aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorLester Solbakken <lesters@oath.com>2021-11-22 15:41:12 +0100
committerLester Solbakken <lesters@oath.com>2021-11-22 15:41:12 +0100
commitbf026551a16d12d4b4e75949933d080f39a85eef (patch)
treef74bcec8d1c82b4c1f49f07e80c31b6750450167
parentea7d530c80515724f8384ded392da5b9a3ab3741 (diff)
Only evaluate ONNX models once in stateless model eval
-rw-r--r--config-model/src/test/cfg/application/stateless_eval/mul.onnx17
-rwxr-xr-xconfig-model/src/test/cfg/application/stateless_eval/mul.py20
-rw-r--r--config-model/src/test/java/com/yahoo/vespa/model/container/ml/ModelsEvaluatorTest.java10
-rw-r--r--model-evaluation/abi-spec.json16
-rw-r--r--model-evaluation/src/main/java/ai/vespa/models/evaluation/FunctionEvaluator.java10
-rw-r--r--model-evaluation/src/main/java/ai/vespa/models/evaluation/Model.java18
-rw-r--r--model-evaluation/src/main/java/ai/vespa/models/evaluation/ModelsEvaluator.java11
-rw-r--r--model-evaluation/src/main/java/ai/vespa/models/evaluation/MultiFunctionEvaluator.java120
-rw-r--r--model-evaluation/src/main/java/ai/vespa/models/evaluation/OnnxModel.java4
-rw-r--r--model-evaluation/src/test/java/ai/vespa/models/evaluation/OnnxEvaluatorTest.java13
10 files changed, 221 insertions, 18 deletions
diff --git a/config-model/src/test/cfg/application/stateless_eval/mul.onnx b/config-model/src/test/cfg/application/stateless_eval/mul.onnx
index 087e2c3427f..26411c96986 100644
--- a/config-model/src/test/cfg/application/stateless_eval/mul.onnx
+++ b/config-model/src/test/cfg/application/stateless_eval/mul.onnx
@@ -1,7 +1,10 @@
-mul.py:f
-
+mul.py:Ÿ
+
input1
-input2output"MulmulZ
+input2output1"Mul
+
+input1
+input2output2"AddmulZ
input1

@@ -9,8 +12,12 @@
input2

-b
-output
+b
+output1
+
+
+b
+output2

