summaryrefslogtreecommitdiffstats
path: root/model-evaluation
diff options
context:
space:
mode:
authorArne Juul <arnej@yahooinc.com>2023-02-22 19:24:26 +0000
committerArne Juul <arnej@yahooinc.com>2023-02-22 19:24:55 +0000
commit0884752871f37850249362fd20d66d4ae765b8ec (patch)
tree3758df43cc6c74a74207d95e21e9f636be2609b7 /model-evaluation
parentfc5c9a366b06a3e04091be1e8f784be8bb82e1f5 (diff)
handle non-identifier onnx input/output names: instead of the conflicting
ad-hoc code in OnnxEvaluator, do it as part of general input/output mapping in OnnxModel.
Diffstat (limited to 'model-evaluation')
-rw-r--r--model-evaluation/src/main/java/ai/vespa/models/evaluation/OnnxModel.java43
1 files changed, 27 insertions, 16 deletions
diff --git a/model-evaluation/src/main/java/ai/vespa/models/evaluation/OnnxModel.java b/model-evaluation/src/main/java/ai/vespa/models/evaluation/OnnxModel.java
index aa586a43d98..b86cf60318a 100644
--- a/model-evaluation/src/main/java/ai/vespa/models/evaluation/OnnxModel.java
+++ b/model-evaluation/src/main/java/ai/vespa/models/evaluation/OnnxModel.java
@@ -48,9 +48,11 @@ class OnnxModel implements AutoCloseable {
final List<OutputSpec> outputSpecs = new ArrayList<>();
void addInputMapping(String onnxName, String source) {
+ assert(referencedEvaluator == null);
inputSpecs.add(new InputSpec(onnxName, source));
}
void addOutputMapping(String onnxName, String outputAs) {
+ assert(referencedEvaluator == null);
outputSpecs.add(new OutputSpec(onnxName, outputAs));
}
@@ -75,17 +77,18 @@ class OnnxModel implements AutoCloseable {
public void load() {
if (referencedEvaluator == null) {
referencedEvaluator = cache.evaluatorOf(modelFile.getPath(), options);
- fillInputTypes(evaluator().getInputInfo());
- fillOutputTypes(evaluator().getOutputInfo());
+ fillInputTypes(evaluator().getInputs());
+ fillOutputTypes(evaluator().getOutputs());
}
}
- void fillInputTypes(Map<String, TensorType> wantedTypes) {
+ void fillInputTypes(Map<String, OnnxEvaluator.IdAndType> wantedTypes) {
if (inputSpecs.isEmpty()) {
for (var entry : wantedTypes.entrySet()) {
String name = entry.getKey();
- TensorType tType = entry.getValue();
- var spec = new InputSpec(name, name, tType);
+ String source = entry.getValue().id();
+ TensorType tType = entry.getValue().type();
+ var spec = new InputSpec(name, source, tType);
inputSpecs.add(spec);
}
} else {
@@ -96,23 +99,24 @@ class OnnxModel implements AutoCloseable {
wantedTypes.size() + " actual model inputs");
}
for (var spec : inputSpecs) {
- TensorType tType = wantedTypes.get(spec.onnxName);
- if (tType == null) {
+ var entry = wantedTypes.get(spec.onnxName);
+ if (entry == null) {
throw new IllegalArgumentException("Onnx model " + name() +
": No type in actual model for configured input "
+ spec.onnxName);
}
- spec.wantedType = tType;
+ spec.wantedType = entry.type();
}
}
}
- void fillOutputTypes(Map<String, TensorType> outputTypes) {
+ void fillOutputTypes(Map<String, OnnxEvaluator.IdAndType> outputTypes) {
if (outputSpecs.isEmpty()) {
for (var entry : outputTypes.entrySet()) {
String name = entry.getKey();
- TensorType tType = entry.getValue();
- var spec = new OutputSpec(name, name, tType);
+ String as = entry.getValue().id();
+ TensorType tType = entry.getValue().type();
+ var spec = new OutputSpec(name, as, tType);
outputSpecs.add(spec);
}
} else {
@@ -123,13 +127,13 @@ class OnnxModel implements AutoCloseable {
outputTypes.size() + " actual model outputs");
}
for (var spec : outputSpecs) {
- TensorType tType = outputTypes.get(spec.onnxName);
- if (tType == null) {
+ var entry = outputTypes.get(spec.onnxName);
+ if (entry == null) {
throw new IllegalArgumentException("Onnx model " + name() +
": No type in actual model for configured output "
+ spec.onnxName);
}
- spec.expectedType = tType;
+ spec.expectedType = entry.type();
}
}
}
@@ -153,14 +157,21 @@ class OnnxModel implements AutoCloseable {
public Tensor evaluate(Map<String, Tensor> inputs, String output) {
var mapped = new HashMap<String, Tensor>();
for (var spec : inputSpecs) {
- mapped.put(spec.onnxName, inputs.get(spec.source));
+ Tensor val = inputs.get(spec.source);
+ if (val == null) {
+ throw new IllegalArgumentException("evaluate ONNX model " + name() + ": missing input from source " + spec.source);
+ }
+ mapped.put(spec.onnxName, val);
}
- String onnxName = output;
+ String onnxName = null;
for (var spec : outputSpecs) {
if (spec.outputAs.equals(output)) {
onnxName = spec.onnxName;
}
}
+ if (onnxName == null) {
+ throw new IllegalArgumentException("evaluate ONNX model " + name() + ": no output available as: " + output);
+ }
return evaluator().evaluate(mapped, onnxName);
}