summaryrefslogtreecommitdiffstats
path: root/model-evaluation
diff options
context:
space:
mode:
authorJon Bratseth <bratseth@oath.com>2018-08-25 22:58:16 +0200
committerJon Bratseth <bratseth@oath.com>2018-08-25 22:58:16 +0200
commitb525b8d8efcf71b421db1e549e4f078514e26135 (patch)
tree1011ce314160c766e119a42c67daf6bc35980fe4 /model-evaluation
parentccda281b6c60de0e6c7108a8532d7f7438ebd9ae (diff)
Improve evaluation API
Diffstat (limited to 'model-evaluation')
-rw-r--r--model-evaluation/src/main/java/ai/vespa/models/evaluation/FunctionReference.java2
-rw-r--r--model-evaluation/src/main/java/ai/vespa/models/evaluation/Model.java49
2 files changed, 46 insertions, 5 deletions
diff --git a/model-evaluation/src/main/java/ai/vespa/models/evaluation/FunctionReference.java b/model-evaluation/src/main/java/ai/vespa/models/evaluation/FunctionReference.java
index 3b50cef6e2e..00fcad94ce8 100644
--- a/model-evaluation/src/main/java/ai/vespa/models/evaluation/FunctionReference.java
+++ b/model-evaluation/src/main/java/ai/vespa/models/evaluation/FunctionReference.java
@@ -22,7 +22,7 @@ import java.util.regex.Pattern;
class FunctionReference {
private static final Pattern referencePattern =
- Pattern.compile("rankingExpression\\(([a-zA-Z0-9_]+)(@[a-f0-9]+\\.[a-f0-9]+)?\\)(\\.rankingScript)?");
+ Pattern.compile("rankingExpression\\(([a-zA-Z0-9_.]+)(@[a-f0-9]+\\.[a-f0-9]+)?\\)(\\.rankingScript)?");
/** The name of the function referenced */
private final String name;
diff --git a/model-evaluation/src/main/java/ai/vespa/models/evaluation/Model.java b/model-evaluation/src/main/java/ai/vespa/models/evaluation/Model.java
index ca739195867..d8b7e82677c 100644
--- a/model-evaluation/src/main/java/ai/vespa/models/evaluation/Model.java
+++ b/model-evaluation/src/main/java/ai/vespa/models/evaluation/Model.java
@@ -7,6 +7,7 @@ import com.yahoo.searchlib.rankingexpression.ExpressionFunction;
import com.yahoo.searchlib.rankingexpression.evaluation.ContextIndex;
import com.yahoo.searchlib.rankingexpression.evaluation.ExpressionOptimizer;
+import java.util.Arrays;
import java.util.Collection;
import java.util.Collections;
import java.util.List;
@@ -116,14 +117,54 @@ public class Model {
/**
* Returns an evaluator which can be used to evaluate the given function in a single thread once.
-
+ *
* Usage:
* <code>Tensor result = model.evaluatorOf("myFunction").bind("foo", value).bind("bar", value).evaluate()</code>
*
- * @throws IllegalArgumentException if the function is not present
+ * @param names the names identifying the function - this can be from 0 to 2, specifying function or "signature"
+ * name, and "output", respectively. Names which are unnecessary to determine the desired function
+ * uniquely (e.g if there is just one function or output) can be omitted.
+ * @throws IllegalArgumentException if the function is not present, or not uniquely identified by the names given
*/
- public FunctionEvaluator evaluatorOf(String function) { // TODO: Parameter overloading?
- return new FunctionEvaluator(requireFunction(function), requireContextProprotype(function).copy());
+ public FunctionEvaluator evaluatorOf(String ... names) { // TODO: Parameter overloading?
+ if (names.length == 0) {
+ if (functions.size() > 1)
+ throwUndeterminedFunction("More than one function is available in " + this + ", but no name is given");
+ return evaluatorOf(functions.get(0));
+ }
+ else if (names.length == 1) {
+ String name = names[0];
+ ExpressionFunction function = function(name);
+ if (function != null) return evaluatorOf(function);
+
+ List<ExpressionFunction> functionsStartingByName =
+ functions.stream().filter(f -> f.getName().startsWith(name + ".")).collect(Collectors.toList());
+ if (functionsStartingByName.size() == 0)
+ throwUndeterminedFunction("No function '" + name + "' in " + this);
+ else if (functionsStartingByName.size() == 1)
+ return evaluatorOf(functionsStartingByName.get(0));
+ else
+ throwUndeterminedFunction("Multiple functions start by '" + name + "' in " + this);
+
+ }
+ else if (names.length == 2) {
+ String name = names[0] + "." + names[1];
+ ExpressionFunction function = function(name);
+ if (function == null) throwUndeterminedFunction("No function '" + name + "' in " + this);
+ return evaluatorOf(function);
+ }
+ throw new IllegalArgumentException("No more than 2 names can be given when choosing a function, got " +
+ Arrays.toString(names));
+ }
+
+ /** Returns a single-use evaluator of a function */
+ private FunctionEvaluator evaluatorOf(ExpressionFunction function) {
+ return new FunctionEvaluator(function, requireContextProprotype(function.getName()).copy());
+ }
+
+ private void throwUndeterminedFunction(String message) {
+ throw new IllegalArgumentException(message + ". Available functions: " +
+ functions.stream().map(f -> f.getName()).collect(Collectors.joining(", ")));
}
@Override