aboutsummaryrefslogtreecommitdiffstats
path: root/model-evaluation
diff options
context:
space:
mode:
authorLester Solbakken <lesters@oath.com>2020-08-24 15:17:01 +0200
committerLester Solbakken <lesters@oath.com>2020-08-24 15:17:01 +0200
commitb62953e2d78aa8d25011d3b007f44d7e75bc5bf6 (patch)
tree1e20e870576ac6f76f3141561523bd37b2077140 /model-evaluation
parentef1f0e04884a31f55011374b4fff0dcbe9fa7e30 (diff)
Model eval: handle unambiguous outputs and models converted from tensorflow to onnx
Diffstat (limited to 'model-evaluation')
-rw-r--r--model-evaluation/src/main/java/ai/vespa/models/evaluation/Model.java30
1 files changed, 20 insertions, 10 deletions
diff --git a/model-evaluation/src/main/java/ai/vespa/models/evaluation/Model.java b/model-evaluation/src/main/java/ai/vespa/models/evaluation/Model.java
index 90700acc0d3..03bbb436026 100644
--- a/model-evaluation/src/main/java/ai/vespa/models/evaluation/Model.java
+++ b/model-evaluation/src/main/java/ai/vespa/models/evaluation/Model.java
@@ -129,7 +129,7 @@ public class Model {
}
/** Returns the given function, or throws a IllegalArgumentException if it does not exist */
- private LazyArrayContext requireContextProprotype(String name) {
+ private LazyArrayContext requireContextPrototype(String name) {
LazyArrayContext context = contextPrototypes.get(name);
if (context == null) // Implies function is not present
throw new IllegalArgumentException("No function named '" + name + "' in " + this + ". Available functions: " +
@@ -183,21 +183,31 @@ public class Model {
ExpressionFunction function = function(name);
if (function != null) return evaluatorOf(function);
+ // Check if the name is a signature
List<ExpressionFunction> functionsStartingByName =
functions.stream().filter(f -> f.getName().startsWith(name + ".")).collect(Collectors.toList());
- if (functionsStartingByName.size() == 0)
- throwUndeterminedFunction("No function '" + name + "' in " + this);
- else if (functionsStartingByName.size() == 1)
+ if (functionsStartingByName.size() == 1)
return evaluatorOf(functionsStartingByName.get(0));
- else
+ if (functionsStartingByName.size() > 1)
throwUndeterminedFunction("Multiple functions start by '" + name + "' in " + this);
+ // Check if the name is unambiguous as an output
+ List<ExpressionFunction> functionsEndingByName =
+ functions.stream().filter(f -> f.getName().endsWith("." + name)).collect(Collectors.toList());
+ if (functionsEndingByName.size() == 1)
+ return evaluatorOf(functionsEndingByName.get(0));
+ if (functionsEndingByName.size() > 1)
+ throwUndeterminedFunction("Multiple functions called '" + name + "' in " + this);
+
+ // To handle TensorFlow conversion to ONNX
+ if (name.startsWith("serving_default")) {
+ return evaluatorOf("default" + name.substring("serving_default".length()));
+ }
+
+ throwUndeterminedFunction("No function '" + name + "' in " + this);
}
else if (names.length == 2) {
- String name = names[0] + "." + names[1];
- ExpressionFunction function = function(name);
- if (function == null) throwUndeterminedFunction("No function '" + name + "' in " + this);
- return evaluatorOf(function);
+ return evaluatorOf(names[0] + "." + names[1]);
}
throw new IllegalArgumentException("No more than 2 names can be given when choosing a function, got " +
Arrays.toString(names));
@@ -205,7 +215,7 @@ public class Model {
/** Returns a single-use evaluator of a function */
private FunctionEvaluator evaluatorOf(ExpressionFunction function) {
- return new FunctionEvaluator(function, requireContextProprotype(function.getName()).copy());
+ return new FunctionEvaluator(function, requireContextPrototype(function.getName()).copy());
}
private void throwUndeterminedFunction(String message) {