diff options
author | Bjørn Christian Seime <bjorncs@yahooinc.com> | 2023-03-30 16:08:24 +0200 |
---|---|---|
committer | Bjørn Christian Seime <bjorncs@yahooinc.com> | 2023-03-30 16:08:31 +0200 |
commit | 4f3ca3ec859b23dad2801623465a495c29fbc436 (patch) | |
tree | debec19a2e10a9bab43b833f40d6bb42f928b220 /model-integration | |
parent | 5f641cbe5a558550b787945cea9ee4e20a3a659a (diff) |
Support loading ONNX models through byte array
Rewrite OnnxRuntimeTest to test through its public API
Diffstat (limited to 'model-integration')
4 files changed, 162 insertions, 36 deletions
diff --git a/model-integration/src/main/java/ai/vespa/modelintegration/evaluator/OnnxEvaluator.java b/model-integration/src/main/java/ai/vespa/modelintegration/evaluator/OnnxEvaluator.java index 7cdc27b6d63..02fa7b68dc4 100644 --- a/model-integration/src/main/java/ai/vespa/modelintegration/evaluator/OnnxEvaluator.java +++ b/model-integration/src/main/java/ai/vespa/modelintegration/evaluator/OnnxEvaluator.java @@ -7,6 +7,7 @@ import ai.onnxruntime.OnnxTensor; import ai.onnxruntime.OnnxValue; import ai.onnxruntime.OrtException; import ai.onnxruntime.OrtSession; +import ai.vespa.modelintegration.evaluator.OnnxRuntime.ModelPathOrData; import ai.vespa.modelintegration.evaluator.OnnxRuntime.ReferencedOrtSession; import com.yahoo.tensor.Tensor; import com.yahoo.tensor.TensorType; @@ -28,7 +29,11 @@ public class OnnxEvaluator implements AutoCloseable { private final ReferencedOrtSession session; OnnxEvaluator(String modelPath, OnnxEvaluatorOptions options, OnnxRuntime runtime) { - session = createSession(modelPath, runtime, options, true); + session = createSession(ModelPathOrData.of(modelPath), runtime, options, true); + } + + OnnxEvaluator(byte[] data, OnnxEvaluatorOptions options, OnnxRuntime runtime) { + session = createSession(ModelPathOrData.of(data), runtime, options, true); } public Tensor evaluate(Map<String, Tensor> inputs, String output) { @@ -125,19 +130,20 @@ public class OnnxEvaluator implements AutoCloseable { } } - private static ReferencedOrtSession createSession(String modelPath, OnnxRuntime runtime, OnnxEvaluatorOptions options, boolean tryCuda) { + private static ReferencedOrtSession createSession( + ModelPathOrData model, OnnxRuntime runtime, OnnxEvaluatorOptions options, boolean tryCuda) { if (options == null) { options = new OnnxEvaluatorOptions(); } try { - return runtime.acquireSession(modelPath, options, tryCuda && options.requestingGpu()); + return runtime.acquireSession(model, options, tryCuda && options.requestingGpu()); } catch (OrtException 
e) { if (e.getCode() == OrtException.OrtErrorCode.ORT_NO_SUCHFILE) { - throw new IllegalArgumentException("No such file: " + modelPath); + throw new IllegalArgumentException("No such file: " + model.path().get()); } if (tryCuda && isCudaError(e) && !options.gpuDeviceRequired()) { // Failed in CUDA native code, but GPU device is optional, so we can proceed without it - return createSession(modelPath, runtime, options, false); + return createSession(model, runtime, options, false); } if (isCudaError(e)) { throw new IllegalArgumentException("GPU device is required, but CUDA initialization failed", e); @@ -146,6 +152,9 @@ public class OnnxEvaluator implements AutoCloseable { } } + // For unit testing + OrtSession ortSession() { return session.instance(); } + private String mapToInternalName(String outputName) throws OrtException { var info = session.instance().getOutputInfo(); var internalNames = info.keySet(); diff --git a/model-integration/src/main/java/ai/vespa/modelintegration/evaluator/OnnxRuntime.java b/model-integration/src/main/java/ai/vespa/modelintegration/evaluator/OnnxRuntime.java index 42830041c02..ece1db55c1e 100644 --- a/model-integration/src/main/java/ai/vespa/modelintegration/evaluator/OnnxRuntime.java +++ b/model-integration/src/main/java/ai/vespa/modelintegration/evaluator/OnnxRuntime.java @@ -10,9 +10,15 @@ import com.yahoo.component.annotation.Inject; import com.yahoo.jdisc.ResourceReference; import com.yahoo.jdisc.refcount.DebugReferencesWithStack; import com.yahoo.jdisc.refcount.References; +import net.jpountz.xxhash.XXHashFactory; +import java.io.IOException; +import java.io.UncheckedIOException; +import java.nio.file.Files; +import java.nio.file.Paths; import java.util.HashMap; import java.util.Map; +import java.util.Optional; import java.util.logging.Level; import java.util.logging.Logger; @@ -26,14 +32,22 @@ import static com.yahoo.yolean.Exceptions.throwUnchecked; public class OnnxRuntime extends AbstractComponent { // For unit testing - 
@FunctionalInterface interface OrtSessionFactory { + interface OrtSessionFactory { OrtSession create(String path, OrtSession.SessionOptions opts) throws OrtException; + OrtSession create(byte[] data, OrtSession.SessionOptions opts) throws OrtException; } private static final Logger log = Logger.getLogger(OnnxRuntime.class.getName()); private static final OrtEnvironmentResult ortEnvironment = getOrtEnvironment(); - private static final OrtSessionFactory defaultFactory = (path, opts) -> ortEnvironment().createSession(path, opts); + private static final OrtSessionFactory defaultFactory = new OrtSessionFactory() { + @Override public OrtSession create(String path, OrtSession.SessionOptions opts) throws OrtException { + return ortEnvironment().createSession(path, opts); + } + @Override public OrtSession create(byte[] data, OrtSession.SessionOptions opts) throws OrtException { + return ortEnvironment().createSession(data, opts); + } + }; private final Object monitor = new Object(); private final Map<OrtSessionId, SharedOrtSession> sessions = new HashMap<>(); @@ -43,6 +57,14 @@ public class OnnxRuntime extends AbstractComponent { OnnxRuntime(OrtSessionFactory factory) { this.factory = factory; } + public OnnxEvaluator evaluatorOf(byte[] model) { + return new OnnxEvaluator(model, null, this); + } + + public OnnxEvaluator evaluatorOf(byte[] model, OnnxEvaluatorOptions options) { + return new OnnxEvaluator(model, options, this); + } + public OnnxEvaluator evaluatorOf(String modelPath) { return new OnnxEvaluator(modelPath, null, this); } @@ -105,8 +127,8 @@ public class OnnxRuntime extends AbstractComponent { }; } - ReferencedOrtSession acquireSession(String modelPath, OnnxEvaluatorOptions options, boolean loadCuda) throws OrtException { - var sessionId = new OrtSessionId(modelPath, options, loadCuda); + ReferencedOrtSession acquireSession(ModelPathOrData model, OnnxEvaluatorOptions options, boolean loadCuda) throws OrtException { + var sessionId = new 
OrtSessionId(calculateModelHash(model), options, loadCuda); synchronized (monitor) { var sharedSession = sessions.get(sessionId); if (sharedSession != null) { @@ -114,8 +136,9 @@ public class OnnxRuntime extends AbstractComponent { } } + var opts = options.getOptions(loadCuda); // Note: identical models loaded simultaneously will result in duplicate session instances - var session = factory.create(modelPath, options.getOptions(loadCuda)); + var session = model.path().isPresent() ? factory.create(model.path().get(), opts) : factory.create(model.data().get(), opts); log.fine(() -> "Created new session (%s)".formatted(System.identityHashCode(session))); var sharedSession = new SharedOrtSession(sessionId, session); @@ -125,25 +148,52 @@ public class OnnxRuntime extends AbstractComponent { return referencedSession; } + private static long calculateModelHash(ModelPathOrData model) { + if (model.path().isPresent()) { + try (var hasher = XXHashFactory.fastestInstance().newStreamingHash64(0); + var in = Files.newInputStream(Paths.get(model.path().get()))) { + byte[] buffer = new byte[8192]; + int bytesRead; + while ((bytesRead = in.read(buffer)) != -1) { + hasher.update(buffer, 0, bytesRead); + } + return hasher.getValue(); + } catch (IOException e) { + throw new UncheckedIOException(e); + } + } else { + var data = model.data().get(); + return XXHashFactory.fastestInstance().hash64().hash(data, 0, data.length, 0); + } + } + int sessionsCached() { synchronized(monitor) { return sessions.size(); } } - public static class ReferencedOrtSession implements AutoCloseable { + static class ReferencedOrtSession implements AutoCloseable { private final OrtSession instance; private final ResourceReference ref; - public ReferencedOrtSession(OrtSession instance, ResourceReference ref) { + ReferencedOrtSession(OrtSession instance, ResourceReference ref) { this.instance = instance; this.ref = ref; } - public OrtSession instance() { return instance; } + OrtSession instance() { return 
instance; } @Override public void close() { ref.close(); } } + record ModelPathOrData(Optional<String> path, Optional<byte[]> data) { + static ModelPathOrData of(String path) { return new ModelPathOrData(Optional.of(path), Optional.empty()); } + static ModelPathOrData of(byte[] data) { return new ModelPathOrData(Optional.empty(), Optional.of(data)); } + ModelPathOrData { + if (path.isEmpty() == data.isEmpty()) throw new IllegalArgumentException("Either path or data must be non-empty"); + } + } + // Assumes options are never modified after being stored in `onnxSessions` - record OrtSessionId(String modelPath, OnnxEvaluatorOptions options, boolean loadCuda) {} + private record OrtSessionId(long modelHash, OnnxEvaluatorOptions options, boolean loadCuda) {} - record OrtEnvironmentResult(OrtEnvironment env, Throwable failure) {} + private record OrtEnvironmentResult(OrtEnvironment env, Throwable failure) {} private class SharedOrtSession { private final OrtSessionId id; diff --git a/model-integration/src/test/java/ai/vespa/modelintegration/evaluator/OnnxEvaluatorTest.java b/model-integration/src/test/java/ai/vespa/modelintegration/evaluator/OnnxEvaluatorTest.java index b2d76baa566..5a367ef83e4 100644 --- a/model-integration/src/test/java/ai/vespa/modelintegration/evaluator/OnnxEvaluatorTest.java +++ b/model-integration/src/test/java/ai/vespa/modelintegration/evaluator/OnnxEvaluatorTest.java @@ -6,6 +6,9 @@ import com.yahoo.tensor.Tensor; import com.yahoo.tensor.TensorType; import org.junit.Test; +import java.io.IOException; +import java.nio.file.Files; +import java.nio.file.Paths; import java.util.HashMap; import java.util.Map; @@ -156,6 +159,17 @@ public class OnnxEvaluatorTest { assertEquals(3, allResults.size()); } + @Test + public void testLoadModelFromBytes() throws IOException { + assumeTrue(OnnxRuntime.isRuntimeAvailable()); + var runtime = new OnnxRuntime(); + var model = Files.readAllBytes(Paths.get("src/test/models/onnx/simple/simple.onnx")); + var evaluator = 
runtime.evaluatorOf(model); + assertEquals(3, evaluator.getInputs().size()); + assertEquals(1, evaluator.getOutputs().size()); + evaluator.close(); + } + private void assertEvaluate(OnnxRuntime runtime, String model, String output, String... input) { OnnxEvaluator evaluator = runtime.evaluatorOf("src/test/models/onnx/" + model); Map<String, Tensor> inputs = new HashMap<>(); diff --git a/model-integration/src/test/java/ai/vespa/modelintegration/evaluator/OnnxRuntimeTest.java b/model-integration/src/test/java/ai/vespa/modelintegration/evaluator/OnnxRuntimeTest.java index 81b1237e770..fdbd4fa4e5c 100644 --- a/model-integration/src/test/java/ai/vespa/modelintegration/evaluator/OnnxRuntimeTest.java +++ b/model-integration/src/test/java/ai/vespa/modelintegration/evaluator/OnnxRuntimeTest.java @@ -2,16 +2,18 @@ package ai.vespa.modelintegration.evaluator; -import ai.onnxruntime.OrtException; -import ai.onnxruntime.OrtSession; import org.junit.jupiter.api.Test; +import java.io.IOException; +import java.nio.file.Files; +import java.nio.file.Paths; + import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertFalse; import static org.junit.jupiter.api.Assertions.assertNotSame; import static org.junit.jupiter.api.Assertions.assertSame; -import static org.mockito.Mockito.mock; -import static org.mockito.Mockito.never; -import static org.mockito.Mockito.verify; +import static org.junit.jupiter.api.Assertions.assertTrue; +import static org.junit.jupiter.api.Assumptions.assumeTrue; /** * @author bjorncs @@ -19,30 +21,81 @@ import static org.mockito.Mockito.verify; class OnnxRuntimeTest { @Test - void reuses_sessions_while_active() throws OrtException { - var runtime = new OnnxRuntime((__, ___) -> mock(OrtSession.class)); - var session1 = runtime.acquireSession("model1", new OnnxEvaluatorOptions(), false); - var session2 = runtime.acquireSession("model1", new OnnxEvaluatorOptions(), false); - var session3 = 
runtime.acquireSession("model2", new OnnxEvaluatorOptions(), false); - assertSame(session1.instance(), session2.instance()); - assertNotSame(session1.instance(), session3.instance()); + void reuses_sessions_while_active() { + assumeTrue(OnnxRuntime.isRuntimeAvailable()); + OnnxRuntime runtime = new OnnxRuntime(); + String model1 = "src/test/models/onnx/simple/simple.onnx"; + var evaluator1 = runtime.evaluatorOf(model1); + var evaluator2 = runtime.evaluatorOf(model1); + String model2 = "src/test/models/onnx/simple/matmul.onnx"; + var evaluator3 = runtime.evaluatorOf(model2); + assertSameSession(evaluator1, evaluator2); + assertNotSameSession(evaluator1, evaluator3); assertEquals(2, runtime.sessionsCached()); - session1.close(); - session2.close(); + evaluator1.close(); + evaluator2.close(); assertEquals(1, runtime.sessionsCached()); - verify(session1.instance()).close(); - verify(session3.instance(), never()).close(); + assertClosed(evaluator1); + assertNotClosed(evaluator3); - session3.close(); + evaluator3.close(); assertEquals(0, runtime.sessionsCached()); - verify(session3.instance()).close(); + assertClosed(evaluator3); - var session4 = runtime.acquireSession("model1", new OnnxEvaluatorOptions(), false); - assertNotSame(session1.instance(), session4.instance()); + var session4 = runtime.evaluatorOf(model1); + assertNotSameSession(evaluator1, session4); assertEquals(1, runtime.sessionsCached()); session4.close(); assertEquals(0, runtime.sessionsCached()); - verify(session4.instance()).close(); + assertClosed(session4); + } + + @Test + void loads_model_from_byte_array() throws IOException { + assumeTrue(OnnxRuntime.isRuntimeAvailable()); + var runtime = new OnnxRuntime(); + byte[] bytes = Files.readAllBytes(Paths.get("src/test/models/onnx/simple/simple.onnx")); + var evaluator1 = runtime.evaluatorOf(bytes); + var evaluator2 = runtime.evaluatorOf(bytes); + assertEquals(3, evaluator1.getInputs().size()); + assertEquals(1, runtime.sessionsCached()); + 
assertSameSession(evaluator1, evaluator2); + evaluator2.close(); + evaluator1.close(); + assertEquals(0, runtime.sessionsCached()); + assertClosed(evaluator1); + } + + @Test + void loading_same_model_from_bytes_and_file_resolve_to_same_instance() throws IOException { + assumeTrue(OnnxRuntime.isRuntimeAvailable()); + var runtime = new OnnxRuntime(); + String modelPath = "src/test/models/onnx/simple/simple.onnx"; + byte[] bytes = Files.readAllBytes(Paths.get(modelPath)); + try (var evaluator1 = runtime.evaluatorOf(bytes); + var evaluator2 = runtime.evaluatorOf(modelPath)) { + assertSameSession(evaluator1, evaluator2); + assertEquals(1, runtime.sessionsCached()); + } + } + + private static void assertClosed(OnnxEvaluator evaluator) { assertTrue(isClosed(evaluator), "Session is not closed"); } + private static void assertNotClosed(OnnxEvaluator evaluator) { assertFalse(isClosed(evaluator), "Session is closed"); } + private static void assertSameSession(OnnxEvaluator evaluator1, OnnxEvaluator evaluator2) { + assertSame(evaluator1.ortSession(), evaluator2.ortSession()); + } + private static void assertNotSameSession(OnnxEvaluator evaluator1, OnnxEvaluator evaluator2) { + assertNotSame(evaluator1.ortSession(), evaluator2.ortSession()); + } + + private static boolean isClosed(OnnxEvaluator evaluator) { + try { + evaluator.getInputs(); + return false; + } catch (IllegalStateException e) { + assertEquals("Asking for inputs from a closed OrtSession.", e.getMessage()); + return true; + } } }
\ No newline at end of file |