Diffstat (limited to 'model-integration/src/test')
-rw-r--r--  model-integration/src/test/java/ai/vespa/embedding/ColBertEmbedderTest.java  | 125
-rw-r--r--  model-integration/src/test/models/onnx/transformer/colbert-dummy-v2.onnx     | bin 0 -> 13011175 bytes
-rw-r--r--  model-integration/src/test/models/onnx/transformer/tokenizer.json            | 175
3 files changed, 300 insertions, 0 deletions
diff --git a/model-integration/src/test/java/ai/vespa/embedding/ColBertEmbedderTest.java b/model-integration/src/test/java/ai/vespa/embedding/ColBertEmbedderTest.java
new file mode 100644
index 00000000000..70f91eb44ad
--- /dev/null
+++ b/model-integration/src/test/java/ai/vespa/embedding/ColBertEmbedderTest.java
@@ -0,0 +1,125 @@
+package ai.vespa.embedding;
+
+import ai.vespa.modelintegration.evaluator.OnnxRuntime;
+import com.yahoo.config.ModelReference;
+import com.yahoo.embedding.ColBertEmbedderConfig;
+import com.yahoo.language.process.Embedder;
+import com.yahoo.tensor.IndexedTensor;
+import com.yahoo.tensor.MixedTensor;
+import com.yahoo.tensor.Tensor;
+import com.yahoo.tensor.TensorType;
+import org.junit.Test;
+
+import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertThrows;
+import static org.junit.Assume.assumeTrue;
+
+public class ColBertEmbedderTest {
+
+ @Test
+ public void testPacking() {
+ assertPackedRight(
+ "" +
+ "tensor<float>(d1[6],d2[8]):" +
+ "[" +
+ "[0, 0, 0, 0, 0, 0, 0, 1]," +
+ "[0, 0, 0, 0, 0, 1, 0, 1]," +
+ "[0, 0, 0, 0, 0, 0, 1, 1]," +
+ "[0, 1, 1, 1, 1, 1, 1, 1]," +
+ "[1, 0, 0, 0, 0, 0, 0, 0]," +
+ "[1, 1, 1, 1, 1, 1, 1, 1]" +
+ "]",
+ TensorType.fromSpec("tensor<int8>(dt{},x[1])"),
+ "tensor<int8>(dt{},x[1]):{0:1.0, 1:5.0, 2:3.0, 3:127.0, 4:-128.0, 5:-1.0}"
+ );
+ assertPackedRight(
+ "" +
+ "tensor<float>(d1[2],d2[16]):" +
+ "[" +
+ "[0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0, 0, 0, 0]," +
+ "[0, 0, 0, 0, 0, 1, 0, 1, 0, 0, 0, 0, 0, 0, 0, 1]" +
+ "]",
+ TensorType.fromSpec("tensor<int8>(dt{},x[2])"),
+ "tensor<int8>(dt{},x[2]):{0:[1.0, -128.0], 1:[5.0, 1.0]}"
+ );
+ }
+
+ @Test
+ public void testEmbedder() {
+ assertEmbed("tensor<float>(dt{},x[128])", "this is a document", indexingContext);
+ assertEmbed("tensor<int8>(dt{},x[16])", "this is a document", indexingContext);
+ assertEmbed("tensor<float>(qt{},x[128])", "this is a query", queryContext);
+
+ assertThrows(IllegalArgumentException.class, () -> {
+ //throws because int8 is not supported for query context
+ assertEmbed("tensor<int8>(qt{},x[16])", "this is a query", queryContext);
+ });
+ assertThrows(IllegalArgumentException.class, () -> {
+ // throws because the requested float dimension (16) is smaller than the model output dimension (128)
+ assertEmbed("tensor<float>(qt{},x[16])", "this is a query", queryContext);
+ });
+
+ assertThrows(IllegalArgumentException.class, () -> {
+ // throws because 128 bits pack into 16 int8 values, which does not fit into x[15]
+ assertEmbed("tensor<int8>(qt{},x[15])", "this is a query", indexingContext);
+ });
+ }
+
+ @Test
+ public void testLengthLimits() {
+ StringBuilder sb = new StringBuilder();
+ for(int i = 0; i < 1024; i++) {
+ sb.append("annoyance");
+ sb.append(" ");
+ }
+ String text = sb.toString();
+ Tensor fullFloat = assertEmbed("tensor<float>(dt{},x[128])", text, indexingContext);
+ assertEquals(512*128,fullFloat.size());
+
+ Tensor query = assertEmbed("tensor<float>(dt{},x[128])", text, queryContext);
+ assertEquals(32*128,query.size());
+
+ Tensor binaryRep = assertEmbed("tensor<int8>(dt{},x[16])", text, indexingContext);
+ assertEquals(512*16,binaryRep.size());
+
+ Tensor shortDoc = assertEmbed("tensor<int8>(dt{},x[16])", "annoyance", indexingContext);
+ // 4 tokens, 16 bytes each = 64 bytes
+ // (the 4 tokens are CLS, the special marker token, the sequence token, and SEP)
+ assertEquals(4*16, shortDoc.size());
+ }
+
+ static Tensor assertEmbed(String tensorSpec, String text, Embedder.Context context) {
+ TensorType destType = TensorType.fromSpec(tensorSpec);
+ Tensor result = embedder.embed(text, context, destType);
+ assertEquals(destType,result.type());
+ MixedTensor mixedTensor = (MixedTensor) result;
+ if(context == queryContext) {
+ assertEquals(32*mixedTensor.denseSubspaceSize(),mixedTensor.size());
+ }
+ return result;
+ }
+
+ static void assertPackedRight(String numbers, TensorType destination, String expected) {
+ Tensor packed = ColBertEmbedder.toBitTensor((IndexedTensor) Tensor.from(numbers), destination);
+ assertEquals(expected,packed.toString());
+ }
+
+ static final Embedder embedder;
+ static final Embedder.Context indexingContext;
+ static final Embedder.Context queryContext;
+ static {
+ indexingContext = new Embedder.Context("schema.indexing");
+ queryContext = new Embedder.Context("query(qt)");
+ embedder = getEmbedder();
+ }
+ private static Embedder getEmbedder() {
+ String vocabPath = "src/test/models/onnx/transformer/tokenizer.json";
+ String modelPath = "src/test/models/onnx/transformer/colbert-dummy-v2.onnx";
+ assumeTrue(OnnxRuntime.isRuntimeAvailable(modelPath));
+ ColBertEmbedderConfig.Builder builder = new ColBertEmbedderConfig.Builder();
+ builder.tokenizerPath(ModelReference.valueOf(vocabPath));
+ builder.transformerModel(ModelReference.valueOf(modelPath));
+ builder.transformerGpuDevice(-1);
+ return new ColBertEmbedder(new OnnxRuntime(), Embedder.Runtime.testInstance(), builder.build());
+ }
+}
\ No newline at end of file
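A note on the packing that testPacking verifies: each group of 8 float values is thresholded and packed, most significant bit first, into one signed int8, which is why [0, 0, 0, 0, 0, 0, 0, 1] becomes 1, [1, 0, 0, 0, 0, 0, 0, 0] becomes -128, and a 128-dimensional float embedding fits into x[16] int8 cells (128 / 8 = 16), matching the sizes asserted in testEmbedder. The sketch below is a minimal illustration of that scheme inferred from the expected values in the test; it is not the ColBertEmbedder.toBitTensor implementation, and the class and method names are hypothetical.

// Minimal sketch (hypothetical class/method names) of MSB-first bit packing,
// consistent with the expected values asserted in testPacking above.
public class BitPackSketch {

    // Packs 8 float values into one signed int8: a positive value sets its bit,
    // with the first value mapped to the most significant (sign) bit.
    static byte packRow(float[] values) {
        byte packed = 0;
        for (int bit = 0; bit < 8; bit++) {
            if (values[bit] > 0) {
                packed |= (byte) (1 << (7 - bit));
            }
        }
        return packed;
    }

    public static void main(String[] args) {
        System.out.println(packRow(new float[] {0, 0, 0, 0, 0, 0, 0, 1})); // 1
        System.out.println(packRow(new float[] {0, 0, 0, 0, 0, 1, 0, 1})); // 5
        System.out.println(packRow(new float[] {1, 0, 0, 0, 0, 0, 0, 0})); // -128
        System.out.println(packRow(new float[] {1, 1, 1, 1, 1, 1, 1, 1})); // -1
    }
}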
diff --git a/model-integration/src/test/models/onnx/transformer/colbert-dummy-v2.onnx b/model-integration/src/test/models/onnx/transformer/colbert-dummy-v2.onnx
new file mode 100644
index 00000000000..5ab1060e59e
--- /dev/null
+++ b/model-integration/src/test/models/onnx/transformer/colbert-dummy-v2.onnx
Binary files differ
diff --git a/model-integration/src/test/models/onnx/transformer/tokenizer.json b/model-integration/src/test/models/onnx/transformer/tokenizer.json
new file mode 100644
index 00000000000..28340f289bb
--- /dev/null
+++ b/model-integration/src/test/models/onnx/transformer/tokenizer.json
@@ -0,0 +1,175 @@
+{
+ "version": "1.0",
+ "truncation": null,
+ "padding": null,
+ "added_tokens": [
+ {
+ "id": 0,
+ "content": "[PAD]",
+ "single_word": false,
+ "lstrip": false,
+ "rstrip": false,
+ "normalized": false,
+ "special": true
+ },
+ {
+ "id": 100,
+ "content": "[UNK]",
+ "single_word": false,
+ "lstrip": false,
+ "rstrip": false,
+ "normalized": false,
+ "special": true
+ },
+ {
+ "id": 101,
+ "content": "[CLS]",
+ "single_word": false,
+ "lstrip": false,
+ "rstrip": false,
+ "normalized": false,
+ "special": true
+ },
+ {
+ "id": 102,
+ "content": "[SEP]",
+ "single_word": false,
+ "lstrip": false,
+ "rstrip": false,
+ "normalized": false,
+ "special": true
+ },
+ {
+ "id": 103,
+ "content": "[MASK]",
+ "single_word": false,
+ "lstrip": false,
+ "rstrip": false,
+ "normalized": false,
+ "special": true
+ }
+ ],
+ "normalizer": {
+ "type": "BertNormalizer",
+ "clean_text": true,
+ "handle_chinese_chars": true,
+ "strip_accents": null,
+ "lowercase": true
+ },
+ "pre_tokenizer": {
+ "type": "BertPreTokenizer"
+ },
+ "post_processor": {
+ "type": "TemplateProcessing",
+ "single": [
+ {
+ "SpecialToken": {
+ "id": "[CLS]",
+ "type_id": 0
+ }
+ },
+ {
+ "Sequence": {
+ "id": "A",
+ "type_id": 0
+ }
+ },
+ {
+ "SpecialToken": {
+ "id": "[SEP]",
+ "type_id": 0
+ }
+ }
+ ],
+ "pair": [
+ {
+ "SpecialToken": {
+ "id": "[CLS]",
+ "type_id": 0
+ }
+ },
+ {
+ "Sequence": {
+ "id": "A",
+ "type_id": 0
+ }
+ },
+ {
+ "SpecialToken": {
+ "id": "[SEP]",
+ "type_id": 0
+ }
+ },
+ {
+ "Sequence": {
+ "id": "B",
+ "type_id": 1
+ }
+ },
+ {
+ "SpecialToken": {
+ "id": "[SEP]",
+ "type_id": 1
+ }
+ }
+ ],
+ "special_tokens": {
+ "[CLS]": {
+ "id": "[CLS]",
+ "ids": [101],
+ "tokens": ["[CLS]"]
+ },
+ "[SEP]": {
+ "id": "[SEP]",
+ "ids": [102],
+ "tokens": ["[SEP]"]
+ }
+ }
+ },
+ "decoder": {
+ "type": "WordPiece",
+ "prefix": "##",
+ "cleanup": true
+ },
+ "model": {
+ "type": "WordPiece",
+ "unk_token": "[UNK]",
+ "continuing_subword_prefix": "##",
+ "max_input_chars_per_word": 100,
+ "vocab": {
+ "[PAD]": 0,
+ "[unused0]": 1,
+ "[unused1]": 2,
+ "[UNK]": 100,
+ "[CLS]": 101,
+ "[SEP]": 102,
+ "[MASK]": 103,
+ "a": 1037,
+ "b": 1038,
+ "c": 1039,
+ "d": 1040,
+ "e": 1041,
+ "f": 1042,
+ "g": 1043,
+ "h": 1044,
+ "i": 1045,
+ "j": 1046,
+ "k": 1047,
+ "l": 1048,
+ "m": 1049,
+ "n": 1050,
+ "o": 1051,
+ "p": 1052,
+ "q": 1053,
+ "r": 1054,
+ "s": 1055,
+ "t": 1056,
+ "u": 1057,
+ "v": 1058,
+ "w": 1059,
+ "x": 1060,
+ "y": 1061,
+ "z": 1062
+ }
+ }
+}
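For orientation, the TemplateProcessing post_processor above wraps a single tokenized sequence as [CLS] &lt;sequence A&gt; [SEP], using ids 101 and 102 from the vocab in this file. The following is a small illustrative sketch of that single-sequence rule only; the class and method names are hypothetical, and tokenization of the sequence itself is done by the WordPiece model defined above.

// Hypothetical sketch of the "single" TemplateProcessing rule from tokenizer.json:
// output = [CLS] + sequence A + [SEP], with [CLS] = 101 and [SEP] = 102.
import java.util.ArrayList;
import java.util.List;

public class SingleTemplateSketch {

    static List<Integer> applySingleTemplate(List<Integer> sequenceIds) {
        List<Integer> out = new ArrayList<>();
        out.add(101);            // [CLS]
        out.addAll(sequenceIds); // Sequence A, type_id 0
        out.add(102);            // [SEP]
        return out;
    }

    public static void main(String[] args) {
        // The letters a, b, c map to ids 1037, 1038, 1039 in the vocab above
        System.out.println(applySingleTemplate(List.of(1037, 1038, 1039))); // [101, 1037, 1038, 1039, 102]
    }
}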