aboutsummaryrefslogtreecommitdiffstats
path: root/model-integration/src/test
diff options
context:
space:
mode:
authorBjørn Christian Seime <bjorncs@yahooinc.com>2023-02-27 17:02:23 +0100
committerBjørn Christian Seime <bjorncs@yahooinc.com>2023-02-27 18:13:08 +0100
commit5271b5d7241aa2aa2538b2072b8cae9b8f3d689a (patch)
tree12f025b12e86e5f9490b74dd2cae68283f779e67 /model-integration/src/test
parent6b40c6053b8542ae20a5bbe669f84f2d478fd697 (diff)
Replace `OnnxEvaluatorCache` with OnnxRuntime
Require an `OnnxRuntime` instance to create `OnnxEvaluator` instances. Cache underlying `OrtSession` instead of `OnnxEvaluator`. Move static helpers for checking Onnx runtime availability from `OnnxEvaluator` to `OnnxRuntime`.
Diffstat (limited to 'model-integration/src/test')
-rw-r--r--model-integration/src/test/java/ai/vespa/embedding/BertBaseEmbedderTest.java19
-rw-r--r--model-integration/src/test/java/ai/vespa/embedding/huggingface/HuggingFaceEmbedderTest.java16
-rw-r--r--model-integration/src/test/java/ai/vespa/llm/GeneratorTest.java10
-rw-r--r--model-integration/src/test/java/ai/vespa/modelintegration/evaluator/OnnxEvaluatorCacheTest.java38
-rw-r--r--model-integration/src/test/java/ai/vespa/modelintegration/evaluator/OnnxEvaluatorTest.java28
-rw-r--r--model-integration/src/test/java/ai/vespa/modelintegration/evaluator/OnnxRuntimeTest.java48
6 files changed, 85 insertions, 74 deletions
diff --git a/model-integration/src/test/java/ai/vespa/embedding/BertBaseEmbedderTest.java b/model-integration/src/test/java/ai/vespa/embedding/BertBaseEmbedderTest.java
index b06a54d68bb..329b87cacd1 100644
--- a/model-integration/src/test/java/ai/vespa/embedding/BertBaseEmbedderTest.java
+++ b/model-integration/src/test/java/ai/vespa/embedding/BertBaseEmbedderTest.java
@@ -1,13 +1,12 @@
package ai.vespa.embedding;
-import ai.vespa.modelintegration.evaluator.OnnxEvaluator;
+import ai.vespa.modelintegration.evaluator.OnnxRuntime;
import com.yahoo.config.ModelReference;
import com.yahoo.embedding.BertBaseEmbedderConfig;
import com.yahoo.tensor.Tensor;
import com.yahoo.tensor.TensorType;
import org.junit.Test;
-import java.lang.IllegalArgumentException;
import java.util.List;
import static org.junit.Assert.assertEquals;
@@ -20,12 +19,12 @@ public class BertBaseEmbedderTest {
public void testEmbedder() {
String vocabPath = "src/test/models/onnx/transformer/dummy_vocab.txt";
String modelPath = "src/test/models/onnx/transformer/dummy_transformer.onnx";
- assumeTrue(OnnxEvaluator.isRuntimeAvailable(modelPath));
+ assumeTrue(OnnxRuntime.isRuntimeAvailable(modelPath));
BertBaseEmbedderConfig.Builder builder = new BertBaseEmbedderConfig.Builder();
builder.tokenizerVocab(ModelReference.valueOf(vocabPath));
builder.transformerModel(ModelReference.valueOf(modelPath));
- BertBaseEmbedder embedder = new BertBaseEmbedder(builder.build());
+ BertBaseEmbedder embedder = newBertBaseEmbedder(builder.build());
TensorType destType = TensorType.fromSpec("tensor<float>(x[7])");
List<Integer> tokens = List.of(1,2,3,4,5); // use random tokens instead of invoking the tokenizer
@@ -39,13 +38,13 @@ public class BertBaseEmbedderTest {
public void testEmbedderWithoutTokenTypeIdsName() {
String vocabPath = "src/test/models/onnx/transformer/dummy_vocab.txt";
String modelPath = "src/test/models/onnx/transformer/dummy_transformer_without_type_ids.onnx";
- assumeTrue(OnnxEvaluator.isRuntimeAvailable(modelPath));
+ assumeTrue(OnnxRuntime.isRuntimeAvailable(modelPath));
BertBaseEmbedderConfig.Builder builder = new BertBaseEmbedderConfig.Builder();
builder.tokenizerVocab(ModelReference.valueOf(vocabPath));
builder.transformerModel(ModelReference.valueOf(modelPath));
builder.transformerTokenTypeIds("");
- BertBaseEmbedder embedder = new BertBaseEmbedder(builder.build());
+ BertBaseEmbedder embedder = newBertBaseEmbedder(builder.build());
TensorType destType = TensorType.fromSpec("tensor<float>(x[7])");
List<Integer> tokens = List.of(1,2,3,4,5); // use random tokens instead of invoking the tokenizer
@@ -59,14 +58,18 @@ public class BertBaseEmbedderTest {
public void testEmbedderWithoutTokenTypeIdsNameButWithConfig() {
String vocabPath = "src/test/models/onnx/transformer/dummy_vocab.txt";
String modelPath = "src/test/models/onnx/transformer/dummy_transformer_without_type_ids.onnx";
- assumeTrue(OnnxEvaluator.isRuntimeAvailable(modelPath));
+ assumeTrue(OnnxRuntime.isRuntimeAvailable(modelPath));
BertBaseEmbedderConfig.Builder builder = new BertBaseEmbedderConfig.Builder();
builder.tokenizerVocab(ModelReference.valueOf(vocabPath));
builder.transformerModel(ModelReference.valueOf(modelPath));
// we did not configured BertBaseEmbedder to accept missing token type ids
// so we expect ctor to throw
- assertThrows(IllegalArgumentException.class, () -> { new BertBaseEmbedder(builder.build()); });
+ assertThrows(IllegalArgumentException.class, () -> { newBertBaseEmbedder(builder.build()); });
+ }
+
+ private static BertBaseEmbedder newBertBaseEmbedder(BertBaseEmbedderConfig cfg) {
+ return new BertBaseEmbedder(new OnnxRuntime(), cfg);
}
}
diff --git a/model-integration/src/test/java/ai/vespa/embedding/huggingface/HuggingFaceEmbedderTest.java b/model-integration/src/test/java/ai/vespa/embedding/huggingface/HuggingFaceEmbedderTest.java
index c67b6b0dcab..0ff9acc9a69 100644
--- a/model-integration/src/test/java/ai/vespa/embedding/huggingface/HuggingFaceEmbedderTest.java
+++ b/model-integration/src/test/java/ai/vespa/embedding/huggingface/HuggingFaceEmbedderTest.java
@@ -1,19 +1,5 @@
package ai.vespa.embedding.huggingface;
-import ai.vespa.modelintegration.evaluator.OnnxEvaluator;
-import com.yahoo.config.ModelReference;
-import com.yahoo.tensor.Tensor;
-import com.yahoo.tensor.TensorType;
-import org.junit.Test;
-
-import com.yahoo.embedding.huggingface.HuggingFaceEmbedderConfig;
-
-import java.io.IOException;
-import java.util.List;
-
-import static org.junit.Assume.assumeTrue;
-import static org.junit.Assert.assertEquals;
-
public class HuggingFaceEmbedderTest {
/*
@Test
@@ -21,7 +7,7 @@ public class HuggingFaceEmbedderTest {
String modelPath = "src/test/models/hf/model.onnx";
String tokenizerPath = "src/test/models/hf/tokenizer.json";
- assumeTrue(OnnxEvaluator.isRuntimeAvailable(modelPath));
+ assumeTrue(OnnxRuntime.isRuntimeAvailable(modelPath));
HuggingFaceEmbedderConfig.Builder builder = new HuggingFaceEmbedderConfig.Builder();
builder.tokenizerPath(ModelReference.valueOf(tokenizerPath));
diff --git a/model-integration/src/test/java/ai/vespa/llm/GeneratorTest.java b/model-integration/src/test/java/ai/vespa/llm/GeneratorTest.java
index 733430aa10d..c22902b344f 100644
--- a/model-integration/src/test/java/ai/vespa/llm/GeneratorTest.java
+++ b/model-integration/src/test/java/ai/vespa/llm/GeneratorTest.java
@@ -1,6 +1,6 @@
package ai.vespa.llm;
-import ai.vespa.modelintegration.evaluator.OnnxEvaluator;
+import ai.vespa.modelintegration.evaluator.OnnxRuntime;
import com.yahoo.config.ModelReference;
import com.yahoo.llm.GeneratorConfig;
import org.junit.Test;
@@ -15,13 +15,13 @@ public class GeneratorTest {
String vocabPath = "src/test/models/onnx/llm/en.wiki.bpe.vs10000.model";
String encoderModelPath = "src/test/models/onnx/llm/random_encoder.onnx";
String decoderModelPath = "src/test/models/onnx/llm/random_decoder.onnx";
- assumeTrue(OnnxEvaluator.isRuntimeAvailable(encoderModelPath));
+ assumeTrue(OnnxRuntime.isRuntimeAvailable(encoderModelPath));
GeneratorConfig.Builder builder = new GeneratorConfig.Builder();
builder.tokenizerModel(ModelReference.valueOf(vocabPath));
builder.encoderModel(ModelReference.valueOf(encoderModelPath));
builder.decoderModel(ModelReference.valueOf(decoderModelPath));
- Generator generator = new Generator(builder.build());
+ Generator generator = newGenerator(builder.build());
GeneratorOptions options = new GeneratorOptions();
options.setSearchMethod(GeneratorOptions.SearchMethod.GREEDY);
@@ -33,4 +33,8 @@ public class GeneratorTest {
assertEquals("<unk> linear recruit latest sack annually institutions cert solid references", result);
}
+ private static Generator newGenerator(GeneratorConfig cfg) {
+ return new Generator(new OnnxRuntime(), cfg);
+ }
+
}
diff --git a/model-integration/src/test/java/ai/vespa/modelintegration/evaluator/OnnxEvaluatorCacheTest.java b/model-integration/src/test/java/ai/vespa/modelintegration/evaluator/OnnxEvaluatorCacheTest.java
deleted file mode 100644
index acce660f466..00000000000
--- a/model-integration/src/test/java/ai/vespa/modelintegration/evaluator/OnnxEvaluatorCacheTest.java
+++ /dev/null
@@ -1,38 +0,0 @@
-// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
-
-package ai.vespa.modelintegration.evaluator;
-
-import org.junit.jupiter.api.Test;
-
-import static org.junit.jupiter.api.Assertions.assertEquals;
-import static org.junit.jupiter.api.Assertions.assertNotSame;
-import static org.junit.jupiter.api.Assertions.assertSame;
-import static org.mockito.Mockito.mock;
-
-/**
- * @author bjorncs
- */
-class OnnxEvaluatorCacheTest {
-
- @Test
- void reuses_instance_while_in_use() {
- var cache = new OnnxEvaluatorCache((__, ___) -> mock(OnnxEvaluator.class));
- var referencedEvaluator1 = cache.evaluatorOf("model1", new OnnxEvaluatorOptions());
- var referencedEvaluator2 = cache.evaluatorOf("model1", new OnnxEvaluatorOptions());
- var referencedEvaluator3 = cache.evaluatorOf("model2", new OnnxEvaluatorOptions());
- assertSame(referencedEvaluator1.evaluator(), referencedEvaluator2.evaluator());
- assertNotSame(referencedEvaluator1.evaluator(), referencedEvaluator3.evaluator());
- assertEquals(2, cache.size());
- referencedEvaluator1.close();
- referencedEvaluator2.close();
- assertEquals(1, cache.size());
- referencedEvaluator3.close();
- assertEquals(0, cache.size());
- var referencedEvaluator4 = cache.evaluatorOf("model1", new OnnxEvaluatorOptions());
- assertNotSame(referencedEvaluator1.evaluator(), referencedEvaluator4.evaluator());
- assertEquals(1, cache.size());
- referencedEvaluator4.close();
- assertEquals(0, cache.size());
- }
-
-} \ No newline at end of file
diff --git a/model-integration/src/test/java/ai/vespa/modelintegration/evaluator/OnnxEvaluatorTest.java b/model-integration/src/test/java/ai/vespa/modelintegration/evaluator/OnnxEvaluatorTest.java
index 83f355821e5..5aba54de11b 100644
--- a/model-integration/src/test/java/ai/vespa/modelintegration/evaluator/OnnxEvaluatorTest.java
+++ b/model-integration/src/test/java/ai/vespa/modelintegration/evaluator/OnnxEvaluatorTest.java
@@ -5,23 +5,31 @@ package ai.vespa.modelintegration.evaluator;
import com.yahoo.tensor.Tensor;
import com.yahoo.tensor.TensorType;
import org.junit.Test;
+import org.junit.jupiter.api.BeforeAll;
import java.util.HashMap;
import java.util.Map;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertTrue;
-import static org.junit.Assume.assumeTrue;
+import static org.junit.Assume.assumeNotNull;
/**
* @author lesters
*/
public class OnnxEvaluatorTest {
+ private static OnnxRuntime runtime;
+
+ @BeforeAll
+ public static void beforeAll() {
+ if (OnnxRuntime.isRuntimeAvailable()) runtime = new OnnxRuntime();
+ }
+
@Test
public void testSimpleModel() {
- assumeTrue(OnnxEvaluator.isRuntimeAvailable());
- OnnxEvaluator evaluator = new OnnxEvaluator("src/test/models/onnx/simple/simple.onnx");
+ assumeNotNull(runtime);
+ OnnxEvaluator evaluator = runtime.evaluatorOf("src/test/models/onnx/simple/simple.onnx");
// Input types
Map<String, TensorType> inputTypes = evaluator.getInputInfo();
@@ -45,8 +53,8 @@ public class OnnxEvaluatorTest {
@Test
public void testBatchDimension() {
- assumeTrue(OnnxEvaluator.isRuntimeAvailable());
- OnnxEvaluator evaluator = new OnnxEvaluator("src/test/models/onnx/pytorch/one_layer.onnx");
+ assumeNotNull(runtime);
+ OnnxEvaluator evaluator = runtime.evaluatorOf("src/test/models/onnx/pytorch/one_layer.onnx");
// Input types
Map<String, TensorType> inputTypes = evaluator.getInputInfo();
@@ -64,7 +72,7 @@ public class OnnxEvaluatorTest {
@Test
public void testMatMul() {
- assumeTrue(OnnxEvaluator.isRuntimeAvailable());
+ assumeNotNull(runtime);
String expected = "tensor<float>(d0[2],d1[4]):[38,44,50,56,83,98,113,128]";
String input1 = "tensor<float>(d0[2],d1[3]):[1,2,3,4,5,6]";
String input2 = "tensor<float>(d0[3],d1[4]):[1,2,3,4,5,6,7,8,9,10,11,12]";
@@ -73,7 +81,7 @@ public class OnnxEvaluatorTest {
@Test
public void testTypes() {
- assumeTrue(OnnxEvaluator.isRuntimeAvailable());
+ assumeNotNull(runtime);
assertEvaluate("add_double.onnx", "tensor(d0[1]):[3]", "tensor(d0[1]):[1]", "tensor(d0[1]):[2]");
assertEvaluate("add_float.onnx", "tensor<float>(d0[1]):[3]", "tensor<float>(d0[1]):[1]", "tensor<float>(d0[1]):[2]");
assertEvaluate("add_int64.onnx", "tensor<double>(d0[1]):[3]", "tensor<double>(d0[1]):[1]", "tensor<double>(d0[1]):[2]");
@@ -86,8 +94,8 @@ public class OnnxEvaluatorTest {
@Test
public void testNotIdentifiers() {
- assumeTrue(OnnxEvaluator.isRuntimeAvailable());
- OnnxEvaluator evaluator = new OnnxEvaluator("src/test/models/onnx/badnames.onnx");
+ assumeNotNull(runtime);
+ OnnxEvaluator evaluator = runtime.evaluatorOf("src/test/models/onnx/badnames.onnx");
var inputInfo = evaluator.getInputInfo();
var outputInfo = evaluator.getOutputInfo();
for (var entry : inputInfo.entrySet()) {
@@ -152,7 +160,7 @@ public class OnnxEvaluatorTest {
}
private void assertEvaluate(String model, String output, String... input) {
- OnnxEvaluator evaluator = new OnnxEvaluator("src/test/models/onnx/" + model);
+ OnnxEvaluator evaluator = runtime.evaluatorOf("src/test/models/onnx/" + model);
Map<String, Tensor> inputs = new HashMap<>();
for (int i = 0; i < input.length; ++i) {
inputs.put("input" + (i+1), Tensor.from(input[i]));
diff --git a/model-integration/src/test/java/ai/vespa/modelintegration/evaluator/OnnxRuntimeTest.java b/model-integration/src/test/java/ai/vespa/modelintegration/evaluator/OnnxRuntimeTest.java
new file mode 100644
index 00000000000..81b1237e770
--- /dev/null
+++ b/model-integration/src/test/java/ai/vespa/modelintegration/evaluator/OnnxRuntimeTest.java
@@ -0,0 +1,48 @@
+// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+
+package ai.vespa.modelintegration.evaluator;
+
+import ai.onnxruntime.OrtException;
+import ai.onnxruntime.OrtSession;
+import org.junit.jupiter.api.Test;
+
+import static org.junit.jupiter.api.Assertions.assertEquals;
+import static org.junit.jupiter.api.Assertions.assertNotSame;
+import static org.junit.jupiter.api.Assertions.assertSame;
+import static org.mockito.Mockito.mock;
+import static org.mockito.Mockito.never;
+import static org.mockito.Mockito.verify;
+
+/**
+ * @author bjorncs
+ */
+class OnnxRuntimeTest {
+
+ @Test
+ void reuses_sessions_while_active() throws OrtException {
+ var runtime = new OnnxRuntime((__, ___) -> mock(OrtSession.class));
+ var session1 = runtime.acquireSession("model1", new OnnxEvaluatorOptions(), false);
+ var session2 = runtime.acquireSession("model1", new OnnxEvaluatorOptions(), false);
+ var session3 = runtime.acquireSession("model2", new OnnxEvaluatorOptions(), false);
+ assertSame(session1.instance(), session2.instance());
+ assertNotSame(session1.instance(), session3.instance());
+ assertEquals(2, runtime.sessionsCached());
+
+ session1.close();
+ session2.close();
+ assertEquals(1, runtime.sessionsCached());
+ verify(session1.instance()).close();
+ verify(session3.instance(), never()).close();
+
+ session3.close();
+ assertEquals(0, runtime.sessionsCached());
+ verify(session3.instance()).close();
+
+ var session4 = runtime.acquireSession("model1", new OnnxEvaluatorOptions(), false);
+ assertNotSame(session1.instance(), session4.instance());
+ assertEquals(1, runtime.sessionsCached());
+ session4.close();
+ assertEquals(0, runtime.sessionsCached());
+ verify(session4.instance()).close();
+ }
+} \ No newline at end of file