summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorLester Solbakken <lesters@users.noreply.github.com>2020-01-10 13:01:26 +0100
committerGitHub <noreply@github.com>2020-01-10 13:01:26 +0100
commit1bfeb920e039dd22f586c382c66fef90af6f4459 (patch)
tree1d9ac3ffee2ecb4defbc143c30d088ccbfcc3086
parentf99a8f34016a3a40f19893cf903737e701606fc5 (diff)
parentf114bedb76443cf68cfbf98769e41f1d6e4b9932 (diff)
Merge pull request #11740 from vespa-engine/lesters/resolve-input-types-for-stateless-model-evaluation
Add resolving of input types for stateless model evaluation
-rw-r--r--config-model/src/main/java/com/yahoo/searchdefinition/MapEvaluationTypeContext.java8
-rw-r--r--config-model/src/main/java/com/yahoo/vespa/model/ml/ConvertedModel.java17
-rw-r--r--config-model/src/test/java/com/yahoo/vespa/model/ml/ModelEvaluationTest.java6
-rw-r--r--model-integration/src/main/java/ai/vespa/rankingexpression/importer/ImportedModel.java3
4 files changed, 30 insertions, 4 deletions
diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/MapEvaluationTypeContext.java b/config-model/src/main/java/com/yahoo/searchdefinition/MapEvaluationTypeContext.java
index 301aa1faa83..172e538d708 100644
--- a/config-model/src/main/java/com/yahoo/searchdefinition/MapEvaluationTypeContext.java
+++ b/config-model/src/main/java/com/yahoo/searchdefinition/MapEvaluationTypeContext.java
@@ -102,8 +102,7 @@ public class MapEvaluationTypeContext extends FunctionReferenceContext implement
currentResolutionCallStack.stream().map(Reference::toString).collect(Collectors.joining(" -> ")) +
" -> " + reference);
-
- // Bound toi a function argument, and not to a same-named identifier (which would lead to a loop)?
+ // Bound to a function argument, and not to a same-named identifier (which would lead to a loop)?
Optional<String> binding = boundIdentifier(reference);
if (binding.isPresent() && ! binding.get().equals(reference.toString())) {
try {
@@ -139,6 +138,11 @@ public class MapEvaluationTypeContext extends FunctionReferenceContext implement
return featureTensorType.get();
}
+ // A directly injected identifier? (Useful for stateless model evaluation)
+ if (reference.isIdentifier() && featureTypes.containsKey(reference)) {
+ return featureTypes.get(reference);
+ }
+
// We do not know what this is - since we do not have complete knowledge about the match features
// in Java we must assume this is a match feature and return the double type - which is the type of
// all match features
diff --git a/config-model/src/main/java/com/yahoo/vespa/model/ml/ConvertedModel.java b/config-model/src/main/java/com/yahoo/vespa/model/ml/ConvertedModel.java
index c42da6dcd19..c3d6f457ce8 100644
--- a/config-model/src/main/java/com/yahoo/vespa/model/ml/ConvertedModel.java
+++ b/config-model/src/main/java/com/yahoo/vespa/model/ml/ConvertedModel.java
@@ -387,12 +387,25 @@ public class ConvertedModel {
private static void reduceBatchDimensions(RankingExpression expression, ImportedMlModel model,
RankProfile profile, QueryProfileRegistry queryProfiles) {
MapEvaluationTypeContext typeContext = profile.typeContext(queryProfiles);
- TensorType typeBeforeReducing = expression.getRoot().type(typeContext);
- // Check generated functions for inputs to reduce
+ // Add any missing inputs for type resolution
Set<String> functionNames = new HashSet<>();
addFunctionNamesIn(expression.getRoot(), functionNames, model);
for (String functionName : functionNames) {
+ Optional<TensorType> requiredType = model.inputTypeSpec(functionName).map(TensorType::fromSpec);
+ if (requiredType.isPresent()) {
+ Reference ref = Reference.fromIdentifier(functionName);
+ if (typeContext.getType(ref).equals(TensorType.empty)) {
+ typeContext.setType(ref, requiredType.get());
+ }
+ }
+ }
+ typeContext.forgetResolvedTypes();
+
+ TensorType typeBeforeReducing = expression.getRoot().type(typeContext);
+
+ // Check generated functions for inputs to reduce
+ for (String functionName : functionNames) {
if ( ! model.functions().containsKey(functionName)) continue;
RankProfile.RankingExpressionFunction rankingExpressionFunction = profile.getFunctions().get(functionName);
diff --git a/config-model/src/test/java/com/yahoo/vespa/model/ml/ModelEvaluationTest.java b/config-model/src/test/java/com/yahoo/vespa/model/ml/ModelEvaluationTest.java
index 216f90240f7..77b25489047 100644
--- a/config-model/src/test/java/com/yahoo/vespa/model/ml/ModelEvaluationTest.java
+++ b/config-model/src/test/java/com/yahoo/vespa/model/ml/ModelEvaluationTest.java
@@ -121,26 +121,32 @@ public class ModelEvaluationTest {
Model tensorflow_mnist = evaluator.models().get("mnist_saved");
assertNotNull(tensorflow_mnist);
+ assertEquals(1, tensorflow_mnist.functions().size());
assertNotNull(tensorflow_mnist.evaluatorOf("serving_default"));
assertNotNull(tensorflow_mnist.evaluatorOf("serving_default", "y"));
assertNotNull(tensorflow_mnist.evaluatorOf("serving_default.y"));
assertNotNull(evaluator.evaluatorOf("mnist_saved", "serving_default.y"));
assertNotNull(evaluator.evaluatorOf("mnist_saved", "serving_default", "y"));
+ assertEquals(TensorType.fromSpec("tensor(d0[],d1[784])"), tensorflow_mnist.functions().get(0).argumentTypes().get("input"));
Model onnx_mnist_softmax = evaluator.models().get("mnist_softmax");
assertNotNull(onnx_mnist_softmax);
+ assertEquals(1, onnx_mnist_softmax.functions().size());
assertNotNull(onnx_mnist_softmax.evaluatorOf());
assertNotNull(onnx_mnist_softmax.evaluatorOf("default"));
assertNotNull(onnx_mnist_softmax.evaluatorOf("default", "add"));
assertNotNull(onnx_mnist_softmax.evaluatorOf("default.add"));
assertNotNull(evaluator.evaluatorOf("mnist_softmax", "default.add"));
assertNotNull(evaluator.evaluatorOf("mnist_softmax", "default", "add"));
+ assertEquals(TensorType.fromSpec("tensor<float>(d0[],d1[784])"), onnx_mnist_softmax.functions().get(0).argumentTypes().get("Placeholder"));
Model tensorflow_mnist_softmax = evaluator.models().get("mnist_softmax_saved");
assertNotNull(tensorflow_mnist_softmax);
+ assertEquals(1, tensorflow_mnist_softmax.functions().size());
assertNotNull(tensorflow_mnist_softmax.evaluatorOf());
assertNotNull(tensorflow_mnist_softmax.evaluatorOf("serving_default"));
assertNotNull(tensorflow_mnist_softmax.evaluatorOf("serving_default", "y"));
+ assertEquals(TensorType.fromSpec("tensor(d0[],d1[784])"), tensorflow_mnist_softmax.functions().get(0).argumentTypes().get("Placeholder"));
}
private final String mnistProfile =
diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/ImportedModel.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/ImportedModel.java
index a9be1bbd40e..47fe66dd424 100644
--- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/ImportedModel.java
+++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/ImportedModel.java
@@ -64,6 +64,9 @@ public class ImportedModel implements ImportedMlModel {
@Override
public String source() { return source; }
+ @Override
+ public String toString() { return "imported model '" + name + "' from " + source; }
+
/** Returns an immutable map of the inputs of this */
public Map<String, TensorType> inputs() { return Collections.unmodifiableMap(inputs); }