author     Jo Kristian Bergum <bergum@yahooinc.com>    2023-09-21 12:02:08 +0200
committer  Jo Kristian Bergum <bergum@yahooinc.com>    2023-09-21 12:02:08 +0200
commit     d4692ee4fe82f34958679f0f87777a6e5c23d8db (patch)
tree       3c070c4c5a624482d7639561e972544d8e456fde /model-integration
parent     7faeffcc5901ae88c1c3d1814665d0db6ca1d900 (diff)
Add ColBERT embedder
Diffstat (limited to 'model-integration')
-rw-r--r--  model-integration/src/main/java/ai/vespa/embedding/ColBertEmbedder.java      | 299
-rw-r--r--  model-integration/src/test/java/ai/vespa/embedding/ColBertEmbedderTest.java  | 125
-rw-r--r--  model-integration/src/test/models/onnx/transformer/colbert-dummy-v2.onnx     | bin 0 -> 13011175 bytes
-rw-r--r--  model-integration/src/test/models/onnx/transformer/tokenizer.json            | 175
4 files changed, 599 insertions, 0 deletions
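
Before the diff itself, a standalone sketch of the token layout the new embedder builds, using the special-token ids from the bundled test tokenizer.json ([CLS]=101, [SEP]=102, [MASK]=103, [unused0]=1, [unused1]=2). The class and method names here are illustrative only, not part of this commit.

import java.util.ArrayList;
import java.util.List;

// Hypothetical helper, for illustration only -- not part of this commit.
class ColBertInputLayoutSketch {

    static final long CLS = 101, SEP = 102, MASK = 103; // ids from the test tokenizer.json
    static final long Q_TOKEN = 1;  // [unused0], marks a query
    static final long D_TOKEN = 2;  // [unused1], marks a document

    // Queries: [CLS] [unused0] <ids, truncated to max - 3> [SEP], then
    // padded with [MASK] up to maxQueryTokens; the padding gets attention 0.
    static List<Long> queryInput(List<Long> ids, int maxQueryTokens) {
        List<Long> input = new ArrayList<>();
        input.add(CLS);
        input.add(Q_TOKEN);
        input.addAll(ids.subList(0, Math.min(ids.size(), maxQueryTokens - 3)));
        input.add(SEP);
        while (input.size() < maxQueryTokens) input.add(MASK);
        return input;
    }

    // Documents: same layout with [unused1], no [MASK] padding; the embedder
    // also filters punctuation token ids out of <ids> before truncation.
    static List<Long> documentInput(List<Long> ids, int maxDocumentTokens) {
        List<Long> input = new ArrayList<>();
        input.add(CLS);
        input.add(D_TOKEN);
        input.addAll(ids.subList(0, Math.min(ids.size(), maxDocumentTokens - 3)));
        input.add(SEP);
        return input;
    }
}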
diff --git a/model-integration/src/main/java/ai/vespa/embedding/ColBertEmbedder.java b/model-integration/src/main/java/ai/vespa/embedding/ColBertEmbedder.java
new file mode 100644
index 00000000000..5c3b18e2949
--- /dev/null
+++ b/model-integration/src/main/java/ai/vespa/embedding/ColBertEmbedder.java
@@ -0,0 +1,299 @@
+package ai.vespa.embedding;
+
+import ai.vespa.modelintegration.evaluator.OnnxEvaluator;
+import ai.vespa.modelintegration.evaluator.OnnxEvaluatorOptions;
+import com.yahoo.api.annotations.Beta;
+import ai.vespa.modelintegration.evaluator.OnnxRuntime;
+import com.yahoo.component.AbstractComponent;
+import com.yahoo.component.annotation.Inject;
+import com.yahoo.embedding.ColBertEmbedderConfig;
+import com.yahoo.language.huggingface.HuggingFaceTokenizer;
+import com.yahoo.language.process.Embedder;
+import com.yahoo.tensor.IndexedTensor;
+import com.yahoo.tensor.Tensor;
+import com.yahoo.tensor.TensorAddress;
+import com.yahoo.tensor.TensorType;
+import com.yahoo.tensor.functions.Reduce;
+
+import java.nio.file.Paths;
+import java.util.Map;
+import java.util.List;
+import java.util.ArrayList;
+import java.util.Set;
+import java.util.HashSet;
+import java.util.BitSet;
+import java.util.Arrays;
+
+import static com.yahoo.language.huggingface.ModelInfo.TruncationStrategy.LONGEST_FIRST;
+
+/**
+ * A ColBERT embedder implementation that maps text to multiple vectors, one vector per subword id.
+ * This embedder uses a HuggingFace tokenizer to produce a token sequence that is then input to a transformer model.
+ *
+ * See col-bert-embedder.def for configurable parameters.
+ *
+ * @author bergum
+ */
+@Beta
+public class ColBertEmbedder extends AbstractComponent implements Embedder {
+
+    private final Embedder.Runtime runtime;
+    private final String inputIdsName;
+    private final String attentionMaskName;
+    private final String outputName;
+
+    private final HuggingFaceTokenizer tokenizer;
+    private final OnnxEvaluator evaluator;
+
+    private final int maxTransformerTokens;
+    private final int maxQueryTokens;
+    private final int maxDocumentTokens;
+
+    private final long startSequenceToken;
+    private final long endSequenceToken;
+    private final long maskSequenceToken;
+
+    @Inject
+    public ColBertEmbedder(OnnxRuntime onnx, Embedder.Runtime runtime, ColBertEmbedderConfig config) {
+        this.runtime = runtime;
+        inputIdsName = config.transformerInputIds();
+        attentionMaskName = config.transformerAttentionMask();
+        outputName = config.transformerOutput();
+        maxTransformerTokens = config.transformerMaxTokens();
+        if (config.maxDocumentTokens() > maxTransformerTokens)
+            throw new IllegalArgumentException("maxDocumentTokens must be less than or equal to transformerMaxTokens");
+        maxDocumentTokens = config.maxDocumentTokens();
+        maxQueryTokens = config.maxQueryTokens();
+        startSequenceToken = config.transformerStartSequenceToken();
+        endSequenceToken = config.transformerEndSequenceToken();
+        maskSequenceToken = config.transformerMaskToken();
+
+        var tokenizerPath = Paths.get(config.tokenizerPath().toString());
+        var builder = new HuggingFaceTokenizer.Builder()
+                .addSpecialTokens(false)
+                .addDefaultModel(tokenizerPath)
+                .setPadding(false);
+        var info = HuggingFaceTokenizer.getModelInfo(tokenizerPath);
+        if (info.maxLength() == -1 || info.truncation() != LONGEST_FIRST) {
+            // Force truncation to the max token length accepted by the model
+            // if tokenizer.json contains no valid truncation configuration
+            int maxLength = info.maxLength() > 0 && info.maxLength() <= config.transformerMaxTokens()
+                    ? info.maxLength()
+                    : config.transformerMaxTokens();
+            builder.setTruncation(true).setMaxLength(maxLength);
+        }
+        this.tokenizer = builder.build();
+
+        var onnxOpts = new OnnxEvaluatorOptions();
+        if (config.transformerGpuDevice() >= 0)
+            onnxOpts.setGpuDevice(config.transformerGpuDevice());
+        onnxOpts.setExecutionMode(config.transformerExecutionMode().toString());
+        onnxOpts.setThreads(config.transformerInterOpThreads(), config.transformerIntraOpThreads());
+        evaluator = onnx.evaluatorOf(config.transformerModel().toString(), onnxOpts);
+        validateModel();
+    }
+
+    public void validateModel() {
+        Map<String, TensorType> inputs = evaluator.getInputInfo();
+        validateName(inputs, inputIdsName, "input");
+        validateName(inputs, attentionMaskName, "input");
+        Map<String, TensorType> outputs = evaluator.getOutputInfo();
+        validateName(outputs, outputName, "output");
+    }
+
+    private void validateName(Map<String, TensorType> types, String name, String type) {
+        if (!types.containsKey(name)) {
+            throw new IllegalArgumentException("Model does not contain required " + type + ": '" + name + "'. " +
+                                               "Model contains: " + String.join(",", types.keySet()));
+        }
+    }
+
+    @Override
+    public List<Integer> embed(String text, Context context) {
+        throw new UnsupportedOperationException("This embedder only supports embed with tensor type");
+    }
+
+    @Override
+    public Tensor embed(String text, Context context, TensorType tensorType) {
+        if (!verifyTensorType(tensorType)) {
+            throw new IllegalArgumentException("Invalid ColBERT embedder tensor destination. " +
+                                               "Wanted a mixed 2-d mapped-indexed tensor, got " + tensorType.toString());
+        }
+        if (context.getDestination().startsWith("query")) {
+            return embedQuery(text, context, tensorType);
+        } else {
+            return embedDocument(text, context, tensorType);
+        }
+    }
+
+    protected Tensor embedQuery(String text, Context context, TensorType tensorType) {
+        if (tensorType.valueType() == TensorType.Value.INT8)
+            throw new IllegalArgumentException("ColBert query embed does not accept int8 tensor value type");
+
+        long Q_TOKEN_ID = 1; // [unused0] token id used during training to differentiate query versus document.
+
+        var start = System.nanoTime();
+        var encoding = tokenizer.encode(text, context.getLanguage());
+        runtime.sampleSequenceLength(encoding.ids().size(), context);
+
+        List<Long> ids = encoding.ids();
+        if (ids.size() > maxQueryTokens - 3)
+            ids = ids.subList(0, maxQueryTokens - 3);
+
+        List<Long> inputIds = new ArrayList<>(maxQueryTokens);
+        List<Long> attentionMask = new ArrayList<>(maxQueryTokens);
+
+        inputIds.add(startSequenceToken);
+        inputIds.add(Q_TOKEN_ID);
+        inputIds.addAll(ids);
+        inputIds.add(endSequenceToken);
+        int length = inputIds.size();
+
+        int padding = maxQueryTokens - length;
+        for (int i = 0; i < padding; i++)
+            inputIds.add(maskSequenceToken);
+
+        for (int i = 0; i < length; i++)
+            attentionMask.add((long) 1);
+        for (int i = 0; i < padding; i++)
+            attentionMask.add((long) 0); // Do not attend to mask paddings
+
+        Tensor inputIdsTensor = createTensorRepresentation(inputIds, "d1");
+        Tensor attentionMaskTensor = createTensorRepresentation(attentionMask, "d1");
+
+        var inputs = Map.of(inputIdsName, inputIdsTensor.expand("d0"),
+                            attentionMaskName, attentionMaskTensor.expand("d0"));
+        Map<String, Tensor> outputs = evaluator.evaluate(inputs);
+        Tensor tokenEmbeddings = outputs.get(outputName);
+        IndexedTensor result = (IndexedTensor) tokenEmbeddings.reduce(Reduce.Aggregator.min, "d0");
+
+        int dims = tensorType.indexedSubtype().dimensions().get(0).size().get().intValue();
+        if (dims != result.shape()[1]) {
+            throw new IllegalArgumentException("Token dimensionality does not" +
+                                               " match indexed dimensionality of " + dims);
+        }
+        Tensor.Builder builder = Tensor.Builder.of(tensorType);
+        for (int token = 0; token < result.shape()[0]; token++)
+            for (int d = 0; d < result.shape()[1]; d++)
+                builder.cell(TensorAddress.of(token, d), result.get(TensorAddress.of(token, d)));
+        runtime.sampleEmbeddingLatency((System.nanoTime() - start) / 1_000_000d, context);
+        return builder.build();
+    }
+
+    protected Tensor embedDocument(String text, Context context, TensorType tensorType) {
+        long D_TOKEN_ID = 2; // [unused1] token id used during training to differentiate query versus document.
+
+        var start = System.nanoTime();
+        var encoding = tokenizer.encode(text, context.getLanguage());
+        runtime.sampleSequenceLength(encoding.ids().size(), context);
+
+        List<Long> ids = encoding.ids().stream()
+                .filter(token -> !PUNCTUATION_TOKEN_IDS.contains(token)).toList();
+
+        if (ids.size() > maxDocumentTokens - 3)
+            ids = ids.subList(0, maxDocumentTokens - 3);
+        List<Long> inputIds = new ArrayList<>(maxDocumentTokens);
+        List<Long> attentionMask = new ArrayList<>(maxDocumentTokens);
+        inputIds.add(startSequenceToken);
+        inputIds.add(D_TOKEN_ID);
+        inputIds.addAll(ids);
+        inputIds.add(endSequenceToken);
+        for (int i = 0; i < inputIds.size(); i++)
+            attentionMask.add((long) 1);
+
+        Tensor inputIdsTensor = createTensorRepresentation(inputIds, "d1");
+        Tensor attentionMaskTensor = createTensorRepresentation(attentionMask, "d1");
+
+        var inputs = Map.of(inputIdsName, inputIdsTensor.expand("d0"),
+                            attentionMaskName, attentionMaskTensor.expand("d0"));
+
+        Map<String, Tensor> outputs = evaluator.evaluate(inputs);
+        Tensor tokenEmbeddings = outputs.get(outputName);
+        IndexedTensor result = (IndexedTensor) tokenEmbeddings.reduce(Reduce.Aggregator.min, "d0");
+        Tensor contextualEmbeddings;
+        if (tensorType.valueType() == TensorType.Value.INT8) {
+            contextualEmbeddings = toBitTensor(result, tensorType);
+        } else {
+            contextualEmbeddings = toFloatTensor(result, tensorType);
+        }
+
+        runtime.sampleEmbeddingLatency((System.nanoTime() - start) / 1_000_000d, context);
+        return contextualEmbeddings;
+    }
+
+    public static Tensor toFloatTensor(IndexedTensor result, TensorType type) {
+        int size = type.indexedSubtype().dimensions().size();
+        if (size != 1)
+            throw new IllegalArgumentException("Indexed tensor must have one dimension");
+        int dims = type.indexedSubtype().dimensions().get(0).size().get().intValue();
+        int resultDim = (int) result.shape()[1];
+        if (resultDim != dims) {
+            throw new IllegalArgumentException("Not possible to map token vector embedding with " + resultDim +
+                                               " dimensions into tensor with " + dims);
+        }
+        Tensor.Builder builder = Tensor.Builder.of(type);
+        for (int token = 0; token < result.shape()[0]; token++) {
+            for (int d = 0; d < result.shape()[1]; d++) {
+                var value = result.get(TensorAddress.of(token, d));
+                builder.cell(TensorAddress.of(token, d), value);
+            }
+        }
+        return builder.build();
+    }
+
+    public static Tensor toBitTensor(IndexedTensor result, TensorType type) {
+        if (type.valueType() != TensorType.Value.INT8)
+            throw new IllegalArgumentException("Only an int8 tensor type can be" +
+                                               " the destination of bit packing");
+        int size = type.indexedSubtype().dimensions().size();
+        if (size != 1)
+            throw new IllegalArgumentException("Indexed tensor must have one dimension");
+        int dims = type.indexedSubtype().dimensions().get(0).size().get().intValue();
+        int resultDim = (int) result.shape()[1];
+        if (resultDim / 8 != dims) {
+            throw new IllegalArgumentException("Not possible to pack " + resultDim +
+                                               " dimensions into " + dims);
+        }
+        Tensor.Builder builder = Tensor.Builder.of(type);
+        for (int token = 0; token < result.shape()[0]; token++) {
+            BitSet bitSet = new BitSet(8);
+            int key = 0;
+            for (int d = 0; d < result.shape()[1]; d++) {
+                var value = result.get(TensorAddress.of(token, d));
+                int bitIndex = 7 - (d % 8);
+                if (value > 0.0) {
+                    bitSet.set(bitIndex);
+                } else {
+                    bitSet.clear(bitIndex);
+                }
+                if ((d + 1) % 8 == 0) {
+                    byte[] bytes = bitSet.toByteArray();
+                    byte packed = (bytes.length == 0) ? 0 : bytes[0];
+                    builder.cell(TensorAddress.of(token, key), packed);
+                    key++;
+                    bitSet = new BitSet(8);
+                }
+            }
+        }
+        return builder.build();
+    }
+
+    protected boolean verifyTensorType(TensorType target) {
+        return target.dimensions().size() == 2 &&
+               target.indexedSubtype().rank() == 1 && target.mappedSubtype().rank() == 1;
+    }
+
+    private IndexedTensor createTensorRepresentation(List<Long> input, String dimension) {
+        int size = input.size();
+        TensorType type = new TensorType.Builder(TensorType.Value.FLOAT).indexed(dimension, size).build();
+        IndexedTensor.Builder builder = IndexedTensor.Builder.of(type);
+        for (int i = 0; i < size; ++i) {
+            builder.cell(input.get(i), i);
+        }
+        return builder.build();
+    }
+
+    private static final Set<Long> PUNCTUATION_TOKEN_IDS = new HashSet<>(
+            Arrays.asList(999L, 1000L, 1001L, 1002L, 1003L, 1004L, 1005L, 1006L,
+                          1007L, 1008L, 1009L, 1010L, 1011L, 1012L, 1013L, 1024L,
+                          1025L, 1026L, 1027L, 1028L, 1029L, 1030L, 1031L, 1032L,
+                          1033L, 1034L, 1035L, 1036L, 1063L, 1064L, 1065L, 1066L));
+}
diff --git a/model-integration/src/test/java/ai/vespa/embedding/ColBertEmbedderTest.java b/model-integration/src/test/java/ai/vespa/embedding/ColBertEmbedderTest.java
new file mode 100644
index 00000000000..70f91eb44ad
--- /dev/null
+++ b/model-integration/src/test/java/ai/vespa/embedding/ColBertEmbedderTest.java
@@ -0,0 +1,125 @@
+package ai.vespa.embedding;
+
+import ai.vespa.modelintegration.evaluator.OnnxRuntime;
+import com.yahoo.config.ModelReference;
+import com.yahoo.embedding.ColBertEmbedderConfig;
+import com.yahoo.language.process.Embedder;
+import com.yahoo.tensor.IndexedTensor;
+import com.yahoo.tensor.MixedTensor;
+import com.yahoo.tensor.Tensor;
+import com.yahoo.tensor.TensorType;
+import org.junit.Test;
+
+import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertThrows;
+import static org.junit.Assume.assumeTrue;
+
+public class ColBertEmbedderTest {
+
+    @Test
+    public void testPacking() {
+        assertPackedRight(
+                "tensor<float>(d1[6],d2[8]):" +
+                "[" +
+                "[0, 0, 0, 0, 0, 0, 0, 1]," +
+                "[0, 0, 0, 0, 0, 1, 0, 1]," +
+                "[0, 0, 0, 0, 0, 0, 1, 1]," +
+                "[0, 1, 1, 1, 1, 1, 1, 1]," +
+                "[1, 0, 0, 0, 0, 0, 0, 0]," +
+                "[1, 1, 1, 1, 1, 1, 1, 1]" +
+                "]",
+                TensorType.fromSpec("tensor<int8>(dt{},x[1])"),
+                "tensor<int8>(dt{},x[1]):{0:1.0, 1:5.0, 2:3.0, 3:127.0, 4:-128.0, 5:-1.0}"
+        );
+        assertPackedRight(
+                "tensor<float>(d1[2],d2[16]):" +
+                "[" +
+                "[0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0, 0, 0, 0]," +
+                "[0, 0, 0, 0, 0, 1, 0, 1, 0, 0, 0, 0, 0, 0, 0, 1]" +
+                "]",
+                TensorType.fromSpec("tensor<int8>(dt{},x[2])"),
+                "tensor<int8>(dt{},x[2]):{0:[1.0, -128.0], 1:[5.0, 1.0]}"
+        );
+    }
+
+    @Test
+    public void testEmbedder() {
+        assertEmbed("tensor<float>(dt{},x[128])", "this is a document", indexingContext);
+        assertEmbed("tensor<int8>(dt{},x[16])", "this is a document", indexingContext);
+        assertEmbed("tensor<float>(qt{},x[128])", "this is a query", queryContext);
+
+        assertThrows(IllegalArgumentException.class, () -> {
+            // throws because int8 is not supported for query context
+            assertEmbed("tensor<int8>(qt{},x[16])", "this is a query", queryContext);
+        });
+        assertThrows(IllegalArgumentException.class, () -> {
+            // throws because 16 is less than model output (128) and we want float
+            assertEmbed("tensor<float>(qt{},x[16])", "this is a query", queryContext);
+        });
+        assertThrows(IllegalArgumentException.class, () -> {
+            // throws because 128/8 does not fit into 15
+            assertEmbed("tensor<int8>(qt{},x[15])", "this is a query", indexingContext);
+        });
+    }
+
+    @Test
+    public void testLengthLimits() {
+        StringBuilder sb = new StringBuilder();
+        for (int i = 0; i < 1024; i++) {
+            sb.append("annoyance");
+            sb.append(" ");
+        }
+        String text = sb.toString();
+        Tensor fullFloat = assertEmbed("tensor<float>(dt{},x[128])", text, indexingContext);
+        assertEquals(512 * 128, fullFloat.size());
+
+        Tensor query = assertEmbed("tensor<float>(dt{},x[128])", text, queryContext);
+        assertEquals(32 * 128, query.size());
+
+        Tensor binaryRep = assertEmbed("tensor<int8>(dt{},x[16])", text, indexingContext);
+        assertEquals(512 * 16, binaryRep.size());
+
+        Tensor shortDoc = assertEmbed("tensor<int8>(dt{},x[16])", "annoyance", indexingContext);
+        // 4 tokens ([CLS], [unused1], sequence, [SEP]), 16 bytes each = 64 bytes
+        assertEquals(4 * 16, shortDoc.size());
+    }
+
+    static Tensor assertEmbed(String tensorSpec, String text, Embedder.Context context) {
+        TensorType destType = TensorType.fromSpec(tensorSpec);
+        Tensor result = embedder.embed(text, context, destType);
+        assertEquals(destType, result.type());
+        MixedTensor mixedTensor = (MixedTensor) result;
+        if (context == queryContext) {
+            assertEquals(32 * mixedTensor.denseSubspaceSize(), mixedTensor.size());
+        }
+        return result;
+    }
+
+    static void assertPackedRight(String numbers, TensorType destination, String expected) {
+        Tensor packed = ColBertEmbedder.toBitTensor((IndexedTensor) Tensor.from(numbers), destination);
+        assertEquals(expected, packed.toString());
+    }
+
+    static final Embedder embedder;
+    static final Embedder.Context indexingContext;
+    static final Embedder.Context queryContext;
+
+    static {
+        indexingContext = new Embedder.Context("schema.indexing");
+        queryContext = new Embedder.Context("query(qt)");
+        embedder = getEmbedder();
+    }
+
+    private static Embedder getEmbedder() {
+        String vocabPath = "src/test/models/onnx/transformer/tokenizer.json";
+        String modelPath = "src/test/models/onnx/transformer/colbert-dummy-v2.onnx";
+        assumeTrue(OnnxRuntime.isRuntimeAvailable(modelPath));
+        ColBertEmbedderConfig.Builder builder = new ColBertEmbedderConfig.Builder();
+        builder.tokenizerPath(ModelReference.valueOf(vocabPath));
+        builder.transformerModel(ModelReference.valueOf(modelPath));
+        builder.transformerGpuDevice(-1);
+        return new ColBertEmbedder(new OnnxRuntime(), Embedder.Runtime.testInstance(), builder.build());
+    }
+}
\ No newline at end of file
diff --git a/model-integration/src/test/models/onnx/transformer/colbert-dummy-v2.onnx b/model-integration/src/test/models/onnx/transformer/colbert-dummy-v2.onnx
new file mode 100644
index 00000000000..5ab1060e59e
--- /dev/null
+++ b/model-integration/src/test/models/onnx/transformer/colbert-dummy-v2.onnx
Binary files differ
diff --git a/model-integration/src/test/models/onnx/transformer/tokenizer.json b/model-integration/src/test/models/onnx/transformer/tokenizer.json
new file mode 100644
index 00000000000..28340f289bb
--- /dev/null
+++ b/model-integration/src/test/models/onnx/transformer/tokenizer.json
@@ -0,0 +1,175 @@
+{
+  "version": "1.0",
+  "truncation": null,
+  "padding": null,
+  "added_tokens": [
+    {
+      "id": 0,
+      "content": "[PAD]",
+      "single_word": false,
+      "lstrip": false,
+      "rstrip": false,
+      "normalized": false,
+      "special": true
+    },
+    {
+      "id": 100,
+      "content": "[UNK]",
+      "single_word": false,
+      "lstrip": false,
+      "rstrip": false,
+      "normalized": false,
+      "special": true
+    },
+    {
+      "id": 101,
+      "content": "[CLS]",
+      "single_word": false,
+      "lstrip": false,
+      "rstrip": false,
+      "normalized": false,
+      "special": true
+    },
+    {
+      "id": 102,
+      "content": "[SEP]",
+      "single_word": false,
+      "lstrip": false,
+      "rstrip": false,
+      "normalized": false,
+      "special": true
+    },
+    {
+      "id": 103,
+      "content": "[MASK]",
+      "single_word": false,
+      "lstrip": false,
+      "rstrip": false,
+      "normalized": false,
+      "special": true
+    }
+  ],
+  "normalizer": {
+    "type": "BertNormalizer",
+    "clean_text": true,
+    "handle_chinese_chars": true,
+    "strip_accents": null,
+    "lowercase": true
+  },
+  "pre_tokenizer": {
+    "type": "BertPreTokenizer"
+  },
+  "post_processor": {
+    "type": "TemplateProcessing",
+    "single": [
+      {
+        "SpecialToken": {
+          "id": "[CLS]",
+          "type_id": 0
+        }
+      },
+      {
+        "Sequence": {
+          "id": "A",
+          "type_id": 0
+        }
+      },
+      {
+        "SpecialToken": {
+          "id": "[SEP]",
+          "type_id": 0
+        }
+      }
+    ],
+    "pair": [
+      {
+        "SpecialToken": {
+          "id": "[CLS]",
+          "type_id": 0
+        }
+      },
+      {
+        "Sequence": {
+          "id": "A",
+          "type_id": 0
+        }
+      },
+      {
+        "SpecialToken": {
+          "id": "[SEP]",
+          "type_id": 0
+        }
+      },
+      {
+        "Sequence": {
+          "id": "B",
+          "type_id": 1
+        }
+      },
+      {
+        "SpecialToken": {
+          "id": "[SEP]",
+          "type_id": 1
+        }
+      }
+    ],
+    "special_tokens": {
+      "[CLS]": {
+        "id": "[CLS]",
+        "ids": [101],
+        "tokens": ["[CLS]"]
+      },
+      "[SEP]": {
+        "id": "[SEP]",
+        "ids": [102],
+        "tokens": ["[SEP]"]
+      }
+    }
+  },
+  "decoder": {
+    "type": "WordPiece",
+    "prefix": "##",
+    "cleanup": true
+  },
+  "model": {
+    "type": "WordPiece",
+    "unk_token": "[UNK]",
+    "continuing_subword_prefix": "##",
+    "max_input_chars_per_word": 100,
+    "vocab": {
+      "[PAD]": 0,
+      "[unused0]": 1,
+      "[unused1]": 2,
+      "[UNK]": 100,
+      "[CLS]": 101,
+      "[SEP]": 102,
+      "[MASK]": 103,
+      "a": 1037,
+      "b": 1038,
+      "c": 1039,
+      "d": 1040,
+      "e": 1041,
+      "f": 1042,
+      "g": 1043,
+      "h": 1044,
+      "i": 1045,
+      "j": 1046,
+      "k": 1047,
+      "l": 1048,
+      "m": 1049,
+      "n": 1050,
+      "o": 1051,
+      "p": 1052,
+      "q": 1053,
+      "r": 1054,
+      "s": 1055,
+      "t": 1056,
+      "u": 1057,
+      "v": 1058,
+      "w": 1059,
+      "x": 1060,
+      "y": 1061,
+      "z": 1062
+    }
+  }
+}
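
As a closing illustration of the bit packing done by ColBertEmbedder.toBitTensor above: each run of 8 float dimensions is folded into one int8 cell, most significant bit first, with a bit set when the corresponding value is positive. This standalone sketch (class and method names are made up for illustration, not part of the commit) reproduces rows from testPacking.

// Standalone illustration of the sign-based bit packing in toBitTensor.
// Class/method names are hypothetical; only the packing rule mirrors the commit.
public class BitPackSketch {

    // Pack 8 values into one byte, MSB first; a bit is set when value > 0.
    static byte pack8(float[] values) {
        byte packed = 0;
        for (int i = 0; i < 8; i++)
            if (values[i] > 0.0f)
                packed |= (byte) (1 << (7 - i));
        return packed;
    }

    public static void main(String[] args) {
        // Same inputs and outputs as rows of ColBertEmbedderTest.testPacking:
        System.out.println(pack8(new float[] {0, 0, 0, 0, 0, 0, 0, 1})); // 1
        System.out.println(pack8(new float[] {0, 0, 0, 0, 0, 1, 0, 1})); // 5
        System.out.println(pack8(new float[] {1, 0, 0, 0, 0, 0, 0, 0})); // -128
        System.out.println(pack8(new float[] {1, 1, 1, 1, 1, 1, 1, 1})); // -1
    }
}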