aboutsummaryrefslogtreecommitdiffstats
path: root/model-evaluation
diff options
context:
space:
mode:
authorArne Juul <arnej@yahooinc.com>2023-02-11 21:08:04 +0000
committerArne Juul <arnej@yahooinc.com>2023-02-11 21:08:04 +0000
commit94e2606685038688052dd5a0819d60ee6a27c3ce (patch)
tree1a6ae270dabfe55fd3b11dfbcb312c18263c4a85 /model-evaluation
parent30cdbc3cad0279e6d6c9bc337926598046ab09a6 (diff)
pick up specified name of output from onnx model
Diffstat (limited to 'model-evaluation')
-rw-r--r--model-evaluation/src/main/java/ai/vespa/models/evaluation/FunctionEvaluator.java8
1 files changed, 7 insertions, 1 deletions
diff --git a/model-evaluation/src/main/java/ai/vespa/models/evaluation/FunctionEvaluator.java b/model-evaluation/src/main/java/ai/vespa/models/evaluation/FunctionEvaluator.java
index 3fada4c8b6d..0ebf9880105 100644
--- a/model-evaluation/src/main/java/ai/vespa/models/evaluation/FunctionEvaluator.java
+++ b/model-evaluation/src/main/java/ai/vespa/models/evaluation/FunctionEvaluator.java
@@ -122,13 +122,19 @@ public class FunctionEvaluator {
private void evaluateOnnxModels() {
for (Map.Entry<String, OnnxModel> entry : context().onnxModels().entrySet()) {
String onnxFeature = entry.getKey();
+ String outputName = function.getName(); // Function name is output of model (sometimes)
+ int idx = onnxFeature.indexOf(").");
+ if (idx > 0 && idx + 2 < onnxFeature.length()) {
+ // explicitly specified as onnx(modelname).outputname ; pick the last part
+ outputName = onnxFeature.substring(idx+2);
+ }
OnnxModel onnxModel = entry.getValue();
if (context.get(onnxFeature).equals(context.defaultValue())) {
Map<String, Tensor> inputs = new HashMap<>();
for (Map.Entry<String, TensorType> input: onnxModel.inputs().entrySet()) {
inputs.put(input.getKey(), context.get(input.getKey()).asTensor());
}
- Tensor result = onnxModel.evaluate(inputs, function.getName()); // Function name is output of model
+ Tensor result = onnxModel.evaluate(inputs, outputName);
context.put(onnxFeature, new TensorValue(result));
}
}