B \ No newline at end of file
diff --git a/config-model/src/test/cfg/application/stateless_eval/mul.py b/config-model/src/test/cfg/application/stateless_eval/mul.py
index 9fcb8612af9..6bbc4e23200 100755
--- a/config-model/src/test/cfg/application/stateless_eval/mul.py
+++ b/config-model/src/test/cfg/application/stateless_eval/mul.py
@@ -2,25 +2,31 @@
import onnx
from onnx import helper, TensorProto
-INPUT_1 = helper.make_tensor_value_info('input1', TensorProto.FLOAT, [1])
-INPUT_2 = helper.make_tensor_value_info('input2', TensorProto.FLOAT, [1])
-OUTPUT = helper.make_tensor_value_info('output', TensorProto.FLOAT, [1])
+INPUT1 = helper.make_tensor_value_info('input1', TensorProto.FLOAT, [1])
+INPUT2 = helper.make_tensor_value_info('input2', TensorProto.FLOAT, [1])
+OUTPUT1 = helper.make_tensor_value_info('output1', TensorProto.FLOAT, [1])
+OUTPUT2 = helper.make_tensor_value_info('output2', TensorProto.FLOAT, [1])
nodes = [
helper.make_node(
'Mul',
['input1', 'input2'],
- ['output'],
+ ['output1'],
+ ),
+ helper.make_node(
+ 'Add',
+ ['input1', 'input2'],
+ ['output2'],
),
]
graph_def = helper.make_graph(
nodes,
'mul',
[
- INPUT_1,
- INPUT_2
+ INPUT1,
+ INPUT2
],
- [OUTPUT],
+ [OUTPUT1, OUTPUT2],
)
model_def = helper.make_model(graph_def, producer_name='mul.py', opset_imports=[onnx.OperatorSetIdProto(version=12)])
onnx.save(model_def, 'mul.onnx')
diff --git a/config-model/src/test/java/com/yahoo/vespa/model/container/ml/ModelsEvaluatorTest.java b/config-model/src/test/java/com/yahoo/vespa/model/container/ml/ModelsEvaluatorTest.java
index 5630d3cc186..70c4cb942bc 100644
--- a/config-model/src/test/java/com/yahoo/vespa/model/container/ml/ModelsEvaluatorTest.java
+++ b/config-model/src/test/java/com/yahoo/vespa/model/container/ml/ModelsEvaluatorTest.java
@@ -3,9 +3,12 @@ package com.yahoo.vespa.model.container.ml;
import ai.vespa.models.evaluation.FunctionEvaluator;
import ai.vespa.models.evaluation.ModelsEvaluator;
+import ai.vespa.models.evaluation.MultiFunctionEvaluator;
import com.yahoo.tensor.Tensor;
import org.junit.Test;
+import java.util.Map;
+
import static org.junit.Assert.assertEquals;
/**
@@ -21,12 +24,17 @@ public class ModelsEvaluatorTest {
assertEquals(3, modelsEvaluator.models().size());
// ONNX model evaluation
- FunctionEvaluator mul = modelsEvaluator.evaluatorOf("mul");
+ FunctionEvaluator mul = modelsEvaluator.evaluatorOf("mul", "output1");
Tensor input1 = Tensor.from("tensor<float>(d0[1]):[2]");
Tensor input2 = Tensor.from("tensor<float>(d0[1]):[3]");
Tensor output = mul.bind("input1", input1).bind("input2", input2).evaluate();
assertEquals(6.0, output.sum().asDouble(), 1e-9);
+ MultiFunctionEvaluator eval = modelsEvaluator.multiEvaluatorOf("mul");
+ Map<String, Tensor> out = eval.bind("input1", input1).bind("input2", input2).evaluate();
+ assertEquals(6.0, out.get("output1").sum().asDouble(), 1e-9);
+ assertEquals(5.0, out.get("output2").sum().asDouble(), 1e-9);
+
// LightGBM model evaluation
FunctionEvaluator lgbm = modelsEvaluator.evaluatorOf("lightgbm_regression");
lgbm.bind("numerical_1", 0.1).bind("numerical_2", 0.2).bind("categorical_1", "a").bind("categorical_2", "i");
diff --git a/model-evaluation/abi-spec.json b/model-evaluation/abi-spec.json
index 6728d5cd9b4..3f23e7456ad 100644
--- a/model-evaluation/abi-spec.json
+++ b/model-evaluation/abi-spec.json
@@ -56,6 +56,7 @@
"public java.lang.String name()",
"public java.util.List functions()",
"public varargs ai.vespa.models.evaluation.FunctionEvaluator evaluatorOf(java.lang.String[])",
+ "public varargs ai.vespa.models.evaluation.MultiFunctionEvaluator multiEvaluatorOf(java.lang.String[])",
"public java.lang.String toString()"
],
"fields": []
@@ -72,10 +73,25 @@
"public void <init>(java.util.Map)",
"public java.util.Map models()",
"public varargs ai.vespa.models.evaluation.FunctionEvaluator evaluatorOf(java.lang.String, java.lang.String[])",
+ "public varargs ai.vespa.models.evaluation.MultiFunctionEvaluator multiEvaluatorOf(java.lang.String, java.lang.String[])",
"public ai.vespa.models.evaluation.Model requireModel(java.lang.String)"
],
"fields": []
},
+ "ai.vespa.models.evaluation.MultiFunctionEvaluator": {
+ "superClass": "java.lang.Object",
+ "interfaces": [],
+ "attributes": [
+ "public"
+ ],
+ "methods": [
+ "public ai.vespa.models.evaluation.MultiFunctionEvaluator bind(java.lang.String, com.yahoo.tensor.Tensor)",
+ "public ai.vespa.models.evaluation.MultiFunctionEvaluator bind(java.lang.String, double)",
+ "public java.util.Map evaluate()",
+ "public java.util.List functions()"
+ ],
+ "fields": []
+ },
"ai.vespa.models.evaluation.RankProfilesConfigImporter": {
"superClass": "java.lang.Object",
"interfaces": [],
diff --git a/model-evaluation/src/main/java/ai/vespa/models/evaluation/FunctionEvaluator.java b/model-evaluation/src/main/java/ai/vespa/models/evaluation/FunctionEvaluator.java
index 6af33e29e62..aa13cb96845 100644
--- a/model-evaluation/src/main/java/ai/vespa/models/evaluation/FunctionEvaluator.java
+++ b/model-evaluation/src/main/java/ai/vespa/models/evaluation/FunctionEvaluator.java
@@ -101,14 +101,18 @@ public class FunctionEvaluator {
}
public Tensor evaluate() {
- for (Map.Entry<String, TensorType> argument : function.argumentTypes().entrySet()) {
- checkArgument(argument.getKey(), argument.getValue());
- }
+ checkArguments();
evaluated = true;
evaluateOnnxModels();
return function.getBody().evaluate(context).asTensor();
}
+ void checkArguments() {
+ for (Map.Entry<String, TensorType> argument : function.argumentTypes().entrySet()) {
+ checkArgument(argument.getKey(), argument.getValue());
+ }
+ }
+
private void checkArgument(String name, TensorType type) {
if (context.isMissing(name))
throw new IllegalStateException("Missing argument '" + name + "': Must be bound to a value of type " + type);
diff --git a/model-evaluation/src/main/java/ai/vespa/models/evaluation/Model.java b/model-evaluation/src/main/java/ai/vespa/models/evaluation/Model.java
index 8af5f7bc499..84ab6e81840 100644
--- a/model-evaluation/src/main/java/ai/vespa/models/evaluation/Model.java
+++ b/model-evaluation/src/main/java/ai/vespa/models/evaluation/Model.java
@@ -6,9 +6,7 @@ import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import com.yahoo.searchlib.rankingexpression.ExpressionFunction;
import com.yahoo.searchlib.rankingexpression.evaluation.ContextIndex;
-import com.yahoo.searchlib.rankingexpression.evaluation.DoubleValue;
import com.yahoo.searchlib.rankingexpression.evaluation.ExpressionOptimizer;
-import com.yahoo.searchlib.rankingexpression.evaluation.Value;
import com.yahoo.tensor.TensorType;
import java.util.Arrays;
@@ -232,6 +230,22 @@ public class Model {
return new FunctionEvaluator(function, requireContextPrototype(function.getName()).copy());
}
+ /**
+ * Returns an evaluator which can be used to evaluate the given model in a single thread once.
+ *
+ * @param names The names identifying the outputs. If none are passed, evaluates all outputs.
+ * @throws IllegalArgumentException if the function is not present.
+ */
+ public MultiFunctionEvaluator multiEvaluatorOf(String ... names) {
+ List<FunctionEvaluator> evaluators;
+ if (names.length == 0) {
+ evaluators = functions.stream().map(this::evaluatorOf).collect(Collectors.toList());
+ } else {
+ evaluators = Arrays.stream(names).map(this::evaluatorOf).collect(Collectors.toList());
+ }
+ return new MultiFunctionEvaluator(evaluators);
+ }
+
private void throwUndeterminedFunction(String message) {
throw new IllegalArgumentException(message + ". Available functions: " +
functions.stream().map(f -> f.getName()).collect(Collectors.joining(", ")));
diff --git a/model-evaluation/src/main/java/ai/vespa/models/evaluation/ModelsEvaluator.java b/model-evaluation/src/main/java/ai/vespa/models/evaluation/ModelsEvaluator.java
index 01427ca811a..bd00f5510c6 100644
--- a/model-evaluation/src/main/java/ai/vespa/models/evaluation/ModelsEvaluator.java
+++ b/model-evaluation/src/main/java/ai/vespa/models/evaluation/ModelsEvaluator.java
@@ -60,6 +60,17 @@ public class ModelsEvaluator extends AbstractComponent {
return requireModel(modelName).evaluatorOf(names);
}
+ /**
+ * Returns a model evaluator which can be used to evaluate multiple functions in a model
+ *
+ * @param modelName the name of the model
+ * @param names the names of the outputs to evaluate, or none if all should be evaluated
+ * @throws IllegalArgumentException if the function or model is not present
+ */
+ public MultiFunctionEvaluator multiEvaluatorOf(String modelName, String ... names) {
+ return requireModel(modelName).multiEvaluatorOf(names);
+ }
+
/** Returns the given model, or throws a IllegalArgumentException if it does not exist */
public Model requireModel(String name) {
Model model = models.get(name);
diff --git a/model-evaluation/src/main/java/ai/vespa/models/evaluation/MultiFunctionEvaluator.java b/model-evaluation/src/main/java/ai/vespa/models/evaluation/MultiFunctionEvaluator.java
new file mode 100644
index 00000000000..53d470ecc19
--- /dev/null
+++ b/model-evaluation/src/main/java/ai/vespa/models/evaluation/MultiFunctionEvaluator.java
@@ -0,0 +1,120 @@
+// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+package ai.vespa.models.evaluation;
+
+import com.yahoo.searchlib.rankingexpression.evaluation.TensorValue;
+import com.yahoo.tensor.Tensor;
+import com.yahoo.tensor.TensorType;
+
+import java.util.HashMap;
+import java.util.HashSet;
+import java.util.List;
+import java.util.Map;
+import java.util.Set;
+
+/**
+ * An evaluator which can be used to evaluate a model with multiple outputs.
+ * This will ensure that ONNX models are only evaluated once.
+ *
+ * @author lesters
+ */
+public class MultiFunctionEvaluator {
+
+ private final List<FunctionEvaluator> functions;
+ private boolean evaluated = false;
+
+ MultiFunctionEvaluator(List<FunctionEvaluator> functions) {
+ this.functions = functions;
+ }
+
+ /**
+ * Binds the given variable referred in this expression to the given value.
+ *
+ * @param name the variable to bind
+ * @param value the value this becomes bound to
+ * @return this for chaining
+ */
+ public MultiFunctionEvaluator bind(String name, Tensor value) {
+ if (evaluated)
+ throw new IllegalStateException("Cannot bind a new value in a used evaluator");
+ for (FunctionEvaluator function : functions) {
+ if (function.function().argumentTypes().containsKey(name)) {
+ function.bind(name, value); // only bind input to the functions that need them
+ }
+ }
+ return this;
+ }
+
+ /**
+ * Binds the given variable referred in this expression to the given value.
+ * This is equivalent to <code>bind(name, Tensor.Builder.of(TensorType.empty).cell(value).build())</code>
+ *
+ * @param name the variable to bind
+ * @param value the value this becomes bound to
+ * @return this for chaining
+ */
+ public MultiFunctionEvaluator bind(String name, double value) {
+ return bind(name, Tensor.Builder.of(TensorType.empty).cell(value).build());
+ }
+
+ public Map<String, Tensor> evaluate() {
+ for (FunctionEvaluator function : functions) {
+ function.checkArguments();
+ }
+
+ evaluateOnnxModels(); // evaluate each ONNX model only once
+
+ Map<String, Tensor> results = new HashMap<>();
+ for (FunctionEvaluator function : functions) {
+ results.put(function.function().getName(), function.evaluate());
+ }
+ evaluated = true;
+ return results;
+ }
+
+ /**
+ * Evaluate all ONNX models across all functions once and add the result
+ * back to the functions' context.
+ */
+ private void evaluateOnnxModels() {
+ Set<OnnxModel> onnxModels = new HashSet<>();
+ for (FunctionEvaluator function : functions) {
+ onnxModels.addAll(function.context().onnxModels().values());
+ }
+
+ for (OnnxModel onnxModel : onnxModels) {
+
+ // Gather inputs from all functions. Inputs with the same name must have the same value.
+ Map<String, Tensor> inputs = new HashMap<>();
+ for (FunctionEvaluator function : functions) {
+ for (OnnxModel functionModel : function.context().onnxModels().values()) {
+ if (functionModel.name().equals(onnxModel.name())) {
+ for (String inputName: onnxModel.inputs().keySet()) {
+ inputs.put(inputName, function.context().get(inputName).asTensor());
+ }
+ }
+ }
+ }
+
+ // Evaluate model once.
+ Map<String, Tensor> outputs = onnxModel.evaluate(inputs);
+
+ // Add outputs back to the context of the functions that need them; they won't be recalculated.
+ for (FunctionEvaluator function : functions) {
+ for (Map.Entry<String, OnnxModel> entry : function.context().onnxModels().entrySet()) {
+ String onnxFeature = entry.getKey();
+ OnnxModel functionModel = entry.getValue();
+ if (functionModel.name().equals(onnxModel.name())) {
+ Tensor result = outputs.get(function.function().getName()); // Function name is output of model
+ function.context().put(onnxFeature, new TensorValue(result));
+ }
+ }
+ }
+
+ }
+ }
+
+ public List<FunctionEvaluator> functions() {
+ return functions;
+ }
+
+}
diff --git a/model-evaluation/src/main/java/ai/vespa/models/evaluation/OnnxModel.java b/model-evaluation/src/main/java/ai/vespa/models/evaluation/OnnxModel.java
index 19a9a1dccd5..06045b07f7c 100644
--- a/model-evaluation/src/main/java/ai/vespa/models/evaluation/OnnxModel.java
+++ b/model-evaluation/src/main/java/ai/vespa/models/evaluation/OnnxModel.java
@@ -50,6 +50,10 @@ class OnnxModel {
return evaluator().evaluate(inputs, output);
}
+ public Map<String, Tensor> evaluate(Map<String, Tensor> inputs) {
+ return evaluator().evaluate(inputs);
+ }
+
private OnnxEvaluator evaluator() {
if (evaluator == null) {
throw new IllegalStateException("ONNX model has not been loaded.");
diff --git a/model-evaluation/src/test/java/ai/vespa/models/evaluation/OnnxEvaluatorTest.java b/model-evaluation/src/test/java/ai/vespa/models/evaluation/OnnxEvaluatorTest.java
index a15c35fe854..59ab378e43a 100644
--- a/model-evaluation/src/test/java/ai/vespa/models/evaluation/OnnxEvaluatorTest.java
+++ b/model-evaluation/src/test/java/ai/vespa/models/evaluation/OnnxEvaluatorTest.java
@@ -18,6 +18,7 @@ import java.util.HashMap;
import java.util.Map;
import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertFalse;
import static org.junit.Assert.assertTrue;
/**
@@ -44,6 +45,18 @@ public class OnnxEvaluatorTest {
function.bind("input2", Tensor.from("tensor<float>(d0[1]):[3]"));
assertEquals(5.0, function.evaluate().sum().asDouble(), delta);
+ MultiFunctionEvaluator evaluator = models.multiEvaluatorOf("add_mul");
+ Tensor input1 = Tensor.from("tensor<float>(d0[1]):[2]");
+ Tensor input2 = Tensor.from("tensor<float>(d0[1]):[3]");
+ Map<String, Tensor> result = evaluator.bind("input1", input1).bind("input2", input2).evaluate();
+ assertEquals(6.0, result.get("output1").sum().asDouble(), delta);
+ assertEquals(5.0, result.get("output2").sum().asDouble(), delta);
+
+ evaluator = models.multiEvaluatorOf("add_mul", "output1");
+ result = evaluator.bind("input1", input1).bind("input2", input2).evaluate();
+ assertTrue("Result does not contain requested output", result.containsKey("output1"));
+ assertFalse("Result contains output that was not requested", result.containsKey("output2"));
+
function = models.evaluatorOf("one_layer");
function.bind("input", Tensor.from("tensor<float>(d0[2],d1[3]):[[0.1, 0.2, 0.3],[0.4,0.5,0.6]]"));
assertEquals(function.evaluate(), Tensor.from("tensor<float>(d0[2],d1[1]):[0.63931,0.67574]"));