diff options
Diffstat (limited to 'model-integration/src/test')
-rw-r--r-- | model-integration/src/test/java/ai/vespa/embedding/ColBertEmbedderTest.java | 125 | ||||
-rw-r--r-- | model-integration/src/test/models/onnx/transformer/colbert-dummy-v2.onnx | bin | 0 -> 13011175 bytes | |||
-rw-r--r-- | model-integration/src/test/models/onnx/transformer/tokenizer.json | 175 |
3 files changed, 300 insertions, 0 deletions
diff --git a/model-integration/src/test/java/ai/vespa/embedding/ColBertEmbedderTest.java b/model-integration/src/test/java/ai/vespa/embedding/ColBertEmbedderTest.java new file mode 100644 index 00000000000..70f91eb44ad --- /dev/null +++ b/model-integration/src/test/java/ai/vespa/embedding/ColBertEmbedderTest.java @@ -0,0 +1,125 @@ +package ai.vespa.embedding; + +import ai.vespa.modelintegration.evaluator.OnnxRuntime; +import com.yahoo.config.ModelReference; +import com.yahoo.embedding.ColBertEmbedderConfig; +import com.yahoo.language.process.Embedder; +import com.yahoo.tensor.IndexedTensor; +import com.yahoo.tensor.MixedTensor; +import com.yahoo.tensor.Tensor; +import com.yahoo.tensor.TensorType; +import org.junit.Test; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertThrows; +import static org.junit.Assume.assumeTrue; + +public class ColBertEmbedderTest { + + @Test + public void testPacking() { + assertPackedRight( + "" + + "tensor<float>(d1[6],d2[8]):" + + "[" + + "[0, 0, 0, 0, 0, 0, 0, 1]," + + "[0, 0, 0, 0, 0, 1, 0, 1]," + + "[0, 0, 0, 0, 0, 0, 1, 1]," + + "[0, 1, 1, 1, 1, 1, 1, 1]," + + "[1, 0, 0, 0, 0, 0, 0, 0]," + + "[1, 1, 1, 1, 1, 1, 1, 1]" + + "]", + TensorType.fromSpec("tensor<int8>(dt{},x[1])"), + "tensor<int8>(dt{},x[1]):{0:1.0, 1:5.0, 2:3.0, 3:127.0, 4:-128.0, 5:-1.0}" + ); + assertPackedRight( + "" + + "tensor<float>(d1[2],d2[16]):" + + "[" + + "[0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0, 0, 0, 0]," + + "[0, 0, 0, 0, 0, 1, 0, 1, 0, 0, 0, 0, 0, 0, 0, 1]" + + "]", + TensorType.fromSpec("tensor<int8>(dt{},x[2])"), + "tensor<int8>(dt{},x[2]):{0:[1.0, -128.0], 1:[5.0, 1.0]}" + ); + } + + @Test + public void testEmbedder() { + assertEmbed("tensor<float>(dt{},x[128])", "this is a document", indexingContext); + assertEmbed("tensor<int8>(dt{},x[16])", "this is a document", indexingContext); + assertEmbed("tensor<float>(qt{},x[128])", "this is a query", queryContext); + + assertThrows(IllegalArgumentException.class, () -> { + //throws because int8 is not supported for query context + assertEmbed("tensor<int8>(qt{},x[16])", "this is a query", queryContext); + }); + assertThrows(IllegalArgumentException.class, () -> { + //throws because 16 is less than model output (128) and we want float + assertEmbed("tensor<float>(qt{},x[16])", "this is a query", queryContext); + }); + + assertThrows(IllegalArgumentException.class, () -> { + //throws because 128/8 does not fit into 15 + assertEmbed("tensor<int8>(qt{},x[15])", "this is a query", indexingContext); + }); + } + + @Test + public void testLenghtLimits() { + StringBuilder sb = new StringBuilder(); + for(int i = 0; i < 1024; i++) { + sb.append("annoyance"); + sb.append(" "); + } + String text = sb.toString(); + Tensor fullFloat = assertEmbed("tensor<float>(dt{},x[128])", text, indexingContext); + assertEquals(512*128,fullFloat.size()); + + Tensor query = assertEmbed("tensor<float>(dt{},x[128])", text, queryContext); + assertEquals(32*128,query.size()); + + Tensor binaryRep = assertEmbed("tensor<int8>(dt{},x[16])", text, indexingContext); + assertEquals(512*16,binaryRep.size()); + + Tensor shortDoc = assertEmbed("tensor<int8>(dt{},x[16])", "annoyance", indexingContext); + // 4 tokens, 16 bytes each = 64 bytes + //because of CLS, special, sequence, SEP + assertEquals(4*16,shortDoc.size());; + } + + static Tensor assertEmbed(String tensorSpec, String text, Embedder.Context context) { + TensorType destType = TensorType.fromSpec(tensorSpec); + Tensor result = embedder.embed(text, context, destType); + assertEquals(destType,result.type()); + MixedTensor mixedTensor = (MixedTensor) result; + if(context == queryContext) { + assertEquals(32*mixedTensor.denseSubspaceSize(),mixedTensor.size()); + } + return result; + } + + static void assertPackedRight(String numbers, TensorType destination,String expected) { + Tensor packed = ColBertEmbedder.toBitTensor((IndexedTensor) Tensor.from(numbers), destination); + assertEquals(expected,packed.toString()); + } + + static final Embedder embedder; + static final Embedder.Context indexingContext; + static final Embedder.Context queryContext; + static { + indexingContext = new Embedder.Context("schema.indexing"); + queryContext = new Embedder.Context("query(qt)"); + embedder = getEmbedder(); + } + private static Embedder getEmbedder() { + String vocabPath = "src/test/models/onnx/transformer/tokenizer.json"; + String modelPath = "src/test/models/onnx/transformer/colbert-dummy-v2.onnx"; + assumeTrue(OnnxRuntime.isRuntimeAvailable(modelPath)); + ColBertEmbedderConfig.Builder builder = new ColBertEmbedderConfig.Builder(); + builder.tokenizerPath(ModelReference.valueOf(vocabPath)); + builder.transformerModel(ModelReference.valueOf(modelPath)); + builder.transformerGpuDevice(-1); + return new ColBertEmbedder(new OnnxRuntime(), Embedder.Runtime.testInstance(), builder.build()); + } +}
\ No newline at end of file diff --git a/model-integration/src/test/models/onnx/transformer/colbert-dummy-v2.onnx b/model-integration/src/test/models/onnx/transformer/colbert-dummy-v2.onnx Binary files differnew file mode 100644 index 00000000000..5ab1060e59e --- /dev/null +++ b/model-integration/src/test/models/onnx/transformer/colbert-dummy-v2.onnx diff --git a/model-integration/src/test/models/onnx/transformer/tokenizer.json b/model-integration/src/test/models/onnx/transformer/tokenizer.json new file mode 100644 index 00000000000..28340f289bb --- /dev/null +++ b/model-integration/src/test/models/onnx/transformer/tokenizer.json @@ -0,0 +1,175 @@ +{ + "version": "1.0", + "truncation": null, + "padding": null, + "added_tokens": [ + { + "id": 0, + "content": "[PAD]", + "single_word": false, + "lstrip": false, + "rstrip": false, + "normalized": false, + "special": true + }, + { + "id": 100, + "content": "[UNK]", + "single_word": false, + "lstrip": false, + "rstrip": false, + "normalized": false, + "special": true + }, + { + "id": 101, + "content": "[CLS]", + "single_word": false, + "lstrip": false, + "rstrip": false, + "normalized": false, + "special": true + }, + { + "id": 102, + "content": "[SEP]", + "single_word": false, + "lstrip": false, + "rstrip": false, + "normalized": false, + "special": true + }, + { + "id": 103, + "content": "[MASK]", + "single_word": false, + "lstrip": false, + "rstrip": false, + "normalized": false, + "special": true + } + ], + "normalizer": { + "type": "BertNormalizer", + "clean_text": true, + "handle_chinese_chars": true, + "strip_accents": null, + "lowercase": true + }, + "pre_tokenizer": { + "type": "BertPreTokenizer" + }, + "post_processor": { + "type": "TemplateProcessing", + "single": [ + { + "SpecialToken": { + "id": "[CLS]", + "type_id": 0 + } + }, + { + "Sequence": { + "id": "A", + "type_id": 0 + } + }, + { + "SpecialToken": { + "id": "[SEP]", + "type_id": 0 + } + } + ], + "pair": [ + { + "SpecialToken": { + "id": "[CLS]", + "type_id": 0 + } + }, + { + "Sequence": { + "id": "A", + "type_id": 0 + } + }, + { + "SpecialToken": { + "id": "[SEP]", + "type_id": 0 + } + }, + { + "Sequence": { + "id": "B", + "type_id": 1 + } + }, + { + "SpecialToken": { + "id": "[SEP]", + "type_id": 1 + } + } + ], + "special_tokens": { + "[CLS]": { + "id": "[CLS]", + "ids": [101], + "tokens": ["[CLS]"] + }, + "[SEP]": { + "id": "[SEP]", + "ids": [102], + "tokens": ["[SEP]"] + } + } + }, + "decoder": { + "type": "WordPiece", + "prefix": "##", + "cleanup": true + }, + "model": { + "type": "WordPiece", + "unk_token": "[UNK]", + "continuing_subword_prefix": "##", + "max_input_chars_per_word": 100, + "vocab": { + "[PAD]": 0, + "[unused0]": 1, + "[unused1]": 2, + "[UNK]": 100, + "[CLS]": 101, + "[SEP]": 102, + "[MASK]": 103, + "a": 1037, + "b": 1038, + "c": 1039, + "d": 1040, + "e": 1041, + "f": 1042, + "g": 1043, + "h": 1044, + "i": 1045, + "j": 1046, + "k": 1047, + "l": 1048, + "m": 1049, + "n": 1050, + "o": 1051, + "p": 1052, + "q": 1053, + "r": 1054, + "s": 1055, + "t": 1056, + "u": 1057, + "v": 1058, + "w": 1059, + "x": 1060, + "y": 1061, + "z": 1062 + } + } +} |