diff options
author | Lester Solbakken <lesters@oath.com> | 2021-11-22 15:41:12 +0100 |
---|---|---|
committer | Lester Solbakken <lesters@oath.com> | 2021-11-22 15:41:12 +0100 |
commit | bf026551a16d12d4b4e75949933d080f39a85eef (patch) | |
tree | f74bcec8d1c82b4c1f49f07e80c31b6750450167 | |
parent | ea7d530c80515724f8384ded392da5b9a3ab3741 (diff) |
Only evaluate ONNX models once in stateless model eval
10 files changed, 221 insertions, 18 deletions
diff --git a/config-model/src/test/cfg/application/stateless_eval/mul.onnx b/config-model/src/test/cfg/application/stateless_eval/mul.onnx index 087e2c3427f..26411c96986 100644 --- a/config-model/src/test/cfg/application/stateless_eval/mul.onnx +++ b/config-model/src/test/cfg/application/stateless_eval/mul.onnx @@ -1,7 +1,10 @@ -mul.py:f - +mul.py:Ÿ + input1 -input2output"MulmulZ +input2output1"Mul + +input1 +input2output2"AddmulZ input1 @@ -9,8 +12,12 @@ input2 -b -output +b +output1 + + +b +output2 B
\ No newline at end of file diff --git a/config-model/src/test/cfg/application/stateless_eval/mul.py b/config-model/src/test/cfg/application/stateless_eval/mul.py index 9fcb8612af9..6bbc4e23200 100755 --- a/config-model/src/test/cfg/application/stateless_eval/mul.py +++ b/config-model/src/test/cfg/application/stateless_eval/mul.py @@ -2,25 +2,31 @@ import onnx from onnx import helper, TensorProto -INPUT_1 = helper.make_tensor_value_info('input1', TensorProto.FLOAT, [1]) -INPUT_2 = helper.make_tensor_value_info('input2', TensorProto.FLOAT, [1]) -OUTPUT = helper.make_tensor_value_info('output', TensorProto.FLOAT, [1]) +INPUT1 = helper.make_tensor_value_info('input1', TensorProto.FLOAT, [1]) +INPUT2 = helper.make_tensor_value_info('input2', TensorProto.FLOAT, [1]) +OUTPUT1 = helper.make_tensor_value_info('output1', TensorProto.FLOAT, [1]) +OUTPUT2 = helper.make_tensor_value_info('output2', TensorProto.FLOAT, [1]) nodes = [ helper.make_node( 'Mul', ['input1', 'input2'], - ['output'], + ['output1'], + ), + helper.make_node( + 'Add', + ['input1', 'input2'], + ['output2'], ), ] graph_def = helper.make_graph( nodes, 'mul', [ - INPUT_1, - INPUT_2 + INPUT1, + INPUT2 ], - [OUTPUT], + [OUTPUT1, OUTPUT2], ) model_def = helper.make_model(graph_def, producer_name='mul.py', opset_imports=[onnx.OperatorSetIdProto(version=12)]) onnx.save(model_def, 'mul.onnx') diff --git a/config-model/src/test/java/com/yahoo/vespa/model/container/ml/ModelsEvaluatorTest.java b/config-model/src/test/java/com/yahoo/vespa/model/container/ml/ModelsEvaluatorTest.java index 5630d3cc186..70c4cb942bc 100644 --- a/config-model/src/test/java/com/yahoo/vespa/model/container/ml/ModelsEvaluatorTest.java +++ b/config-model/src/test/java/com/yahoo/vespa/model/container/ml/ModelsEvaluatorTest.java @@ -3,9 +3,12 @@ package com.yahoo.vespa.model.container.ml; import ai.vespa.models.evaluation.FunctionEvaluator; import ai.vespa.models.evaluation.ModelsEvaluator; +import ai.vespa.models.evaluation.MultiFunctionEvaluator; import com.yahoo.tensor.Tensor; import org.junit.Test; +import java.util.Map; + import static org.junit.Assert.assertEquals; /** @@ -21,12 +24,17 @@ public class ModelsEvaluatorTest { assertEquals(3, modelsEvaluator.models().size()); // ONNX model evaluation - FunctionEvaluator mul = modelsEvaluator.evaluatorOf("mul"); + FunctionEvaluator mul = modelsEvaluator.evaluatorOf("mul", "output1"); Tensor input1 = Tensor.from("tensor<float>(d0[1]):[2]"); Tensor input2 = Tensor.from("tensor<float>(d0[1]):[3]"); Tensor output = mul.bind("input1", input1).bind("input2", input2).evaluate(); assertEquals(6.0, output.sum().asDouble(), 1e-9); + MultiFunctionEvaluator eval = modelsEvaluator.multiEvaluatorOf("mul"); + Map<String, Tensor> out = eval.bind("input1", input1).bind("input2", input2).evaluate(); + assertEquals(6.0, out.get("output1").sum().asDouble(), 1e-9); + assertEquals(5.0, out.get("output2").sum().asDouble(), 1e-9); + // LightGBM model evaluation FunctionEvaluator lgbm = modelsEvaluator.evaluatorOf("lightgbm_regression"); lgbm.bind("numerical_1", 0.1).bind("numerical_2", 0.2).bind("categorical_1", "a").bind("categorical_2", "i"); diff --git a/model-evaluation/abi-spec.json b/model-evaluation/abi-spec.json index 6728d5cd9b4..3f23e7456ad 100644 --- a/model-evaluation/abi-spec.json +++ b/model-evaluation/abi-spec.json @@ -56,6 +56,7 @@ "public java.lang.String name()", "public java.util.List functions()", "public varargs ai.vespa.models.evaluation.FunctionEvaluator evaluatorOf(java.lang.String[])", + "public varargs ai.vespa.models.evaluation.MultiFunctionEvaluator multiEvaluatorOf(java.lang.String[])", "public java.lang.String toString()" ], "fields": [] @@ -72,10 +73,25 @@ "public void <init>(java.util.Map)", "public java.util.Map models()", "public varargs ai.vespa.models.evaluation.FunctionEvaluator evaluatorOf(java.lang.String, java.lang.String[])", + "public varargs ai.vespa.models.evaluation.MultiFunctionEvaluator multiEvaluatorOf(java.lang.String, java.lang.String[])", "public ai.vespa.models.evaluation.Model requireModel(java.lang.String)" ], "fields": [] }, + "ai.vespa.models.evaluation.MultiFunctionEvaluator": { + "superClass": "java.lang.Object", + "interfaces": [], + "attributes": [ + "public" + ], + "methods": [ + "public ai.vespa.models.evaluation.MultiFunctionEvaluator bind(java.lang.String, com.yahoo.tensor.Tensor)", + "public ai.vespa.models.evaluation.MultiFunctionEvaluator bind(java.lang.String, double)", + "public java.util.Map evaluate()", + "public java.util.List functions()" + ], + "fields": [] + }, "ai.vespa.models.evaluation.RankProfilesConfigImporter": { "superClass": "java.lang.Object", "interfaces": [], diff --git a/model-evaluation/src/main/java/ai/vespa/models/evaluation/FunctionEvaluator.java b/model-evaluation/src/main/java/ai/vespa/models/evaluation/FunctionEvaluator.java index 6af33e29e62..aa13cb96845 100644 --- a/model-evaluation/src/main/java/ai/vespa/models/evaluation/FunctionEvaluator.java +++ b/model-evaluation/src/main/java/ai/vespa/models/evaluation/FunctionEvaluator.java @@ -101,14 +101,18 @@ public class FunctionEvaluator { } public Tensor evaluate() { - for (Map.Entry<String, TensorType> argument : function.argumentTypes().entrySet()) { - checkArgument(argument.getKey(), argument.getValue()); - } + checkArguments(); evaluated = true; evaluateOnnxModels(); return function.getBody().evaluate(context).asTensor(); } + void checkArguments() { + for (Map.Entry<String, TensorType> argument : function.argumentTypes().entrySet()) { + checkArgument(argument.getKey(), argument.getValue()); + } + } + private void checkArgument(String name, TensorType type) { if (context.isMissing(name)) throw new IllegalStateException("Missing argument '" + name + "': Must be bound to a value of type " + type); diff --git a/model-evaluation/src/main/java/ai/vespa/models/evaluation/Model.java b/model-evaluation/src/main/java/ai/vespa/models/evaluation/Model.java index 8af5f7bc499..84ab6e81840 100644 --- a/model-evaluation/src/main/java/ai/vespa/models/evaluation/Model.java +++ b/model-evaluation/src/main/java/ai/vespa/models/evaluation/Model.java @@ -6,9 +6,7 @@ import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; import com.yahoo.searchlib.rankingexpression.ExpressionFunction; import com.yahoo.searchlib.rankingexpression.evaluation.ContextIndex; -import com.yahoo.searchlib.rankingexpression.evaluation.DoubleValue; import com.yahoo.searchlib.rankingexpression.evaluation.ExpressionOptimizer; -import com.yahoo.searchlib.rankingexpression.evaluation.Value; import com.yahoo.tensor.TensorType; import java.util.Arrays; @@ -232,6 +230,22 @@ public class Model { return new FunctionEvaluator(function, requireContextPrototype(function.getName()).copy()); } + /** + * Returns an evaluator which can be used to evaluate the given model in a single thread once. + * + * @param names The names identifying the outputs. If none are passed, evaluates all outputs. + * @throws IllegalArgumentException if the function is not present. + */ + public MultiFunctionEvaluator multiEvaluatorOf(String ... names) { + List<FunctionEvaluator> evaluators; + if (names.length == 0) { + evaluators = functions.stream().map(this::evaluatorOf).collect(Collectors.toList()); + } else { + evaluators = Arrays.stream(names).map(this::evaluatorOf).collect(Collectors.toList()); + } + return new MultiFunctionEvaluator(evaluators); + } + private void throwUndeterminedFunction(String message) { throw new IllegalArgumentException(message + ". Available functions: " + functions.stream().map(f -> f.getName()).collect(Collectors.joining(", "))); diff --git a/model-evaluation/src/main/java/ai/vespa/models/evaluation/ModelsEvaluator.java b/model-evaluation/src/main/java/ai/vespa/models/evaluation/ModelsEvaluator.java index 01427ca811a..bd00f5510c6 100644 --- a/model-evaluation/src/main/java/ai/vespa/models/evaluation/ModelsEvaluator.java +++ b/model-evaluation/src/main/java/ai/vespa/models/evaluation/ModelsEvaluator.java @@ -60,6 +60,17 @@ public class ModelsEvaluator extends AbstractComponent { return requireModel(modelName).evaluatorOf(names); } + /** + * Returns a model evaluator which can be used to evaluate multiple functions in a model + * + * @param modelName the name of the model + * @param names the names of the outputs to evaluate, or none if all should be evaluated + * @throws IllegalArgumentException if the function or model is not present + */ + public MultiFunctionEvaluator multiEvaluatorOf(String modelName, String ... names) { + return requireModel(modelName).multiEvaluatorOf(names); + } + /** Returns the given model, or throws a IllegalArgumentException if it does not exist */ public Model requireModel(String name) { Model model = models.get(name); diff --git a/model-evaluation/src/main/java/ai/vespa/models/evaluation/MultiFunctionEvaluator.java b/model-evaluation/src/main/java/ai/vespa/models/evaluation/MultiFunctionEvaluator.java new file mode 100644 index 00000000000..53d470ecc19 --- /dev/null +++ b/model-evaluation/src/main/java/ai/vespa/models/evaluation/MultiFunctionEvaluator.java @@ -0,0 +1,120 @@ +// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package ai.vespa.models.evaluation; + +import com.yahoo.searchlib.rankingexpression.evaluation.TensorValue; +import com.yahoo.tensor.Tensor; +import com.yahoo.tensor.TensorType; + +import java.util.HashMap; +import java.util.HashSet; +import java.util.List; +import java.util.Map; +import java.util.Set; + +/** + * An evaluator which can be used to evaluate a model with multiple outputs. + * This will ensure that ONNX models are only evaluated once. + * + * @author lesters + */ +public class MultiFunctionEvaluator { + + private final List<FunctionEvaluator> functions; + private boolean evaluated = false; + + MultiFunctionEvaluator(List<FunctionEvaluator> functions) { + this.functions = functions; + } + + /** + * Binds the given variable referred in this expression to the given value. + * + * @param name the variable to bind + * @param value the value this becomes bound to + * @return this for chaining + */ + public MultiFunctionEvaluator bind(String name, Tensor value) { + if (evaluated) + throw new IllegalStateException("Cannot bind a new value in a used evaluator"); + for (FunctionEvaluator function : functions) { + if (function.function().argumentTypes().containsKey(name)) { + function.bind(name, value); // only bind input to the functions that need them + } + } + return this; + } + + /** + * Binds the given variable referred in this expression to the given value. + * This is equivalent to <code>bind(name, Tensor.Builder.of(TensorType.empty).cell(value).build())</code> + * + * @param name the variable to bind + * @param value the value this becomes bound to + * @return this for chaining + */ + public MultiFunctionEvaluator bind(String name, double value) { + return bind(name, Tensor.Builder.of(TensorType.empty).cell(value).build()); + } + + public Map<String, Tensor> evaluate() { + for (FunctionEvaluator function : functions) { + function.checkArguments(); + } + + evaluateOnnxModels(); // evaluate each ONNX model only once + + Map<String, Tensor> results = new HashMap<>(); + for (FunctionEvaluator function : functions) { + results.put(function.function().getName(), function.evaluate()); + } + evaluated = true; + return results; + } + + /** + * Evaluate all ONNX models across all functions once and add the result + * back to the functions' context. + */ + private void evaluateOnnxModels() { + Set<OnnxModel> onnxModels = new HashSet<>(); + for (FunctionEvaluator function : functions) { + onnxModels.addAll(function.context().onnxModels().values()); + } + + for (OnnxModel onnxModel : onnxModels) { + + // Gather inputs from all functions. Inputs with the same name must have the same value. + Map<String, Tensor> inputs = new HashMap<>(); + for (FunctionEvaluator function : functions) { + for (OnnxModel functionModel : function.context().onnxModels().values()) { + if (functionModel.name().equals(onnxModel.name())) { + for (String inputName: onnxModel.inputs().keySet()) { + inputs.put(inputName, function.context().get(inputName).asTensor()); + } + } + } + } + + // Evaluate model once. + Map<String, Tensor> outputs = onnxModel.evaluate(inputs); + + // Add outputs back to the context of the functions that need them; they won't be recalculated. + for (FunctionEvaluator function : functions) { + for (Map.Entry<String, OnnxModel> entry : function.context().onnxModels().entrySet()) { + String onnxFeature = entry.getKey(); + OnnxModel functionModel = entry.getValue(); + if (functionModel.name().equals(onnxModel.name())) { + Tensor result = outputs.get(function.function().getName()); // Function name is output of model + function.context().put(onnxFeature, new TensorValue(result)); + } + } + } + + } + } + + public List<FunctionEvaluator> functions() { + return functions; + } + +} diff --git a/model-evaluation/src/main/java/ai/vespa/models/evaluation/OnnxModel.java b/model-evaluation/src/main/java/ai/vespa/models/evaluation/OnnxModel.java index 19a9a1dccd5..06045b07f7c 100644 --- a/model-evaluation/src/main/java/ai/vespa/models/evaluation/OnnxModel.java +++ b/model-evaluation/src/main/java/ai/vespa/models/evaluation/OnnxModel.java @@ -50,6 +50,10 @@ class OnnxModel { return evaluator().evaluate(inputs, output); } + public Map<String, Tensor> evaluate(Map<String, Tensor> inputs) { + return evaluator().evaluate(inputs); + } + private OnnxEvaluator evaluator() { if (evaluator == null) { throw new IllegalStateException("ONNX model has not been loaded."); diff --git a/model-evaluation/src/test/java/ai/vespa/models/evaluation/OnnxEvaluatorTest.java b/model-evaluation/src/test/java/ai/vespa/models/evaluation/OnnxEvaluatorTest.java index a15c35fe854..59ab378e43a 100644 --- a/model-evaluation/src/test/java/ai/vespa/models/evaluation/OnnxEvaluatorTest.java +++ b/model-evaluation/src/test/java/ai/vespa/models/evaluation/OnnxEvaluatorTest.java @@ -18,6 +18,7 @@ import java.util.HashMap; import java.util.Map; import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertFalse; import static org.junit.Assert.assertTrue; /** @@ -44,6 +45,18 @@ public class OnnxEvaluatorTest { function.bind("input2", Tensor.from("tensor<float>(d0[1]):[3]")); assertEquals(5.0, function.evaluate().sum().asDouble(), delta); + MultiFunctionEvaluator evaluator = models.multiEvaluatorOf("add_mul"); + Tensor input1 = Tensor.from("tensor<float>(d0[1]):[2]"); + Tensor input2 = Tensor.from("tensor<float>(d0[1]):[3]"); + Map<String, Tensor> result = evaluator.bind("input1", input1).bind("input2", input2).evaluate(); + assertEquals(6.0, result.get("output1").sum().asDouble(), delta); + assertEquals(5.0, result.get("output2").sum().asDouble(), delta); + + evaluator = models.multiEvaluatorOf("add_mul", "output1"); + result = evaluator.bind("input1", input1).bind("input2", input2).evaluate(); + assertTrue("Result does not contain requested output", result.containsKey("output1")); + assertFalse("Result contains output that was not requested", result.containsKey("output2")); + function = models.evaluatorOf("one_layer"); function.bind("input", Tensor.from("tensor<float>(d0[2],d1[3]):[[0.1, 0.2, 0.3],[0.4,0.5,0.6]]")); assertEquals(function.evaluate(), Tensor.from("tensor<float>(d0[2],d1[1]):[0.63931,0.67574]")); |