aboutsummaryrefslogtreecommitdiffstats
path: root/model-integration
diff options
context:
space:
mode:
authorArne Juul <arnej@yahooinc.com>2023-02-22 19:24:26 +0000
committerArne Juul <arnej@yahooinc.com>2023-02-22 19:24:55 +0000
commit0884752871f37850249362fd20d66d4ae765b8ec (patch)
tree3758df43cc6c74a74207d95e21e9f636be2609b7 /model-integration
parentfc5c9a366b06a3e04091be1e8f784be8bb82e1f5 (diff)
handle non-identifier onnx input/output names: instead of the conflicting
ad-hoc code in OnnxEvaluator, do it as part of general input/output mapping in OnnxModel.
Diffstat (limited to 'model-integration')
-rw-r--r--model-integration/src/main/java/ai/vespa/modelintegration/evaluator/OnnxEvaluator.java30
1 files changed, 30 insertions, 0 deletions
diff --git a/model-integration/src/main/java/ai/vespa/modelintegration/evaluator/OnnxEvaluator.java b/model-integration/src/main/java/ai/vespa/modelintegration/evaluator/OnnxEvaluator.java
index 66eb8caabd0..c2d97e37074 100644
--- a/model-integration/src/main/java/ai/vespa/modelintegration/evaluator/OnnxEvaluator.java
+++ b/model-integration/src/main/java/ai/vespa/modelintegration/evaluator/OnnxEvaluator.java
@@ -2,6 +2,7 @@
package ai.vespa.modelintegration.evaluator;
+import ai.onnxruntime.NodeInfo;
import ai.onnxruntime.OnnxTensor;
import ai.onnxruntime.OnnxValue;
import ai.onnxruntime.OrtEnvironment;
@@ -72,6 +73,35 @@ public class OnnxEvaluator implements AutoCloseable {
}
}
+ public record IdAndType(String id, TensorType type) { }
+
+ private Map<String, IdAndType> toSpecMap(Map<String, NodeInfo> infoMap) {
+ Map<String, IdAndType> result = new HashMap<>();
+ for (var info : infoMap.entrySet()) {
+ String name = info.getKey();
+ String ident = TensorConverter.asValidName(name);
+ TensorType t = TensorConverter.toVespaType(info.getValue().getInfo());
+ result.put(name, new IdAndType(ident, t));
+ }
+ return result;
+ }
+
+ public Map<String, IdAndType> getInputs() {
+ try {
+ return toSpecMap(session.getInputInfo());
+ } catch (OrtException e) {
+ throw new RuntimeException("ONNX Runtime exception", e);
+ }
+ }
+
+ public Map<String, IdAndType> getOutputs() {
+ try {
+ return toSpecMap(session.getOutputInfo());
+ } catch (OrtException e) {
+ throw new RuntimeException("ONNX Runtime exception", e);
+ }
+ }
+
public Map<String, TensorType> getInputInfo() {
try {
return TensorConverter.toVespaTypes(session.getInputInfo());