summaryrefslogtreecommitdiffstats
path: root/model-evaluation
diff options
context:
space:
mode:
authorArne Juul <arnej@yahooinc.com>2023-03-08 13:08:24 +0000
committerArne Juul <arnej@yahooinc.com>2023-03-08 13:08:24 +0000
commit9dd9026441ca39df01dbdb456e3e38e402d28a5e (patch)
tree429144ca9e9af230ebe4a1d589c73699d9b766e9 /model-evaluation
parent686a300d8dee0ebb4257dc3b609399533c6dd40b (diff)
add ExpressionNode computing an output from an ONNX model
Diffstat (limited to 'model-evaluation')
-rw-r--r--model-evaluation/src/main/java/ai/vespa/models/evaluation/OnnxExpressionNode.java102
-rw-r--r--model-evaluation/src/main/java/ai/vespa/models/evaluation/OnnxModel.java19
2 files changed, 121 insertions, 0 deletions
diff --git a/model-evaluation/src/main/java/ai/vespa/models/evaluation/OnnxExpressionNode.java b/model-evaluation/src/main/java/ai/vespa/models/evaluation/OnnxExpressionNode.java
new file mode 100644
index 00000000000..a50d9e36d74
--- /dev/null
+++ b/model-evaluation/src/main/java/ai/vespa/models/evaluation/OnnxExpressionNode.java
@@ -0,0 +1,102 @@
+// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+package ai.vespa.models.evaluation;
+
+import com.yahoo.searchlib.rankingexpression.Reference;
+import com.yahoo.searchlib.rankingexpression.evaluation.Context;
+import com.yahoo.searchlib.rankingexpression.evaluation.TensorValue;
+import com.yahoo.searchlib.rankingexpression.evaluation.Value;
+import com.yahoo.searchlib.rankingexpression.rule.CompositeNode;
+import com.yahoo.searchlib.rankingexpression.rule.ExpressionNode;
+import com.yahoo.searchlib.rankingexpression.rule.ReferenceNode;
+import com.yahoo.searchlib.rankingexpression.rule.SerializationContext;
+import com.yahoo.tensor.Tensor;
+import com.yahoo.tensor.TensorType;
+import com.yahoo.tensor.evaluation.TypeContext;
+
+import java.util.ArrayList;
+import java.util.Deque;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.Objects;
+import java.util.Optional;
+
+/**
+ * make it possible to evaluate an ONNX model anywhere in the ranking expression tree
+ */
+class OnnxExpressionNode extends CompositeNode {
+ private final OnnxModel model;
+ private final String onnxOutputName;
+ private final TensorType expectedType;
+ private final String outputAs;
+ private final List<String> modelInputs = new ArrayList<>();
+ private final List<ExpressionNode> inputRefs = new ArrayList<>();
+
+ OnnxExpressionNode(OnnxModel model, String onnxOutputName, TensorType expectedType, String outputAs) {
+ this.model = model;
+ this.onnxOutputName = onnxOutputName;
+ this.expectedType = expectedType;
+ this.outputAs = outputAs;
+ for (var input : model.inputSpecs) {
+ modelInputs.add(input.onnxName);
+ var optRef = parseOnnxInput(input.source);
+ if (optRef.isEmpty()) {
+ throw new IllegalArgumentException("Bad input source for ONNX model " + model.name() + ": '" + input + "'");
+ }
+ var ref = optRef.get();
+ inputRefs.add(new ReferenceNode(ref));
+ }
+ }
+
+ static Optional<Reference> parseOnnxInput(String input) {
+ var optRef = Reference.simple(input);
+ if (optRef.isPresent()) {
+ return optRef;
+ }
+ try {
+ var ref = Reference.fromIdentifier(input);
+ return Optional.of(ref);
+ } catch (Exception e) {
+ // fallthrough
+ }
+ return Optional.empty();
+ }
+
+ @Override
+ public List<ExpressionNode> children() { return List.copyOf(inputRefs); }
+
+ @Override
+ public CompositeNode setChildren(List<ExpressionNode> children) {
+ if (inputRefs.size() != children.size()) {
+ throw new IllegalArgumentException("bad setChildren");
+ }
+ inputRefs.clear();
+ inputRefs.addAll(children);
+ return this;
+ }
+
+ @Override
+ public Value evaluate(Context context) {
+ Map<String, Tensor> inputs = new HashMap<>();
+ for (int i = 0; i < modelInputs.size(); i++) {
+ Value inputValue = inputRefs.get(i).evaluate(context);
+ inputs.put(modelInputs.get(i), inputValue.asTensor());
+ }
+ return new TensorValue(model.evaluate(inputs, onnxOutputName));
+ }
+
+ @Override
+ public TensorType type(TypeContext<Reference> context) { return expectedType; }
+
+ @Override
+ public int hashCode() { return Objects.hash("OnnxExpressionNode", model.name(), onnxOutputName); }
+
+ @Override
+ public StringBuilder toString(StringBuilder b, SerializationContext context, Deque<String> path, CompositeNode parent) {
+ b.append("onnx_expression_node(").append(model.name()).append(")");
+ if (outputAs != null && ! outputAs.equals("")) {
+ b.append(".").append(outputAs);
+ }
+ return b;
+ }
+}
diff --git a/model-evaluation/src/main/java/ai/vespa/models/evaluation/OnnxModel.java b/model-evaluation/src/main/java/ai/vespa/models/evaluation/OnnxModel.java
index cf97c20e881..59febf7cdbf 100644
--- a/model-evaluation/src/main/java/ai/vespa/models/evaluation/OnnxModel.java
+++ b/model-evaluation/src/main/java/ai/vespa/models/evaluation/OnnxModel.java
@@ -4,6 +4,9 @@ package ai.vespa.models.evaluation;
import ai.vespa.modelintegration.evaluator.OnnxEvaluator;
import ai.vespa.modelintegration.evaluator.OnnxEvaluatorOptions;
import ai.vespa.modelintegration.evaluator.OnnxRuntime;
+import com.yahoo.searchlib.rankingexpression.Reference;
+import com.yahoo.searchlib.rankingexpression.rule.CompositeNode;
+import com.yahoo.searchlib.rankingexpression.rule.ExpressionNode;
import com.yahoo.tensor.Tensor;
import com.yahoo.tensor.TensorType;
@@ -64,6 +67,7 @@ class OnnxModel implements AutoCloseable {
private final OnnxRuntime onnx;
private OnnxEvaluator evaluator;
+ private final Map<String, ExpressionNode> exprPerOutput = new HashMap<>();
OnnxModel(String name, File modelFile, OnnxEvaluatorOptions options, OnnxRuntime onnx) {
this.name = name;
@@ -81,6 +85,7 @@ class OnnxModel implements AutoCloseable {
evaluator = onnx.evaluatorOf(modelFile.getPath(), options);
fillInputTypes(evaluator().getInputs());
fillOutputTypes(evaluator().getOutputs());
+ fillOutputExpressions();
}
}
@@ -156,6 +161,20 @@ class OnnxModel implements AutoCloseable {
return map;
}
+ void fillOutputExpressions() {
+ for (var spec : outputSpecs) {
+ var node = new OnnxExpressionNode(this, spec.onnxName, spec.expectedType, spec.outputAs);
+ exprPerOutput.put(spec.outputAs, node);
+ }
+ }
+
+ ExpressionNode getExpressionForOutput(String outputName) {
+ if (outputName == null && exprPerOutput.size() == 1) {
+ return exprPerOutput.values().iterator().next();
+ }
+ return exprPerOutput.get(outputName);
+ }
+
public Tensor evaluate(Map<String, Tensor> inputs, String output) {
var mapped = new HashMap<String, Tensor>();
for (var spec : inputSpecs) {