diff options
author | Arne Juul <arnej@yahooinc.com> | 2023-02-22 19:24:26 +0000 |
---|---|---|
committer | Arne Juul <arnej@yahooinc.com> | 2023-02-22 19:24:55 +0000 |
commit | 0884752871f37850249362fd20d66d4ae765b8ec (patch) | |
tree | 3758df43cc6c74a74207d95e21e9f636be2609b7 /model-evaluation | |
parent | fc5c9a366b06a3e04091be1e8f784be8bb82e1f5 (diff) |
handle non-identifier onnx input/output names: instead of the conflicting
ad-hoc code in OnnxEvaluator, do it as part of general input/output
mapping in OnnxModel.
Diffstat (limited to 'model-evaluation')
-rw-r--r-- | model-evaluation/src/main/java/ai/vespa/models/evaluation/OnnxModel.java | 43 |
1 files changed, 27 insertions, 16 deletions
diff --git a/model-evaluation/src/main/java/ai/vespa/models/evaluation/OnnxModel.java b/model-evaluation/src/main/java/ai/vespa/models/evaluation/OnnxModel.java index aa586a43d98..b86cf60318a 100644 --- a/model-evaluation/src/main/java/ai/vespa/models/evaluation/OnnxModel.java +++ b/model-evaluation/src/main/java/ai/vespa/models/evaluation/OnnxModel.java @@ -48,9 +48,11 @@ class OnnxModel implements AutoCloseable { final List<OutputSpec> outputSpecs = new ArrayList<>(); void addInputMapping(String onnxName, String source) { + assert(referencedEvaluator == null); inputSpecs.add(new InputSpec(onnxName, source)); } void addOutputMapping(String onnxName, String outputAs) { + assert(referencedEvaluator == null); outputSpecs.add(new OutputSpec(onnxName, outputAs)); } @@ -75,17 +77,18 @@ class OnnxModel implements AutoCloseable { public void load() { if (referencedEvaluator == null) { referencedEvaluator = cache.evaluatorOf(modelFile.getPath(), options); - fillInputTypes(evaluator().getInputInfo()); - fillOutputTypes(evaluator().getOutputInfo()); + fillInputTypes(evaluator().getInputs()); + fillOutputTypes(evaluator().getOutputs()); } } - void fillInputTypes(Map<String, TensorType> wantedTypes) { + void fillInputTypes(Map<String, OnnxEvaluator.IdAndType> wantedTypes) { if (inputSpecs.isEmpty()) { for (var entry : wantedTypes.entrySet()) { String name = entry.getKey(); - TensorType tType = entry.getValue(); - var spec = new InputSpec(name, name, tType); + String source = entry.getValue().id(); + TensorType tType = entry.getValue().type(); + var spec = new InputSpec(name, source, tType); inputSpecs.add(spec); } } else { @@ -96,23 +99,24 @@ class OnnxModel implements AutoCloseable { wantedTypes.size() + " actual model inputs"); } for (var spec : inputSpecs) { - TensorType tType = wantedTypes.get(spec.onnxName); - if (tType == null) { + var entry = wantedTypes.get(spec.onnxName); + if (entry == null) { throw new IllegalArgumentException("Onnx model " + name() + ": No type in actual model for configured input " + spec.onnxName); } - spec.wantedType = tType; + spec.wantedType = entry.type(); } } } - void fillOutputTypes(Map<String, TensorType> outputTypes) { + void fillOutputTypes(Map<String, OnnxEvaluator.IdAndType> outputTypes) { if (outputSpecs.isEmpty()) { for (var entry : outputTypes.entrySet()) { String name = entry.getKey(); - TensorType tType = entry.getValue(); - var spec = new OutputSpec(name, name, tType); + String as = entry.getValue().id(); + TensorType tType = entry.getValue().type(); + var spec = new OutputSpec(name, as, tType); outputSpecs.add(spec); } } else { @@ -123,13 +127,13 @@ class OnnxModel implements AutoCloseable { outputTypes.size() + " actual model outputs"); } for (var spec : outputSpecs) { - TensorType tType = outputTypes.get(spec.onnxName); - if (tType == null) { + var entry = outputTypes.get(spec.onnxName); + if (entry == null) { throw new IllegalArgumentException("Onnx model " + name() + ": No type in actual model for configured output " + spec.onnxName); } - spec.expectedType = tType; + spec.expectedType = entry.type(); } } } @@ -153,14 +157,21 @@ class OnnxModel implements AutoCloseable { public Tensor evaluate(Map<String, Tensor> inputs, String output) { var mapped = new HashMap<String, Tensor>(); for (var spec : inputSpecs) { - mapped.put(spec.onnxName, inputs.get(spec.source)); + Tensor val = inputs.get(spec.source); + if (val == null) { + throw new IllegalArgumentException("evaluate ONNX model " + name() + ": missing input from source " + spec.source); + } + mapped.put(spec.onnxName, val); } - String onnxName = output; + String onnxName = null; for (var spec : outputSpecs) { if (spec.outputAs.equals(output)) { onnxName = spec.onnxName; } } + if (onnxName == null) { + throw new IllegalArgumentException("evaluate ONNX model " + name() + ": no output available as: " + output); + } return evaluator().evaluate(mapped, onnxName); } |