summaryrefslogtreecommitdiffstats
path: root/model-integration
diff options
context:
space:
mode:
authorJon Bratseth <bratseth@oath.com>2018-11-25 20:07:56 +0100
committerJon Bratseth <bratseth@oath.com>2018-11-25 20:07:56 +0100
commit1d88554bd513783715425120e76fc5f2a86f439f (patch)
tree166c86107d3620014cc7e26d85118c311e1b8cf0 /model-integration
parenta01bc21d9bcbc417a9fb2591079561f59f76865e (diff)
Java type only interface between imported-models and config models
This avoids class incompatibility problems when passing an imported model across bundle boundaries to a config model. Tensor string parsing has been sped up as this relies on it more.
Diffstat (limited to 'model-integration')
-rw-r--r--model-integration/src/main/java/ai/vespa/rankingexpression/importer/ImportedModel.java132
-rw-r--r--model-integration/src/main/java/ai/vespa/rankingexpression/importer/ImportedModels.java2
-rw-r--r--model-integration/src/main/java/ai/vespa/rankingexpression/importer/ModelImporter.java2
-rw-r--r--model-integration/src/test/java/ai/vespa/rankingexpression/importer/onnx/OnnxMnistSoftmaxImportTestCase.java8
-rw-r--r--model-integration/src/test/java/ai/vespa/rankingexpression/importer/tensorflow/TensorFlowMnistSoftmaxImportTestCase.java4
-rw-r--r--model-integration/src/test/java/ai/vespa/rankingexpression/importer/tensorflow/TestableTensorFlowModel.java6
6 files changed, 116 insertions, 38 deletions
diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/ImportedModel.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/ImportedModel.java
index 2866a2c76b2..c2235b9abe9 100644
--- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/ImportedModel.java
+++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/ImportedModel.java
@@ -1,6 +1,7 @@
// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
package ai.vespa.rankingexpression.importer;
+import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import com.yahoo.collections.Pair;
import com.yahoo.searchlib.rankingexpression.ExpressionFunction;
@@ -59,18 +60,29 @@ public class ImportedModel {
/** Returns an immutable map of the inputs of this */
public Map<String, TensorType> inputs() { return Collections.unmodifiableMap(inputs); }
+ // CFG
+ public Optional<String> inputTypeSpec(String input) {
+ return Optional.ofNullable(inputs.get(input)).map(TensorType::toString);
+ }
+
/**
- * Returns an immutable map of the small constants of this.
+ * Returns an immutable map of the small constants of this, represented as strings on the standard tensor form.
* These should have sizes up to a few kb at most, and correspond to constant values given in the source model.
*/
- public Map<String, Tensor> smallConstants() { return Collections.unmodifiableMap(smallConstants); }
+ // CFG
+ public Map<String, String> smallConstants() { return asTensorStrings(smallConstants); }
+
+ boolean hasSmallConstant(String name) { return smallConstants.containsKey(name); }
/**
* Returns an immutable map of the large constants of this.
* These can have sizes in gigabytes and must be distributed to nodes separately from configuration.
* For TensorFlow this corresponds to Variable files stored separately.
*/
- public Map<String, Tensor> largeConstants() { return Collections.unmodifiableMap(largeConstants); }
+ // CFG
+ public Map<String, String> largeConstants() { return asTensorStrings(largeConstants); }
+
+ boolean hasLargeConstant(String name) { return largeConstants.containsKey(name); }
/**
* Returns an immutable map of the expressions of this - corresponding to graph nodes
@@ -79,11 +91,14 @@ public class ImportedModel {
*/
public Map<String, RankingExpression> expressions() { return Collections.unmodifiableMap(expressions); }
+ // TODO: Most of the usage of the above can be replaced by a faster expressionNames method
+
/**
* Returns an immutable map of the functions that are part of this model.
* Note that the functions themselves are *not* copies and *not* immutable - they must be copied before modification.
*/
- public Map<String, RankingExpression> functions() { return Collections.unmodifiableMap(functions); }
+ // CFG
+ public Map<String, String> functions() { return asExpressionStrings(functions); }
/** Returns an immutable map of the signatures of this */
public Map<String, Signature> signatures() { return Collections.unmodifiableMap(signatures); }
@@ -108,43 +123,60 @@ public class ImportedModel {
* if signatures are used, or the expression name if signatures are not used and there are multiple
* expressions, and the second is the output name if signature names are used.
*/
- public List<Pair<String, ExpressionFunction>> outputExpressions() {
- List<Pair<String, ExpressionFunction>> expressions = new ArrayList<>();
+ // CFG
+ public List<ImportedFunction> outputExpressions() {
+ List<ImportedFunction> functions = new ArrayList<>();
for (Map.Entry<String, Signature> signatureEntry : signatures().entrySet()) {
for (Map.Entry<String, String> outputEntry : signatureEntry.getValue().outputs().entrySet())
- expressions.add(new Pair<>(signatureEntry.getKey() + "." + outputEntry.getKey(),
- signatureEntry.getValue().outputExpression(outputEntry.getKey())
- .withName(signatureEntry.getKey() + "." + outputEntry.getKey())));
+ functions.add(signatureEntry.getValue().outputFunction(outputEntry.getKey(),
+ signatureEntry.getKey() + "." + outputEntry.getKey()));
if (signatureEntry.getValue().outputs().isEmpty()) // fallback: Signature without outputs
- expressions.add(new Pair<>(signatureEntry.getKey(),
- new ExpressionFunction(signatureEntry.getKey(),
- new ArrayList<>(signatureEntry.getValue().inputs().values()),
- expressions().get(signatureEntry.getKey()),
- signatureEntry.getValue().inputMap(),
- Optional.empty())));
+ functions.add(new ImportedFunction(signatureEntry.getKey(),
+ new ArrayList<>(signatureEntry.getValue().inputs().values()),
+ expressions().get(signatureEntry.getKey()),
+ signatureEntry.getValue().inputMap(),
+ Optional.empty()));
}
if (signatures().isEmpty()) { // fallback for models without signatures
if (expressions().size() == 1) {
Map.Entry<String, RankingExpression> singleEntry = this.expressions.entrySet().iterator().next();
- expressions.add(new Pair<>(singleEntry.getKey(),
- new ExpressionFunction(singleEntry.getKey(),
- new ArrayList<>(inputs.keySet()),
- singleEntry.getValue(),
- inputs,
- Optional.empty())));
+ functions.add(new ImportedFunction(singleEntry.getKey(),
+ new ArrayList<>(inputs.keySet()),
+ singleEntry.getValue(),
+ inputs,
+ Optional.empty()));
}
else {
for (Map.Entry<String, RankingExpression> expressionEntry : expressions().entrySet()) {
- expressions.add(new Pair<>(expressionEntry.getKey(),
- new ExpressionFunction(expressionEntry.getKey(),
- new ArrayList<>(inputs.keySet()),
- expressionEntry.getValue(),
- inputs,
- Optional.empty())));
+ functions.add(new ImportedFunction(expressionEntry.getKey(),
+ new ArrayList<>(inputs.keySet()),
+ expressionEntry.getValue(),
+ inputs,
+ Optional.empty()));
}
}
}
- return expressions;
+ return functions;
+ }
+
+ private Map<String, String> asTensorStrings(Map<String, Tensor> map) {
+ HashMap<String, String> values = new HashMap<>();
+ for (Map.Entry<String, Tensor> entry : map.entrySet()) {
+ Tensor tensor = entry.getValue();
+ // TODO: See Tensor.toStandardString
+ if (tensor.isEmpty() && ! tensor.type().dimensions().isEmpty())
+ values.put(entry.getKey(), tensor.toString());
+ else
+ values.put(entry.getKey(), tensor.type() + ":" + tensor);
+ }
+ return values;
+ }
+
+ private Map<String, String> asExpressionStrings(Map<String, RankingExpression> map) {
+ HashMap<String, String> values = new HashMap<>();
+ for (Map.Entry<String, RankingExpression> entry : map.entrySet())
+ values.put(entry.getKey(), entry.getValue().getRoot().toString());
+ return values;
}
/**
@@ -213,6 +245,17 @@ public class ImportedModel {
Optional.empty());
}
+ /** Returns the expression this output references as an imported function */
+ public ImportedFunction outputFunction(String outputName, String functionName) {
+ return new ImportedFunction(functionName,
+ new ArrayList<>(inputs.values()),
+ owner().expressions().get(outputs.get(outputName)),
+ inputMap(),
+ Optional.empty());
+ }
+
+ // CFG
+
@Override
public String toString() { return "signature '" + name + "'"; }
@@ -223,4 +266,37 @@ public class ImportedModel {
}
+ // CFG
+ public static class ImportedFunction {
+
+ private final String name;
+ private final List<String> arguments;
+ private final Map<String, String> argumentTypes;
+ private final String expression;
+ private final Optional<String> returnType;
+
+ public ImportedFunction(String name, List<String> arguments, RankingExpression expression,
+ Map<String, TensorType> argumentTypes, Optional<TensorType> returnType) {
+ this.name = name;
+ this.arguments = arguments;
+ this.expression = expression.getRoot().toString();
+ this.argumentTypes = asStrings(argumentTypes);
+ this.returnType = returnType.map(TensorType::toString);
+ }
+
+ private static Map<String, String> asStrings(Map<String, TensorType> map) {
+ Map<String, String> stringMap = new HashMap<>();
+ for (Map.Entry<String, TensorType> entry : map.entrySet())
+ stringMap.put(entry.getKey(), entry.getValue().toString());
+ return stringMap;
+ }
+
+ public String name() { return name; }
+ public List<String> arguments() { return Collections.unmodifiableList(arguments); }
+ public Map<String, String> argumentTypes() { return Collections.unmodifiableMap(argumentTypes); }
+ public String expression() { return expression; }
+ public Optional<String> returnType() { return returnType; }
+
+ }
+
}
diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/ImportedModels.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/ImportedModels.java
index 1b7532631e1..bfdaaca1dd7 100644
--- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/ImportedModels.java
+++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/ImportedModels.java
@@ -69,6 +69,7 @@ public class ImportedModels {
* models directory works
* @return the model at this path or null if none
*/
+ // CFG
public ImportedModel get(File modelPath) {
return importedModels.get(toName(modelPath));
}
@@ -78,6 +79,7 @@ public class ImportedModels {
}
/** Returns an immutable collection of all the imported models */
+ // CFG
public Collection<ImportedModel> all() {
return importedModels.values();
}
diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/ModelImporter.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/ModelImporter.java
index cb095e81147..8a885938bf9 100644
--- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/ModelImporter.java
+++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/ModelImporter.java
@@ -121,7 +121,7 @@ public abstract class ModelImporter {
private static Optional<TensorFunction> importConstant(IntermediateOperation operation, ImportedModel model) {
String name = operation.vespaName();
- if (model.largeConstants().containsKey(name) || model.smallConstants().containsKey(name)) {
+ if (model.hasLargeConstant(name) || model.hasSmallConstant(name)) {
return operation.function();
}
diff --git a/model-integration/src/test/java/ai/vespa/rankingexpression/importer/onnx/OnnxMnistSoftmaxImportTestCase.java b/model-integration/src/test/java/ai/vespa/rankingexpression/importer/onnx/OnnxMnistSoftmaxImportTestCase.java
index d3996da9b58..315456c2613 100644
--- a/model-integration/src/test/java/ai/vespa/rankingexpression/importer/onnx/OnnxMnistSoftmaxImportTestCase.java
+++ b/model-integration/src/test/java/ai/vespa/rankingexpression/importer/onnx/OnnxMnistSoftmaxImportTestCase.java
@@ -28,13 +28,13 @@ public class OnnxMnistSoftmaxImportTestCase {
// Check constants
assertEquals(2, model.largeConstants().size());
- Tensor constant0 = model.largeConstants().get("test_Variable");
+ Tensor constant0 = Tensor.from(model.largeConstants().get("test_Variable"));
assertNotNull(constant0);
assertEquals(new TensorType.Builder().indexed("d2", 784).indexed("d1", 10).build(),
constant0.type());
assertEquals(7840, constant0.size());
- Tensor constant1 = model.largeConstants().get("test_Variable_1");
+ Tensor constant1 = Tensor.from(model.largeConstants().get("test_Variable_1"));
assertNotNull(constant1);
assertEquals(new TensorType.Builder().indexed("d1", 10).build(), constant1.type());
assertEquals(10, constant1.size());
@@ -84,8 +84,8 @@ public class OnnxMnistSoftmaxImportTestCase {
private Context contextFrom(ImportedModel result) {
MapContext context = new MapContext();
- result.largeConstants().forEach((name, tensor) -> context.put("constant(" + name + ")", new TensorValue(tensor)));
- result.smallConstants().forEach((name, tensor) -> context.put("constant(" + name + ")", new TensorValue(tensor)));
+ result.largeConstants().forEach((name, tensor) -> context.put("constant(" + name + ")", new TensorValue(Tensor.from(tensor))));
+ result.smallConstants().forEach((name, tensor) -> context.put("constant(" + name + ")", new TensorValue(Tensor.from(tensor))));
return context;
}
diff --git a/model-integration/src/test/java/ai/vespa/rankingexpression/importer/tensorflow/TensorFlowMnistSoftmaxImportTestCase.java b/model-integration/src/test/java/ai/vespa/rankingexpression/importer/tensorflow/TensorFlowMnistSoftmaxImportTestCase.java
index 6215997d8f9..be676186017 100644
--- a/model-integration/src/test/java/ai/vespa/rankingexpression/importer/tensorflow/TensorFlowMnistSoftmaxImportTestCase.java
+++ b/model-integration/src/test/java/ai/vespa/rankingexpression/importer/tensorflow/TensorFlowMnistSoftmaxImportTestCase.java
@@ -24,13 +24,13 @@ public class TensorFlowMnistSoftmaxImportTestCase {
// Check constants
Assert.assertEquals(2, model.get().largeConstants().size());
- Tensor constant0 = model.get().largeConstants().get("test_Variable_read");
+ Tensor constant0 = Tensor.from(model.get().largeConstants().get("test_Variable_read"));
assertNotNull(constant0);
assertEquals(new TensorType.Builder().indexed("d2", 784).indexed("d1", 10).build(),
constant0.type());
assertEquals(7840, constant0.size());
- Tensor constant1 = model.get().largeConstants().get("test_Variable_1_read");
+ Tensor constant1 = Tensor.from(model.get().largeConstants().get("test_Variable_1_read"));
assertNotNull(constant1);
assertEquals(new TensorType.Builder().indexed("d1", 10).build(),
constant1.type());
diff --git a/model-integration/src/test/java/ai/vespa/rankingexpression/importer/tensorflow/TestableTensorFlowModel.java b/model-integration/src/test/java/ai/vespa/rankingexpression/importer/tensorflow/TestableTensorFlowModel.java
index c3b82cccb46..4ff0c96d369 100644
--- a/model-integration/src/test/java/ai/vespa/rankingexpression/importer/tensorflow/TestableTensorFlowModel.java
+++ b/model-integration/src/test/java/ai/vespa/rankingexpression/importer/tensorflow/TestableTensorFlowModel.java
@@ -93,8 +93,8 @@ public class TestableTensorFlowModel {
static Context contextFrom(ImportedModel result) {
TestableModelContext context = new TestableModelContext();
- result.largeConstants().forEach((name, tensor) -> context.put("constant(" + name + ")", new TensorValue(tensor)));
- result.smallConstants().forEach((name, tensor) -> context.put("constant(" + name + ")", new TensorValue(tensor)));
+ result.largeConstants().forEach((name, tensor) -> context.put("constant(" + name + ")", new TensorValue(Tensor.from(tensor))));
+ result.smallConstants().forEach((name, tensor) -> context.put("constant(" + name + ")", new TensorValue(Tensor.from(tensor))));
return context;
}
@@ -108,7 +108,7 @@ public class TestableTensorFlowModel {
private void evaluateFunction(Context context, ImportedModel model, String functionName) {
if (!context.names().contains(functionName)) {
- RankingExpression e = model.functions().get(functionName);
+ RankingExpression e = RankingExpression.from(model.functions().get(functionName));
evaluateFunctionDependencies(context, model, e.getRoot());
context.put(functionName, new TensorValue(e.evaluate(context).asTensor()));
}