diff options
Diffstat (limited to 'model-integration/src/main/java/ai/vespa/modelintegration/evaluator/OnnxEvaluator.java')
-rw-r--r-- | model-integration/src/main/java/ai/vespa/modelintegration/evaluator/OnnxEvaluator.java | 22 |
1 files changed, 21 insertions, 1 deletions
diff --git a/model-integration/src/main/java/ai/vespa/modelintegration/evaluator/OnnxEvaluator.java b/model-integration/src/main/java/ai/vespa/modelintegration/evaluator/OnnxEvaluator.java index 5cc7991d197..9961c24005c 100644 --- a/model-integration/src/main/java/ai/vespa/modelintegration/evaluator/OnnxEvaluator.java +++ b/model-integration/src/main/java/ai/vespa/modelintegration/evaluator/OnnxEvaluator.java @@ -37,6 +37,7 @@ public class OnnxEvaluator { public Tensor evaluate(Map<String, Tensor> inputs, String output) { Map<String, OnnxTensor> onnxInputs = null; try { + output = mapToInternalName(output); onnxInputs = TensorConverter.toOnnxTensors(inputs, environment, session); try (OrtSession.Result result = session.run(onnxInputs, Collections.singleton(output))) { return TensorConverter.toVespaTensor(result.get(0)); @@ -57,7 +58,8 @@ public class OnnxEvaluator { Map<String, Tensor> outputs = new HashMap<>(); try (OrtSession.Result result = session.run(onnxInputs)) { for (Map.Entry<String, OnnxValue> output : result) { - outputs.put(output.getKey(), TensorConverter.toVespaTensor(output.getValue())); + String mapped = TensorConverter.asValidName(output.getKey()); + outputs.put(mapped, TensorConverter.toVespaTensor(output.getValue())); } return outputs; } @@ -133,4 +135,22 @@ public class OnnxEvaluator { } } + private String mapToInternalName(String outputName) throws OrtException { + var info = session.getOutputInfo(); + var internalNames = info.keySet(); + for (String name : internalNames) { + if (name.equals(outputName)) { + return name; + } + } + for (String name : internalNames) { + String mapped = TensorConverter.asValidName(name); + if (mapped.equals(outputName)) { + return name; + } + } + // Probably will not work, but give the correct error from session.run + return outputName; + } + } |