summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorJon Bratseth <bratseth@gmail.com>2023-02-22 21:30:52 +0100
committerGitHub <noreply@github.com>2023-02-22 21:30:52 +0100
commit69f149d4f91a2043f1d801afd89596fedacb69a2 (patch)
tree54bd737c31258dba1d92d5ca14581a705d83bbda
parent37a353eb3056f0f154add4d8787ad05da5bf6629 (diff)
parent0884752871f37850249362fd20d66d4ae765b8ec (diff)
Merge pull request #26146 from vespa-engine/arnej/add-mappings-for-onnx-model-evaluation
add configurable input/output mappings
-rw-r--r--model-evaluation/src/main/java/ai/vespa/models/evaluation/OnnxModel.java128
-rw-r--r--model-evaluation/src/main/java/ai/vespa/models/evaluation/RankProfilesConfigImporter.java9
-rw-r--r--model-evaluation/src/test/java/ai/vespa/models/handler/HandlerTester.java18
-rw-r--r--model-integration/src/main/java/ai/vespa/modelintegration/evaluator/OnnxEvaluator.java30
4 files changed, 175 insertions, 10 deletions
diff --git a/model-evaluation/src/main/java/ai/vespa/models/evaluation/OnnxModel.java b/model-evaluation/src/main/java/ai/vespa/models/evaluation/OnnxModel.java
index ac66b1151f3..b86cf60318a 100644
--- a/model-evaluation/src/main/java/ai/vespa/models/evaluation/OnnxModel.java
+++ b/model-evaluation/src/main/java/ai/vespa/models/evaluation/OnnxModel.java
@@ -8,6 +8,9 @@ import com.yahoo.tensor.Tensor;
import com.yahoo.tensor.TensorType;
import java.io.File;
+import java.util.ArrayList;
+import java.util.HashMap;
+import java.util.List;
import java.util.Map;
/**
@@ -17,6 +20,42 @@ import java.util.Map;
*/
class OnnxModel implements AutoCloseable {
+ static class InputSpec {
+ String onnxName;
+ String source;
+ TensorType wantedType;
+ InputSpec(String name, String source, TensorType tType) {
+ this.onnxName = name;
+ this.source = source;
+ this.wantedType = tType;
+ }
+ InputSpec(String name, String source) { this(name, source, null); }
+ }
+
+ static class OutputSpec {
+ String onnxName;
+ String outputAs;
+ TensorType expectedType;
+ OutputSpec(String name, String as, TensorType tType) {
+ this.onnxName = name;
+ this.outputAs = as;
+ this.expectedType = tType;
+ }
+ OutputSpec(String name, String as) { this(name, as, null); }
+ }
+
+ final List<InputSpec> inputSpecs = new ArrayList<>();
+ final List<OutputSpec> outputSpecs = new ArrayList<>();
+
+ void addInputMapping(String onnxName, String source) {
+ assert(referencedEvaluator == null);
+ inputSpecs.add(new InputSpec(onnxName, source));
+ }
+ void addOutputMapping(String onnxName, String outputAs) {
+ assert(referencedEvaluator == null);
+ outputSpecs.add(new OutputSpec(onnxName, outputAs));
+ }
+
private final String name;
private final File modelFile;
private final OnnxEvaluatorOptions options;
@@ -38,19 +77,102 @@ class OnnxModel implements AutoCloseable {
public void load() {
if (referencedEvaluator == null) {
referencedEvaluator = cache.evaluatorOf(modelFile.getPath(), options);
+ fillInputTypes(evaluator().getInputs());
+ fillOutputTypes(evaluator().getOutputs());
+ }
+ }
+
+ void fillInputTypes(Map<String, OnnxEvaluator.IdAndType> wantedTypes) {
+ if (inputSpecs.isEmpty()) {
+ for (var entry : wantedTypes.entrySet()) {
+ String name = entry.getKey();
+ String source = entry.getValue().id();
+ TensorType tType = entry.getValue().type();
+ var spec = new InputSpec(name, source, tType);
+ inputSpecs.add(spec);
+ }
+ } else {
+ if (wantedTypes.size() != inputSpecs.size()) {
+ throw new IllegalArgumentException("Onnx model " + name() +
+ ": Mismatch between " + inputSpecs.size() +
+ " configured inputs and " +
+ wantedTypes.size() + " actual model inputs");
+ }
+ for (var spec : inputSpecs) {
+ var entry = wantedTypes.get(spec.onnxName);
+ if (entry == null) {
+ throw new IllegalArgumentException("Onnx model " + name() +
+ ": No type in actual model for configured input "
+ + spec.onnxName);
+ }
+ spec.wantedType = entry.type();
+ }
+ }
+ }
+
+ void fillOutputTypes(Map<String, OnnxEvaluator.IdAndType> outputTypes) {
+ if (outputSpecs.isEmpty()) {
+ for (var entry : outputTypes.entrySet()) {
+ String name = entry.getKey();
+ String as = entry.getValue().id();
+ TensorType tType = entry.getValue().type();
+ var spec = new OutputSpec(name, as, tType);
+ outputSpecs.add(spec);
+ }
+ } else {
+ if (outputTypes.size() != outputSpecs.size()) {
+ throw new IllegalArgumentException("Onnx model " + name() +
+ ": Mismatch between " + outputSpecs.size() +
+ " configured outputs and " +
+ outputTypes.size() + " actual model outputs");
+ }
+ for (var spec : outputSpecs) {
+ var entry = outputTypes.get(spec.onnxName);
+ if (entry == null) {
+ throw new IllegalArgumentException("Onnx model " + name() +
+ ": No type in actual model for configured output "
+ + spec.onnxName);
+ }
+ spec.expectedType = entry.type();
+ }
}
}
public Map<String, TensorType> inputs() {
- return evaluator().getInputInfo();
+ var map = new HashMap<String, TensorType>();
+ for (var spec : inputSpecs) {
+ map.put(spec.source, spec.wantedType);
+ }
+ return map;
}
public Map<String, TensorType> outputs() {
- return evaluator().getOutputInfo();
+ var map = new HashMap<String, TensorType>();
+ for (var spec : outputSpecs) {
+ map.put(spec.outputAs, spec.expectedType);
+ }
+ return map;
}
public Tensor evaluate(Map<String, Tensor> inputs, String output) {
- return evaluator().evaluate(inputs, output);
+ var mapped = new HashMap<String, Tensor>();
+ for (var spec : inputSpecs) {
+ Tensor val = inputs.get(spec.source);
+ if (val == null) {
+ throw new IllegalArgumentException("evaluate ONNX model " + name() + ": missing input from source " + spec.source);
+ }
+ mapped.put(spec.onnxName, val);
+ }
+ String onnxName = null;
+ for (var spec : outputSpecs) {
+ if (spec.outputAs.equals(output)) {
+ onnxName = spec.onnxName;
+ }
+ }
+ if (onnxName == null) {
+ throw new IllegalArgumentException("evaluate ONNX model " + name() + ": no output available as: " + output);
+ }
+ return evaluator().evaluate(mapped, onnxName);
}
private OnnxEvaluator evaluator() {
diff --git a/model-evaluation/src/main/java/ai/vespa/models/evaluation/RankProfilesConfigImporter.java b/model-evaluation/src/main/java/ai/vespa/models/evaluation/RankProfilesConfigImporter.java
index 2d91f86117e..098e6e7a1f6 100644
--- a/model-evaluation/src/main/java/ai/vespa/models/evaluation/RankProfilesConfigImporter.java
+++ b/model-evaluation/src/main/java/ai/vespa/models/evaluation/RankProfilesConfigImporter.java
@@ -186,7 +186,14 @@ public class RankProfilesConfigImporter {
options.setInterOpThreads(onnxModelConfig.stateless_interop_threads());
options.setIntraOpThreads(onnxModelConfig.stateless_intraop_threads());
options.setGpuDevice(onnxModelConfig.gpu_device(), onnxModelConfig.gpu_device_required());
- return new OnnxModel(name, file, options, onnxEvaluatorCache);
+ var m = new OnnxModel(name, file, options, onnxEvaluatorCache);
+ for (var spec : onnxModelConfig.input()) {
+ m.addInputMapping(spec.name(), spec.source());
+ }
+ for (var spec : onnxModelConfig.output()) {
+ m.addOutputMapping(spec.name(), spec.as());
+ }
+ return m;
} catch (InterruptedException e) {
throw new IllegalStateException("Gave up waiting for ONNX model " + onnxModelConfig.name());
}
diff --git a/model-evaluation/src/test/java/ai/vespa/models/handler/HandlerTester.java b/model-evaluation/src/test/java/ai/vespa/models/handler/HandlerTester.java
index 6c4dd886f4b..38215858366 100644
--- a/model-evaluation/src/test/java/ai/vespa/models/handler/HandlerTester.java
+++ b/model-evaluation/src/test/java/ai/vespa/models/handler/HandlerTester.java
@@ -27,16 +27,22 @@ class HandlerTester {
}
private static Predicate<String> matchString(String expected) {
return s -> {
- //System.out.println("Expected: " + expected);
- //System.out.println("Actual: " + s);
- return expected.equals(s);
+ boolean result = expected.equals(s);
+ if (!result) {
+ System.out.println("Expected: " + expected);
+ System.out.println("Actual: " + s);
+ }
+ return result;
};
}
private static Predicate<String> matchJsonString(String expected) {
return s -> {
- //System.out.println("Expected: " + expected);
- //System.out.println("Actual: " + s);
- return JSON.canonical(expected).equals(JSON.canonical(s));
+ boolean result = JSON.canonical(expected).equals(JSON.canonical(s));
+ if (!result) {
+ System.out.println("Expected: " + expected);
+ System.out.println("Actual: " + s);
+ }
+ return result;
};
}
public static Predicate<String> matchJson(String... expectedJson) {
diff --git a/model-integration/src/main/java/ai/vespa/modelintegration/evaluator/OnnxEvaluator.java b/model-integration/src/main/java/ai/vespa/modelintegration/evaluator/OnnxEvaluator.java
index 66eb8caabd0..c2d97e37074 100644
--- a/model-integration/src/main/java/ai/vespa/modelintegration/evaluator/OnnxEvaluator.java
+++ b/model-integration/src/main/java/ai/vespa/modelintegration/evaluator/OnnxEvaluator.java
@@ -2,6 +2,7 @@
package ai.vespa.modelintegration.evaluator;
+import ai.onnxruntime.NodeInfo;
import ai.onnxruntime.OnnxTensor;
import ai.onnxruntime.OnnxValue;
import ai.onnxruntime.OrtEnvironment;
@@ -72,6 +73,35 @@ public class OnnxEvaluator implements AutoCloseable {
}
}
+ public record IdAndType(String id, TensorType type) { }
+
+ private Map<String, IdAndType> toSpecMap(Map<String, NodeInfo> infoMap) {
+ Map<String, IdAndType> result = new HashMap<>();
+ for (var info : infoMap.entrySet()) {
+ String name = info.getKey();
+ String ident = TensorConverter.asValidName(name);
+ TensorType t = TensorConverter.toVespaType(info.getValue().getInfo());
+ result.put(name, new IdAndType(ident, t));
+ }
+ return result;
+ }
+
+ public Map<String, IdAndType> getInputs() {
+ try {
+ return toSpecMap(session.getInputInfo());
+ } catch (OrtException e) {
+ throw new RuntimeException("ONNX Runtime exception", e);
+ }
+ }
+
+ public Map<String, IdAndType> getOutputs() {
+ try {
+ return toSpecMap(session.getOutputInfo());
+ } catch (OrtException e) {
+ throw new RuntimeException("ONNX Runtime exception", e);
+ }
+ }
+
public Map<String, TensorType> getInputInfo() {
try {
return TensorConverter.toVespaTypes(session.getInputInfo());