diff options
author | Lester Solbakken <lesters@users.noreply.github.com> | 2020-01-10 13:01:26 +0100 |
---|---|---|
committer | GitHub <noreply@github.com> | 2020-01-10 13:01:26 +0100 |
commit | 1bfeb920e039dd22f586c382c66fef90af6f4459 (patch) | |
tree | 1d9ac3ffee2ecb4defbc143c30d088ccbfcc3086 | |
parent | f99a8f34016a3a40f19893cf903737e701606fc5 (diff) | |
parent | f114bedb76443cf68cfbf98769e41f1d6e4b9932 (diff) |
Merge pull request #11740 from vespa-engine/lesters/resolve-input-types-for-stateless-model-evaluation
Add resolving of input types for stateless model evaluation
4 files changed, 30 insertions, 4 deletions
diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/MapEvaluationTypeContext.java b/config-model/src/main/java/com/yahoo/searchdefinition/MapEvaluationTypeContext.java index 301aa1faa83..172e538d708 100644 --- a/config-model/src/main/java/com/yahoo/searchdefinition/MapEvaluationTypeContext.java +++ b/config-model/src/main/java/com/yahoo/searchdefinition/MapEvaluationTypeContext.java @@ -102,8 +102,7 @@ public class MapEvaluationTypeContext extends FunctionReferenceContext implement currentResolutionCallStack.stream().map(Reference::toString).collect(Collectors.joining(" -> ")) + " -> " + reference); - - // Bound toi a function argument, and not to a same-named identifier (which would lead to a loop)? + // Bound to a function argument, and not to a same-named identifier (which would lead to a loop)? Optional<String> binding = boundIdentifier(reference); if (binding.isPresent() && ! binding.get().equals(reference.toString())) { try { @@ -139,6 +138,11 @@ public class MapEvaluationTypeContext extends FunctionReferenceContext implement return featureTensorType.get(); } + // A directly injected identifier? (Useful for stateless model evaluation) + if (reference.isIdentifier() && featureTypes.containsKey(reference)) { + return featureTypes.get(reference); + } + // We do not know what this is - since we do not have complete knowledge about the match features // in Java we must assume this is a match feature and return the double type - which is the type of // all match features diff --git a/config-model/src/main/java/com/yahoo/vespa/model/ml/ConvertedModel.java b/config-model/src/main/java/com/yahoo/vespa/model/ml/ConvertedModel.java index c42da6dcd19..c3d6f457ce8 100644 --- a/config-model/src/main/java/com/yahoo/vespa/model/ml/ConvertedModel.java +++ b/config-model/src/main/java/com/yahoo/vespa/model/ml/ConvertedModel.java @@ -387,12 +387,25 @@ public class ConvertedModel { private static void reduceBatchDimensions(RankingExpression expression, ImportedMlModel model, RankProfile profile, QueryProfileRegistry queryProfiles) { MapEvaluationTypeContext typeContext = profile.typeContext(queryProfiles); - TensorType typeBeforeReducing = expression.getRoot().type(typeContext); - // Check generated functions for inputs to reduce + // Add any missing inputs for type resolution Set<String> functionNames = new HashSet<>(); addFunctionNamesIn(expression.getRoot(), functionNames, model); for (String functionName : functionNames) { + Optional<TensorType> requiredType = model.inputTypeSpec(functionName).map(TensorType::fromSpec); + if (requiredType.isPresent()) { + Reference ref = Reference.fromIdentifier(functionName); + if (typeContext.getType(ref).equals(TensorType.empty)) { + typeContext.setType(ref, requiredType.get()); + } + } + } + typeContext.forgetResolvedTypes(); + + TensorType typeBeforeReducing = expression.getRoot().type(typeContext); + + // Check generated functions for inputs to reduce + for (String functionName : functionNames) { if ( ! model.functions().containsKey(functionName)) continue; RankProfile.RankingExpressionFunction rankingExpressionFunction = profile.getFunctions().get(functionName); diff --git a/config-model/src/test/java/com/yahoo/vespa/model/ml/ModelEvaluationTest.java b/config-model/src/test/java/com/yahoo/vespa/model/ml/ModelEvaluationTest.java index 216f90240f7..77b25489047 100644 --- a/config-model/src/test/java/com/yahoo/vespa/model/ml/ModelEvaluationTest.java +++ b/config-model/src/test/java/com/yahoo/vespa/model/ml/ModelEvaluationTest.java @@ -121,26 +121,32 @@ public class ModelEvaluationTest { Model tensorflow_mnist = evaluator.models().get("mnist_saved"); assertNotNull(tensorflow_mnist); + assertEquals(1, tensorflow_mnist.functions().size()); assertNotNull(tensorflow_mnist.evaluatorOf("serving_default")); assertNotNull(tensorflow_mnist.evaluatorOf("serving_default", "y")); assertNotNull(tensorflow_mnist.evaluatorOf("serving_default.y")); assertNotNull(evaluator.evaluatorOf("mnist_saved", "serving_default.y")); assertNotNull(evaluator.evaluatorOf("mnist_saved", "serving_default", "y")); + assertEquals(TensorType.fromSpec("tensor(d0[],d1[784])"), tensorflow_mnist.functions().get(0).argumentTypes().get("input")); Model onnx_mnist_softmax = evaluator.models().get("mnist_softmax"); assertNotNull(onnx_mnist_softmax); + assertEquals(1, onnx_mnist_softmax.functions().size()); assertNotNull(onnx_mnist_softmax.evaluatorOf()); assertNotNull(onnx_mnist_softmax.evaluatorOf("default")); assertNotNull(onnx_mnist_softmax.evaluatorOf("default", "add")); assertNotNull(onnx_mnist_softmax.evaluatorOf("default.add")); assertNotNull(evaluator.evaluatorOf("mnist_softmax", "default.add")); assertNotNull(evaluator.evaluatorOf("mnist_softmax", "default", "add")); + assertEquals(TensorType.fromSpec("tensor<float>(d0[],d1[784])"), onnx_mnist_softmax.functions().get(0).argumentTypes().get("Placeholder")); Model tensorflow_mnist_softmax = evaluator.models().get("mnist_softmax_saved"); assertNotNull(tensorflow_mnist_softmax); + assertEquals(1, tensorflow_mnist_softmax.functions().size()); assertNotNull(tensorflow_mnist_softmax.evaluatorOf()); assertNotNull(tensorflow_mnist_softmax.evaluatorOf("serving_default")); assertNotNull(tensorflow_mnist_softmax.evaluatorOf("serving_default", "y")); + assertEquals(TensorType.fromSpec("tensor(d0[],d1[784])"), tensorflow_mnist_softmax.functions().get(0).argumentTypes().get("Placeholder")); } private final String mnistProfile = diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/ImportedModel.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/ImportedModel.java index a9be1bbd40e..47fe66dd424 100644 --- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/ImportedModel.java +++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/ImportedModel.java @@ -64,6 +64,9 @@ public class ImportedModel implements ImportedMlModel { @Override public String source() { return source; } + @Override + public String toString() { return "imported model '" + name + "' from " + source; } + /** Returns an immutable map of the inputs of this */ public Map<String, TensorType> inputs() { return Collections.unmodifiableMap(inputs); } |