diff options
author | Jon Bratseth <bratseth@oath.com> | 2018-08-25 22:58:16 +0200 |
---|---|---|
committer | Jon Bratseth <bratseth@oath.com> | 2018-08-25 22:58:16 +0200 |
commit | b525b8d8efcf71b421db1e549e4f078514e26135 (patch) | |
tree | 1011ce314160c766e119a42c67daf6bc35980fe4 /model-evaluation | |
parent | ccda281b6c60de0e6c7108a8532d7f7438ebd9ae (diff) |
Improve evaluation API
Diffstat (limited to 'model-evaluation')
-rw-r--r-- | model-evaluation/src/main/java/ai/vespa/models/evaluation/FunctionReference.java | 2 | ||||
-rw-r--r-- | model-evaluation/src/main/java/ai/vespa/models/evaluation/Model.java | 49 |
2 files changed, 46 insertions, 5 deletions
diff --git a/model-evaluation/src/main/java/ai/vespa/models/evaluation/FunctionReference.java b/model-evaluation/src/main/java/ai/vespa/models/evaluation/FunctionReference.java index 3b50cef6e2e..00fcad94ce8 100644 --- a/model-evaluation/src/main/java/ai/vespa/models/evaluation/FunctionReference.java +++ b/model-evaluation/src/main/java/ai/vespa/models/evaluation/FunctionReference.java @@ -22,7 +22,7 @@ import java.util.regex.Pattern; class FunctionReference { private static final Pattern referencePattern = - Pattern.compile("rankingExpression\\(([a-zA-Z0-9_]+)(@[a-f0-9]+\\.[a-f0-9]+)?\\)(\\.rankingScript)?"); + Pattern.compile("rankingExpression\\(([a-zA-Z0-9_.]+)(@[a-f0-9]+\\.[a-f0-9]+)?\\)(\\.rankingScript)?"); /** The name of the function referenced */ private final String name; diff --git a/model-evaluation/src/main/java/ai/vespa/models/evaluation/Model.java b/model-evaluation/src/main/java/ai/vespa/models/evaluation/Model.java index ca739195867..d8b7e82677c 100644 --- a/model-evaluation/src/main/java/ai/vespa/models/evaluation/Model.java +++ b/model-evaluation/src/main/java/ai/vespa/models/evaluation/Model.java @@ -7,6 +7,7 @@ import com.yahoo.searchlib.rankingexpression.ExpressionFunction; import com.yahoo.searchlib.rankingexpression.evaluation.ContextIndex; import com.yahoo.searchlib.rankingexpression.evaluation.ExpressionOptimizer; +import java.util.Arrays; import java.util.Collection; import java.util.Collections; import java.util.List; @@ -116,14 +117,54 @@ public class Model { /** * Returns an evaluator which can be used to evaluate the given function in a single thread once. - + * * Usage: * <code>Tensor result = model.evaluatorOf("myFunction").bind("foo", value).bind("bar", value).evaluate()</code> * - * @throws IllegalArgumentException if the function is not present + * @param names the names identifying the function - this can be from 0 to 2, specifying function or "signature" + * name, and "output", respectively. Names which are unnecessary to determine the desired function + * uniquely (e.g if there is just one function or output) can be omitted. + * @throws IllegalArgumentException if the function is not present, or not uniquely identified by the names given */ - public FunctionEvaluator evaluatorOf(String function) { // TODO: Parameter overloading? - return new FunctionEvaluator(requireFunction(function), requireContextProprotype(function).copy()); + public FunctionEvaluator evaluatorOf(String ... names) { // TODO: Parameter overloading? + if (names.length == 0) { + if (functions.size() > 1) + throwUndeterminedFunction("More than one function is available in " + this + ", but no name is given"); + return evaluatorOf(functions.get(0)); + } + else if (names.length == 1) { + String name = names[0]; + ExpressionFunction function = function(name); + if (function != null) return evaluatorOf(function); + + List<ExpressionFunction> functionsStartingByName = + functions.stream().filter(f -> f.getName().startsWith(name + ".")).collect(Collectors.toList()); + if (functionsStartingByName.size() == 0) + throwUndeterminedFunction("No function '" + name + "' in " + this); + else if (functionsStartingByName.size() == 1) + return evaluatorOf(functionsStartingByName.get(0)); + else + throwUndeterminedFunction("Multiple functions start by '" + name + "' in " + this); + + } + else if (names.length == 2) { + String name = names[0] + "." + names[1]; + ExpressionFunction function = function(name); + if (function == null) throwUndeterminedFunction("No function '" + name + "' in " + this); + return evaluatorOf(function); + } + throw new IllegalArgumentException("No more than 2 names can be given when choosing a function, got " + + Arrays.toString(names)); + } + + /** Returns a single-use evaluator of a function */ + private FunctionEvaluator evaluatorOf(ExpressionFunction function) { + return new FunctionEvaluator(function, requireContextProprotype(function.getName()).copy()); + } + + private void throwUndeterminedFunction(String message) { + throw new IllegalArgumentException(message + ". Available functions: " + + functions.stream().map(f -> f.getName()).collect(Collectors.joining(", "))); } @Override |