diff options
author | Arne Juul <arnej@yahooinc.com> | 2023-02-11 21:08:04 +0000 |
---|---|---|
committer | Arne Juul <arnej@yahooinc.com> | 2023-02-11 21:08:04 +0000 |
commit | 94e2606685038688052dd5a0819d60ee6a27c3ce (patch) | |
tree | 1a6ae270dabfe55fd3b11dfbcb312c18263c4a85 /model-evaluation/src/main/java/ai | |
parent | 30cdbc3cad0279e6d6c9bc337926598046ab09a6 (diff) |
pick up specified name of output from onnx model
Diffstat (limited to 'model-evaluation/src/main/java/ai')
-rw-r--r-- | model-evaluation/src/main/java/ai/vespa/models/evaluation/FunctionEvaluator.java | 8 |
1 files changed, 7 insertions, 1 deletions
diff --git a/model-evaluation/src/main/java/ai/vespa/models/evaluation/FunctionEvaluator.java b/model-evaluation/src/main/java/ai/vespa/models/evaluation/FunctionEvaluator.java index 3fada4c8b6d..0ebf9880105 100644 --- a/model-evaluation/src/main/java/ai/vespa/models/evaluation/FunctionEvaluator.java +++ b/model-evaluation/src/main/java/ai/vespa/models/evaluation/FunctionEvaluator.java @@ -122,13 +122,19 @@ public class FunctionEvaluator { private void evaluateOnnxModels() { for (Map.Entry<String, OnnxModel> entry : context().onnxModels().entrySet()) { String onnxFeature = entry.getKey(); + String outputName = function.getName(); // Function name is output of model (sometimes) + int idx = onnxFeature.indexOf(")."); + if (idx > 0 && idx + 2 < onnxFeature.length()) { + // explicitly specified as onnx(modelname).outputname ; pick the last part + outputName = onnxFeature.substring(idx+2); + } OnnxModel onnxModel = entry.getValue(); if (context.get(onnxFeature).equals(context.defaultValue())) { Map<String, Tensor> inputs = new HashMap<>(); for (Map.Entry<String, TensorType> input: onnxModel.inputs().entrySet()) { inputs.put(input.getKey(), context.get(input.getKey()).asTensor()); } - Tensor result = onnxModel.evaluate(inputs, function.getName()); // Function name is output of model + Tensor result = onnxModel.evaluate(inputs, outputName); context.put(onnxFeature, new TensorValue(result)); } } |