diff options
author | Jon Bratseth <bratseth@oath.com> | 2018-11-25 20:07:56 +0100 |
---|---|---|
committer | Jon Bratseth <bratseth@oath.com> | 2018-11-25 20:07:56 +0100 |
commit | 1d88554bd513783715425120e76fc5f2a86f439f (patch) | |
tree | 166c86107d3620014cc7e26d85118c311e1b8cf0 /model-integration | |
parent | a01bc21d9bcbc417a9fb2591079561f59f76865e (diff) |
Java-type-only interface between imported-models and config models
This avoids class incompatibility problems when passing an
imported model across bundle boundaries to a config model.
Tensor string parsing has been sped up, since the new string-based interface relies on it more heavily.
Diffstat (limited to 'model-integration')
6 files changed, 116 insertions, 38 deletions
diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/ImportedModel.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/ImportedModel.java index 2866a2c76b2..c2235b9abe9 100644 --- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/ImportedModel.java +++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/ImportedModel.java @@ -1,6 +1,7 @@ // Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. package ai.vespa.rankingexpression.importer; +import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; import com.yahoo.collections.Pair; import com.yahoo.searchlib.rankingexpression.ExpressionFunction; @@ -59,18 +60,29 @@ public class ImportedModel { /** Returns an immutable map of the inputs of this */ public Map<String, TensorType> inputs() { return Collections.unmodifiableMap(inputs); } + // CFG + public Optional<String> inputTypeSpec(String input) { + return Optional.ofNullable(inputs.get(input)).map(TensorType::toString); + } + /** - * Returns an immutable map of the small constants of this. + * Returns an immutable map of the small constants of this, represented as strings on the standard tensor form. * These should have sizes up to a few kb at most, and correspond to constant values given in the source model. */ - public Map<String, Tensor> smallConstants() { return Collections.unmodifiableMap(smallConstants); } + // CFG + public Map<String, String> smallConstants() { return asTensorStrings(smallConstants); } + + boolean hasSmallConstant(String name) { return smallConstants.containsKey(name); } /** * Returns an immutable map of the large constants of this. * These can have sizes in gigabytes and must be distributed to nodes separately from configuration. * For TensorFlow this corresponds to Variable files stored separately. 
*/ - public Map<String, Tensor> largeConstants() { return Collections.unmodifiableMap(largeConstants); } + // CFG + public Map<String, String> largeConstants() { return asTensorStrings(largeConstants); } + + boolean hasLargeConstant(String name) { return largeConstants.containsKey(name); } /** * Returns an immutable map of the expressions of this - corresponding to graph nodes @@ -79,11 +91,14 @@ public class ImportedModel { */ public Map<String, RankingExpression> expressions() { return Collections.unmodifiableMap(expressions); } + // TODO: Most of the usage of the above can be replaced by a faster expressionNames method + /** * Returns an immutable map of the functions that are part of this model. * Note that the functions themselves are *not* copies and *not* immutable - they must be copied before modification. */ - public Map<String, RankingExpression> functions() { return Collections.unmodifiableMap(functions); } + // CFG + public Map<String, String> functions() { return asExpressionStrings(functions); } /** Returns an immutable map of the signatures of this */ public Map<String, Signature> signatures() { return Collections.unmodifiableMap(signatures); } @@ -108,43 +123,60 @@ public class ImportedModel { * if signatures are used, or the expression name if signatures are not used and there are multiple * expressions, and the second is the output name if signature names are used. */ - public List<Pair<String, ExpressionFunction>> outputExpressions() { - List<Pair<String, ExpressionFunction>> expressions = new ArrayList<>(); + // CFG + public List<ImportedFunction> outputExpressions() { + List<ImportedFunction> functions = new ArrayList<>(); for (Map.Entry<String, Signature> signatureEntry : signatures().entrySet()) { for (Map.Entry<String, String> outputEntry : signatureEntry.getValue().outputs().entrySet()) - expressions.add(new Pair<>(signatureEntry.getKey() + "." 
+ outputEntry.getKey(), - signatureEntry.getValue().outputExpression(outputEntry.getKey()) - .withName(signatureEntry.getKey() + "." + outputEntry.getKey()))); + functions.add(signatureEntry.getValue().outputFunction(outputEntry.getKey(), + signatureEntry.getKey() + "." + outputEntry.getKey())); if (signatureEntry.getValue().outputs().isEmpty()) // fallback: Signature without outputs - expressions.add(new Pair<>(signatureEntry.getKey(), - new ExpressionFunction(signatureEntry.getKey(), - new ArrayList<>(signatureEntry.getValue().inputs().values()), - expressions().get(signatureEntry.getKey()), - signatureEntry.getValue().inputMap(), - Optional.empty()))); + functions.add(new ImportedFunction(signatureEntry.getKey(), + new ArrayList<>(signatureEntry.getValue().inputs().values()), + expressions().get(signatureEntry.getKey()), + signatureEntry.getValue().inputMap(), + Optional.empty())); } if (signatures().isEmpty()) { // fallback for models without signatures if (expressions().size() == 1) { Map.Entry<String, RankingExpression> singleEntry = this.expressions.entrySet().iterator().next(); - expressions.add(new Pair<>(singleEntry.getKey(), - new ExpressionFunction(singleEntry.getKey(), - new ArrayList<>(inputs.keySet()), - singleEntry.getValue(), - inputs, - Optional.empty()))); + functions.add(new ImportedFunction(singleEntry.getKey(), + new ArrayList<>(inputs.keySet()), + singleEntry.getValue(), + inputs, + Optional.empty())); } else { for (Map.Entry<String, RankingExpression> expressionEntry : expressions().entrySet()) { - expressions.add(new Pair<>(expressionEntry.getKey(), - new ExpressionFunction(expressionEntry.getKey(), - new ArrayList<>(inputs.keySet()), - expressionEntry.getValue(), - inputs, - Optional.empty()))); + functions.add(new ImportedFunction(expressionEntry.getKey(), + new ArrayList<>(inputs.keySet()), + expressionEntry.getValue(), + inputs, + Optional.empty())); } } } - return expressions; + return functions; + } + + private Map<String, String> 
asTensorStrings(Map<String, Tensor> map) { + HashMap<String, String> values = new HashMap<>(); + for (Map.Entry<String, Tensor> entry : map.entrySet()) { + Tensor tensor = entry.getValue(); + // TODO: See Tensor.toStandardString + if (tensor.isEmpty() && ! tensor.type().dimensions().isEmpty()) + values.put(entry.getKey(), tensor.toString()); + else + values.put(entry.getKey(), tensor.type() + ":" + tensor); + } + return values; + } + + private Map<String, String> asExpressionStrings(Map<String, RankingExpression> map) { + HashMap<String, String> values = new HashMap<>(); + for (Map.Entry<String, RankingExpression> entry : map.entrySet()) + values.put(entry.getKey(), entry.getValue().getRoot().toString()); + return values; } /** @@ -213,6 +245,17 @@ public class ImportedModel { Optional.empty()); } + /** Returns the expression this output references as an imported function */ + public ImportedFunction outputFunction(String outputName, String functionName) { + return new ImportedFunction(functionName, + new ArrayList<>(inputs.values()), + owner().expressions().get(outputs.get(outputName)), + inputMap(), + Optional.empty()); + } + + // CFG + @Override public String toString() { return "signature '" + name + "'"; } @@ -223,4 +266,37 @@ public class ImportedModel { } + // CFG + public static class ImportedFunction { + + private final String name; + private final List<String> arguments; + private final Map<String, String> argumentTypes; + private final String expression; + private final Optional<String> returnType; + + public ImportedFunction(String name, List<String> arguments, RankingExpression expression, + Map<String, TensorType> argumentTypes, Optional<TensorType> returnType) { + this.name = name; + this.arguments = arguments; + this.expression = expression.getRoot().toString(); + this.argumentTypes = asStrings(argumentTypes); + this.returnType = returnType.map(TensorType::toString); + } + + private static Map<String, String> asStrings(Map<String, TensorType> map) { 
+ Map<String, String> stringMap = new HashMap<>(); + for (Map.Entry<String, TensorType> entry : map.entrySet()) + stringMap.put(entry.getKey(), entry.getValue().toString()); + return stringMap; + } + + public String name() { return name; } + public List<String> arguments() { return Collections.unmodifiableList(arguments); } + public Map<String, String> argumentTypes() { return Collections.unmodifiableMap(argumentTypes); } + public String expression() { return expression; } + public Optional<String> returnType() { return returnType; } + + } + } diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/ImportedModels.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/ImportedModels.java index 1b7532631e1..bfdaaca1dd7 100644 --- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/ImportedModels.java +++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/ImportedModels.java @@ -69,6 +69,7 @@ public class ImportedModels { * models directory works * @return the model at this path or null if none */ + // CFG public ImportedModel get(File modelPath) { return importedModels.get(toName(modelPath)); } @@ -78,6 +79,7 @@ public class ImportedModels { } /** Returns an immutable collection of all the imported models */ + // CFG public Collection<ImportedModel> all() { return importedModels.values(); } diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/ModelImporter.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/ModelImporter.java index cb095e81147..8a885938bf9 100644 --- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/ModelImporter.java +++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/ModelImporter.java @@ -121,7 +121,7 @@ public abstract class ModelImporter { private static Optional<TensorFunction> importConstant(IntermediateOperation operation, ImportedModel model) { String name = 
operation.vespaName(); - if (model.largeConstants().containsKey(name) || model.smallConstants().containsKey(name)) { + if (model.hasLargeConstant(name) || model.hasSmallConstant(name)) { return operation.function(); } diff --git a/model-integration/src/test/java/ai/vespa/rankingexpression/importer/onnx/OnnxMnistSoftmaxImportTestCase.java b/model-integration/src/test/java/ai/vespa/rankingexpression/importer/onnx/OnnxMnistSoftmaxImportTestCase.java index d3996da9b58..315456c2613 100644 --- a/model-integration/src/test/java/ai/vespa/rankingexpression/importer/onnx/OnnxMnistSoftmaxImportTestCase.java +++ b/model-integration/src/test/java/ai/vespa/rankingexpression/importer/onnx/OnnxMnistSoftmaxImportTestCase.java @@ -28,13 +28,13 @@ public class OnnxMnistSoftmaxImportTestCase { // Check constants assertEquals(2, model.largeConstants().size()); - Tensor constant0 = model.largeConstants().get("test_Variable"); + Tensor constant0 = Tensor.from(model.largeConstants().get("test_Variable")); assertNotNull(constant0); assertEquals(new TensorType.Builder().indexed("d2", 784).indexed("d1", 10).build(), constant0.type()); assertEquals(7840, constant0.size()); - Tensor constant1 = model.largeConstants().get("test_Variable_1"); + Tensor constant1 = Tensor.from(model.largeConstants().get("test_Variable_1")); assertNotNull(constant1); assertEquals(new TensorType.Builder().indexed("d1", 10).build(), constant1.type()); assertEquals(10, constant1.size()); @@ -84,8 +84,8 @@ public class OnnxMnistSoftmaxImportTestCase { private Context contextFrom(ImportedModel result) { MapContext context = new MapContext(); - result.largeConstants().forEach((name, tensor) -> context.put("constant(" + name + ")", new TensorValue(tensor))); - result.smallConstants().forEach((name, tensor) -> context.put("constant(" + name + ")", new TensorValue(tensor))); + result.largeConstants().forEach((name, tensor) -> context.put("constant(" + name + ")", new TensorValue(Tensor.from(tensor)))); + 
result.smallConstants().forEach((name, tensor) -> context.put("constant(" + name + ")", new TensorValue(Tensor.from(tensor)))); return context; } diff --git a/model-integration/src/test/java/ai/vespa/rankingexpression/importer/tensorflow/TensorFlowMnistSoftmaxImportTestCase.java b/model-integration/src/test/java/ai/vespa/rankingexpression/importer/tensorflow/TensorFlowMnistSoftmaxImportTestCase.java index 6215997d8f9..be676186017 100644 --- a/model-integration/src/test/java/ai/vespa/rankingexpression/importer/tensorflow/TensorFlowMnistSoftmaxImportTestCase.java +++ b/model-integration/src/test/java/ai/vespa/rankingexpression/importer/tensorflow/TensorFlowMnistSoftmaxImportTestCase.java @@ -24,13 +24,13 @@ public class TensorFlowMnistSoftmaxImportTestCase { // Check constants Assert.assertEquals(2, model.get().largeConstants().size()); - Tensor constant0 = model.get().largeConstants().get("test_Variable_read"); + Tensor constant0 = Tensor.from(model.get().largeConstants().get("test_Variable_read")); assertNotNull(constant0); assertEquals(new TensorType.Builder().indexed("d2", 784).indexed("d1", 10).build(), constant0.type()); assertEquals(7840, constant0.size()); - Tensor constant1 = model.get().largeConstants().get("test_Variable_1_read"); + Tensor constant1 = Tensor.from(model.get().largeConstants().get("test_Variable_1_read")); assertNotNull(constant1); assertEquals(new TensorType.Builder().indexed("d1", 10).build(), constant1.type()); diff --git a/model-integration/src/test/java/ai/vespa/rankingexpression/importer/tensorflow/TestableTensorFlowModel.java b/model-integration/src/test/java/ai/vespa/rankingexpression/importer/tensorflow/TestableTensorFlowModel.java index c3b82cccb46..4ff0c96d369 100644 --- a/model-integration/src/test/java/ai/vespa/rankingexpression/importer/tensorflow/TestableTensorFlowModel.java +++ b/model-integration/src/test/java/ai/vespa/rankingexpression/importer/tensorflow/TestableTensorFlowModel.java @@ -93,8 +93,8 @@ public class 
TestableTensorFlowModel { static Context contextFrom(ImportedModel result) { TestableModelContext context = new TestableModelContext(); - result.largeConstants().forEach((name, tensor) -> context.put("constant(" + name + ")", new TensorValue(tensor))); - result.smallConstants().forEach((name, tensor) -> context.put("constant(" + name + ")", new TensorValue(tensor))); + result.largeConstants().forEach((name, tensor) -> context.put("constant(" + name + ")", new TensorValue(Tensor.from(tensor)))); + result.smallConstants().forEach((name, tensor) -> context.put("constant(" + name + ")", new TensorValue(Tensor.from(tensor)))); return context; } @@ -108,7 +108,7 @@ public class TestableTensorFlowModel { private void evaluateFunction(Context context, ImportedModel model, String functionName) { if (!context.names().contains(functionName)) { - RankingExpression e = model.functions().get(functionName); + RankingExpression e = RankingExpression.from(model.functions().get(functionName)); evaluateFunctionDependencies(context, model, e.getRoot()); context.put(functionName, new TensorValue(e.evaluate(context).asTensor())); } |