diff options
author | Jon Bratseth <bratseth@oath.com> | 2018-07-20 16:09:21 +0200 |
---|---|---|
committer | Jon Bratseth <bratseth@oath.com> | 2018-07-20 16:09:21 +0200 |
commit | 43d5255f1879214796482deea5a5c024a8abf618 (patch) | |
tree | 678bca650e7b74553cd83920c1174adb2bdabc0a /model-inference | |
parent | e6a8a79025abd0dc4c29d74a4f22687c93f19532 (diff) |
Tighten interface
Diffstat (limited to 'model-inference')
6 files changed, 89 insertions, 81 deletions
diff --git a/model-inference/src/main/java/ai/vespa/models/evaluation/FunctionEvaluator.java b/model-inference/src/main/java/ai/vespa/models/evaluation/FunctionEvaluator.java new file mode 100644 index 00000000000..c4ecc1b25c9 --- /dev/null +++ b/model-inference/src/main/java/ai/vespa/models/evaluation/FunctionEvaluator.java @@ -0,0 +1,54 @@ +// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package ai.vespa.models.evaluation; + +import com.yahoo.searchlib.rankingexpression.evaluation.TensorValue; +import com.yahoo.tensor.Tensor; +import com.yahoo.tensor.TensorType; + +/** + * An evaluator which can be used to evaluate a single function once. + * + * @author bratseth + */ +// This wraps all access to the context and the ranking expression to avoid incorrect usage +public class FunctionEvaluator { + + private final LazyArrayContext context; + private boolean evaluated = false; + + FunctionEvaluator(LazyArrayContext context) { + this.context = context; + } + + /** + * Binds the given variable referred in this expression to the given value. + * + * @param name the variable to bind + * @param value the value this becomes bound to + * @return this for chaining + */ + public FunctionEvaluator bind(String name, Tensor value) { + if (evaluated) + throw new IllegalStateException("You cannot bind a value in a used evaluator"); + context.put(name, new TensorValue(value)); + return this; + } + + /** + * Binds the given variable referred in this expression to the given value. + * This is equivalent to <code>bind(name, Tensor.Builder.of(TensorType.empty).cell(value).build())</code> + * + * @param name the variable to bind + * @param value the value this becomes bound to + * @return this for chaining + */ + public FunctionEvaluator bind(String name, double value) { + return bind(name, Tensor.Builder.of(TensorType.empty).cell(value).build()); + } + + public Tensor evaluate() { + evaluated = true; + return context.expression().evaluate(context).asTensor(); + } + +} diff --git a/model-inference/src/main/java/ai/vespa/models/evaluation/LazyArrayContext.java b/model-inference/src/main/java/ai/vespa/models/evaluation/LazyArrayContext.java index 09da165c170..729d8af01dc 100644 --- a/model-inference/src/main/java/ai/vespa/models/evaluation/LazyArrayContext.java +++ b/model-inference/src/main/java/ai/vespa/models/evaluation/LazyArrayContext.java @@ -17,7 +17,6 @@ import com.yahoo.tensor.TensorType; import java.util.Arrays; import java.util.LinkedHashSet; -import java.util.List; import java.util.Map; import java.util.Set; @@ -28,11 +27,11 @@ import java.util.Set; */ final class LazyArrayContext extends Context implements ContextIndex { - private final String expressionName; + private final RankingExpression expression; private final IndexedBindings indexedBindings; - private LazyArrayContext(String expressionName, IndexedBindings indexedBindings) { - this.expressionName = expressionName; + private LazyArrayContext(RankingExpression expression, IndexedBindings indexedBindings) { + this.expression = expression; this.indexedBindings = indexedBindings.copy(this); } @@ -42,7 +41,7 @@ final class LazyArrayContext extends Context implements ContextIndex { * @param expression the expression to create a context for */ LazyArrayContext(RankingExpression expression, Map<String, ExpressionFunction> functions) { - this.expressionName = expression.getName(); + this.expression = expression; this.indexedBindings = new IndexedBindings(expression, functions, this); } @@ -118,14 +117,16 @@ final class LazyArrayContext extends Context implements ContextIndex { } @Override - public String toString() { return "context of '" + expressionName + "'"; } + public String toString() { return "context of '" + expression.getName() + "'"; } + + RankingExpression expression() { return expression; } /** - * Creates a clone of this context suitable for evaluating against the same ranking expression + * Creates a copy of this context suitable for evaluating against the same ranking expression * in a different thread or for re-binding free variables. */ LazyArrayContext copy() { - return new LazyArrayContext(expressionName, indexedBindings); + return new LazyArrayContext(expression, indexedBindings); } private static class IndexedBindings { diff --git a/model-inference/src/main/java/ai/vespa/models/evaluation/Model.java b/model-inference/src/main/java/ai/vespa/models/evaluation/Model.java index 766e17a7320..9a639d0803f 100644 --- a/model-inference/src/main/java/ai/vespa/models/evaluation/Model.java +++ b/model-inference/src/main/java/ai/vespa/models/evaluation/Model.java @@ -4,7 +4,6 @@ package ai.vespa.models.evaluation; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; import com.yahoo.searchlib.rankingexpression.ExpressionFunction; -import com.yahoo.searchlib.rankingexpression.evaluation.Context; import java.util.Collection; import java.util.Collections; @@ -34,6 +33,7 @@ public class Model { } Model(String name, Collection<ExpressionFunction> functions, Collection<ExpressionFunction> referredFunctions) { + // TODO: Optimize functions this.name = name; this.functions = ImmutableList.copyOf(functions); @@ -56,7 +56,7 @@ public class Model { public String name() { return name; } - /** Returns an immutable list of the free (callable) functions of this */ + /** Returns an immutable list of the free functions of this */ public List<ExpressionFunction> functions() { return functions; } /** Returns the given function, or throws a IllegalArgumentException if it does not exist */ @@ -89,13 +89,15 @@ public class Model { Map<String, ExpressionFunction> boundFunctions() { return referredFunctions; } /** - * Returns a function which can be used to evaluate the given function + * Returns an evaluator which can be used to evaluate the given function in a single thread once. + + * Usage: + * <code>Tensor result = model.evaluatorOf("myFunction").bind("foo", value).bind("bar", value).evaluate()</code> * * @throws IllegalArgumentException if the function is not present */ - // TODO: Rename to singleThreadedContextFor, move context protottype creation to construction, clone here - public Context contextFor(String function) { - return requireContextProprotype(function).copy(); + public FunctionEvaluator evaluatorOf(String function) { // TODO: Parameter overloading? + return new FunctionEvaluator(requireContextProprotype(function).copy()); } @Override diff --git a/model-inference/src/main/java/ai/vespa/models/evaluation/ModelsEvaluator.java b/model-inference/src/main/java/ai/vespa/models/evaluation/ModelsEvaluator.java index 44fec1b9f2c..b36e06e5505 100644 --- a/model-inference/src/main/java/ai/vespa/models/evaluation/ModelsEvaluator.java +++ b/model-inference/src/main/java/ai/vespa/models/evaluation/ModelsEvaluator.java @@ -2,11 +2,6 @@ package ai.vespa.models.evaluation; import com.google.common.collect.ImmutableMap; -import com.yahoo.searchlib.rankingexpression.ExpressionFunction; -import com.yahoo.searchlib.rankingexpression.RankingExpression; -import com.yahoo.searchlib.rankingexpression.evaluation.Context; -import com.yahoo.searchlib.rankingexpression.evaluation.Value; -import com.yahoo.tensor.Tensor; import com.yahoo.vespa.config.search.RankProfilesConfig; import java.util.Map; @@ -14,7 +9,8 @@ import java.util.stream.Collectors; /** * Evaluates machine-learned models added to Vespa applications and available as config form. - * TODO: Write JavaDoc similar to RankingExpression + * Usage: + * <code>Tensor result = evaluator.bind("foo", value).bind("bar", value").evaluate()</code> * * @author bratseth */ @@ -34,25 +30,12 @@ public class ModelsEvaluator { * * @throws IllegalArgumentException if the function or model is not present */ - public Context contextFor(String modelName, String functionName) { - return requireModel(modelName).contextFor(functionName); - } - - /** - * Evaluates the given function in the given model. - * - * @param modelName the name of the model to evaluate - * @param functionName the function to evaluate in the model - * @param context the evaluation context which provides bindings of the arguments to the function - * @return the tensor resulting from evaluating the model, never null - * @throws IllegalArgumentException if the model of function is not present - */ - public Tensor evaluate(String modelName, String functionName, Context context) { - return requireModel(modelName).requireFunction(functionName).getBody().evaluate(context).asTensor(); + public FunctionEvaluator evaluatorOf(String modelName, String functionName) { + return requireModel(modelName).evaluatorOf(functionName); } /** Returns the given model, or throws a IllegalArgumentException if it does not exist */ - public Model requireModel(String name) { + Model requireModel(String name) { Model model = models.get(name); if (model == null) throw new IllegalArgumentException("No model named '" + name + ". Available models: " + diff --git a/model-inference/src/main/java/ai/vespa/models/evaluation/RankProfilesConfigImporter.java b/model-inference/src/main/java/ai/vespa/models/evaluation/RankProfilesConfigImporter.java index 5d6a7b0c942..bd0453f2826 100644 --- a/model-inference/src/main/java/ai/vespa/models/evaluation/RankProfilesConfigImporter.java +++ b/model-inference/src/main/java/ai/vespa/models/evaluation/RankProfilesConfigImporter.java @@ -1,7 +1,6 @@ // Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. package ai.vespa.models.evaluation; -import ai.vespa.models.evaluation.Model; import com.yahoo.searchlib.rankingexpression.ExpressionFunction; import com.yahoo.searchlib.rankingexpression.RankingExpression; import com.yahoo.searchlib.rankingexpression.parser.ParseException; diff --git a/model-inference/src/test/java/ai/vespa/models/evaluation/ModelsEvaluatorTest.java b/model-inference/src/test/java/ai/vespa/models/evaluation/ModelsEvaluatorTest.java index c2f299824c8..38a5a0c9797 100644 --- a/model-inference/src/test/java/ai/vespa/models/evaluation/ModelsEvaluatorTest.java +++ b/model-inference/src/test/java/ai/vespa/models/evaluation/ModelsEvaluatorTest.java @@ -3,10 +3,6 @@ package ai.vespa.models.evaluation; import com.yahoo.config.subscription.ConfigGetter; import com.yahoo.config.subscription.FileSource; -import com.yahoo.searchlib.rankingexpression.evaluation.ArrayContext; -import com.yahoo.searchlib.rankingexpression.evaluation.Context; -import com.yahoo.searchlib.rankingexpression.evaluation.MapContext; -import com.yahoo.searchlib.rankingexpression.evaluation.Value; import com.yahoo.tensor.Tensor; import com.yahoo.vespa.config.search.RankProfilesConfig; import org.junit.Test; @@ -22,55 +18,28 @@ public class ModelsEvaluatorTest { private static final double delta = 0.00000000001; - private ModelsEvaluator createEvaluator() { + private ModelsEvaluator createModels() { String configPath = "src/test/resources/config/rankexpression/rank-profiles.cfg"; RankProfilesConfig config = new ConfigGetter<>(new FileSource(new File(configPath)), RankProfilesConfig.class).getConfig(""); return new ModelsEvaluator(config); } @Test - public void testScalarMapContextEvaluation() { - ModelsEvaluator evaluator = createEvaluator(); - MapContext context = new MapContext(); - context.put("var1", 3); - context.put("var2", 5); - assertEquals(32.0, evaluator.evaluate("macros", "fourtimessum", context).asDouble(), delta); + public void testTensorEvaluation() { + ModelsEvaluator models = createModels(); + FunctionEvaluator function = models.evaluatorOf("macros", "fourtimessum"); + function.bind("var1", Tensor.from("{{x:0}:3,{x:1}:5}")); + function.bind("var2", Tensor.from("{{x:0}:7,{x:1}:11}")); + assertEquals(Tensor.from("{{x:0}:40.0,{x:1}:64.0}"), function.evaluate()); } @Test - public void testTensorMapContextEvaluation() { - ModelsEvaluator evaluator = createEvaluator(); - MapContext context = new MapContext(); - context.put("var1", Value.of(Tensor.from("{{x:0}:3,{x:1}:5}"))); - context.put("var2", Value.of(Tensor.from("{{x:0}:7,{x:1}:11}"))); - assertEquals(Tensor.from("{{x:0}:40.0,{x:1}:64.0}"), evaluator.evaluate("macros", "fourtimessum", context)); - } - - @Test - public void testScalarArrayContextEvaluation() { - ModelsEvaluator evaluator = createEvaluator(); - ArrayContext context = new ArrayContext(evaluator.requireModel("macros").requireFunction("fourtimessum").getBody()); - context.put("var1", Value.of(Tensor.from("{{x:0}:3,{x:1}:5}"))); - context.put("var2", Value.of(Tensor.from("{{x:0}:7,{x:1}:11}"))); - assertEquals(Tensor.from("{{x:0}:40.0,{x:1}:64.0}"), evaluator.evaluate("macros", "fourtimessum", context)); - } - - @Test - public void testTensorArrayContextEvaluation() { - ModelsEvaluator evaluator = createEvaluator(); - ArrayContext context = new ArrayContext(evaluator.requireModel("macros").requireFunction("fourtimessum").getBody()); - context.put("var1", Value.of(Tensor.from("{{x:0}:3,{x:1}:5}"))); - context.put("var2", Value.of(Tensor.from("{{x:0}:7,{x:1}:11}"))); - assertEquals(Tensor.from("{{x:0}:40.0,{x:1}:64.0}"), evaluator.evaluate("macros", "fourtimessum", context)); - } - - @Test - public void testEvaluationDependingOnBoundMacro() { - ModelsEvaluator evaluator = createEvaluator(); - Context context = evaluator.contextFor("macros", "secondphase"); - context.put("match", 3); - context.put("rankBoost", 5); - assertEquals(32.0, evaluator.evaluate("macros", "secondphase", context).asDouble(), delta); + public void testEvaluationDependingOnMacroTakingArguments() { + ModelsEvaluator models = createModels(); + FunctionEvaluator function = models.evaluatorOf("macros", "secondphase"); + function.bind("match", 3); + function.bind("rankBoost", 5); + assertEquals(32.0, function.evaluate().asDouble(), delta); } } |