diff options
Diffstat (limited to 'model-evaluation/src/main/java/ai/vespa/models/evaluation/OnnxModel.java')
-rw-r--r-- | model-evaluation/src/main/java/ai/vespa/models/evaluation/OnnxModel.java | 19 |
1 files changed, 19 insertions, 0 deletions
diff --git a/model-evaluation/src/main/java/ai/vespa/models/evaluation/OnnxModel.java b/model-evaluation/src/main/java/ai/vespa/models/evaluation/OnnxModel.java index cf97c20e881..59febf7cdbf 100644 --- a/model-evaluation/src/main/java/ai/vespa/models/evaluation/OnnxModel.java +++ b/model-evaluation/src/main/java/ai/vespa/models/evaluation/OnnxModel.java @@ -4,6 +4,9 @@ package ai.vespa.models.evaluation; import ai.vespa.modelintegration.evaluator.OnnxEvaluator; import ai.vespa.modelintegration.evaluator.OnnxEvaluatorOptions; import ai.vespa.modelintegration.evaluator.OnnxRuntime; +import com.yahoo.searchlib.rankingexpression.Reference; +import com.yahoo.searchlib.rankingexpression.rule.CompositeNode; +import com.yahoo.searchlib.rankingexpression.rule.ExpressionNode; import com.yahoo.tensor.Tensor; import com.yahoo.tensor.TensorType; @@ -64,6 +67,7 @@ class OnnxModel implements AutoCloseable { private final OnnxRuntime onnx; private OnnxEvaluator evaluator; + private final Map<String, ExpressionNode> exprPerOutput = new HashMap<>(); OnnxModel(String name, File modelFile, OnnxEvaluatorOptions options, OnnxRuntime onnx) { this.name = name; @@ -81,6 +85,7 @@ class OnnxModel implements AutoCloseable { evaluator = onnx.evaluatorOf(modelFile.getPath(), options); fillInputTypes(evaluator().getInputs()); fillOutputTypes(evaluator().getOutputs()); + fillOutputExpressions(); } } @@ -156,6 +161,20 @@ class OnnxModel implements AutoCloseable { return map; } + void fillOutputExpressions() { + for (var spec : outputSpecs) { + var node = new OnnxExpressionNode(this, spec.onnxName, spec.expectedType, spec.outputAs); + exprPerOutput.put(spec.outputAs, node); + } + } + + ExpressionNode getExpressionForOutput(String outputName) { + if (outputName == null && exprPerOutput.size() == 1) { + return exprPerOutput.values().iterator().next(); + } + return exprPerOutput.get(outputName); + } + public Tensor evaluate(Map<String, Tensor> inputs, String output) { var mapped = new HashMap<String, Tensor>(); for (var spec : inputSpecs) { |