aboutsummaryrefslogtreecommitdiffstats
path: root/model-integration
diff options
context:
space:
mode:
authorArne Juul <arnej@yahooinc.com>2023-02-10 11:18:11 +0000
committerArne Juul <arnej@yahooinc.com>2023-02-10 11:19:20 +0000
commit616b5d99ff8ed7f52b33f110941416180a8e8697 (patch)
tree6497200a9fdbbd5b28e9348156b16f649c57a853 /model-integration
parent6cafa7d885b5e0fe2ca7e33a786d0fa79d0e48ff (diff)
ensure outputs with names as promised by getOutputInfo()
Diffstat (limited to 'model-integration')
-rw-r--r--model-integration/src/main/java/ai/vespa/modelintegration/evaluator/OnnxEvaluator.java22
-rw-r--r--model-integration/src/test/java/ai/vespa/modelintegration/evaluator/OnnxEvaluatorTest.java68
-rw-r--r--model-integration/src/test/models/onnx/badnames.onnx34
3 files changed, 123 insertions, 1 deletions
diff --git a/model-integration/src/main/java/ai/vespa/modelintegration/evaluator/OnnxEvaluator.java b/model-integration/src/main/java/ai/vespa/modelintegration/evaluator/OnnxEvaluator.java
index 5cc7991d197..9961c24005c 100644
--- a/model-integration/src/main/java/ai/vespa/modelintegration/evaluator/OnnxEvaluator.java
+++ b/model-integration/src/main/java/ai/vespa/modelintegration/evaluator/OnnxEvaluator.java
@@ -37,6 +37,7 @@ public class OnnxEvaluator {
public Tensor evaluate(Map<String, Tensor> inputs, String output) {
Map<String, OnnxTensor> onnxInputs = null;
try {
+ output = mapToInternalName(output);
onnxInputs = TensorConverter.toOnnxTensors(inputs, environment, session);
try (OrtSession.Result result = session.run(onnxInputs, Collections.singleton(output))) {
return TensorConverter.toVespaTensor(result.get(0));
@@ -57,7 +58,8 @@ public class OnnxEvaluator {
Map<String, Tensor> outputs = new HashMap<>();
try (OrtSession.Result result = session.run(onnxInputs)) {
for (Map.Entry<String, OnnxValue> output : result) {
- outputs.put(output.getKey(), TensorConverter.toVespaTensor(output.getValue()));
+ String mapped = TensorConverter.asValidName(output.getKey());
+ outputs.put(mapped, TensorConverter.toVespaTensor(output.getValue()));
}
return outputs;
}
@@ -133,4 +135,22 @@ public class OnnxEvaluator {
}
}
+ private String mapToInternalName(String outputName) throws OrtException {
+ var info = session.getOutputInfo();
+ var internalNames = info.keySet();
+ for (String name : internalNames) {
+ if (name.equals(outputName)) {
+ return name;
+ }
+ }
+ for (String name : internalNames) {
+ String mapped = TensorConverter.asValidName(name);
+ if (mapped.equals(outputName)) {
+ return name;
+ }
+ }
+ // Probably will not work, but give the correct error from session.run
+ return outputName;
+ }
+
}
diff --git a/model-integration/src/test/java/ai/vespa/modelintegration/evaluator/OnnxEvaluatorTest.java b/model-integration/src/test/java/ai/vespa/modelintegration/evaluator/OnnxEvaluatorTest.java
index 6266dcef174..83f355821e5 100644
--- a/model-integration/src/test/java/ai/vespa/modelintegration/evaluator/OnnxEvaluatorTest.java
+++ b/model-integration/src/test/java/ai/vespa/modelintegration/evaluator/OnnxEvaluatorTest.java
@@ -10,6 +10,7 @@ import java.util.HashMap;
import java.util.Map;
import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertTrue;
import static org.junit.Assume.assumeTrue;
/**
@@ -83,6 +84,73 @@ public class OnnxEvaluatorTest {
// assertEvaluate("cast_bfloat16_float.onnx", "tensor<float>(d0[1]):[1]", "tensor<bfloat16>(d0[1]):[1]");
}
+ @Test
+ public void testNotIdentifiers() {
+ assumeTrue(OnnxEvaluator.isRuntimeAvailable());
+ OnnxEvaluator evaluator = new OnnxEvaluator("src/test/models/onnx/badnames.onnx");
+ var inputInfo = evaluator.getInputInfo();
+ var outputInfo = evaluator.getOutputInfo();
+ for (var entry : inputInfo.entrySet()) {
+ System.out.println("wants input: " + entry.getKey() + " with type " + entry.getValue());
+ }
+ for (var entry : outputInfo.entrySet()) {
+ System.out.println("will produce output: " + entry.getKey() + " with type " + entry.getValue());
+ }
+
+ assertEquals(3, inputInfo.size());
+ assertTrue(inputInfo.containsKey("first_input"));
+ assertTrue(inputInfo.containsKey("second_input_0"));
+ assertTrue(inputInfo.containsKey("third_input"));
+
+ assertEquals(3, outputInfo.size());
+ assertTrue(outputInfo.containsKey("path_to_output_0"));
+ assertTrue(outputInfo.containsKey("path_to_output_1"));
+ assertTrue(outputInfo.containsKey("path_to_output_2"));
+
+ Map<String, Tensor> inputs = new HashMap<>();
+ inputs.put("first_input", Tensor.from("tensor(d0[2]):[2,3]"));
+ inputs.put("second_input_0", Tensor.from("tensor(d0[2]):[4,5]"));
+ inputs.put("third_input", Tensor.from("tensor(d0[2]):[6,7]"));
+
+ Tensor result;
+ result = evaluator.evaluate(inputs, "path_to_output_0");
+ System.out.println("got result: " + result);
+ assertTrue(result != null);
+
+ result = evaluator.evaluate(inputs, "path_to_output_1");
+ System.out.println("got result: " + result);
+ assertTrue(result != null);
+
+ result = evaluator.evaluate(inputs, "path_to_output_2");
+ System.out.println("got result: " + result);
+ assertTrue(result != null);
+
+ var allResults = evaluator.evaluate(inputs);
+ assertTrue(allResults != null);
+ for (var entry : allResults.entrySet()) {
+ System.out.println("produced output: " + entry.getKey() + " with type " + entry.getValue());
+ }
+ assertEquals(3, allResults.size());
+ assertTrue(allResults.containsKey("path_to_output_0"));
+ assertTrue(allResults.containsKey("path_to_output_1"));
+ assertTrue(allResults.containsKey("path_to_output_2"));
+
+ // we can also get output by onnx-internal name
+ result = evaluator.evaluate(inputs, "path/to/output:0");
+ System.out.println("got result: " + result);
+ assertTrue(result != null);
+
+ // we can also send input by onnx-internal name
+ inputs.remove("second_input_0");
+ inputs.put("second/input:0", Tensor.from("tensor(d0[2]):[8,9]"));
+ allResults = evaluator.evaluate(inputs);
+ assertTrue(allResults != null);
+ for (var entry : allResults.entrySet()) {
+ System.out.println("produced output: " + entry.getKey() + " with type " + entry.getValue());
+ }
+ assertEquals(3, allResults.size());
+ }
+
private void assertEvaluate(String model, String output, String... input) {
OnnxEvaluator evaluator = new OnnxEvaluator("src/test/models/onnx/" + model);
Map<String, Tensor> inputs = new HashMap<>();
diff --git a/model-integration/src/test/models/onnx/badnames.onnx b/model-integration/src/test/models/onnx/badnames.onnx
new file mode 100644
index 00000000000..f3898205c6a
--- /dev/null
+++ b/model-integration/src/test/models/onnx/badnames.onnx
@@ -0,0 +1,34 @@
+create_model.py:í
+4
+ first_input
+second/input:0path/to/output:0"Add
+4
+ third_input
+second/input:0path/to/output:1"Add
+;
+path/to/output:0
+path/to/output:1path/to/output:2"Addsimple_scoringZ
+ first_input
+
+
+Z
+second/input:0
+
+
+Z
+ third_input
+
+
+b
+path/to/output:0
+
+
+b
+path/to/output:1
+
+
+b
+path/to/output:2
+
+
+B \ No newline at end of file