aboutsummaryrefslogtreecommitdiffstats
path: root/model-evaluation
diff options
context:
space:
mode:
authorArne Juul <arnej@yahooinc.com>2023-02-22 12:43:27 +0000
committerArne Juul <arnej@yahooinc.com>2023-02-22 14:50:50 +0000
commitfc5c9a366b06a3e04091be1e8f784be8bb82e1f5 (patch)
tree23f4a17f0a5afe6d1fc156fa480ccb9a007fb8b0 /model-evaluation
parent890e0ac9e795ca1c95e459f98a54593ac151051c (diff)
add configurable input/output mappings
Diffstat (limited to 'model-evaluation')
-rw-r--r--model-evaluation/src/main/java/ai/vespa/models/evaluation/OnnxModel.java117
-rw-r--r--model-evaluation/src/main/java/ai/vespa/models/evaluation/RankProfilesConfigImporter.java9
-rw-r--r--model-evaluation/src/test/java/ai/vespa/models/handler/HandlerTester.java18
3 files changed, 134 insertions, 10 deletions
diff --git a/model-evaluation/src/main/java/ai/vespa/models/evaluation/OnnxModel.java b/model-evaluation/src/main/java/ai/vespa/models/evaluation/OnnxModel.java
index ac66b1151f3..aa586a43d98 100644
--- a/model-evaluation/src/main/java/ai/vespa/models/evaluation/OnnxModel.java
+++ b/model-evaluation/src/main/java/ai/vespa/models/evaluation/OnnxModel.java
@@ -8,6 +8,9 @@ import com.yahoo.tensor.Tensor;
import com.yahoo.tensor.TensorType;
import java.io.File;
+import java.util.ArrayList;
+import java.util.HashMap;
+import java.util.List;
import java.util.Map;
/**
@@ -17,6 +20,40 @@ import java.util.Map;
*/
class OnnxModel implements AutoCloseable {
+ static class InputSpec {
+ String onnxName;
+ String source;
+ TensorType wantedType;
+ InputSpec(String name, String source, TensorType tType) {
+ this.onnxName = name;
+ this.source = source;
+ this.wantedType = tType;
+ }
+ InputSpec(String name, String source) { this(name, source, null); }
+ }
+
+ static class OutputSpec {
+ String onnxName;
+ String outputAs;
+ TensorType expectedType;
+ OutputSpec(String name, String as, TensorType tType) {
+ this.onnxName = name;
+ this.outputAs = as;
+ this.expectedType = tType;
+ }
+ OutputSpec(String name, String as) { this(name, as, null); }
+ }
+
+ final List<InputSpec> inputSpecs = new ArrayList<>();
+ final List<OutputSpec> outputSpecs = new ArrayList<>();
+
+ void addInputMapping(String onnxName, String source) {
+ inputSpecs.add(new InputSpec(onnxName, source));
+ }
+ void addOutputMapping(String onnxName, String outputAs) {
+ outputSpecs.add(new OutputSpec(onnxName, outputAs));
+ }
+
private final String name;
private final File modelFile;
private final OnnxEvaluatorOptions options;
@@ -38,19 +75,93 @@ class OnnxModel implements AutoCloseable {
public void load() {
if (referencedEvaluator == null) {
referencedEvaluator = cache.evaluatorOf(modelFile.getPath(), options);
+ fillInputTypes(evaluator().getInputInfo());
+ fillOutputTypes(evaluator().getOutputInfo());
+ }
+ }
+
+ void fillInputTypes(Map<String, TensorType> wantedTypes) {
+ if (inputSpecs.isEmpty()) {
+ for (var entry : wantedTypes.entrySet()) {
+ String name = entry.getKey();
+ TensorType tType = entry.getValue();
+ var spec = new InputSpec(name, name, tType);
+ inputSpecs.add(spec);
+ }
+ } else {
+ if (wantedTypes.size() != inputSpecs.size()) {
+ throw new IllegalArgumentException("Onnx model " + name() +
+ ": Mismatch between " + inputSpecs.size() +
+ " configured inputs and " +
+ wantedTypes.size() + " actual model inputs");
+ }
+ for (var spec : inputSpecs) {
+ TensorType tType = wantedTypes.get(spec.onnxName);
+ if (tType == null) {
+ throw new IllegalArgumentException("Onnx model " + name() +
+ ": No type in actual model for configured input "
+ + spec.onnxName);
+ }
+ spec.wantedType = tType;
+ }
+ }
+ }
+
+ void fillOutputTypes(Map<String, TensorType> outputTypes) {
+ if (outputSpecs.isEmpty()) {
+ for (var entry : outputTypes.entrySet()) {
+ String name = entry.getKey();
+ TensorType tType = entry.getValue();
+ var spec = new OutputSpec(name, name, tType);
+ outputSpecs.add(spec);
+ }
+ } else {
+ if (outputTypes.size() != outputSpecs.size()) {
+ throw new IllegalArgumentException("Onnx model " + name() +
+ ": Mismatch between " + outputSpecs.size() +
+ " configured outputs and " +
+ outputTypes.size() + " actual model outputs");
+ }
+ for (var spec : outputSpecs) {
+ TensorType tType = outputTypes.get(spec.onnxName);
+ if (tType == null) {
+ throw new IllegalArgumentException("Onnx model " + name() +
+ ": No type in actual model for configured output "
+ + spec.onnxName);
+ }
+ spec.expectedType = tType;
+ }
}
}
public Map<String, TensorType> inputs() {
- return evaluator().getInputInfo();
+ var map = new HashMap<String, TensorType>();
+ for (var spec : inputSpecs) {
+ map.put(spec.source, spec.wantedType);
+ }
+ return map;
}
public Map<String, TensorType> outputs() {
- return evaluator().getOutputInfo();
+ var map = new HashMap<String, TensorType>();
+ for (var spec : outputSpecs) {
+ map.put(spec.outputAs, spec.expectedType);
+ }
+ return map;
}
public Tensor evaluate(Map<String, Tensor> inputs, String output) {
- return evaluator().evaluate(inputs, output);
+ var mapped = new HashMap<String, Tensor>();
+ for (var spec : inputSpecs) {
+ mapped.put(spec.onnxName, inputs.get(spec.source));
+ }
+ String onnxName = output;
+ for (var spec : outputSpecs) {
+ if (spec.outputAs.equals(output)) {
+ onnxName = spec.onnxName;
+ }
+ }
+ return evaluator().evaluate(mapped, onnxName);
}
private OnnxEvaluator evaluator() {
diff --git a/model-evaluation/src/main/java/ai/vespa/models/evaluation/RankProfilesConfigImporter.java b/model-evaluation/src/main/java/ai/vespa/models/evaluation/RankProfilesConfigImporter.java
index 2d91f86117e..098e6e7a1f6 100644
--- a/model-evaluation/src/main/java/ai/vespa/models/evaluation/RankProfilesConfigImporter.java
+++ b/model-evaluation/src/main/java/ai/vespa/models/evaluation/RankProfilesConfigImporter.java
@@ -186,7 +186,14 @@ public class RankProfilesConfigImporter {
options.setInterOpThreads(onnxModelConfig.stateless_interop_threads());
options.setIntraOpThreads(onnxModelConfig.stateless_intraop_threads());
options.setGpuDevice(onnxModelConfig.gpu_device(), onnxModelConfig.gpu_device_required());
- return new OnnxModel(name, file, options, onnxEvaluatorCache);
+ var m = new OnnxModel(name, file, options, onnxEvaluatorCache);
+ for (var spec : onnxModelConfig.input()) {
+ m.addInputMapping(spec.name(), spec.source());
+ }
+ for (var spec : onnxModelConfig.output()) {
+ m.addOutputMapping(spec.name(), spec.as());
+ }
+ return m;
} catch (InterruptedException e) {
throw new IllegalStateException("Gave up waiting for ONNX model " + onnxModelConfig.name());
}
diff --git a/model-evaluation/src/test/java/ai/vespa/models/handler/HandlerTester.java b/model-evaluation/src/test/java/ai/vespa/models/handler/HandlerTester.java
index 6c4dd886f4b..38215858366 100644
--- a/model-evaluation/src/test/java/ai/vespa/models/handler/HandlerTester.java
+++ b/model-evaluation/src/test/java/ai/vespa/models/handler/HandlerTester.java
@@ -27,16 +27,22 @@ class HandlerTester {
}
private static Predicate<String> matchString(String expected) {
return s -> {
- //System.out.println("Expected: " + expected);
- //System.out.println("Actual: " + s);
- return expected.equals(s);
+ boolean result = expected.equals(s);
+ if (!result) {
+ System.out.println("Expected: " + expected);
+ System.out.println("Actual: " + s);
+ }
+ return result;
};
}
private static Predicate<String> matchJsonString(String expected) {
return s -> {
- //System.out.println("Expected: " + expected);
- //System.out.println("Actual: " + s);
- return JSON.canonical(expected).equals(JSON.canonical(s));
+ boolean result = JSON.canonical(expected).equals(JSON.canonical(s));
+ if (!result) {
+ System.out.println("Expected: " + expected);
+ System.out.println("Actual: " + s);
+ }
+ return result;
};
}
public static Predicate<String> matchJson(String... expectedJson) {