aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--config-model/src/test/java/com/yahoo/vespa/model/ml/ModelEvaluationTest.java7
-rw-r--r--model-evaluation/src/main/java/ai/vespa/models/evaluation/Model.java30
2 files changed, 27 insertions, 10 deletions
diff --git a/config-model/src/test/java/com/yahoo/vespa/model/ml/ModelEvaluationTest.java b/config-model/src/test/java/com/yahoo/vespa/model/ml/ModelEvaluationTest.java
index 36b11cee067..9a88b2a31f6 100644
--- a/config-model/src/test/java/com/yahoo/vespa/model/ml/ModelEvaluationTest.java
+++ b/config-model/src/test/java/com/yahoo/vespa/model/ml/ModelEvaluationTest.java
@@ -142,8 +142,15 @@ public class ModelEvaluationTest {
assertNotNull(onnx_mnist_softmax.evaluatorOf("default"));
assertNotNull(onnx_mnist_softmax.evaluatorOf("default", "add"));
assertNotNull(onnx_mnist_softmax.evaluatorOf("default.add"));
+ assertNotNull(onnx_mnist_softmax.evaluatorOf("add"));
+ assertNotNull(onnx_mnist_softmax.evaluatorOf("serving_default"));
+ assertNotNull(onnx_mnist_softmax.evaluatorOf("serving_default", "add"));
+ assertNotNull(onnx_mnist_softmax.evaluatorOf("serving_default.add"));
assertNotNull(evaluator.evaluatorOf("mnist_softmax", "default.add"));
assertNotNull(evaluator.evaluatorOf("mnist_softmax", "default", "add"));
+ assertNotNull(evaluator.evaluatorOf("mnist_softmax", "add"));
+ assertNotNull(evaluator.evaluatorOf("mnist_softmax", "serving_default.add"));
+ assertNotNull(evaluator.evaluatorOf("mnist_softmax", "serving_default", "add"));
assertEquals(TensorType.fromSpec("tensor<float>(d0[],d1[784])"), onnx_mnist_softmax.functions().get(0).argumentTypes().get("Placeholder"));
Model tensorflow_mnist_softmax = evaluator.models().get("mnist_softmax_saved");
diff --git a/model-evaluation/src/main/java/ai/vespa/models/evaluation/Model.java b/model-evaluation/src/main/java/ai/vespa/models/evaluation/Model.java
index 90700acc0d3..03bbb436026 100644
--- a/model-evaluation/src/main/java/ai/vespa/models/evaluation/Model.java
+++ b/model-evaluation/src/main/java/ai/vespa/models/evaluation/Model.java
@@ -129,7 +129,7 @@ public class Model {
}
/** Returns the given function, or throws a IllegalArgumentException if it does not exist */
- private LazyArrayContext requireContextProprotype(String name) {
+ private LazyArrayContext requireContextPrototype(String name) {
LazyArrayContext context = contextPrototypes.get(name);
if (context == null) // Implies function is not present
throw new IllegalArgumentException("No function named '" + name + "' in " + this + ". Available functions: " +
@@ -183,21 +183,31 @@ public class Model {
ExpressionFunction function = function(name);
if (function != null) return evaluatorOf(function);
+ // Check if the name is a signature
List<ExpressionFunction> functionsStartingByName =
functions.stream().filter(f -> f.getName().startsWith(name + ".")).collect(Collectors.toList());
- if (functionsStartingByName.size() == 0)
- throwUndeterminedFunction("No function '" + name + "' in " + this);
- else if (functionsStartingByName.size() == 1)
+ if (functionsStartingByName.size() == 1)
return evaluatorOf(functionsStartingByName.get(0));
- else
+ if (functionsStartingByName.size() > 1)
throwUndeterminedFunction("Multiple functions start by '" + name + "' in " + this);
+ // Check if the name is unambiguous as an output
+ List<ExpressionFunction> functionsEndingByName =
+ functions.stream().filter(f -> f.getName().endsWith("." + name)).collect(Collectors.toList());
+ if (functionsEndingByName.size() == 1)
+ return evaluatorOf(functionsEndingByName.get(0));
+ if (functionsEndingByName.size() > 1)
+ throwUndeterminedFunction("Multiple functions called '" + name + "' in " + this);
+
+ // To handle TensorFlow conversion to ONNX
+ if (name.startsWith("serving_default")) {
+ return evaluatorOf("default" + name.substring("serving_default".length()));
+ }
+
+ throwUndeterminedFunction("No function '" + name + "' in " + this);
}
else if (names.length == 2) {
- String name = names[0] + "." + names[1];
- ExpressionFunction function = function(name);
- if (function == null) throwUndeterminedFunction("No function '" + name + "' in " + this);
- return evaluatorOf(function);
+ return evaluatorOf(names[0] + "." + names[1]);
}
throw new IllegalArgumentException("No more than 2 names can be given when choosing a function, got " +
Arrays.toString(names));
@@ -205,7 +215,7 @@ public class Model {
/** Returns a single-use evaluator of a function */
private FunctionEvaluator evaluatorOf(ExpressionFunction function) {
- return new FunctionEvaluator(function, requireContextProprotype(function.getName()).copy());
+ return new FunctionEvaluator(function, requireContextPrototype(function.getName()).copy());
}
private void throwUndeterminedFunction(String message) {