summaryrefslogtreecommitdiffstats
path: root/model-inference
diff options
context:
space:
mode:
authorJon Bratseth <bratseth@oath.com>2018-07-20 16:09:21 +0200
committerJon Bratseth <bratseth@oath.com>2018-07-20 16:09:21 +0200
commit43d5255f1879214796482deea5a5c024a8abf618 (patch)
tree678bca650e7b74553cd83920c1174adb2bdabc0a /model-inference
parente6a8a79025abd0dc4c29d74a4f22687c93f19532 (diff)
Tighten interface
Diffstat (limited to 'model-inference')
-rw-r--r--model-inference/src/main/java/ai/vespa/models/evaluation/FunctionEvaluator.java54
-rw-r--r--model-inference/src/main/java/ai/vespa/models/evaluation/LazyArrayContext.java17
-rw-r--r--model-inference/src/main/java/ai/vespa/models/evaluation/Model.java14
-rw-r--r--model-inference/src/main/java/ai/vespa/models/evaluation/ModelsEvaluator.java27
-rw-r--r--model-inference/src/main/java/ai/vespa/models/evaluation/RankProfilesConfigImporter.java1
-rw-r--r--model-inference/src/test/java/ai/vespa/models/evaluation/ModelsEvaluatorTest.java57
6 files changed, 89 insertions, 81 deletions
diff --git a/model-inference/src/main/java/ai/vespa/models/evaluation/FunctionEvaluator.java b/model-inference/src/main/java/ai/vespa/models/evaluation/FunctionEvaluator.java
new file mode 100644
index 00000000000..c4ecc1b25c9
--- /dev/null
+++ b/model-inference/src/main/java/ai/vespa/models/evaluation/FunctionEvaluator.java
@@ -0,0 +1,54 @@
+// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+package ai.vespa.models.evaluation;
+
+import com.yahoo.searchlib.rankingexpression.evaluation.TensorValue;
+import com.yahoo.tensor.Tensor;
+import com.yahoo.tensor.TensorType;
+
+/**
+ * An evaluator which can be used to evaluate a single function once.
+ *
+ * @author bratseth
+ */
+// This wraps all access to the context and the ranking expression to avoid incorrect usage
+public class FunctionEvaluator {
+
+ private final LazyArrayContext context;
+ private boolean evaluated = false;
+
+ FunctionEvaluator(LazyArrayContext context) {
+ this.context = context;
+ }
+
+ /**
+ * Binds the given variable referred in this expression to the given value.
+ *
+ * @param name the variable to bind
+ * @param value the value this becomes bound to
+ * @return this for chaining
+ */
+ public FunctionEvaluator bind(String name, Tensor value) {
+ if (evaluated)
+ throw new IllegalStateException("You cannot bind a value in a used evaluator");
+ context.put(name, new TensorValue(value));
+ return this;
+ }
+
+ /**
+ * Binds the given variable referred in this expression to the given value.
+ * This is equivalent to <code>bind(name, Tensor.Builder.of(TensorType.empty).cell(value).build())</code>
+ *
+ * @param name the variable to bind
+ * @param value the value this becomes bound to
+ * @return this for chaining
+ */
+ public FunctionEvaluator bind(String name, double value) {
+ return bind(name, Tensor.Builder.of(TensorType.empty).cell(value).build());
+ }
+
+ public Tensor evaluate() {
+ evaluated = true;
+ return context.expression().evaluate(context).asTensor();
+ }
+
+}
diff --git a/model-inference/src/main/java/ai/vespa/models/evaluation/LazyArrayContext.java b/model-inference/src/main/java/ai/vespa/models/evaluation/LazyArrayContext.java
index 09da165c170..729d8af01dc 100644
--- a/model-inference/src/main/java/ai/vespa/models/evaluation/LazyArrayContext.java
+++ b/model-inference/src/main/java/ai/vespa/models/evaluation/LazyArrayContext.java
@@ -17,7 +17,6 @@ import com.yahoo.tensor.TensorType;
import java.util.Arrays;
import java.util.LinkedHashSet;
-import java.util.List;
import java.util.Map;
import java.util.Set;
@@ -28,11 +27,11 @@ import java.util.Set;
*/
final class LazyArrayContext extends Context implements ContextIndex {
- private final String expressionName;
+ private final RankingExpression expression;
private final IndexedBindings indexedBindings;
- private LazyArrayContext(String expressionName, IndexedBindings indexedBindings) {
- this.expressionName = expressionName;
+ private LazyArrayContext(RankingExpression expression, IndexedBindings indexedBindings) {
+ this.expression = expression;
this.indexedBindings = indexedBindings.copy(this);
}
@@ -42,7 +41,7 @@ final class LazyArrayContext extends Context implements ContextIndex {
* @param expression the expression to create a context for
*/
LazyArrayContext(RankingExpression expression, Map<String, ExpressionFunction> functions) {
- this.expressionName = expression.getName();
+ this.expression = expression;
this.indexedBindings = new IndexedBindings(expression, functions, this);
}
@@ -118,14 +117,16 @@ final class LazyArrayContext extends Context implements ContextIndex {
}
@Override
- public String toString() { return "context of '" + expressionName + "'"; }
+ public String toString() { return "context of '" + expression.getName() + "'"; }
+
+ RankingExpression expression() { return expression; }
/**
- * Creates a clone of this context suitable for evaluating against the same ranking expression
+ * Creates a copy of this context suitable for evaluating against the same ranking expression
* in a different thread or for re-binding free variables.
*/
LazyArrayContext copy() {
- return new LazyArrayContext(expressionName, indexedBindings);
+ return new LazyArrayContext(expression, indexedBindings);
}
private static class IndexedBindings {
diff --git a/model-inference/src/main/java/ai/vespa/models/evaluation/Model.java b/model-inference/src/main/java/ai/vespa/models/evaluation/Model.java
index 766e17a7320..9a639d0803f 100644
--- a/model-inference/src/main/java/ai/vespa/models/evaluation/Model.java
+++ b/model-inference/src/main/java/ai/vespa/models/evaluation/Model.java
@@ -4,7 +4,6 @@ package ai.vespa.models.evaluation;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import com.yahoo.searchlib.rankingexpression.ExpressionFunction;
-import com.yahoo.searchlib.rankingexpression.evaluation.Context;
import java.util.Collection;
import java.util.Collections;
@@ -34,6 +33,7 @@ public class Model {
}
Model(String name, Collection<ExpressionFunction> functions, Collection<ExpressionFunction> referredFunctions) {
+ // TODO: Optimize functions
this.name = name;
this.functions = ImmutableList.copyOf(functions);
@@ -56,7 +56,7 @@ public class Model {
public String name() { return name; }
- /** Returns an immutable list of the free (callable) functions of this */
+ /** Returns an immutable list of the free functions of this */
public List<ExpressionFunction> functions() { return functions; }
/** Returns the given function, or throws a IllegalArgumentException if it does not exist */
@@ -89,13 +89,15 @@ public class Model {
Map<String, ExpressionFunction> boundFunctions() { return referredFunctions; }
/**
- * Returns a function which can be used to evaluate the given function
+ * Returns an evaluator which can be used to evaluate the given function in a single thread once.
+
+ * Usage:
+ * <code>Tensor result = model.evaluatorOf("myFunction").bind("foo", value).bind("bar", value).evaluate()</code>
*
* @throws IllegalArgumentException if the function is not present
*/
- // TODO: Rename to singleThreadedContextFor, move context protottype creation to construction, clone here
- public Context contextFor(String function) {
- return requireContextProprotype(function).copy();
+ public FunctionEvaluator evaluatorOf(String function) { // TODO: Parameter overloading?
+ return new FunctionEvaluator(requireContextProprotype(function).copy());
}
@Override
diff --git a/model-inference/src/main/java/ai/vespa/models/evaluation/ModelsEvaluator.java b/model-inference/src/main/java/ai/vespa/models/evaluation/ModelsEvaluator.java
index 44fec1b9f2c..b36e06e5505 100644
--- a/model-inference/src/main/java/ai/vespa/models/evaluation/ModelsEvaluator.java
+++ b/model-inference/src/main/java/ai/vespa/models/evaluation/ModelsEvaluator.java
@@ -2,11 +2,6 @@
package ai.vespa.models.evaluation;
import com.google.common.collect.ImmutableMap;
-import com.yahoo.searchlib.rankingexpression.ExpressionFunction;
-import com.yahoo.searchlib.rankingexpression.RankingExpression;
-import com.yahoo.searchlib.rankingexpression.evaluation.Context;
-import com.yahoo.searchlib.rankingexpression.evaluation.Value;
-import com.yahoo.tensor.Tensor;
import com.yahoo.vespa.config.search.RankProfilesConfig;
import java.util.Map;
@@ -14,7 +9,8 @@ import java.util.stream.Collectors;
/**
* Evaluates machine-learned models added to Vespa applications and available as config form.
- * TODO: Write JavaDoc similar to RankingExpression
+ * Usage:
+ * <code>Tensor result = evaluator.bind("foo", value).bind("bar", value").evaluate()</code>
*
* @author bratseth
*/
@@ -34,25 +30,12 @@ public class ModelsEvaluator {
*
* @throws IllegalArgumentException if the function or model is not present
*/
- public Context contextFor(String modelName, String functionName) {
- return requireModel(modelName).contextFor(functionName);
- }
-
- /**
- * Evaluates the given function in the given model.
- *
- * @param modelName the name of the model to evaluate
- * @param functionName the function to evaluate in the model
- * @param context the evaluation context which provides bindings of the arguments to the function
- * @return the tensor resulting from evaluating the model, never null
- * @throws IllegalArgumentException if the model of function is not present
- */
- public Tensor evaluate(String modelName, String functionName, Context context) {
- return requireModel(modelName).requireFunction(functionName).getBody().evaluate(context).asTensor();
+ public FunctionEvaluator evaluatorOf(String modelName, String functionName) {
+ return requireModel(modelName).evaluatorOf(functionName);
}
/** Returns the given model, or throws a IllegalArgumentException if it does not exist */
- public Model requireModel(String name) {
+ Model requireModel(String name) {
Model model = models.get(name);
if (model == null)
throw new IllegalArgumentException("No model named '" + name + ". Available models: " +
diff --git a/model-inference/src/main/java/ai/vespa/models/evaluation/RankProfilesConfigImporter.java b/model-inference/src/main/java/ai/vespa/models/evaluation/RankProfilesConfigImporter.java
index 5d6a7b0c942..bd0453f2826 100644
--- a/model-inference/src/main/java/ai/vespa/models/evaluation/RankProfilesConfigImporter.java
+++ b/model-inference/src/main/java/ai/vespa/models/evaluation/RankProfilesConfigImporter.java
@@ -1,7 +1,6 @@
// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
package ai.vespa.models.evaluation;
-import ai.vespa.models.evaluation.Model;
import com.yahoo.searchlib.rankingexpression.ExpressionFunction;
import com.yahoo.searchlib.rankingexpression.RankingExpression;
import com.yahoo.searchlib.rankingexpression.parser.ParseException;
diff --git a/model-inference/src/test/java/ai/vespa/models/evaluation/ModelsEvaluatorTest.java b/model-inference/src/test/java/ai/vespa/models/evaluation/ModelsEvaluatorTest.java
index c2f299824c8..38a5a0c9797 100644
--- a/model-inference/src/test/java/ai/vespa/models/evaluation/ModelsEvaluatorTest.java
+++ b/model-inference/src/test/java/ai/vespa/models/evaluation/ModelsEvaluatorTest.java
@@ -3,10 +3,6 @@ package ai.vespa.models.evaluation;
import com.yahoo.config.subscription.ConfigGetter;
import com.yahoo.config.subscription.FileSource;
-import com.yahoo.searchlib.rankingexpression.evaluation.ArrayContext;
-import com.yahoo.searchlib.rankingexpression.evaluation.Context;
-import com.yahoo.searchlib.rankingexpression.evaluation.MapContext;
-import com.yahoo.searchlib.rankingexpression.evaluation.Value;
import com.yahoo.tensor.Tensor;
import com.yahoo.vespa.config.search.RankProfilesConfig;
import org.junit.Test;
@@ -22,55 +18,28 @@ public class ModelsEvaluatorTest {
private static final double delta = 0.00000000001;
- private ModelsEvaluator createEvaluator() {
+ private ModelsEvaluator createModels() {
String configPath = "src/test/resources/config/rankexpression/rank-profiles.cfg";
RankProfilesConfig config = new ConfigGetter<>(new FileSource(new File(configPath)), RankProfilesConfig.class).getConfig("");
return new ModelsEvaluator(config);
}
@Test
- public void testScalarMapContextEvaluation() {
- ModelsEvaluator evaluator = createEvaluator();
- MapContext context = new MapContext();
- context.put("var1", 3);
- context.put("var2", 5);
- assertEquals(32.0, evaluator.evaluate("macros", "fourtimessum", context).asDouble(), delta);
+ public void testTensorEvaluation() {
+ ModelsEvaluator models = createModels();
+ FunctionEvaluator function = models.evaluatorOf("macros", "fourtimessum");
+ function.bind("var1", Tensor.from("{{x:0}:3,{x:1}:5}"));
+ function.bind("var2", Tensor.from("{{x:0}:7,{x:1}:11}"));
+ assertEquals(Tensor.from("{{x:0}:40.0,{x:1}:64.0}"), function.evaluate());
}
@Test
- public void testTensorMapContextEvaluation() {
- ModelsEvaluator evaluator = createEvaluator();
- MapContext context = new MapContext();
- context.put("var1", Value.of(Tensor.from("{{x:0}:3,{x:1}:5}")));
- context.put("var2", Value.of(Tensor.from("{{x:0}:7,{x:1}:11}")));
- assertEquals(Tensor.from("{{x:0}:40.0,{x:1}:64.0}"), evaluator.evaluate("macros", "fourtimessum", context));
- }
-
- @Test
- public void testScalarArrayContextEvaluation() {
- ModelsEvaluator evaluator = createEvaluator();
- ArrayContext context = new ArrayContext(evaluator.requireModel("macros").requireFunction("fourtimessum").getBody());
- context.put("var1", Value.of(Tensor.from("{{x:0}:3,{x:1}:5}")));
- context.put("var2", Value.of(Tensor.from("{{x:0}:7,{x:1}:11}")));
- assertEquals(Tensor.from("{{x:0}:40.0,{x:1}:64.0}"), evaluator.evaluate("macros", "fourtimessum", context));
- }
-
- @Test
- public void testTensorArrayContextEvaluation() {
- ModelsEvaluator evaluator = createEvaluator();
- ArrayContext context = new ArrayContext(evaluator.requireModel("macros").requireFunction("fourtimessum").getBody());
- context.put("var1", Value.of(Tensor.from("{{x:0}:3,{x:1}:5}")));
- context.put("var2", Value.of(Tensor.from("{{x:0}:7,{x:1}:11}")));
- assertEquals(Tensor.from("{{x:0}:40.0,{x:1}:64.0}"), evaluator.evaluate("macros", "fourtimessum", context));
- }
-
- @Test
- public void testEvaluationDependingOnBoundMacro() {
- ModelsEvaluator evaluator = createEvaluator();
- Context context = evaluator.contextFor("macros", "secondphase");
- context.put("match", 3);
- context.put("rankBoost", 5);
- assertEquals(32.0, evaluator.evaluate("macros", "secondphase", context).asDouble(), delta);
+ public void testEvaluationDependingOnMacroTakingArguments() {
+ ModelsEvaluator models = createModels();
+ FunctionEvaluator function = models.evaluatorOf("macros", "secondphase");
+ function.bind("match", 3);
+ function.bind("rankBoost", 5);
+ assertEquals(32.0, function.evaluate().asDouble(), delta);
}
}