diff options
author | Lester Solbakken <lesters@oath.com> | 2020-08-24 15:17:01 +0200 |
---|---|---|
committer | Lester Solbakken <lesters@oath.com> | 2020-08-24 15:17:01 +0200 |
commit | b62953e2d78aa8d25011d3b007f44d7e75bc5bf6 (patch) | |
tree | 1e20e870576ac6f76f3141561523bd37b2077140 /model-evaluation | |
parent | ef1f0e04884a31f55011374b4fff0dcbe9fa7e30 (diff) |
Model eval: handle unambiguous outputs and models converted from tensorflow to onnx
Diffstat (limited to 'model-evaluation')
-rw-r--r-- | model-evaluation/src/main/java/ai/vespa/models/evaluation/Model.java | 30 |
1 files changed, 20 insertions, 10 deletions
diff --git a/model-evaluation/src/main/java/ai/vespa/models/evaluation/Model.java b/model-evaluation/src/main/java/ai/vespa/models/evaluation/Model.java index 90700acc0d3..03bbb436026 100644 --- a/model-evaluation/src/main/java/ai/vespa/models/evaluation/Model.java +++ b/model-evaluation/src/main/java/ai/vespa/models/evaluation/Model.java @@ -129,7 +129,7 @@ public class Model { } /** Returns the given function, or throws a IllegalArgumentException if it does not exist */ - private LazyArrayContext requireContextProprotype(String name) { + private LazyArrayContext requireContextPrototype(String name) { LazyArrayContext context = contextPrototypes.get(name); if (context == null) // Implies function is not present throw new IllegalArgumentException("No function named '" + name + "' in " + this + ". Available functions: " + @@ -183,21 +183,31 @@ public class Model { ExpressionFunction function = function(name); if (function != null) return evaluatorOf(function); + // Check if the name is a signature List<ExpressionFunction> functionsStartingByName = functions.stream().filter(f -> f.getName().startsWith(name + ".")).collect(Collectors.toList()); - if (functionsStartingByName.size() == 0) - throwUndeterminedFunction("No function '" + name + "' in " + this); - else if (functionsStartingByName.size() == 1) + if (functionsStartingByName.size() == 1) return evaluatorOf(functionsStartingByName.get(0)); - else + if (functionsStartingByName.size() > 1) throwUndeterminedFunction("Multiple functions start by '" + name + "' in " + this); + // Check if the name is unambiguous as an output + List<ExpressionFunction> functionsEndingByName = + functions.stream().filter(f -> f.getName().endsWith("." + name)).collect(Collectors.toList()); + if (functionsEndingByName.size() == 1) + return evaluatorOf(functionsEndingByName.get(0)); + if (functionsEndingByName.size() > 1) + throwUndeterminedFunction("Multiple functions called '" + name + "' in " + this); + + // To handle TensorFlow conversion to ONNX + if (name.startsWith("serving_default")) { + return evaluatorOf("default" + name.substring("serving_default".length())); + } + + throwUndeterminedFunction("No function '" + name + "' in " + this); } else if (names.length == 2) { - String name = names[0] + "." + names[1]; - ExpressionFunction function = function(name); - if (function == null) throwUndeterminedFunction("No function '" + name + "' in " + this); - return evaluatorOf(function); + return evaluatorOf(names[0] + "." + names[1]); } throw new IllegalArgumentException("No more than 2 names can be given when choosing a function, got " + Arrays.toString(names)); @@ -205,7 +215,7 @@ public class Model { /** Returns a single-use evaluator of a function */ private FunctionEvaluator evaluatorOf(ExpressionFunction function) { - return new FunctionEvaluator(function, requireContextProprotype(function.getName()).copy()); + return new FunctionEvaluator(function, requireContextPrototype(function.getName()).copy()); } private void throwUndeterminedFunction(String message) { |