summary | refs | log | tree | commit | diff | stats
diff options
context:
space:
mode:
author: Jon Bratseth <bratseth@gmail.com> 2023-03-31 10:28:47 +0200
committer: GitHub <noreply@github.com> 2023-03-31 10:28:47 +0200
commit: d65d548169183b47b931b3c5e39ad5a8fae06ce5 (patch)
tree: 7bff2d13a232465f53d28ef1ea877f16799a0dd0
parent: ac597f0921fe837578f22037d2ee1e557d7d3099 (diff)
parent: cb0059e6e3d8c7abec74e92a5f69f33714297917 (diff)
Merge pull request #26654 from vespa-engine/bjorncs/onnx-model-from-bytes
Bjorncs/onnx model from bytes
-rw-r--r-- model-integration/pom.xml | 5
-rw-r--r-- model-integration/src/main/java/ai/vespa/modelintegration/evaluator/OnnxEvaluator.java | 19
-rw-r--r-- model-integration/src/main/java/ai/vespa/modelintegration/evaluator/OnnxRuntime.java | 70
-rw-r--r-- model-integration/src/test/java/ai/vespa/modelintegration/evaluator/OnnxEvaluatorTest.java | 53
-rw-r--r-- model-integration/src/test/java/ai/vespa/modelintegration/evaluator/OnnxRuntimeTest.java | 95
5 files changed, 185 insertions, 57 deletions
diff --git a/model-integration/pom.xml b/model-integration/pom.xml
index 9bb60827a68..c27ed9d2c31 100644
--- a/model-integration/pom.xml
+++ b/model-integration/pom.xml
@@ -106,6 +106,11 @@
</dependency>
<dependency>
+ <groupId>org.lz4</groupId>
+ <artifactId>lz4-java</artifactId>
+ </dependency>
+
+ <dependency>
<groupId>junit</groupId>
<artifactId>junit</artifactId>
<scope>test</scope>
diff --git a/model-integration/src/main/java/ai/vespa/modelintegration/evaluator/OnnxEvaluator.java b/model-integration/src/main/java/ai/vespa/modelintegration/evaluator/OnnxEvaluator.java
index 7cdc27b6d63..02fa7b68dc4 100644
--- a/model-integration/src/main/java/ai/vespa/modelintegration/evaluator/OnnxEvaluator.java
+++ b/model-integration/src/main/java/ai/vespa/modelintegration/evaluator/OnnxEvaluator.java
@@ -7,6 +7,7 @@ import ai.onnxruntime.OnnxTensor;
import ai.onnxruntime.OnnxValue;
import ai.onnxruntime.OrtException;
import ai.onnxruntime.OrtSession;
+import ai.vespa.modelintegration.evaluator.OnnxRuntime.ModelPathOrData;
import ai.vespa.modelintegration.evaluator.OnnxRuntime.ReferencedOrtSession;
import com.yahoo.tensor.Tensor;
import com.yahoo.tensor.TensorType;
@@ -28,7 +29,11 @@ public class OnnxEvaluator implements AutoCloseable {
private final ReferencedOrtSession session;
OnnxEvaluator(String modelPath, OnnxEvaluatorOptions options, OnnxRuntime runtime) {
- session = createSession(modelPath, runtime, options, true);
+ session = createSession(ModelPathOrData.of(modelPath), runtime, options, true);
+ }
+
+ OnnxEvaluator(byte[] data, OnnxEvaluatorOptions options, OnnxRuntime runtime) {
+ session = createSession(ModelPathOrData.of(data), runtime, options, true);
}
public Tensor evaluate(Map<String, Tensor> inputs, String output) {
@@ -125,19 +130,20 @@ public class OnnxEvaluator implements AutoCloseable {
}
}
- private static ReferencedOrtSession createSession(String modelPath, OnnxRuntime runtime, OnnxEvaluatorOptions options, boolean tryCuda) {
+ private static ReferencedOrtSession createSession(
+ ModelPathOrData model, OnnxRuntime runtime, OnnxEvaluatorOptions options, boolean tryCuda) {
if (options == null) {
options = new OnnxEvaluatorOptions();
}
try {
- return runtime.acquireSession(modelPath, options, tryCuda && options.requestingGpu());
+ return runtime.acquireSession(model, options, tryCuda && options.requestingGpu());
} catch (OrtException e) {
if (e.getCode() == OrtException.OrtErrorCode.ORT_NO_SUCHFILE) {
- throw new IllegalArgumentException("No such file: " + modelPath);
+ throw new IllegalArgumentException("No such file: " + model.path().get());
}
if (tryCuda && isCudaError(e) && !options.gpuDeviceRequired()) {
// Failed in CUDA native code, but GPU device is optional, so we can proceed without it
- return createSession(modelPath, runtime, options, false);
+ return createSession(model, runtime, options, false);
}
if (isCudaError(e)) {
throw new IllegalArgumentException("GPU device is required, but CUDA initialization failed", e);
@@ -146,6 +152,9 @@ public class OnnxEvaluator implements AutoCloseable {
}
}
+ // For unit testing
+ OrtSession ortSession() { return session.instance(); }
+
private String mapToInternalName(String outputName) throws OrtException {
var info = session.instance().getOutputInfo();
var internalNames = info.keySet();
diff --git a/model-integration/src/main/java/ai/vespa/modelintegration/evaluator/OnnxRuntime.java b/model-integration/src/main/java/ai/vespa/modelintegration/evaluator/OnnxRuntime.java
index 42830041c02..ece1db55c1e 100644
--- a/model-integration/src/main/java/ai/vespa/modelintegration/evaluator/OnnxRuntime.java
+++ b/model-integration/src/main/java/ai/vespa/modelintegration/evaluator/OnnxRuntime.java
@@ -10,9 +10,15 @@ import com.yahoo.component.annotation.Inject;
import com.yahoo.jdisc.ResourceReference;
import com.yahoo.jdisc.refcount.DebugReferencesWithStack;
import com.yahoo.jdisc.refcount.References;
+import net.jpountz.xxhash.XXHashFactory;
+import java.io.IOException;
+import java.io.UncheckedIOException;
+import java.nio.file.Files;
+import java.nio.file.Paths;
import java.util.HashMap;
import java.util.Map;
+import java.util.Optional;
import java.util.logging.Level;
import java.util.logging.Logger;
@@ -26,14 +32,22 @@ import static com.yahoo.yolean.Exceptions.throwUnchecked;
public class OnnxRuntime extends AbstractComponent {
// For unit testing
- @FunctionalInterface interface OrtSessionFactory {
+ interface OrtSessionFactory {
OrtSession create(String path, OrtSession.SessionOptions opts) throws OrtException;
+ OrtSession create(byte[] data, OrtSession.SessionOptions opts) throws OrtException;
}
private static final Logger log = Logger.getLogger(OnnxRuntime.class.getName());
private static final OrtEnvironmentResult ortEnvironment = getOrtEnvironment();
- private static final OrtSessionFactory defaultFactory = (path, opts) -> ortEnvironment().createSession(path, opts);
+ private static final OrtSessionFactory defaultFactory = new OrtSessionFactory() {
+ @Override public OrtSession create(String path, OrtSession.SessionOptions opts) throws OrtException {
+ return ortEnvironment().createSession(path, opts);
+ }
+ @Override public OrtSession create(byte[] data, OrtSession.SessionOptions opts) throws OrtException {
+ return ortEnvironment().createSession(data, opts);
+ }
+ };
private final Object monitor = new Object();
private final Map<OrtSessionId, SharedOrtSession> sessions = new HashMap<>();
@@ -43,6 +57,14 @@ public class OnnxRuntime extends AbstractComponent {
OnnxRuntime(OrtSessionFactory factory) { this.factory = factory; }
+ public OnnxEvaluator evaluatorOf(byte[] model) {
+ return new OnnxEvaluator(model, null, this);
+ }
+
+ public OnnxEvaluator evaluatorOf(byte[] model, OnnxEvaluatorOptions options) {
+ return new OnnxEvaluator(model, options, this);
+ }
+
public OnnxEvaluator evaluatorOf(String modelPath) {
return new OnnxEvaluator(modelPath, null, this);
}
@@ -105,8 +127,8 @@ public class OnnxRuntime extends AbstractComponent {
};
}
- ReferencedOrtSession acquireSession(String modelPath, OnnxEvaluatorOptions options, boolean loadCuda) throws OrtException {
- var sessionId = new OrtSessionId(modelPath, options, loadCuda);
+ ReferencedOrtSession acquireSession(ModelPathOrData model, OnnxEvaluatorOptions options, boolean loadCuda) throws OrtException {
+ var sessionId = new OrtSessionId(calculateModelHash(model), options, loadCuda);
synchronized (monitor) {
var sharedSession = sessions.get(sessionId);
if (sharedSession != null) {
@@ -114,8 +136,9 @@ public class OnnxRuntime extends AbstractComponent {
}
}
+ var opts = options.getOptions(loadCuda);
// Note: identical models loaded simultaneously will result in duplicate session instances
- var session = factory.create(modelPath, options.getOptions(loadCuda));
+ var session = model.path().isPresent() ? factory.create(model.path().get(), opts) : factory.create(model.data().get(), opts);
log.fine(() -> "Created new session (%s)".formatted(System.identityHashCode(session)));
var sharedSession = new SharedOrtSession(sessionId, session);
@@ -125,25 +148,52 @@ public class OnnxRuntime extends AbstractComponent {
return referencedSession;
}
+ private static long calculateModelHash(ModelPathOrData model) {
+ if (model.path().isPresent()) {
+ try (var hasher = XXHashFactory.fastestInstance().newStreamingHash64(0);
+ var in = Files.newInputStream(Paths.get(model.path().get()))) {
+ byte[] buffer = new byte[8192];
+ int bytesRead;
+ while ((bytesRead = in.read(buffer)) != -1) {
+ hasher.update(buffer, 0, bytesRead);
+ }
+ return hasher.getValue();
+ } catch (IOException e) {
+ throw new UncheckedIOException(e);
+ }
+ } else {
+ var data = model.data().get();
+ return XXHashFactory.fastestInstance().hash64().hash(data, 0, data.length, 0);
+ }
+ }
+
int sessionsCached() { synchronized(monitor) { return sessions.size(); } }
- public static class ReferencedOrtSession implements AutoCloseable {
+ static class ReferencedOrtSession implements AutoCloseable {
private final OrtSession instance;
private final ResourceReference ref;
- public ReferencedOrtSession(OrtSession instance, ResourceReference ref) {
+ ReferencedOrtSession(OrtSession instance, ResourceReference ref) {
this.instance = instance;
this.ref = ref;
}
- public OrtSession instance() { return instance; }
+ OrtSession instance() { return instance; }
@Override public void close() { ref.close(); }
}
+ record ModelPathOrData(Optional<String> path, Optional<byte[]> data) {
+ static ModelPathOrData of(String path) { return new ModelPathOrData(Optional.of(path), Optional.empty()); }
+ static ModelPathOrData of(byte[] data) { return new ModelPathOrData(Optional.empty(), Optional.of(data)); }
+ ModelPathOrData {
+ if (path.isEmpty() == data.isEmpty()) throw new IllegalArgumentException("Either path or data must be non-empty");
+ }
+ }
+
// Assumes options are never modified after being stored in `sessions`
- record OrtSessionId(String modelPath, OnnxEvaluatorOptions options, boolean loadCuda) {}
+ private record OrtSessionId(long modelHash, OnnxEvaluatorOptions options, boolean loadCuda) {}
- record OrtEnvironmentResult(OrtEnvironment env, Throwable failure) {}
+ private record OrtEnvironmentResult(OrtEnvironment env, Throwable failure) {}
private class SharedOrtSession {
private final OrtSessionId id;
diff --git a/model-integration/src/test/java/ai/vespa/modelintegration/evaluator/OnnxEvaluatorTest.java b/model-integration/src/test/java/ai/vespa/modelintegration/evaluator/OnnxEvaluatorTest.java
index 5aba54de11b..5a367ef83e4 100644
--- a/model-integration/src/test/java/ai/vespa/modelintegration/evaluator/OnnxEvaluatorTest.java
+++ b/model-integration/src/test/java/ai/vespa/modelintegration/evaluator/OnnxEvaluatorTest.java
@@ -5,30 +5,26 @@ package ai.vespa.modelintegration.evaluator;
import com.yahoo.tensor.Tensor;
import com.yahoo.tensor.TensorType;
import org.junit.Test;
-import org.junit.jupiter.api.BeforeAll;
+import java.io.IOException;
+import java.nio.file.Files;
+import java.nio.file.Paths;
import java.util.HashMap;
import java.util.Map;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertTrue;
-import static org.junit.Assume.assumeNotNull;
+import static org.junit.Assume.assumeTrue;
/**
* @author lesters
*/
public class OnnxEvaluatorTest {
- private static OnnxRuntime runtime;
-
- @BeforeAll
- public static void beforeAll() {
- if (OnnxRuntime.isRuntimeAvailable()) runtime = new OnnxRuntime();
- }
-
@Test
public void testSimpleModel() {
- assumeNotNull(runtime);
+ assumeTrue(OnnxRuntime.isRuntimeAvailable());
+ var runtime = new OnnxRuntime();
OnnxEvaluator evaluator = runtime.evaluatorOf("src/test/models/onnx/simple/simple.onnx");
// Input types
@@ -53,7 +49,8 @@ public class OnnxEvaluatorTest {
@Test
public void testBatchDimension() {
- assumeNotNull(runtime);
+ assumeTrue(OnnxRuntime.isRuntimeAvailable());
+ var runtime = new OnnxRuntime();
OnnxEvaluator evaluator = runtime.evaluatorOf("src/test/models/onnx/pytorch/one_layer.onnx");
// Input types
@@ -72,21 +69,23 @@ public class OnnxEvaluatorTest {
@Test
public void testMatMul() {
- assumeNotNull(runtime);
+ assumeTrue(OnnxRuntime.isRuntimeAvailable());
+ var runtime = new OnnxRuntime();
String expected = "tensor<float>(d0[2],d1[4]):[38,44,50,56,83,98,113,128]";
String input1 = "tensor<float>(d0[2],d1[3]):[1,2,3,4,5,6]";
String input2 = "tensor<float>(d0[3],d1[4]):[1,2,3,4,5,6,7,8,9,10,11,12]";
- assertEvaluate("simple/matmul.onnx", expected, input1, input2);
+ assertEvaluate(runtime, "simple/matmul.onnx", expected, input1, input2);
}
@Test
public void testTypes() {
- assumeNotNull(runtime);
- assertEvaluate("add_double.onnx", "tensor(d0[1]):[3]", "tensor(d0[1]):[1]", "tensor(d0[1]):[2]");
- assertEvaluate("add_float.onnx", "tensor<float>(d0[1]):[3]", "tensor<float>(d0[1]):[1]", "tensor<float>(d0[1]):[2]");
- assertEvaluate("add_int64.onnx", "tensor<double>(d0[1]):[3]", "tensor<double>(d0[1]):[1]", "tensor<double>(d0[1]):[2]");
- assertEvaluate("cast_int8_float.onnx", "tensor<float>(d0[1]):[-128]", "tensor<int8>(d0[1]):[128]");
- assertEvaluate("cast_float_int8.onnx", "tensor<int8>(d0[1]):[-1]", "tensor<float>(d0[1]):[255]");
+ assumeTrue(OnnxRuntime.isRuntimeAvailable());
+ var runtime = new OnnxRuntime();
+ assertEvaluate(runtime, "add_double.onnx", "tensor(d0[1]):[3]", "tensor(d0[1]):[1]", "tensor(d0[1]):[2]");
+ assertEvaluate(runtime, "add_float.onnx", "tensor<float>(d0[1]):[3]", "tensor<float>(d0[1]):[1]", "tensor<float>(d0[1]):[2]");
+ assertEvaluate(runtime, "add_int64.onnx", "tensor<double>(d0[1]):[3]", "tensor<double>(d0[1]):[1]", "tensor<double>(d0[1]):[2]");
+ assertEvaluate(runtime, "cast_int8_float.onnx", "tensor<float>(d0[1]):[-128]", "tensor<int8>(d0[1]):[128]");
+ assertEvaluate(runtime, "cast_float_int8.onnx", "tensor<int8>(d0[1]):[-1]", "tensor<float>(d0[1]):[255]");
// ONNX Runtime 1.8.0 does not support much of bfloat16 yet
// assertEvaluate(runtime, "cast_bfloat16_float.onnx", "tensor<float>(d0[1]):[1]", "tensor<bfloat16>(d0[1]):[1]");
@@ -94,7 +93,8 @@ public class OnnxEvaluatorTest {
@Test
public void testNotIdentifiers() {
- assumeNotNull(runtime);
+ assumeTrue(OnnxRuntime.isRuntimeAvailable());
+ var runtime = new OnnxRuntime();
OnnxEvaluator evaluator = runtime.evaluatorOf("src/test/models/onnx/badnames.onnx");
var inputInfo = evaluator.getInputInfo();
var outputInfo = evaluator.getOutputInfo();
@@ -159,7 +159,18 @@ public class OnnxEvaluatorTest {
assertEquals(3, allResults.size());
}
- private void assertEvaluate(String model, String output, String... input) {
+ @Test
+ public void testLoadModelFromBytes() throws IOException {
+ assumeTrue(OnnxRuntime.isRuntimeAvailable());
+ var runtime = new OnnxRuntime();
+ var model = Files.readAllBytes(Paths.get("src/test/models/onnx/simple/simple.onnx"));
+ var evaluator = runtime.evaluatorOf(model);
+ assertEquals(3, evaluator.getInputs().size());
+ assertEquals(1, evaluator.getOutputs().size());
+ evaluator.close();
+ }
+
+ private void assertEvaluate(OnnxRuntime runtime, String model, String output, String... input) {
OnnxEvaluator evaluator = runtime.evaluatorOf("src/test/models/onnx/" + model);
Map<String, Tensor> inputs = new HashMap<>();
for (int i = 0; i < input.length; ++i) {
diff --git a/model-integration/src/test/java/ai/vespa/modelintegration/evaluator/OnnxRuntimeTest.java b/model-integration/src/test/java/ai/vespa/modelintegration/evaluator/OnnxRuntimeTest.java
index 81b1237e770..fdbd4fa4e5c 100644
--- a/model-integration/src/test/java/ai/vespa/modelintegration/evaluator/OnnxRuntimeTest.java
+++ b/model-integration/src/test/java/ai/vespa/modelintegration/evaluator/OnnxRuntimeTest.java
@@ -2,16 +2,18 @@
package ai.vespa.modelintegration.evaluator;
-import ai.onnxruntime.OrtException;
-import ai.onnxruntime.OrtSession;
import org.junit.jupiter.api.Test;
+import java.io.IOException;
+import java.nio.file.Files;
+import java.nio.file.Paths;
+
import static org.junit.jupiter.api.Assertions.assertEquals;
+import static org.junit.jupiter.api.Assertions.assertFalse;
import static org.junit.jupiter.api.Assertions.assertNotSame;
import static org.junit.jupiter.api.Assertions.assertSame;
-import static org.mockito.Mockito.mock;
-import static org.mockito.Mockito.never;
-import static org.mockito.Mockito.verify;
+import static org.junit.jupiter.api.Assertions.assertTrue;
+import static org.junit.jupiter.api.Assumptions.assumeTrue;
/**
* @author bjorncs
@@ -19,30 +21,81 @@ import static org.mockito.Mockito.verify;
class OnnxRuntimeTest {
@Test
- void reuses_sessions_while_active() throws OrtException {
- var runtime = new OnnxRuntime((__, ___) -> mock(OrtSession.class));
- var session1 = runtime.acquireSession("model1", new OnnxEvaluatorOptions(), false);
- var session2 = runtime.acquireSession("model1", new OnnxEvaluatorOptions(), false);
- var session3 = runtime.acquireSession("model2", new OnnxEvaluatorOptions(), false);
- assertSame(session1.instance(), session2.instance());
- assertNotSame(session1.instance(), session3.instance());
+ void reuses_sessions_while_active() {
+ assumeTrue(OnnxRuntime.isRuntimeAvailable());
+ OnnxRuntime runtime = new OnnxRuntime();
+ String model1 = "src/test/models/onnx/simple/simple.onnx";
+ var evaluator1 = runtime.evaluatorOf(model1);
+ var evaluator2 = runtime.evaluatorOf(model1);
+ String model2 = "src/test/models/onnx/simple/matmul.onnx";
+ var evaluator3 = runtime.evaluatorOf(model2);
+ assertSameSession(evaluator1, evaluator2);
+ assertNotSameSession(evaluator1, evaluator3);
assertEquals(2, runtime.sessionsCached());
- session1.close();
- session2.close();
+ evaluator1.close();
+ evaluator2.close();
assertEquals(1, runtime.sessionsCached());
- verify(session1.instance()).close();
- verify(session3.instance(), never()).close();
+ assertClosed(evaluator1);
+ assertNotClosed(evaluator3);
- session3.close();
+ evaluator3.close();
assertEquals(0, runtime.sessionsCached());
- verify(session3.instance()).close();
+ assertClosed(evaluator3);
- var session4 = runtime.acquireSession("model1", new OnnxEvaluatorOptions(), false);
- assertNotSame(session1.instance(), session4.instance());
+ var session4 = runtime.evaluatorOf(model1);
+ assertNotSameSession(evaluator1, session4);
assertEquals(1, runtime.sessionsCached());
session4.close();
assertEquals(0, runtime.sessionsCached());
- verify(session4.instance()).close();
+ assertClosed(session4);
+ }
+
+ @Test
+ void loads_model_from_byte_array() throws IOException {
+ assumeTrue(OnnxRuntime.isRuntimeAvailable());
+ var runtime = new OnnxRuntime();
+ byte[] bytes = Files.readAllBytes(Paths.get("src/test/models/onnx/simple/simple.onnx"));
+ var evaluator1 = runtime.evaluatorOf(bytes);
+ var evaluator2 = runtime.evaluatorOf(bytes);
+ assertEquals(3, evaluator1.getInputs().size());
+ assertEquals(1, runtime.sessionsCached());
+ assertSameSession(evaluator1, evaluator2);
+ evaluator2.close();
+ evaluator1.close();
+ assertEquals(0, runtime.sessionsCached());
+ assertClosed(evaluator1);
+ }
+
+ @Test
+ void loading_same_model_from_bytes_and_file_resolve_to_same_instance() throws IOException {
+ assumeTrue(OnnxRuntime.isRuntimeAvailable());
+ var runtime = new OnnxRuntime();
+ String modelPath = "src/test/models/onnx/simple/simple.onnx";
+ byte[] bytes = Files.readAllBytes(Paths.get(modelPath));
+ try (var evaluator1 = runtime.evaluatorOf(bytes);
+ var evaluator2 = runtime.evaluatorOf(modelPath)) {
+ assertSameSession(evaluator1, evaluator2);
+ assertEquals(1, runtime.sessionsCached());
+ }
+ }
+
+ private static void assertClosed(OnnxEvaluator evaluator) { assertTrue(isClosed(evaluator), "Session is not closed"); }
+ private static void assertNotClosed(OnnxEvaluator evaluator) { assertFalse(isClosed(evaluator), "Session is closed"); }
+ private static void assertSameSession(OnnxEvaluator evaluator1, OnnxEvaluator evaluator2) {
+ assertSame(evaluator1.ortSession(), evaluator2.ortSession());
+ }
+ private static void assertNotSameSession(OnnxEvaluator evaluator1, OnnxEvaluator evaluator2) {
+ assertNotSame(evaluator1.ortSession(), evaluator2.ortSession());
+ }
+
+ private static boolean isClosed(OnnxEvaluator evaluator) {
+ try {
+ evaluator.getInputs();
+ return false;
+ } catch (IllegalStateException e) {
+ assertEquals("Asking for inputs from a closed OrtSession.", e.getMessage());
+ return true;
+ }
}
}
\ No newline at end of file