summary | refs | log | tree | commit | diff | stats
diff options
context:
space:
mode:
-rw-r--r-- model-evaluation/src/main/java/ai/vespa/models/evaluation/OnnxModel.java 43
-rw-r--r-- model-integration/src/main/java/ai/vespa/modelintegration/evaluator/OnnxEvaluator.java 30
2 files changed, 57 insertions, 16 deletions
diff --git a/model-evaluation/src/main/java/ai/vespa/models/evaluation/OnnxModel.java b/model-evaluation/src/main/java/ai/vespa/models/evaluation/OnnxModel.java
index aa586a43d98..b86cf60318a 100644
--- a/model-evaluation/src/main/java/ai/vespa/models/evaluation/OnnxModel.java
+++ b/model-evaluation/src/main/java/ai/vespa/models/evaluation/OnnxModel.java
@@ -48,9 +48,11 @@ class OnnxModel implements AutoCloseable {
final List<OutputSpec> outputSpecs = new ArrayList<>();
void addInputMapping(String onnxName, String source) {
+ assert(referencedEvaluator == null);
inputSpecs.add(new InputSpec(onnxName, source));
}
void addOutputMapping(String onnxName, String outputAs) {
+ assert(referencedEvaluator == null);
outputSpecs.add(new OutputSpec(onnxName, outputAs));
}
@@ -75,17 +77,18 @@ class OnnxModel implements AutoCloseable {
public void load() {
if (referencedEvaluator == null) {
referencedEvaluator = cache.evaluatorOf(modelFile.getPath(), options);
- fillInputTypes(evaluator().getInputInfo());
- fillOutputTypes(evaluator().getOutputInfo());
+ fillInputTypes(evaluator().getInputs());
+ fillOutputTypes(evaluator().getOutputs());
}
}
- void fillInputTypes(Map<String, TensorType> wantedTypes) {
+ void fillInputTypes(Map<String, OnnxEvaluator.IdAndType> wantedTypes) {
if (inputSpecs.isEmpty()) {
for (var entry : wantedTypes.entrySet()) {
String name = entry.getKey();
- TensorType tType = entry.getValue();
- var spec = new InputSpec(name, name, tType);
+ String source = entry.getValue().id();
+ TensorType tType = entry.getValue().type();
+ var spec = new InputSpec(name, source, tType);
inputSpecs.add(spec);
}
} else {
@@ -96,23 +99,24 @@ class OnnxModel implements AutoCloseable {
wantedTypes.size() + " actual model inputs");
}
for (var spec : inputSpecs) {
- TensorType tType = wantedTypes.get(spec.onnxName);
- if (tType == null) {
+ var entry = wantedTypes.get(spec.onnxName);
+ if (entry == null) {
throw new IllegalArgumentException("Onnx model " + name() +
": No type in actual model for configured input "
+ spec.onnxName);
}
- spec.wantedType = tType;
+ spec.wantedType = entry.type();
}
}
}
- void fillOutputTypes(Map<String, TensorType> outputTypes) {
+ void fillOutputTypes(Map<String, OnnxEvaluator.IdAndType> outputTypes) {
if (outputSpecs.isEmpty()) {
for (var entry : outputTypes.entrySet()) {
String name = entry.getKey();
- TensorType tType = entry.getValue();
- var spec = new OutputSpec(name, name, tType);
+ String as = entry.getValue().id();
+ TensorType tType = entry.getValue().type();
+ var spec = new OutputSpec(name, as, tType);
outputSpecs.add(spec);
}
} else {
@@ -123,13 +127,13 @@ class OnnxModel implements AutoCloseable {
outputTypes.size() + " actual model outputs");
}
for (var spec : outputSpecs) {
- TensorType tType = outputTypes.get(spec.onnxName);
- if (tType == null) {
+ var entry = outputTypes.get(spec.onnxName);
+ if (entry == null) {
throw new IllegalArgumentException("Onnx model " + name() +
": No type in actual model for configured output "
+ spec.onnxName);
}
- spec.expectedType = tType;
+ spec.expectedType = entry.type();
}
}
}
@@ -153,14 +157,21 @@ class OnnxModel implements AutoCloseable {
public Tensor evaluate(Map<String, Tensor> inputs, String output) {
var mapped = new HashMap<String, Tensor>();
for (var spec : inputSpecs) {
- mapped.put(spec.onnxName, inputs.get(spec.source));
+ Tensor val = inputs.get(spec.source);
+ if (val == null) {
+ throw new IllegalArgumentException("evaluate ONNX model " + name() + ": missing input from source " + spec.source);
+ }
+ mapped.put(spec.onnxName, val);
}
- String onnxName = output;
+ String onnxName = null;
for (var spec : outputSpecs) {
if (spec.outputAs.equals(output)) {
onnxName = spec.onnxName;
}
}
+ if (onnxName == null) {
+ throw new IllegalArgumentException("evaluate ONNX model " + name() + ": no output available as: " + output);
+ }
return evaluator().evaluate(mapped, onnxName);
}
diff --git a/model-integration/src/main/java/ai/vespa/modelintegration/evaluator/OnnxEvaluator.java b/model-integration/src/main/java/ai/vespa/modelintegration/evaluator/OnnxEvaluator.java
index 66eb8caabd0..c2d97e37074 100644
--- a/model-integration/src/main/java/ai/vespa/modelintegration/evaluator/OnnxEvaluator.java
+++ b/model-integration/src/main/java/ai/vespa/modelintegration/evaluator/OnnxEvaluator.java
@@ -2,6 +2,7 @@
package ai.vespa.modelintegration.evaluator;
+import ai.onnxruntime.NodeInfo;
import ai.onnxruntime.OnnxTensor;
import ai.onnxruntime.OnnxValue;
import ai.onnxruntime.OrtEnvironment;
@@ -72,6 +73,35 @@ public class OnnxEvaluator implements AutoCloseable {
}
}
+ public record IdAndType(String id, TensorType type) { }
+
+ private Map<String, IdAndType> toSpecMap(Map<String, NodeInfo> infoMap) {
+ Map<String, IdAndType> result = new HashMap<>();
+ for (var info : infoMap.entrySet()) {
+ String name = info.getKey();
+ String ident = TensorConverter.asValidName(name);
+ TensorType t = TensorConverter.toVespaType(info.getValue().getInfo());
+ result.put(name, new IdAndType(ident, t));
+ }
+ return result;
+ }
+
+ public Map<String, IdAndType> getInputs() {
+ try {
+ return toSpecMap(session.getInputInfo());
+ } catch (OrtException e) {
+ throw new RuntimeException("ONNX Runtime exception", e);
+ }
+ }
+
+ public Map<String, IdAndType> getOutputs() {
+ try {
+ return toSpecMap(session.getOutputInfo());
+ } catch (OrtException e) {
+ throw new RuntimeException("ONNX Runtime exception", e);
+ }
+ }
+
public Map<String, TensorType> getInputInfo() {
try {
return TensorConverter.toVespaTypes(session.getInputInfo());