diff options
author | Jon Bratseth <bratseth@gmail.com> | 2023-02-22 21:30:52 +0100 |
---|---|---|
committer | GitHub <noreply@github.com> | 2023-02-22 21:30:52 +0100 |
commit | 69f149d4f91a2043f1d801afd89596fedacb69a2 (patch) | |
tree | 54bd737c31258dba1d92d5ca14581a705d83bbda | |
parent | 37a353eb3056f0f154add4d8787ad05da5bf6629 (diff) | |
parent | 0884752871f37850249362fd20d66d4ae765b8ec (diff) |
Merge pull request #26146 from vespa-engine/arnej/add-mappings-for-onnx-model-evaluation
add configurable input/output mappings
4 files changed, 175 insertions, 10 deletions
diff --git a/model-evaluation/src/main/java/ai/vespa/models/evaluation/OnnxModel.java b/model-evaluation/src/main/java/ai/vespa/models/evaluation/OnnxModel.java index ac66b1151f3..b86cf60318a 100644 --- a/model-evaluation/src/main/java/ai/vespa/models/evaluation/OnnxModel.java +++ b/model-evaluation/src/main/java/ai/vespa/models/evaluation/OnnxModel.java @@ -8,6 +8,9 @@ import com.yahoo.tensor.Tensor; import com.yahoo.tensor.TensorType; import java.io.File; +import java.util.ArrayList; +import java.util.HashMap; +import java.util.List; import java.util.Map; /** @@ -17,6 +20,42 @@ import java.util.Map; */ class OnnxModel implements AutoCloseable { + static class InputSpec { + String onnxName; + String source; + TensorType wantedType; + InputSpec(String name, String source, TensorType tType) { + this.onnxName = name; + this.source = source; + this.wantedType = tType; + } + InputSpec(String name, String source) { this(name, source, null); } + } + + static class OutputSpec { + String onnxName; + String outputAs; + TensorType expectedType; + OutputSpec(String name, String as, TensorType tType) { + this.onnxName = name; + this.outputAs = as; + this.expectedType = tType; + } + OutputSpec(String name, String as) { this(name, as, null); } + } + + final List<InputSpec> inputSpecs = new ArrayList<>(); + final List<OutputSpec> outputSpecs = new ArrayList<>(); + + void addInputMapping(String onnxName, String source) { + assert(referencedEvaluator == null); + inputSpecs.add(new InputSpec(onnxName, source)); + } + void addOutputMapping(String onnxName, String outputAs) { + assert(referencedEvaluator == null); + outputSpecs.add(new OutputSpec(onnxName, outputAs)); + } + private final String name; private final File modelFile; private final OnnxEvaluatorOptions options; @@ -38,19 +77,102 @@ class OnnxModel implements AutoCloseable { public void load() { if (referencedEvaluator == null) { referencedEvaluator = cache.evaluatorOf(modelFile.getPath(), options); + fillInputTypes(evaluator().getInputs()); + fillOutputTypes(evaluator().getOutputs()); + } + } + + void fillInputTypes(Map<String, OnnxEvaluator.IdAndType> wantedTypes) { + if (inputSpecs.isEmpty()) { + for (var entry : wantedTypes.entrySet()) { + String name = entry.getKey(); + String source = entry.getValue().id(); + TensorType tType = entry.getValue().type(); + var spec = new InputSpec(name, source, tType); + inputSpecs.add(spec); + } + } else { + if (wantedTypes.size() != inputSpecs.size()) { + throw new IllegalArgumentException("Onnx model " + name() + + ": Mismatch between " + inputSpecs.size() + + " configured inputs and " + + wantedTypes.size() + " actual model inputs"); + } + for (var spec : inputSpecs) { + var entry = wantedTypes.get(spec.onnxName); + if (entry == null) { + throw new IllegalArgumentException("Onnx model " + name() + + ": No type in actual model for configured input " + + spec.onnxName); + } + spec.wantedType = entry.type(); + } + } + } + + void fillOutputTypes(Map<String, OnnxEvaluator.IdAndType> outputTypes) { + if (outputSpecs.isEmpty()) { + for (var entry : outputTypes.entrySet()) { + String name = entry.getKey(); + String as = entry.getValue().id(); + TensorType tType = entry.getValue().type(); + var spec = new OutputSpec(name, as, tType); + outputSpecs.add(spec); + } + } else { + if (outputTypes.size() != outputSpecs.size()) { + throw new IllegalArgumentException("Onnx model " + name() + + ": Mismatch between " + outputSpecs.size() + + " configured outputs and " + + outputTypes.size() + " actual model outputs"); + } + for (var spec : outputSpecs) { + var entry = outputTypes.get(spec.onnxName); + if (entry == null) { + throw new IllegalArgumentException("Onnx model " + name() + + ": No type in actual model for configured output " + + spec.onnxName); + } + spec.expectedType = entry.type(); + } } } public Map<String, TensorType> inputs() { - return evaluator().getInputInfo(); + var map = new HashMap<String, TensorType>(); + for (var spec : inputSpecs) { + map.put(spec.source, spec.wantedType); + } + return map; } public Map<String, TensorType> outputs() { - return evaluator().getOutputInfo(); + var map = new HashMap<String, TensorType>(); + for (var spec : outputSpecs) { + map.put(spec.outputAs, spec.expectedType); + } + return map; } public Tensor evaluate(Map<String, Tensor> inputs, String output) { - return evaluator().evaluate(inputs, output); + var mapped = new HashMap<String, Tensor>(); + for (var spec : inputSpecs) { + Tensor val = inputs.get(spec.source); + if (val == null) { + throw new IllegalArgumentException("evaluate ONNX model " + name() + ": missing input from source " + spec.source); + } + mapped.put(spec.onnxName, val); + } + String onnxName = null; + for (var spec : outputSpecs) { + if (spec.outputAs.equals(output)) { + onnxName = spec.onnxName; + } + } + if (onnxName == null) { + throw new IllegalArgumentException("evaluate ONNX model " + name() + ": no output available as: " + output); + } + return evaluator().evaluate(mapped, onnxName); } private OnnxEvaluator evaluator() { diff --git a/model-evaluation/src/main/java/ai/vespa/models/evaluation/RankProfilesConfigImporter.java b/model-evaluation/src/main/java/ai/vespa/models/evaluation/RankProfilesConfigImporter.java index 2d91f86117e..098e6e7a1f6 100644 --- a/model-evaluation/src/main/java/ai/vespa/models/evaluation/RankProfilesConfigImporter.java +++ b/model-evaluation/src/main/java/ai/vespa/models/evaluation/RankProfilesConfigImporter.java @@ -186,7 +186,14 @@ public class RankProfilesConfigImporter { options.setInterOpThreads(onnxModelConfig.stateless_interop_threads()); options.setIntraOpThreads(onnxModelConfig.stateless_intraop_threads()); options.setGpuDevice(onnxModelConfig.gpu_device(), onnxModelConfig.gpu_device_required()); - return new OnnxModel(name, file, options, onnxEvaluatorCache); + var m = new OnnxModel(name, file, options, onnxEvaluatorCache); + for (var spec : onnxModelConfig.input()) { + m.addInputMapping(spec.name(), spec.source()); + } + for (var spec : onnxModelConfig.output()) { + m.addOutputMapping(spec.name(), spec.as()); + } + return m; } catch (InterruptedException e) { throw new IllegalStateException("Gave up waiting for ONNX model " + onnxModelConfig.name()); } diff --git a/model-evaluation/src/test/java/ai/vespa/models/handler/HandlerTester.java b/model-evaluation/src/test/java/ai/vespa/models/handler/HandlerTester.java index 6c4dd886f4b..38215858366 100644 --- a/model-evaluation/src/test/java/ai/vespa/models/handler/HandlerTester.java +++ b/model-evaluation/src/test/java/ai/vespa/models/handler/HandlerTester.java @@ -27,16 +27,22 @@ class HandlerTester { } private static Predicate<String> matchString(String expected) { return s -> { - //System.out.println("Expected: " + expected); - //System.out.println("Actual: " + s); - return expected.equals(s); + boolean result = expected.equals(s); + if (!result) { + System.out.println("Expected: " + expected); + System.out.println("Actual: " + s); + } + return result; }; } private static Predicate<String> matchJsonString(String expected) { return s -> { - //System.out.println("Expected: " + expected); - //System.out.println("Actual: " + s); - return JSON.canonical(expected).equals(JSON.canonical(s)); + boolean result = JSON.canonical(expected).equals(JSON.canonical(s)); + if (!result) { + System.out.println("Expected: " + expected); + System.out.println("Actual: " + s); + } + return result; }; } public static Predicate<String> matchJson(String... expectedJson) { diff --git a/model-integration/src/main/java/ai/vespa/modelintegration/evaluator/OnnxEvaluator.java b/model-integration/src/main/java/ai/vespa/modelintegration/evaluator/OnnxEvaluator.java index 66eb8caabd0..c2d97e37074 100644 --- a/model-integration/src/main/java/ai/vespa/modelintegration/evaluator/OnnxEvaluator.java +++ b/model-integration/src/main/java/ai/vespa/modelintegration/evaluator/OnnxEvaluator.java @@ -2,6 +2,7 @@ package ai.vespa.modelintegration.evaluator; +import ai.onnxruntime.NodeInfo; import ai.onnxruntime.OnnxTensor; import ai.onnxruntime.OnnxValue; import ai.onnxruntime.OrtEnvironment; @@ -72,6 +73,35 @@ public class OnnxEvaluator implements AutoCloseable { } } + public record IdAndType(String id, TensorType type) { } + + private Map<String, IdAndType> toSpecMap(Map<String, NodeInfo> infoMap) { + Map<String, IdAndType> result = new HashMap<>(); + for (var info : infoMap.entrySet()) { + String name = info.getKey(); + String ident = TensorConverter.asValidName(name); + TensorType t = TensorConverter.toVespaType(info.getValue().getInfo()); + result.put(name, new IdAndType(ident, t)); + } + return result; + } + + public Map<String, IdAndType> getInputs() { + try { + return toSpecMap(session.getInputInfo()); + } catch (OrtException e) { + throw new RuntimeException("ONNX Runtime exception", e); + } + } + + public Map<String, IdAndType> getOutputs() { + try { + return toSpecMap(session.getOutputInfo()); + } catch (OrtException e) { + throw new RuntimeException("ONNX Runtime exception", e); + } + } + public Map<String, TensorType> getInputInfo() { try { return TensorConverter.toVespaTypes(session.getInputInfo()); |