aboutsummaryrefslogtreecommitdiffstats
path: root/model-integration/src/test
diff options
context:
space:
mode:
authorArne Juul <arnej@yahooinc.com>2023-02-10 11:18:11 +0000
committerArne Juul <arnej@yahooinc.com>2023-02-10 11:19:20 +0000
commit616b5d99ff8ed7f52b33f110941416180a8e8697 (patch)
tree6497200a9fdbbd5b28e9348156b16f649c57a853 /model-integration/src/test
parent6cafa7d885b5e0fe2ca7e33a786d0fa79d0e48ff (diff)
ensure outputs with names as promised by getOutputInfo()
Diffstat (limited to 'model-integration/src/test')
-rw-r--r--model-integration/src/test/java/ai/vespa/modelintegration/evaluator/OnnxEvaluatorTest.java68
-rw-r--r--model-integration/src/test/models/onnx/badnames.onnx34
2 files changed, 102 insertions, 0 deletions
diff --git a/model-integration/src/test/java/ai/vespa/modelintegration/evaluator/OnnxEvaluatorTest.java b/model-integration/src/test/java/ai/vespa/modelintegration/evaluator/OnnxEvaluatorTest.java
index 6266dcef174..83f355821e5 100644
--- a/model-integration/src/test/java/ai/vespa/modelintegration/evaluator/OnnxEvaluatorTest.java
+++ b/model-integration/src/test/java/ai/vespa/modelintegration/evaluator/OnnxEvaluatorTest.java
@@ -10,6 +10,7 @@ import java.util.HashMap;
import java.util.Map;
import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertTrue;
import static org.junit.Assume.assumeTrue;
/**
@@ -83,6 +84,73 @@ public class OnnxEvaluatorTest {
// assertEvaluate("cast_bfloat16_float.onnx", "tensor<float>(d0[1]):[1]", "tensor<bfloat16>(d0[1]):[1]");
}
+ @Test
+ public void testNotIdentifiers() {
+ assumeTrue(OnnxEvaluator.isRuntimeAvailable());
+ OnnxEvaluator evaluator = new OnnxEvaluator("src/test/models/onnx/badnames.onnx");
+ var inputInfo = evaluator.getInputInfo();
+ var outputInfo = evaluator.getOutputInfo();
+ for (var entry : inputInfo.entrySet()) {
+ System.out.println("wants input: " + entry.getKey() + " with type " + entry.getValue());
+ }
+ for (var entry : outputInfo.entrySet()) {
+ System.out.println("will produce output: " + entry.getKey() + " with type " + entry.getValue());
+ }
+
+ assertEquals(3, inputInfo.size());
+ assertTrue(inputInfo.containsKey("first_input"));
+ assertTrue(inputInfo.containsKey("second_input_0"));
+ assertTrue(inputInfo.containsKey("third_input"));
+
+ assertEquals(3, outputInfo.size());
+ assertTrue(outputInfo.containsKey("path_to_output_0"));
+ assertTrue(outputInfo.containsKey("path_to_output_1"));
+ assertTrue(outputInfo.containsKey("path_to_output_2"));
+
+ Map<String, Tensor> inputs = new HashMap<>();
+ inputs.put("first_input", Tensor.from("tensor(d0[2]):[2,3]"));
+ inputs.put("second_input_0", Tensor.from("tensor(d0[2]):[4,5]"));
+ inputs.put("third_input", Tensor.from("tensor(d0[2]):[6,7]"));
+
+ Tensor result;
+ result = evaluator.evaluate(inputs, "path_to_output_0");
+ System.out.println("got result: " + result);
+ assertTrue(result != null);
+
+ result = evaluator.evaluate(inputs, "path_to_output_1");
+ System.out.println("got result: " + result);
+ assertTrue(result != null);
+
+ result = evaluator.evaluate(inputs, "path_to_output_2");
+ System.out.println("got result: " + result);
+ assertTrue(result != null);
+
+ var allResults = evaluator.evaluate(inputs);
+ assertTrue(allResults != null);
+ for (var entry : allResults.entrySet()) {
+ System.out.println("produced output: " + entry.getKey() + " with type " + entry.getValue());
+ }
+ assertEquals(3, allResults.size());
+ assertTrue(allResults.containsKey("path_to_output_0"));
+ assertTrue(allResults.containsKey("path_to_output_1"));
+ assertTrue(allResults.containsKey("path_to_output_2"));
+
+ // we can also get output by onnx-internal name
+ result = evaluator.evaluate(inputs, "path/to/output:0");
+ System.out.println("got result: " + result);
+ assertTrue(result != null);
+
+ // we can also send input by onnx-internal name
+ inputs.remove("second_input_0");
+ inputs.put("second/input:0", Tensor.from("tensor(d0[2]):[8,9]"));
+ allResults = evaluator.evaluate(inputs);
+ assertTrue(allResults != null);
+ for (var entry : allResults.entrySet()) {
+ System.out.println("produced output: " + entry.getKey() + " with type " + entry.getValue());
+ }
+ assertEquals(3, allResults.size());
+ }
+
private void assertEvaluate(String model, String output, String... input) {
OnnxEvaluator evaluator = new OnnxEvaluator("src/test/models/onnx/" + model);
Map<String, Tensor> inputs = new HashMap<>();
diff --git a/model-integration/src/test/models/onnx/badnames.onnx b/model-integration/src/test/models/onnx/badnames.onnx
new file mode 100644
index 00000000000..f3898205c6a
--- /dev/null
+++ b/model-integration/src/test/models/onnx/badnames.onnx
@@ -0,0 +1,34 @@
+create_model.py:í
+4
+ first_input
+second/input:0path/to/output:0"Add
+4
+ third_input
+second/input:0path/to/output:1"Add
+;
+path/to/output:0
+path/to/output:1path/to/output:2"Addsimple_scoringZ
+ first_input
+
+
+Z
+second/input:0
+
+
+Z
+ third_input
+
+
+b
+path/to/output:0
+
+
+b
+path/to/output:1
+
+
+b
+path/to/output:2
+
+
+B \ No newline at end of file