diff options
author | Lester Solbakken <lesters@oath.com> | 2021-11-24 12:47:23 +0100 |
---|---|---|
committer | Lester Solbakken <lesters@oath.com> | 2021-11-24 12:47:23 +0100 |
commit | a40eea2df1d64d8586768f1122da90a5756bef10 (patch) | |
tree | ece85a26fe47404a2dfeda1c273640b0bc0f7334 | |
parent | bf026551a16d12d4b4e75949933d080f39a85eef (diff) |
Remove MultiFunctionEvaluator
9 files changed, 125 insertions, 225 deletions
diff --git a/config-model/src/test/java/com/yahoo/vespa/model/container/ml/ModelsEvaluatorTest.java b/config-model/src/test/java/com/yahoo/vespa/model/container/ml/ModelsEvaluatorTest.java index 70c4cb942bc..8ed229b2ff5 100644 --- a/config-model/src/test/java/com/yahoo/vespa/model/container/ml/ModelsEvaluatorTest.java +++ b/config-model/src/test/java/com/yahoo/vespa/model/container/ml/ModelsEvaluatorTest.java @@ -3,12 +3,9 @@ package com.yahoo.vespa.model.container.ml; import ai.vespa.models.evaluation.FunctionEvaluator; import ai.vespa.models.evaluation.ModelsEvaluator; -import ai.vespa.models.evaluation.MultiFunctionEvaluator; import com.yahoo.tensor.Tensor; import org.junit.Test; -import java.util.Map; - import static org.junit.Assert.assertEquals; /** @@ -30,10 +27,11 @@ public class ModelsEvaluatorTest { Tensor output = mul.bind("input1", input1).bind("input2", input2).evaluate(); assertEquals(6.0, output.sum().asDouble(), 1e-9); - MultiFunctionEvaluator eval = modelsEvaluator.multiEvaluatorOf("mul"); - Map<String, Tensor> out = eval.bind("input1", input1).bind("input2", input2).evaluate(); - assertEquals(6.0, out.get("output1").sum().asDouble(), 1e-9); - assertEquals(5.0, out.get("output2").sum().asDouble(), 1e-9); + FunctionEvaluator eval = modelsEvaluator.evaluatorOf("mul"); + output = eval.bind("input1", input1).bind("input2", input2).evaluate(); + assertEquals(6.0, output.sum().asDouble(), 1e-9); + assertEquals(6.0, eval.result("output1").sum().asDouble(), 1e-9); + assertEquals(5.0, eval.result("output2").sum().asDouble(), 1e-9); // LightGBM model evaluation FunctionEvaluator lgbm = modelsEvaluator.evaluatorOf("lightgbm_regression"); diff --git a/model-evaluation/abi-spec.json b/model-evaluation/abi-spec.json index 3f23e7456ad..71dd7ffc2eb 100644 --- a/model-evaluation/abi-spec.json +++ b/model-evaluation/abi-spec.json @@ -6,6 +6,7 @@ "public" ], "methods": [ + "public com.yahoo.tensor.Tensor result(java.lang.String)", "public ai.vespa.models.evaluation.FunctionEvaluator bind(java.lang.String, com.yahoo.tensor.Tensor)", "public ai.vespa.models.evaluation.FunctionEvaluator bind(java.lang.String, double)", "public ai.vespa.models.evaluation.FunctionEvaluator bind(java.lang.String, java.lang.String)", @@ -13,7 +14,8 @@ "public ai.vespa.models.evaluation.FunctionEvaluator setMissingValue(double)", "public com.yahoo.tensor.Tensor evaluate()", "public com.yahoo.searchlib.rankingexpression.ExpressionFunction function()", - "public ai.vespa.models.evaluation.LazyArrayContext context()" + "public ai.vespa.models.evaluation.LazyArrayContext context()", + "public java.util.List outputs()" ], "fields": [] }, @@ -56,7 +58,6 @@ "public java.lang.String name()", "public java.util.List functions()", "public varargs ai.vespa.models.evaluation.FunctionEvaluator evaluatorOf(java.lang.String[])", - "public varargs ai.vespa.models.evaluation.MultiFunctionEvaluator multiEvaluatorOf(java.lang.String[])", "public java.lang.String toString()" ], "fields": [] @@ -73,25 +74,10 @@ "public void <init>(java.util.Map)", "public java.util.Map models()", "public varargs ai.vespa.models.evaluation.FunctionEvaluator evaluatorOf(java.lang.String, java.lang.String[])", - "public varargs ai.vespa.models.evaluation.MultiFunctionEvaluator multiEvaluatorOf(java.lang.String, java.lang.String[])", "public ai.vespa.models.evaluation.Model requireModel(java.lang.String)" ], "fields": [] }, - "ai.vespa.models.evaluation.MultiFunctionEvaluator": { - "superClass": "java.lang.Object", - "interfaces": [], - "attributes": [ - "public" - ], - "methods": [ - "public ai.vespa.models.evaluation.MultiFunctionEvaluator bind(java.lang.String, com.yahoo.tensor.Tensor)", - "public ai.vespa.models.evaluation.MultiFunctionEvaluator bind(java.lang.String, double)", - "public java.util.Map evaluate()", - "public java.util.List functions()" - ], - "fields": [] - }, "ai.vespa.models.evaluation.RankProfilesConfigImporter": { "superClass": "java.lang.Object", "interfaces": [], diff --git a/model-evaluation/src/main/java/ai/vespa/models/evaluation/FunctionEvaluator.java b/model-evaluation/src/main/java/ai/vespa/models/evaluation/FunctionEvaluator.java index aa13cb96845..7a992cb7aa9 100644 --- a/model-evaluation/src/main/java/ai/vespa/models/evaluation/FunctionEvaluator.java +++ b/model-evaluation/src/main/java/ai/vespa/models/evaluation/FunctionEvaluator.java @@ -8,24 +8,37 @@ import com.yahoo.tensor.Tensor; import com.yahoo.tensor.TensorType; import java.util.HashMap; +import java.util.HashSet; +import java.util.List; import java.util.Map; +import java.util.Set; import java.util.stream.Collectors; /** - * An evaluator which can be used to evaluate a single function once. + * An evaluator which can be used to evaluate a function once. * * @author bratseth */ // This wraps all access to the context and the ranking expression to avoid incorrect usage public class FunctionEvaluator { - private final ExpressionFunction function; - private final LazyArrayContext context; + private final List<ExpressionFunction> functions; + private final Map<String, LazyArrayContext> contexts; + private final Map<String, Tensor> results; private boolean evaluated = false; FunctionEvaluator(ExpressionFunction function, LazyArrayContext context) { - this.function = function; - this.context = context; + this(List.of(function), Map.of(function.getName(), context)); + } + + FunctionEvaluator(List<ExpressionFunction> functions, Map<String, LazyArrayContext> contexts) { + this.functions = List.copyOf(functions); + this.contexts = Map.copyOf(contexts); + this.results = new HashMap<>(); + } + + public Tensor result(String name) { + return results.get(name); } /** @@ -38,15 +51,14 @@ public class FunctionEvaluator { public FunctionEvaluator bind(String name, Tensor value) { if (evaluated) throw new IllegalStateException("Cannot bind a new value in a used evaluator"); - TensorType requiredType = function.argumentTypes().get(name); - if (requiredType == null) - throw new IllegalArgumentException("'" + name + "' is not a valid argument in " + function + - ". Expected arguments: " + function.argumentTypes().entrySet().stream() - .map(e -> e.getKey() + ": " + e.getValue()) - .collect(Collectors.joining(", "))); - if ( ! value.type().isAssignableTo(requiredType)) - throw new IllegalArgumentException("'" + name + "' must be of type " + requiredType + ", not " + value.type()); - context.put(name, new TensorValue(value)); + for (ExpressionFunction function : functions) { + if (function.argumentTypes().containsKey(name)) { + TensorType requiredType = function.argumentTypes().get(name); + if ( ! value.type().isAssignableTo(requiredType)) + throw new IllegalArgumentException("'" + name + "' must be of type " + requiredType + ", not " + value.type()); + contexts.get(function.getName()).put(name, new TensorValue(value)); + } + } return this; } @@ -73,7 +85,11 @@ public class FunctionEvaluator { public FunctionEvaluator bind(String name, String value) { if (evaluated) throw new IllegalStateException("Cannot bind a new value in a used evaluator"); - context.put(name, new StringValue(value)); + for (ExpressionFunction function : functions) { + if (function.argumentTypes().containsKey(name)) { + contexts.get(function.getName()).put(name, new StringValue(value)); + } + } return this; } @@ -86,7 +102,9 @@ public class FunctionEvaluator { public FunctionEvaluator setMissingValue(Tensor value) { if (evaluated) throw new IllegalStateException("Cannot change the missing value in a used evaluator"); - context.setMissingValue(value); + for (LazyArrayContext context : contexts.values()) { + context.setMissingValue(value); + } return this; } @@ -102,18 +120,31 @@ public class FunctionEvaluator { public Tensor evaluate() { checkArguments(); - evaluated = true; evaluateOnnxModels(); - return function.getBody().evaluate(context).asTensor(); + + Tensor defaultResult = null; + for (ExpressionFunction function: functions) { + LazyArrayContext context = contexts.get(function.getName()); + Tensor result = function.getBody().evaluate(context).asTensor(); + results.put(function.getName(), function.getBody().evaluate(context).asTensor()); + if (defaultResult == null) { + defaultResult = result; + } + } + evaluated = true; + return defaultResult; } void checkArguments() { - for (Map.Entry<String, TensorType> argument : function.argumentTypes().entrySet()) { - checkArgument(argument.getKey(), argument.getValue()); + for (ExpressionFunction function : functions) { + LazyArrayContext context = contexts.get(function.getName()); + for (Map.Entry<String, TensorType> argument : function.argumentTypes().entrySet()) { + checkArgument(argument.getKey(), argument.getValue(), context); + } } } - private void checkArgument(String name, TensorType type) { + private void checkArgument(String name, TensorType type, LazyArrayContext context) { if (context.isMissing(name)) throw new IllegalStateException("Missing argument '" + name + "': Must be bound to a value of type " + type); if (! context.get(name).type().isAssignableTo(type)) @@ -124,23 +155,52 @@ public class FunctionEvaluator { * Evaluate ONNX models (if not already evaluated) and add the result back to the context. */ private void evaluateOnnxModels() { - for (Map.Entry<String, OnnxModel> entry : context().onnxModels().entrySet()) { - String onnxFeature = entry.getKey(); - OnnxModel onnxModel = entry.getValue(); - if (context.get(onnxFeature).equals(context.defaultValue())) { - Map<String, Tensor> inputs = new HashMap<>(); - for (Map.Entry<String, TensorType> input: onnxModel.inputs().entrySet()) { - inputs.put(input.getKey(), context.get(input.getKey()).asTensor()); + Set<OnnxModel> onnxModels = new HashSet<>(); + for (LazyArrayContext context : contexts.values()) { + onnxModels.addAll(context.onnxModels().values()); + } + + for (OnnxModel onnxModel : onnxModels) { + + // Gather inputs from all functions. Inputs with the same name must have the same value. + Map<String, Tensor> inputs = new HashMap<>(); + for (LazyArrayContext context : contexts.values()) { + for (OnnxModel functionModel : context.onnxModels().values()) { + if (functionModel.name().equals(onnxModel.name())) { + for (String inputName: onnxModel.inputs().keySet()) { + inputs.put(inputName, context.get(inputName).asTensor()); + } + } } - Tensor result = onnxModel.evaluate(inputs, function.getName()); // Function name is output of model - context.put(onnxFeature, new TensorValue(result)); } + + // Evaluate model once. + Map<String, Tensor> outputs = onnxModel.evaluate(inputs); + + // Add outputs back to the context of the functions that need them; they won't be recalculated. + for (ExpressionFunction function : functions) { + LazyArrayContext context = contexts.get(function.getName()); + for (Map.Entry<String, OnnxModel> entry : context.onnxModels().entrySet()) { + String onnxFeature = entry.getKey(); + OnnxModel functionModel = entry.getValue(); + if (functionModel.name().equals(onnxModel.name())) { + Tensor result = outputs.get(function.getName()); // Function name is output of model + context.put(onnxFeature, new TensorValue(result)); + } + } + } + } } - /** Returns the function evaluated by this */ - public ExpressionFunction function() { return function; } + /** Returns the default function evaluated by this */ + public ExpressionFunction function() { return functions.get(0); } - public LazyArrayContext context() { return context; } + public LazyArrayContext context() { return contexts.get(function().getName()); } + + /** Returns the names of the outputs of this function */ + public List<String> outputs() { + return functions.stream().map(ExpressionFunction::getName).collect(Collectors.toList()); + } } diff --git a/model-evaluation/src/main/java/ai/vespa/models/evaluation/LazyArrayContext.java b/model-evaluation/src/main/java/ai/vespa/models/evaluation/LazyArrayContext.java index d97235d11d2..cc53f38f800 100644 --- a/model-evaluation/src/main/java/ai/vespa/models/evaluation/LazyArrayContext.java +++ b/model-evaluation/src/main/java/ai/vespa/models/evaluation/LazyArrayContext.java @@ -153,7 +153,7 @@ public final class LazyArrayContext extends Context implements ContextIndex { /** The mapping from variable name to index */ private final ImmutableMap<String, Integer> nameToIndex; - /** The names which needs to be bound externally when invoking this (i.e not constant or invocation */ + /** The names which needs to be bound externally when invoking this (i.e. not constant or invocation) */ private final ImmutableSet<String> arguments; /** The current values set */ diff --git a/model-evaluation/src/main/java/ai/vespa/models/evaluation/Model.java b/model-evaluation/src/main/java/ai/vespa/models/evaluation/Model.java index 84ab6e81840..ab24986e542 100644 --- a/model-evaluation/src/main/java/ai/vespa/models/evaluation/Model.java +++ b/model-evaluation/src/main/java/ai/vespa/models/evaluation/Model.java @@ -12,6 +12,7 @@ import com.yahoo.tensor.TensorType; import java.util.Arrays; import java.util.Collection; import java.util.Collections; +import java.util.HashMap; import java.util.List; import java.util.Map; import java.util.stream.Collectors; @@ -181,9 +182,7 @@ public class Model { */ public FunctionEvaluator evaluatorOf(String ... names) { // TODO: Parameter overloading? if (names.length == 0) { - if (functions.size() > 1) - throwUndeterminedFunction("More than one function is available in " + this + ", but no name is given"); - return evaluatorOf(functions.get(0)); + return evaluatorOf(functions); } else if (names.length == 1) { String name = names[0]; @@ -230,20 +229,13 @@ public class Model { return new FunctionEvaluator(function, requireContextPrototype(function.getName()).copy()); } - /** - * Returns an evaluator which can be used to evaluate the given model in a single thread once. - * - * @param names The names identifying the outputs. If none are passed, evaluates all outputs. - * @throws IllegalArgumentException if the function is not present. - */ - public MultiFunctionEvaluator multiEvaluatorOf(String ... names) { - List<FunctionEvaluator> evaluators; - if (names.length == 0) { - evaluators = functions.stream().map(this::evaluatorOf).collect(Collectors.toList()); - } else { - evaluators = Arrays.stream(names).map(this::evaluatorOf).collect(Collectors.toList()); + /** Returns a single-use evaluator of a function */ + private FunctionEvaluator evaluatorOf(List<ExpressionFunction> functions) { + Map<String, LazyArrayContext> contexts = new HashMap<>(); + for (ExpressionFunction function : functions) { + contexts.put(function.getName(), requireContextPrototype(function.getName()).copy()); } - return new MultiFunctionEvaluator(evaluators); + return new FunctionEvaluator(functions, contexts); } private void throwUndeterminedFunction(String message) { diff --git a/model-evaluation/src/main/java/ai/vespa/models/evaluation/ModelsEvaluator.java b/model-evaluation/src/main/java/ai/vespa/models/evaluation/ModelsEvaluator.java index bd00f5510c6..01427ca811a 100644 --- a/model-evaluation/src/main/java/ai/vespa/models/evaluation/ModelsEvaluator.java +++ b/model-evaluation/src/main/java/ai/vespa/models/evaluation/ModelsEvaluator.java @@ -60,17 +60,6 @@ public class ModelsEvaluator extends AbstractComponent { return requireModel(modelName).evaluatorOf(names); } - /** - * Returns a model evaluator which can be used to evaluate multiple functions in a model - * - * @param modelName the name of the model - * @param names the names of the outputs to evaluate, or none if all should be evaluated - * @throws IllegalArgumentException if the function or model is not present - */ - public MultiFunctionEvaluator multiEvaluatorOf(String modelName, String ... names) { - return requireModel(modelName).multiEvaluatorOf(names); - } - /** Returns the given model, or throws a IllegalArgumentException if it does not exist */ public Model requireModel(String name) { Model model = models.get(name); diff --git a/model-evaluation/src/main/java/ai/vespa/models/evaluation/MultiFunctionEvaluator.java b/model-evaluation/src/main/java/ai/vespa/models/evaluation/MultiFunctionEvaluator.java deleted file mode 100644 index 53d470ecc19..00000000000 --- a/model-evaluation/src/main/java/ai/vespa/models/evaluation/MultiFunctionEvaluator.java +++ /dev/null @@ -1,120 +0,0 @@ -// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. -package ai.vespa.models.evaluation; - -import com.yahoo.searchlib.rankingexpression.evaluation.TensorValue; -import com.yahoo.tensor.Tensor; -import com.yahoo.tensor.TensorType; - -import java.util.HashMap; -import java.util.HashSet; -import java.util.List; -import java.util.Map; -import java.util.Set; - -/** - * An evaluator which can be used to evaluate a model with multiple outputs. - * This will ensure that ONNX models are only evaluated once. - * - * @author lesters - */ -public class MultiFunctionEvaluator { - - private final List<FunctionEvaluator> functions; - private boolean evaluated = false; - - MultiFunctionEvaluator(List<FunctionEvaluator> functions) { - this.functions = functions; - } - - /** - * Binds the given variable referred in this expression to the given value. - * - * @param name the variable to bind - * @param value the value this becomes bound to - * @return this for chaining - */ - public MultiFunctionEvaluator bind(String name, Tensor value) { - if (evaluated) - throw new IllegalStateException("Cannot bind a new value in a used evaluator"); - for (FunctionEvaluator function : functions) { - if (function.function().argumentTypes().containsKey(name)) { - function.bind(name, value); // only bind input to the functions that need them - } - } - return this; - } - - /** - * Binds the given variable referred in this expression to the given value. - * This is equivalent to <code>bind(name, Tensor.Builder.of(TensorType.empty).cell(value).build())</code> - * - * @param name the variable to bind - * @param value the value this becomes bound to - * @return this for chaining - */ - public MultiFunctionEvaluator bind(String name, double value) { - return bind(name, Tensor.Builder.of(TensorType.empty).cell(value).build()); - } - - public Map<String, Tensor> evaluate() { - for (FunctionEvaluator function : functions) { - function.checkArguments(); - } - - evaluateOnnxModels(); // evaluate each ONNX model only once - - Map<String, Tensor> results = new HashMap<>(); - for (FunctionEvaluator function : functions) { - results.put(function.function().getName(), function.evaluate()); - } - evaluated = true; - return results; - } - - /** - * Evaluate all ONNX models across all functions once and add the result - * back to the functions' context. - */ - private void evaluateOnnxModels() { - Set<OnnxModel> onnxModels = new HashSet<>(); - for (FunctionEvaluator function : functions) { - onnxModels.addAll(function.context().onnxModels().values()); - } - - for (OnnxModel onnxModel : onnxModels) { - - // Gather inputs from all functions. Inputs with the same name must have the same value. - Map<String, Tensor> inputs = new HashMap<>(); - for (FunctionEvaluator function : functions) { - for (OnnxModel functionModel : function.context().onnxModels().values()) { - if (functionModel.name().equals(onnxModel.name())) { - for (String inputName: onnxModel.inputs().keySet()) { - inputs.put(inputName, function.context().get(inputName).asTensor()); - } - } - } - } - - // Evaluate model once. - Map<String, Tensor> outputs = onnxModel.evaluate(inputs); - - // Add outputs back to the context of the functions that need them; they won't be recalculated. - for (FunctionEvaluator function : functions) { - for (Map.Entry<String, OnnxModel> entry : function.context().onnxModels().entrySet()) { - String onnxFeature = entry.getKey(); - OnnxModel functionModel = entry.getValue(); - if (functionModel.name().equals(onnxModel.name())) { - Tensor result = outputs.get(function.function().getName()); // Function name is output of model - function.context().put(onnxFeature, new TensorValue(result)); - } - } - } - - } - } - - public List<FunctionEvaluator> functions() { - return functions; - } - -} diff --git a/model-evaluation/src/test/java/ai/vespa/models/evaluation/ModelsEvaluatorTest.java b/model-evaluation/src/test/java/ai/vespa/models/evaluation/ModelsEvaluatorTest.java index 4cb52216137..3e065d25ad2 100644 --- a/model-evaluation/src/test/java/ai/vespa/models/evaluation/ModelsEvaluatorTest.java +++ b/model-evaluation/src/test/java/ai/vespa/models/evaluation/ModelsEvaluatorTest.java @@ -95,8 +95,8 @@ public class ModelsEvaluatorTest { evaluator.bind("argNone", Tensor.from(TensorType.fromSpec("tensor(d1{})"), "{{d1:foo}:0.1}")); evaluator.evaluate(); } - catch (IllegalArgumentException e) { - assertEquals("'argNone' is not a valid argument in function 'test'. Expected arguments: arg2: tensor(d1{}), arg1: tensor(d0[1])", + catch (IllegalStateException e) { + assertEquals("Argument 'arg2' must be bound to a value of type tensor(d1{})", Exceptions.toMessageString(e)); } diff --git a/model-evaluation/src/test/java/ai/vespa/models/evaluation/OnnxEvaluatorTest.java b/model-evaluation/src/test/java/ai/vespa/models/evaluation/OnnxEvaluatorTest.java index 59ab378e43a..ae77af264a1 100644 --- a/model-evaluation/src/test/java/ai/vespa/models/evaluation/OnnxEvaluatorTest.java +++ b/model-evaluation/src/test/java/ai/vespa/models/evaluation/OnnxEvaluatorTest.java @@ -35,27 +35,22 @@ public class OnnxEvaluatorTest { assertTrue(models.models().containsKey("add_mul")); assertTrue(models.models().containsKey("one_layer")); + Tensor input1 = Tensor.from("tensor<float>(d0[1]):[2]"); + Tensor input2 = Tensor.from("tensor<float>(d0[1]):[3]"); + FunctionEvaluator function = models.evaluatorOf("add_mul", "output1"); - function.bind("input1", Tensor.from("tensor<float>(d0[1]):[2]")); - function.bind("input2", Tensor.from("tensor<float>(d0[1]):[3]")); - assertEquals(6.0, function.evaluate().sum().asDouble(), delta); + Tensor result = function.bind("input1", input1).bind("input2", input2).evaluate(); + assertEquals(6.0, result.sum().asDouble(), delta); function = models.evaluatorOf("add_mul", "output2"); - function.bind("input1", Tensor.from("tensor<float>(d0[1]):[2]")); - function.bind("input2", Tensor.from("tensor<float>(d0[1]):[3]")); - assertEquals(5.0, function.evaluate().sum().asDouble(), delta); - - MultiFunctionEvaluator evaluator = models.multiEvaluatorOf("add_mul"); - Tensor input1 = Tensor.from("tensor<float>(d0[1]):[2]"); - Tensor input2 = Tensor.from("tensor<float>(d0[1]):[3]"); - Map<String, Tensor> result = evaluator.bind("input1", input1).bind("input2", input2).evaluate(); - assertEquals(6.0, result.get("output1").sum().asDouble(), delta); - assertEquals(5.0, result.get("output2").sum().asDouble(), delta); - - evaluator = models.multiEvaluatorOf("add_mul", "output1"); - result = evaluator.bind("input1", input1).bind("input2", input2).evaluate(); - assertTrue("Result does not contain requested output", result.containsKey("output1")); - assertFalse("Result contains output that was not requested", result.containsKey("output2")); + result = function.bind("input1", input1).bind("input2", input2).evaluate(); + assertEquals(5.0, result.sum().asDouble(), delta); + + function = models.evaluatorOf("add_mul"); // contains two models + result = function.bind("input1", input1).bind("input2", input2).evaluate(); + assertEquals(6.0, result.sum().asDouble(), delta); + assertEquals(6.0, function.result("output1").sum().asDouble(), delta); + assertEquals(5.0, function.result("output2").sum().asDouble(), delta); function = models.evaluatorOf("one_layer"); function.bind("input", Tensor.from("tensor<float>(d0[2],d1[3]):[[0.1, 0.2, 0.3],[0.4,0.5,0.6]]")); |