aboutsummaryrefslogtreecommitdiffstats
path: root/model-evaluation/src/main/java/ai/vespa/models/evaluation/OnnxModel.java
diff options
context:
space:
mode:
Diffstat (limited to 'model-evaluation/src/main/java/ai/vespa/models/evaluation/OnnxModel.java')
-rw-r--r--model-evaluation/src/main/java/ai/vespa/models/evaluation/OnnxModel.java19
1 files changed, 19 insertions, 0 deletions
diff --git a/model-evaluation/src/main/java/ai/vespa/models/evaluation/OnnxModel.java b/model-evaluation/src/main/java/ai/vespa/models/evaluation/OnnxModel.java
index cf97c20e881..59febf7cdbf 100644
--- a/model-evaluation/src/main/java/ai/vespa/models/evaluation/OnnxModel.java
+++ b/model-evaluation/src/main/java/ai/vespa/models/evaluation/OnnxModel.java
@@ -4,6 +4,9 @@ package ai.vespa.models.evaluation;
import ai.vespa.modelintegration.evaluator.OnnxEvaluator;
import ai.vespa.modelintegration.evaluator.OnnxEvaluatorOptions;
import ai.vespa.modelintegration.evaluator.OnnxRuntime;
+import com.yahoo.searchlib.rankingexpression.Reference;
+import com.yahoo.searchlib.rankingexpression.rule.CompositeNode;
+import com.yahoo.searchlib.rankingexpression.rule.ExpressionNode;
import com.yahoo.tensor.Tensor;
import com.yahoo.tensor.TensorType;
@@ -64,6 +67,7 @@ class OnnxModel implements AutoCloseable {
private final OnnxRuntime onnx;
private OnnxEvaluator evaluator;
+ private final Map<String, ExpressionNode> exprPerOutput = new HashMap<>();
OnnxModel(String name, File modelFile, OnnxEvaluatorOptions options, OnnxRuntime onnx) {
this.name = name;
@@ -81,6 +85,7 @@ class OnnxModel implements AutoCloseable {
evaluator = onnx.evaluatorOf(modelFile.getPath(), options);
fillInputTypes(evaluator().getInputs());
fillOutputTypes(evaluator().getOutputs());
+ fillOutputExpressions();
}
}
@@ -156,6 +161,20 @@ class OnnxModel implements AutoCloseable {
return map;
}
+ void fillOutputExpressions() {
+ for (var spec : outputSpecs) {
+ var node = new OnnxExpressionNode(this, spec.onnxName, spec.expectedType, spec.outputAs);
+ exprPerOutput.put(spec.outputAs, node);
+ }
+ }
+
+ ExpressionNode getExpressionForOutput(String outputName) {
+ if (outputName == null && exprPerOutput.size() == 1) {
+ return exprPerOutput.values().iterator().next();
+ }
+ return exprPerOutput.get(outputName);
+ }
+
public Tensor evaluate(Map<String, Tensor> inputs, String output) {
var mapped = new HashMap<String, Tensor>();
for (var spec : inputSpecs) {