diff options
author | Arne Juul <arnej@yahooinc.com> | 2023-02-22 19:24:26 +0000 |
---|---|---|
committer | Arne Juul <arnej@yahooinc.com> | 2023-02-22 19:24:55 +0000 |
commit | 0884752871f37850249362fd20d66d4ae765b8ec (patch) | |
tree | 3758df43cc6c74a74207d95e21e9f636be2609b7 /model-integration | |
parent | fc5c9a366b06a3e04091be1e8f784be8bb82e1f5 (diff) |
handle non-identifier onnx input/output names: instead of the conflicting
ad-hoc code in OnnxEvaluator, do it as part of general input/output
mapping in OnnxModel.
Diffstat (limited to 'model-integration')
-rw-r--r-- | model-integration/src/main/java/ai/vespa/modelintegration/evaluator/OnnxEvaluator.java | 30 |
1 files changed, 30 insertions, 0 deletions
diff --git a/model-integration/src/main/java/ai/vespa/modelintegration/evaluator/OnnxEvaluator.java b/model-integration/src/main/java/ai/vespa/modelintegration/evaluator/OnnxEvaluator.java index 66eb8caabd0..c2d97e37074 100644 --- a/model-integration/src/main/java/ai/vespa/modelintegration/evaluator/OnnxEvaluator.java +++ b/model-integration/src/main/java/ai/vespa/modelintegration/evaluator/OnnxEvaluator.java @@ -2,6 +2,7 @@ package ai.vespa.modelintegration.evaluator; +import ai.onnxruntime.NodeInfo; import ai.onnxruntime.OnnxTensor; import ai.onnxruntime.OnnxValue; import ai.onnxruntime.OrtEnvironment; @@ -72,6 +73,35 @@ public class OnnxEvaluator implements AutoCloseable { } } + public record IdAndType(String id, TensorType type) { } + + private Map<String, IdAndType> toSpecMap(Map<String, NodeInfo> infoMap) { + Map<String, IdAndType> result = new HashMap<>(); + for (var info : infoMap.entrySet()) { + String name = info.getKey(); + String ident = TensorConverter.asValidName(name); + TensorType t = TensorConverter.toVespaType(info.getValue().getInfo()); + result.put(name, new IdAndType(ident, t)); + } + return result; + } + + public Map<String, IdAndType> getInputs() { + try { + return toSpecMap(session.getInputInfo()); + } catch (OrtException e) { + throw new RuntimeException("ONNX Runtime exception", e); + } + } + + public Map<String, IdAndType> getOutputs() { + try { + return toSpecMap(session.getOutputInfo()); + } catch (OrtException e) { + throw new RuntimeException("ONNX Runtime exception", e); + } + } + public Map<String, TensorType> getInputInfo() { try { return TensorConverter.toVespaTypes(session.getInputInfo()); |