From 4f3ca3ec859b23dad2801623465a495c29fbc436 Mon Sep 17 00:00:00 2001 From: Bjørn Christian Seime Date: Thu, 30 Mar 2023 16:08:24 +0200 Subject: Support loading ONNX models through byte array Rewrite OnnxRuntimeTest to test through its public API --- .../evaluator/OnnxEvaluatorTest.java | 14 ++++ .../evaluator/OnnxRuntimeTest.java | 95 +++++++++++++++++----- 2 files changed, 88 insertions(+), 21 deletions(-) (limited to 'model-integration/src/test') diff --git a/model-integration/src/test/java/ai/vespa/modelintegration/evaluator/OnnxEvaluatorTest.java b/model-integration/src/test/java/ai/vespa/modelintegration/evaluator/OnnxEvaluatorTest.java index b2d76baa566..5a367ef83e4 100644 --- a/model-integration/src/test/java/ai/vespa/modelintegration/evaluator/OnnxEvaluatorTest.java +++ b/model-integration/src/test/java/ai/vespa/modelintegration/evaluator/OnnxEvaluatorTest.java @@ -6,6 +6,9 @@ import com.yahoo.tensor.Tensor; import com.yahoo.tensor.TensorType; import org.junit.Test; +import java.io.IOException; +import java.nio.file.Files; +import java.nio.file.Paths; import java.util.HashMap; import java.util.Map; @@ -156,6 +159,17 @@ public class OnnxEvaluatorTest { assertEquals(3, allResults.size()); } + @Test + public void testLoadModelFromBytes() throws IOException { + assumeTrue(OnnxRuntime.isRuntimeAvailable()); + var runtime = new OnnxRuntime(); + var model = Files.readAllBytes(Paths.get("src/test/models/onnx/simple/simple.onnx")); + var evaluator = runtime.evaluatorOf(model); + assertEquals(3, evaluator.getInputs().size()); + assertEquals(1, evaluator.getOutputs().size()); + evaluator.close(); + } + private void assertEvaluate(OnnxRuntime runtime, String model, String output, String... 
input) { OnnxEvaluator evaluator = runtime.evaluatorOf("src/test/models/onnx/" + model); Map inputs = new HashMap<>(); diff --git a/model-integration/src/test/java/ai/vespa/modelintegration/evaluator/OnnxRuntimeTest.java b/model-integration/src/test/java/ai/vespa/modelintegration/evaluator/OnnxRuntimeTest.java index 81b1237e770..fdbd4fa4e5c 100644 --- a/model-integration/src/test/java/ai/vespa/modelintegration/evaluator/OnnxRuntimeTest.java +++ b/model-integration/src/test/java/ai/vespa/modelintegration/evaluator/OnnxRuntimeTest.java @@ -2,16 +2,18 @@ package ai.vespa.modelintegration.evaluator; -import ai.onnxruntime.OrtException; -import ai.onnxruntime.OrtSession; import org.junit.jupiter.api.Test; +import java.io.IOException; +import java.nio.file.Files; +import java.nio.file.Paths; + import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertFalse; import static org.junit.jupiter.api.Assertions.assertNotSame; import static org.junit.jupiter.api.Assertions.assertSame; -import static org.mockito.Mockito.mock; -import static org.mockito.Mockito.never; -import static org.mockito.Mockito.verify; +import static org.junit.jupiter.api.Assertions.assertTrue; +import static org.junit.jupiter.api.Assumptions.assumeTrue; /** * @author bjorncs @@ -19,30 +21,81 @@ import static org.mockito.Mockito.verify; class OnnxRuntimeTest { @Test - void reuses_sessions_while_active() throws OrtException { - var runtime = new OnnxRuntime((__, ___) -> mock(OrtSession.class)); - var session1 = runtime.acquireSession("model1", new OnnxEvaluatorOptions(), false); - var session2 = runtime.acquireSession("model1", new OnnxEvaluatorOptions(), false); - var session3 = runtime.acquireSession("model2", new OnnxEvaluatorOptions(), false); - assertSame(session1.instance(), session2.instance()); - assertNotSame(session1.instance(), session3.instance()); + void reuses_sessions_while_active() { + assumeTrue(OnnxRuntime.isRuntimeAvailable()); + 
OnnxRuntime runtime = new OnnxRuntime(); + String model1 = "src/test/models/onnx/simple/simple.onnx"; + var evaluator1 = runtime.evaluatorOf(model1); + var evaluator2 = runtime.evaluatorOf(model1); + String model2 = "src/test/models/onnx/simple/matmul.onnx"; + var evaluator3 = runtime.evaluatorOf(model2); + assertSameSession(evaluator1, evaluator2); + assertNotSameSession(evaluator1, evaluator3); assertEquals(2, runtime.sessionsCached()); - session1.close(); - session2.close(); + evaluator1.close(); + evaluator2.close(); assertEquals(1, runtime.sessionsCached()); - verify(session1.instance()).close(); - verify(session3.instance(), never()).close(); + assertClosed(evaluator1); + assertNotClosed(evaluator3); - session3.close(); + evaluator3.close(); assertEquals(0, runtime.sessionsCached()); - verify(session3.instance()).close(); + assertClosed(evaluator3); - var session4 = runtime.acquireSession("model1", new OnnxEvaluatorOptions(), false); - assertNotSame(session1.instance(), session4.instance()); + var session4 = runtime.evaluatorOf(model1); + assertNotSameSession(evaluator1, session4); assertEquals(1, runtime.sessionsCached()); session4.close(); assertEquals(0, runtime.sessionsCached()); - verify(session4.instance()).close(); + assertClosed(session4); + } + + @Test + void loads_model_from_byte_array() throws IOException { + assumeTrue(OnnxRuntime.isRuntimeAvailable()); + var runtime = new OnnxRuntime(); + byte[] bytes = Files.readAllBytes(Paths.get("src/test/models/onnx/simple/simple.onnx")); + var evaluator1 = runtime.evaluatorOf(bytes); + var evaluator2 = runtime.evaluatorOf(bytes); + assertEquals(3, evaluator1.getInputs().size()); + assertEquals(1, runtime.sessionsCached()); + assertSameSession(evaluator1, evaluator2); + evaluator2.close(); + evaluator1.close(); + assertEquals(0, runtime.sessionsCached()); + assertClosed(evaluator1); + } + + @Test + void loading_same_model_from_bytes_and_file_resolve_to_same_instance() throws IOException { + 
assumeTrue(OnnxRuntime.isRuntimeAvailable()); + var runtime = new OnnxRuntime(); + String modelPath = "src/test/models/onnx/simple/simple.onnx"; + byte[] bytes = Files.readAllBytes(Paths.get(modelPath)); + try (var evaluator1 = runtime.evaluatorOf(bytes); + var evaluator2 = runtime.evaluatorOf(modelPath)) { + assertSameSession(evaluator1, evaluator2); + assertEquals(1, runtime.sessionsCached()); + } + } + + private static void assertClosed(OnnxEvaluator evaluator) { assertTrue(isClosed(evaluator), "Session is not closed"); } + private static void assertNotClosed(OnnxEvaluator evaluator) { assertFalse(isClosed(evaluator), "Session is closed"); } + private static void assertSameSession(OnnxEvaluator evaluator1, OnnxEvaluator evaluator2) { + assertSame(evaluator1.ortSession(), evaluator2.ortSession()); + } + private static void assertNotSameSession(OnnxEvaluator evaluator1, OnnxEvaluator evaluator2) { + assertNotSame(evaluator1.ortSession(), evaluator2.ortSession()); + } + + private static boolean isClosed(OnnxEvaluator evaluator) { + try { + evaluator.getInputs(); + return false; + } catch (IllegalStateException e) { + assertEquals("Asking for inputs from a closed OrtSession.", e.getMessage()); + return true; + } } } \ No newline at end of file -- cgit v1.2.3