diff options
author | Arne Juul <arnej@yahooinc.com> | 2023-02-22 12:43:27 +0000 |
---|---|---|
committer | Arne Juul <arnej@yahooinc.com> | 2023-02-22 14:50:50 +0000 |
commit | fc5c9a366b06a3e04091be1e8f784be8bb82e1f5 (patch) | |
tree | 23f4a17f0a5afe6d1fc156fa480ccb9a007fb8b0 /model-evaluation | |
parent | 890e0ac9e795ca1c95e459f98a54593ac151051c (diff) |
add configurable input/output mappings
Diffstat (limited to 'model-evaluation')
3 files changed, 134 insertions, 10 deletions
diff --git a/model-evaluation/src/main/java/ai/vespa/models/evaluation/OnnxModel.java b/model-evaluation/src/main/java/ai/vespa/models/evaluation/OnnxModel.java index ac66b1151f3..aa586a43d98 100644 --- a/model-evaluation/src/main/java/ai/vespa/models/evaluation/OnnxModel.java +++ b/model-evaluation/src/main/java/ai/vespa/models/evaluation/OnnxModel.java @@ -8,6 +8,9 @@ import com.yahoo.tensor.Tensor; import com.yahoo.tensor.TensorType; import java.io.File; +import java.util.ArrayList; +import java.util.HashMap; +import java.util.List; import java.util.Map; /** @@ -17,6 +20,40 @@ import java.util.Map; */ class OnnxModel implements AutoCloseable { + static class InputSpec { + String onnxName; + String source; + TensorType wantedType; + InputSpec(String name, String source, TensorType tType) { + this.onnxName = name; + this.source = source; + this.wantedType = tType; + } + InputSpec(String name, String source) { this(name, source, null); } + } + + static class OutputSpec { + String onnxName; + String outputAs; + TensorType expectedType; + OutputSpec(String name, String as, TensorType tType) { + this.onnxName = name; + this.outputAs = as; + this.expectedType = tType; + } + OutputSpec(String name, String as) { this(name, as, null); } + } + + final List<InputSpec> inputSpecs = new ArrayList<>(); + final List<OutputSpec> outputSpecs = new ArrayList<>(); + + void addInputMapping(String onnxName, String source) { + inputSpecs.add(new InputSpec(onnxName, source)); + } + void addOutputMapping(String onnxName, String outputAs) { + outputSpecs.add(new OutputSpec(onnxName, outputAs)); + } + private final String name; private final File modelFile; private final OnnxEvaluatorOptions options; @@ -38,19 +75,93 @@ class OnnxModel implements AutoCloseable { public void load() { if (referencedEvaluator == null) { referencedEvaluator = cache.evaluatorOf(modelFile.getPath(), options); + fillInputTypes(evaluator().getInputInfo()); + fillOutputTypes(evaluator().getOutputInfo()); + } + } + + void fillInputTypes(Map<String, TensorType> wantedTypes) { + if (inputSpecs.isEmpty()) { + for (var entry : wantedTypes.entrySet()) { + String name = entry.getKey(); + TensorType tType = entry.getValue(); + var spec = new InputSpec(name, name, tType); + inputSpecs.add(spec); + } + } else { + if (wantedTypes.size() != inputSpecs.size()) { + throw new IllegalArgumentException("Onnx model " + name() + + ": Mismatch between " + inputSpecs.size() + + " configured inputs and " + + wantedTypes.size() + " actual model inputs"); + } + for (var spec : inputSpecs) { + TensorType tType = wantedTypes.get(spec.onnxName); + if (tType == null) { + throw new IllegalArgumentException("Onnx model " + name() + + ": No type in actual model for configured input " + + spec.onnxName); + } + spec.wantedType = tType; + } + } + } + + void fillOutputTypes(Map<String, TensorType> outputTypes) { + if (outputSpecs.isEmpty()) { + for (var entry : outputTypes.entrySet()) { + String name = entry.getKey(); + TensorType tType = entry.getValue(); + var spec = new OutputSpec(name, name, tType); + outputSpecs.add(spec); + } + } else { + if (outputTypes.size() != outputSpecs.size()) { + throw new IllegalArgumentException("Onnx model " + name() + + ": Mismatch between " + outputSpecs.size() + + " configured outputs and " + + outputTypes.size() + " actual model outputs"); + } + for (var spec : outputSpecs) { + TensorType tType = outputTypes.get(spec.onnxName); + if (tType == null) { + throw new IllegalArgumentException("Onnx model " + name() + + ": No type in actual model for configured output " + + spec.onnxName); + } + spec.expectedType = tType; + } } } public Map<String, TensorType> inputs() { - return evaluator().getInputInfo(); + var map = new HashMap<String, TensorType>(); + for (var spec : inputSpecs) { + map.put(spec.source, spec.wantedType); + } + return map; } public Map<String, TensorType> outputs() { - return evaluator().getOutputInfo(); + var map = new HashMap<String, TensorType>(); + for (var spec : outputSpecs) { + map.put(spec.outputAs, spec.expectedType); + } + return map; } public Tensor evaluate(Map<String, Tensor> inputs, String output) { - return evaluator().evaluate(inputs, output); + var mapped = new HashMap<String, Tensor>(); + for (var spec : inputSpecs) { + mapped.put(spec.onnxName, inputs.get(spec.source)); + } + String onnxName = output; + for (var spec : outputSpecs) { + if (spec.outputAs.equals(output)) { + onnxName = spec.onnxName; + } + } + return evaluator().evaluate(mapped, onnxName); } private OnnxEvaluator evaluator() { diff --git a/model-evaluation/src/main/java/ai/vespa/models/evaluation/RankProfilesConfigImporter.java b/model-evaluation/src/main/java/ai/vespa/models/evaluation/RankProfilesConfigImporter.java index 2d91f86117e..098e6e7a1f6 100644 --- a/model-evaluation/src/main/java/ai/vespa/models/evaluation/RankProfilesConfigImporter.java +++ b/model-evaluation/src/main/java/ai/vespa/models/evaluation/RankProfilesConfigImporter.java @@ -186,7 +186,14 @@ public class RankProfilesConfigImporter { options.setInterOpThreads(onnxModelConfig.stateless_interop_threads()); options.setIntraOpThreads(onnxModelConfig.stateless_intraop_threads()); options.setGpuDevice(onnxModelConfig.gpu_device(), onnxModelConfig.gpu_device_required()); - return new OnnxModel(name, file, options, onnxEvaluatorCache); + var m = new OnnxModel(name, file, options, onnxEvaluatorCache); + for (var spec : onnxModelConfig.input()) { + m.addInputMapping(spec.name(), spec.source()); + } + for (var spec : onnxModelConfig.output()) { + m.addOutputMapping(spec.name(), spec.as()); + } + return m; } catch (InterruptedException e) { throw new IllegalStateException("Gave up waiting for ONNX model " + onnxModelConfig.name()); } diff --git a/model-evaluation/src/test/java/ai/vespa/models/handler/HandlerTester.java b/model-evaluation/src/test/java/ai/vespa/models/handler/HandlerTester.java index 6c4dd886f4b..38215858366 100644 --- a/model-evaluation/src/test/java/ai/vespa/models/handler/HandlerTester.java +++ b/model-evaluation/src/test/java/ai/vespa/models/handler/HandlerTester.java @@ -27,16 +27,22 @@ class HandlerTester { } private static Predicate<String> matchString(String expected) { return s -> { - //System.out.println("Expected: " + expected); - //System.out.println("Actual: " + s); - return expected.equals(s); + boolean result = expected.equals(s); + if (!result) { + System.out.println("Expected: " + expected); + System.out.println("Actual: " + s); + } + return result; }; } private static Predicate<String> matchJsonString(String expected) { return s -> { - //System.out.println("Expected: " + expected); - //System.out.println("Actual: " + s); - return JSON.canonical(expected).equals(JSON.canonical(s)); + boolean result = JSON.canonical(expected).equals(JSON.canonical(s)); + if (!result) { + System.out.println("Expected: " + expected); + System.out.println("Actual: " + s); + } + return result; }; } public static Predicate<String> matchJson(String... expectedJson) { |