diff options
author | Bjørn Christian Seime <bjorncs@yahooinc.com> | 2023-02-27 17:02:23 +0100 |
---|---|---|
committer | Bjørn Christian Seime <bjorncs@yahooinc.com> | 2023-02-27 18:13:08 +0100 |
commit | 5271b5d7241aa2aa2538b2072b8cae9b8f3d689a (patch) | |
tree | 12f025b12e86e5f9490b74dd2cae68283f779e67 /model-integration/src/test | |
parent | 6b40c6053b8542ae20a5bbe669f84f2d478fd697 (diff) |
Replace `OnnxEvaluatorCache` with OnnxRuntime
Require an `OnnxRuntime` instance to create `OnnxEvaluator` instances.
Cache underlying `OrtSession` instead of `OnnxEvaluator`.
Move static helpers for checking Onnx runtime availability from `OnnxEvaluator` to `OnnxRuntime`.
Diffstat (limited to 'model-integration/src/test')
6 files changed, 85 insertions, 74 deletions
diff --git a/model-integration/src/test/java/ai/vespa/embedding/BertBaseEmbedderTest.java b/model-integration/src/test/java/ai/vespa/embedding/BertBaseEmbedderTest.java index b06a54d68bb..329b87cacd1 100644 --- a/model-integration/src/test/java/ai/vespa/embedding/BertBaseEmbedderTest.java +++ b/model-integration/src/test/java/ai/vespa/embedding/BertBaseEmbedderTest.java @@ -1,13 +1,12 @@ package ai.vespa.embedding; -import ai.vespa.modelintegration.evaluator.OnnxEvaluator; +import ai.vespa.modelintegration.evaluator.OnnxRuntime; import com.yahoo.config.ModelReference; import com.yahoo.embedding.BertBaseEmbedderConfig; import com.yahoo.tensor.Tensor; import com.yahoo.tensor.TensorType; import org.junit.Test; -import java.lang.IllegalArgumentException; import java.util.List; import static org.junit.Assert.assertEquals; @@ -20,12 +19,12 @@ public class BertBaseEmbedderTest { public void testEmbedder() { String vocabPath = "src/test/models/onnx/transformer/dummy_vocab.txt"; String modelPath = "src/test/models/onnx/transformer/dummy_transformer.onnx"; - assumeTrue(OnnxEvaluator.isRuntimeAvailable(modelPath)); + assumeTrue(OnnxRuntime.isRuntimeAvailable(modelPath)); BertBaseEmbedderConfig.Builder builder = new BertBaseEmbedderConfig.Builder(); builder.tokenizerVocab(ModelReference.valueOf(vocabPath)); builder.transformerModel(ModelReference.valueOf(modelPath)); - BertBaseEmbedder embedder = new BertBaseEmbedder(builder.build()); + BertBaseEmbedder embedder = newBertBaseEmbedder(builder.build()); TensorType destType = TensorType.fromSpec("tensor<float>(x[7])"); List<Integer> tokens = List.of(1,2,3,4,5); // use random tokens instead of invoking the tokenizer @@ -39,13 +38,13 @@ public class BertBaseEmbedderTest { public void testEmbedderWithoutTokenTypeIdsName() { String vocabPath = "src/test/models/onnx/transformer/dummy_vocab.txt"; String modelPath = "src/test/models/onnx/transformer/dummy_transformer_without_type_ids.onnx"; - assumeTrue(OnnxEvaluator.isRuntimeAvailable(modelPath)); + assumeTrue(OnnxRuntime.isRuntimeAvailable(modelPath)); BertBaseEmbedderConfig.Builder builder = new BertBaseEmbedderConfig.Builder(); builder.tokenizerVocab(ModelReference.valueOf(vocabPath)); builder.transformerModel(ModelReference.valueOf(modelPath)); builder.transformerTokenTypeIds(""); - BertBaseEmbedder embedder = new BertBaseEmbedder(builder.build()); + BertBaseEmbedder embedder = newBertBaseEmbedder(builder.build()); TensorType destType = TensorType.fromSpec("tensor<float>(x[7])"); List<Integer> tokens = List.of(1,2,3,4,5); // use random tokens instead of invoking the tokenizer @@ -59,14 +58,18 @@ public class BertBaseEmbedderTest { public void testEmbedderWithoutTokenTypeIdsNameButWithConfig() { String vocabPath = "src/test/models/onnx/transformer/dummy_vocab.txt"; String modelPath = "src/test/models/onnx/transformer/dummy_transformer_without_type_ids.onnx"; - assumeTrue(OnnxEvaluator.isRuntimeAvailable(modelPath)); + assumeTrue(OnnxRuntime.isRuntimeAvailable(modelPath)); BertBaseEmbedderConfig.Builder builder = new BertBaseEmbedderConfig.Builder(); builder.tokenizerVocab(ModelReference.valueOf(vocabPath)); builder.transformerModel(ModelReference.valueOf(modelPath)); // we did not configured BertBaseEmbedder to accept missing token type ids // so we expect ctor to throw - assertThrows(IllegalArgumentException.class, () -> { new BertBaseEmbedder(builder.build()); }); + assertThrows(IllegalArgumentException.class, () -> { newBertBaseEmbedder(builder.build()); }); + } + + private static BertBaseEmbedder newBertBaseEmbedder(BertBaseEmbedderConfig cfg) { + return new BertBaseEmbedder(new OnnxRuntime(), cfg); } } diff --git a/model-integration/src/test/java/ai/vespa/embedding/huggingface/HuggingFaceEmbedderTest.java b/model-integration/src/test/java/ai/vespa/embedding/huggingface/HuggingFaceEmbedderTest.java index c67b6b0dcab..0ff9acc9a69 100644 --- a/model-integration/src/test/java/ai/vespa/embedding/huggingface/HuggingFaceEmbedderTest.java +++ b/model-integration/src/test/java/ai/vespa/embedding/huggingface/HuggingFaceEmbedderTest.java @@ -1,19 +1,5 @@ package ai.vespa.embedding.huggingface; -import ai.vespa.modelintegration.evaluator.OnnxEvaluator; -import com.yahoo.config.ModelReference; -import com.yahoo.tensor.Tensor; -import com.yahoo.tensor.TensorType; -import org.junit.Test; - -import com.yahoo.embedding.huggingface.HuggingFaceEmbedderConfig; - -import java.io.IOException; -import java.util.List; - -import static org.junit.Assume.assumeTrue; -import static org.junit.Assert.assertEquals; - public class HuggingFaceEmbedderTest { /* @Test @@ -21,7 +7,7 @@ public class HuggingFaceEmbedderTest { String modelPath = "src/test/models/hf/model.onnx"; String tokenizerPath = "src/test/models/hf/tokenizer.json"; - assumeTrue(OnnxEvaluator.isRuntimeAvailable(modelPath)); + assumeTrue(OnnxRuntime.isRuntimeAvailable(modelPath)); HuggingFaceEmbedderConfig.Builder builder = new HuggingFaceEmbedderConfig.Builder(); builder.tokenizerPath(ModelReference.valueOf(tokenizerPath)); diff --git a/model-integration/src/test/java/ai/vespa/llm/GeneratorTest.java b/model-integration/src/test/java/ai/vespa/llm/GeneratorTest.java index 733430aa10d..c22902b344f 100644 --- a/model-integration/src/test/java/ai/vespa/llm/GeneratorTest.java +++ b/model-integration/src/test/java/ai/vespa/llm/GeneratorTest.java @@ -1,6 +1,6 @@ package ai.vespa.llm; -import ai.vespa.modelintegration.evaluator.OnnxEvaluator; +import ai.vespa.modelintegration.evaluator.OnnxRuntime; import com.yahoo.config.ModelReference; import com.yahoo.llm.GeneratorConfig; import org.junit.Test; @@ -15,13 +15,13 @@ public class GeneratorTest { String vocabPath = "src/test/models/onnx/llm/en.wiki.bpe.vs10000.model"; String encoderModelPath = "src/test/models/onnx/llm/random_encoder.onnx"; String decoderModelPath = "src/test/models/onnx/llm/random_decoder.onnx"; - assumeTrue(OnnxEvaluator.isRuntimeAvailable(encoderModelPath)); + assumeTrue(OnnxRuntime.isRuntimeAvailable(encoderModelPath)); GeneratorConfig.Builder builder = new GeneratorConfig.Builder(); builder.tokenizerModel(ModelReference.valueOf(vocabPath)); builder.encoderModel(ModelReference.valueOf(encoderModelPath)); builder.decoderModel(ModelReference.valueOf(decoderModelPath)); - Generator generator = new Generator(builder.build()); + Generator generator = newGenerator(builder.build()); GeneratorOptions options = new GeneratorOptions(); options.setSearchMethod(GeneratorOptions.SearchMethod.GREEDY); @@ -33,4 +33,8 @@ public class GeneratorTest { assertEquals("<unk> linear recruit latest sack annually institutions cert solid references", result); } + private static Generator newGenerator(GeneratorConfig cfg) { + return new Generator(new OnnxRuntime(), cfg); + } + } diff --git a/model-integration/src/test/java/ai/vespa/modelintegration/evaluator/OnnxEvaluatorCacheTest.java b/model-integration/src/test/java/ai/vespa/modelintegration/evaluator/OnnxEvaluatorCacheTest.java deleted file mode 100644 index acce660f466..00000000000 --- a/model-integration/src/test/java/ai/vespa/modelintegration/evaluator/OnnxEvaluatorCacheTest.java +++ /dev/null @@ -1,38 +0,0 @@ -// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. - -package ai.vespa.modelintegration.evaluator; - -import org.junit.jupiter.api.Test; - -import static org.junit.jupiter.api.Assertions.assertEquals; -import static org.junit.jupiter.api.Assertions.assertNotSame; -import static org.junit.jupiter.api.Assertions.assertSame; -import static org.mockito.Mockito.mock; - -/** - * @author bjorncs - */ -class OnnxEvaluatorCacheTest { - - @Test - void reuses_instance_while_in_use() { - var cache = new OnnxEvaluatorCache((__, ___) -> mock(OnnxEvaluator.class)); - var referencedEvaluator1 = cache.evaluatorOf("model1", new OnnxEvaluatorOptions()); - var referencedEvaluator2 = cache.evaluatorOf("model1", new OnnxEvaluatorOptions()); - var referencedEvaluator3 = cache.evaluatorOf("model2", new OnnxEvaluatorOptions()); - assertSame(referencedEvaluator1.evaluator(), referencedEvaluator2.evaluator()); - assertNotSame(referencedEvaluator1.evaluator(), referencedEvaluator3.evaluator()); - assertEquals(2, cache.size()); - referencedEvaluator1.close(); - referencedEvaluator2.close(); - assertEquals(1, cache.size()); - referencedEvaluator3.close(); - assertEquals(0, cache.size()); - var referencedEvaluator4 = cache.evaluatorOf("model1", new OnnxEvaluatorOptions()); - assertNotSame(referencedEvaluator1.evaluator(), referencedEvaluator4.evaluator()); - assertEquals(1, cache.size()); - referencedEvaluator4.close(); - assertEquals(0, cache.size()); - } - -}
\ No newline at end of file diff --git a/model-integration/src/test/java/ai/vespa/modelintegration/evaluator/OnnxEvaluatorTest.java b/model-integration/src/test/java/ai/vespa/modelintegration/evaluator/OnnxEvaluatorTest.java index 83f355821e5..5aba54de11b 100644 --- a/model-integration/src/test/java/ai/vespa/modelintegration/evaluator/OnnxEvaluatorTest.java +++ b/model-integration/src/test/java/ai/vespa/modelintegration/evaluator/OnnxEvaluatorTest.java @@ -5,23 +5,31 @@ package ai.vespa.modelintegration.evaluator; import com.yahoo.tensor.Tensor; import com.yahoo.tensor.TensorType; import org.junit.Test; +import org.junit.jupiter.api.BeforeAll; import java.util.HashMap; import java.util.Map; import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertTrue; -import static org.junit.Assume.assumeTrue; +import static org.junit.Assume.assumeNotNull; /** * @author lesters */ public class OnnxEvaluatorTest { + private static OnnxRuntime runtime; + + @BeforeAll + public static void beforeAll() { + if (OnnxRuntime.isRuntimeAvailable()) runtime = new OnnxRuntime(); + } + @Test public void testSimpleModel() { - assumeTrue(OnnxEvaluator.isRuntimeAvailable()); - OnnxEvaluator evaluator = new OnnxEvaluator("src/test/models/onnx/simple/simple.onnx"); + assumeNotNull(runtime); + OnnxEvaluator evaluator = runtime.evaluatorOf("src/test/models/onnx/simple/simple.onnx"); // Input types Map<String, TensorType> inputTypes = evaluator.getInputInfo(); @@ -45,8 +53,8 @@ public class OnnxEvaluatorTest { @Test public void testBatchDimension() { - assumeTrue(OnnxEvaluator.isRuntimeAvailable()); - OnnxEvaluator evaluator = new OnnxEvaluator("src/test/models/onnx/pytorch/one_layer.onnx"); + assumeNotNull(runtime); + OnnxEvaluator evaluator = runtime.evaluatorOf("src/test/models/onnx/pytorch/one_layer.onnx"); // Input types Map<String, TensorType> inputTypes = evaluator.getInputInfo(); @@ -64,7 +72,7 @@ public class OnnxEvaluatorTest { @Test public void testMatMul() { - assumeTrue(OnnxEvaluator.isRuntimeAvailable()); + assumeNotNull(runtime); String expected = "tensor<float>(d0[2],d1[4]):[38,44,50,56,83,98,113,128]"; String input1 = "tensor<float>(d0[2],d1[3]):[1,2,3,4,5,6]"; String input2 = "tensor<float>(d0[3],d1[4]):[1,2,3,4,5,6,7,8,9,10,11,12]"; @@ -73,7 +81,7 @@ public class OnnxEvaluatorTest { @Test public void testTypes() { - assumeTrue(OnnxEvaluator.isRuntimeAvailable()); + assumeNotNull(runtime); assertEvaluate("add_double.onnx", "tensor(d0[1]):[3]", "tensor(d0[1]):[1]", "tensor(d0[1]):[2]"); assertEvaluate("add_float.onnx", "tensor<float>(d0[1]):[3]", "tensor<float>(d0[1]):[1]", "tensor<float>(d0[1]):[2]"); assertEvaluate("add_int64.onnx", "tensor<double>(d0[1]):[3]", "tensor<double>(d0[1]):[1]", "tensor<double>(d0[1]):[2]"); @@ -86,8 +94,8 @@ public class OnnxEvaluatorTest { @Test public void testNotIdentifiers() { - assumeTrue(OnnxEvaluator.isRuntimeAvailable()); - OnnxEvaluator evaluator = new OnnxEvaluator("src/test/models/onnx/badnames.onnx"); + assumeNotNull(runtime); + OnnxEvaluator evaluator = runtime.evaluatorOf("src/test/models/onnx/badnames.onnx"); var inputInfo = evaluator.getInputInfo(); var outputInfo = evaluator.getOutputInfo(); for (var entry : inputInfo.entrySet()) { @@ -152,7 +160,7 @@ public class OnnxEvaluatorTest { } private void assertEvaluate(String model, String output, String... input) { - OnnxEvaluator evaluator = new OnnxEvaluator("src/test/models/onnx/" + model); + OnnxEvaluator evaluator = runtime.evaluatorOf("src/test/models/onnx/" + model); Map<String, Tensor> inputs = new HashMap<>(); for (int i = 0; i < input.length; ++i) { inputs.put("input" + (i+1), Tensor.from(input[i])); diff --git a/model-integration/src/test/java/ai/vespa/modelintegration/evaluator/OnnxRuntimeTest.java b/model-integration/src/test/java/ai/vespa/modelintegration/evaluator/OnnxRuntimeTest.java new file mode 100644 index 00000000000..81b1237e770 --- /dev/null +++ b/model-integration/src/test/java/ai/vespa/modelintegration/evaluator/OnnxRuntimeTest.java @@ -0,0 +1,48 @@ +// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. + +package ai.vespa.modelintegration.evaluator; + +import ai.onnxruntime.OrtException; +import ai.onnxruntime.OrtSession; +import org.junit.jupiter.api.Test; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertNotSame; +import static org.junit.jupiter.api.Assertions.assertSame; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.never; +import static org.mockito.Mockito.verify; + +/** + * @author bjorncs + */ +class OnnxRuntimeTest { + + @Test + void reuses_sessions_while_active() throws OrtException { + var runtime = new OnnxRuntime((__, ___) -> mock(OrtSession.class)); + var session1 = runtime.acquireSession("model1", new OnnxEvaluatorOptions(), false); + var session2 = runtime.acquireSession("model1", new OnnxEvaluatorOptions(), false); + var session3 = runtime.acquireSession("model2", new OnnxEvaluatorOptions(), false); + assertSame(session1.instance(), session2.instance()); + assertNotSame(session1.instance(), session3.instance()); + assertEquals(2, runtime.sessionsCached()); + + session1.close(); + session2.close(); + assertEquals(1, runtime.sessionsCached()); + verify(session1.instance()).close(); + verify(session3.instance(), never()).close(); + + session3.close(); + assertEquals(0, runtime.sessionsCached()); + verify(session3.instance()).close(); + + var session4 = runtime.acquireSession("model1", new OnnxEvaluatorOptions(), false); + assertNotSame(session1.instance(), session4.instance()); + assertEquals(1, runtime.sessionsCached()); + session4.close(); + assertEquals(0, runtime.sessionsCached()); + verify(session4.instance()).close(); + } +}
\ No newline at end of file |