aboutsummaryrefslogtreecommitdiffstats
path: root/model-integration/src/main/java/ai/vespa/modelintegration/evaluator/OnnxEvaluator.java
diff options
context:
space:
mode:
Diffstat (limited to 'model-integration/src/main/java/ai/vespa/modelintegration/evaluator/OnnxEvaluator.java')
-rw-r--r--model-integration/src/main/java/ai/vespa/modelintegration/evaluator/OnnxEvaluator.java22
1 files changed, 21 insertions, 1 deletions
diff --git a/model-integration/src/main/java/ai/vespa/modelintegration/evaluator/OnnxEvaluator.java b/model-integration/src/main/java/ai/vespa/modelintegration/evaluator/OnnxEvaluator.java
index 5cc7991d197..9961c24005c 100644
--- a/model-integration/src/main/java/ai/vespa/modelintegration/evaluator/OnnxEvaluator.java
+++ b/model-integration/src/main/java/ai/vespa/modelintegration/evaluator/OnnxEvaluator.java
@@ -37,6 +37,7 @@ public class OnnxEvaluator {
public Tensor evaluate(Map<String, Tensor> inputs, String output) {
Map<String, OnnxTensor> onnxInputs = null;
try {
+ output = mapToInternalName(output);
onnxInputs = TensorConverter.toOnnxTensors(inputs, environment, session);
try (OrtSession.Result result = session.run(onnxInputs, Collections.singleton(output))) {
return TensorConverter.toVespaTensor(result.get(0));
@@ -57,7 +58,8 @@ public class OnnxEvaluator {
Map<String, Tensor> outputs = new HashMap<>();
try (OrtSession.Result result = session.run(onnxInputs)) {
for (Map.Entry<String, OnnxValue> output : result) {
- outputs.put(output.getKey(), TensorConverter.toVespaTensor(output.getValue()));
+ String mapped = TensorConverter.asValidName(output.getKey());
+ outputs.put(mapped, TensorConverter.toVespaTensor(output.getValue()));
}
return outputs;
}
@@ -133,4 +135,22 @@ public class OnnxEvaluator {
}
}
+ private String mapToInternalName(String outputName) throws OrtException {
+ var info = session.getOutputInfo();
+ var internalNames = info.keySet();
+ for (String name : internalNames) {
+ if (name.equals(outputName)) {
+ return name;
+ }
+ }
+ for (String name : internalNames) {
+ String mapped = TensorConverter.asValidName(name);
+ if (mapped.equals(outputName)) {
+ return name;
+ }
+ }
+ // Probably will not work, but give the correct error from session.run
+ return outputName;
+ }
+
}