diff options
-rw-r--r-- | config-model/src/test/java/com/yahoo/vespa/model/ml/ModelEvaluationTest.java | 7 | ||||
-rw-r--r-- | model-evaluation/src/main/java/ai/vespa/models/evaluation/Model.java | 30 |
2 files changed, 27 insertions, 10 deletions
diff --git a/config-model/src/test/java/com/yahoo/vespa/model/ml/ModelEvaluationTest.java b/config-model/src/test/java/com/yahoo/vespa/model/ml/ModelEvaluationTest.java index 36b11cee067..9a88b2a31f6 100644 --- a/config-model/src/test/java/com/yahoo/vespa/model/ml/ModelEvaluationTest.java +++ b/config-model/src/test/java/com/yahoo/vespa/model/ml/ModelEvaluationTest.java @@ -142,8 +142,15 @@ public class ModelEvaluationTest { assertNotNull(onnx_mnist_softmax.evaluatorOf("default")); assertNotNull(onnx_mnist_softmax.evaluatorOf("default", "add")); assertNotNull(onnx_mnist_softmax.evaluatorOf("default.add")); + assertNotNull(onnx_mnist_softmax.evaluatorOf("add")); + assertNotNull(onnx_mnist_softmax.evaluatorOf("serving_default")); + assertNotNull(onnx_mnist_softmax.evaluatorOf("serving_default", "add")); + assertNotNull(onnx_mnist_softmax.evaluatorOf("serving_default.add")); assertNotNull(evaluator.evaluatorOf("mnist_softmax", "default.add")); assertNotNull(evaluator.evaluatorOf("mnist_softmax", "default", "add")); + assertNotNull(evaluator.evaluatorOf("mnist_softmax", "add")); + assertNotNull(evaluator.evaluatorOf("mnist_softmax", "serving_default.add")); + assertNotNull(evaluator.evaluatorOf("mnist_softmax", "serving_default", "add")); assertEquals(TensorType.fromSpec("tensor<float>(d0[],d1[784])"), onnx_mnist_softmax.functions().get(0).argumentTypes().get("Placeholder")); Model tensorflow_mnist_softmax = evaluator.models().get("mnist_softmax_saved"); diff --git a/model-evaluation/src/main/java/ai/vespa/models/evaluation/Model.java b/model-evaluation/src/main/java/ai/vespa/models/evaluation/Model.java index 90700acc0d3..03bbb436026 100644 --- a/model-evaluation/src/main/java/ai/vespa/models/evaluation/Model.java +++ b/model-evaluation/src/main/java/ai/vespa/models/evaluation/Model.java @@ -129,7 +129,7 @@ public class Model { } /** Returns the given function, or throws a IllegalArgumentException if it does not exist */ - private LazyArrayContext requireContextProprotype(String name) { + private LazyArrayContext requireContextPrototype(String name) { LazyArrayContext context = contextPrototypes.get(name); if (context == null) // Implies function is not present throw new IllegalArgumentException("No function named '" + name + "' in " + this + ". Available functions: " + @@ -183,21 +183,31 @@ public class Model { ExpressionFunction function = function(name); if (function != null) return evaluatorOf(function); + // Check if the name is a signature List<ExpressionFunction> functionsStartingByName = functions.stream().filter(f -> f.getName().startsWith(name + ".")).collect(Collectors.toList()); - if (functionsStartingByName.size() == 0) - throwUndeterminedFunction("No function '" + name + "' in " + this); - else if (functionsStartingByName.size() == 1) + if (functionsStartingByName.size() == 1) return evaluatorOf(functionsStartingByName.get(0)); - else + if (functionsStartingByName.size() > 1) throwUndeterminedFunction("Multiple functions start by '" + name + "' in " + this); + // Check if the name is unambiguous as an output + List<ExpressionFunction> functionsEndingByName = + functions.stream().filter(f -> f.getName().endsWith("." + name)).collect(Collectors.toList()); + if (functionsEndingByName.size() == 1) + return evaluatorOf(functionsEndingByName.get(0)); + if (functionsEndingByName.size() > 1) + throwUndeterminedFunction("Multiple functions called '" + name + "' in " + this); + + // To handle TensorFlow conversion to ONNX + if (name.startsWith("serving_default")) { + return evaluatorOf("default" + name.substring("serving_default".length())); + } + + throwUndeterminedFunction("No function '" + name + "' in " + this); } else if (names.length == 2) { - String name = names[0] + "." + names[1]; - ExpressionFunction function = function(name); - if (function == null) throwUndeterminedFunction("No function '" + name + "' in " + this); - return evaluatorOf(function); + return evaluatorOf(names[0] + "." + names[1]); } throw new IllegalArgumentException("No more than 2 names can be given when choosing a function, got " + Arrays.toString(names)); @@ -205,7 +215,7 @@ public class Model { /** Returns a single-use evaluator of a function */ private FunctionEvaluator evaluatorOf(ExpressionFunction function) { - return new FunctionEvaluator(function, requireContextProprotype(function.getName()).copy()); + return new FunctionEvaluator(function, requireContextPrototype(function.getName()).copy()); } private void throwUndeterminedFunction(String message) { |