From 616b5d99ff8ed7f52b33f110941416180a8e8697 Mon Sep 17 00:00:00 2001 From: Arne Juul Date: Fri, 10 Feb 2023 11:18:11 +0000 Subject: ensure outputs with names as promised by getOutputInfo() --- .../modelintegration/evaluator/OnnxEvaluator.java | 22 ++++++- .../evaluator/OnnxEvaluatorTest.java | 68 ++++++++++++++++++++++ .../src/test/models/onnx/badnames.onnx | 34 +++++++++++ 3 files changed, 123 insertions(+), 1 deletion(-) create mode 100644 model-integration/src/test/models/onnx/badnames.onnx (limited to 'model-integration') diff --git a/model-integration/src/main/java/ai/vespa/modelintegration/evaluator/OnnxEvaluator.java b/model-integration/src/main/java/ai/vespa/modelintegration/evaluator/OnnxEvaluator.java index 5cc7991d197..9961c24005c 100644 --- a/model-integration/src/main/java/ai/vespa/modelintegration/evaluator/OnnxEvaluator.java +++ b/model-integration/src/main/java/ai/vespa/modelintegration/evaluator/OnnxEvaluator.java @@ -37,6 +37,7 @@ public class OnnxEvaluator { public Tensor evaluate(Map inputs, String output) { Map onnxInputs = null; try { + output = mapToInternalName(output); onnxInputs = TensorConverter.toOnnxTensors(inputs, environment, session); try (OrtSession.Result result = session.run(onnxInputs, Collections.singleton(output))) { return TensorConverter.toVespaTensor(result.get(0)); @@ -57,7 +58,8 @@ public class OnnxEvaluator { Map outputs = new HashMap<>(); try (OrtSession.Result result = session.run(onnxInputs)) { for (Map.Entry output : result) { - outputs.put(output.getKey(), TensorConverter.toVespaTensor(output.getValue())); + String mapped = TensorConverter.asValidName(output.getKey()); + outputs.put(mapped, TensorConverter.toVespaTensor(output.getValue())); } return outputs; } @@ -133,4 +135,22 @@ public class OnnxEvaluator { } } + private String mapToInternalName(String outputName) throws OrtException { + var info = session.getOutputInfo(); + var internalNames = info.keySet(); + for (String name : internalNames) { + if (name.equals(outputName)) { + return name; + } + } + for (String name : internalNames) { + String mapped = TensorConverter.asValidName(name); + if (mapped.equals(outputName)) { + return name; + } + } + // Probably will not work, but give the correct error from session.run + return outputName; + } + } diff --git a/model-integration/src/test/java/ai/vespa/modelintegration/evaluator/OnnxEvaluatorTest.java b/model-integration/src/test/java/ai/vespa/modelintegration/evaluator/OnnxEvaluatorTest.java index 6266dcef174..83f355821e5 100644 --- a/model-integration/src/test/java/ai/vespa/modelintegration/evaluator/OnnxEvaluatorTest.java +++ b/model-integration/src/test/java/ai/vespa/modelintegration/evaluator/OnnxEvaluatorTest.java @@ -10,6 +10,7 @@ import java.util.HashMap; import java.util.Map; import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertTrue; import static org.junit.Assume.assumeTrue; /** @@ -83,6 +84,73 @@ public class OnnxEvaluatorTest { // assertEvaluate("cast_bfloat16_float.onnx", "tensor(d0[1]):[1]", "tensor(d0[1]):[1]"); } + @Test + public void testNotIdentifiers() { + assumeTrue(OnnxEvaluator.isRuntimeAvailable()); + OnnxEvaluator evaluator = new OnnxEvaluator("src/test/models/onnx/badnames.onnx"); + var inputInfo = evaluator.getInputInfo(); + var outputInfo = evaluator.getOutputInfo(); + for (var entry : inputInfo.entrySet()) { + System.out.println("wants input: " + entry.getKey() + " with type " + entry.getValue()); + } + for (var entry : outputInfo.entrySet()) { + System.out.println("will produce output: " + entry.getKey() + " with type " + entry.getValue()); + } + + assertEquals(3, inputInfo.size()); + assertTrue(inputInfo.containsKey("first_input")); + assertTrue(inputInfo.containsKey("second_input_0")); + assertTrue(inputInfo.containsKey("third_input")); + + assertEquals(3, outputInfo.size()); + assertTrue(outputInfo.containsKey("path_to_output_0")); + assertTrue(outputInfo.containsKey("path_to_output_1")); + assertTrue(outputInfo.containsKey("path_to_output_2")); + + Map inputs = new HashMap<>(); + inputs.put("first_input", Tensor.from("tensor(d0[2]):[2,3]")); + inputs.put("second_input_0", Tensor.from("tensor(d0[2]):[4,5]")); + inputs.put("third_input", Tensor.from("tensor(d0[2]):[6,7]")); + + Tensor result; + result = evaluator.evaluate(inputs, "path_to_output_0"); + System.out.println("got result: " + result); + assertTrue(result != null); + + result = evaluator.evaluate(inputs, "path_to_output_1"); + System.out.println("got result: " + result); + assertTrue(result != null); + + result = evaluator.evaluate(inputs, "path_to_output_2"); + System.out.println("got result: " + result); + assertTrue(result != null); + + var allResults = evaluator.evaluate(inputs); + assertTrue(allResults != null); + for (var entry : allResults.entrySet()) { + System.out.println("produced output: " + entry.getKey() + " with type " + entry.getValue()); + } + assertEquals(3, allResults.size()); + assertTrue(allResults.containsKey("path_to_output_0")); + assertTrue(allResults.containsKey("path_to_output_1")); + assertTrue(allResults.containsKey("path_to_output_2")); + + // we can also get output by onnx-internal name + result = evaluator.evaluate(inputs, "path/to/output:0"); + System.out.println("got result: " + result); + assertTrue(result != null); + + // we can also send input by onnx-internal name + inputs.remove("second_input_0"); + inputs.put("second/input:0", Tensor.from("tensor(d0[2]):[8,9]")); + allResults = evaluator.evaluate(inputs); + assertTrue(allResults != null); + for (var entry : allResults.entrySet()) { + System.out.println("produced output: " + entry.getKey() + " with type " + entry.getValue()); + } + assertEquals(3, allResults.size()); + } + private void assertEvaluate(String model, String output, String... input) { OnnxEvaluator evaluator = new OnnxEvaluator("src/test/models/onnx/" + model); Map inputs = new HashMap<>(); diff --git a/model-integration/src/test/models/onnx/badnames.onnx b/model-integration/src/test/models/onnx/badnames.onnx new file mode 100644 index 00000000000..f3898205c6a --- /dev/null +++ b/model-integration/src/test/models/onnx/badnames.onnx @@ -0,0 +1,34 @@ +create_model.py:í +4 + first_input +second/input:0path/to/output:0"Add +4 + third_input +second/input:0path/to/output:1"Add +; +path/to/output:0 +path/to/output:1path/to/output:2"Addsimple_scoringZ + first_input + + +Z +second/input:0 + + +Z + third_input + + +b +path/to/output:0 + + +b +path/to/output:1 + + +b +path/to/output:2 + + +B \ No newline at end of file -- cgit v1.2.3