diff options
author | Jo Kristian Bergum <bergum@yahooinc.com> | 2023-09-26 14:14:18 +0200 |
---|---|---|
committer | Jo Kristian Bergum <bergum@yahooinc.com> | 2023-09-26 14:14:18 +0200 |
commit | 01deefc0c007995573c5564be7aa4d0ce1e01203 (patch) | |
tree | b4d009b496e5f14b91f0c7f221a378e3ca916bed | |
parent | 4231e6077a18b6fdf96ac899a7301882ef50d742 (diff) |
Don't index PAD and re-factoring
-rw-r--r-- | model-integration/src/main/java/ai/vespa/embedding/ColBertEmbedder.java | 60 | ||||
-rw-r--r-- | model-integration/src/test/java/ai/vespa/embedding/ColBertEmbedderTest.java | 18 |
2 files changed, 37 insertions, 41 deletions
diff --git a/model-integration/src/main/java/ai/vespa/embedding/ColBertEmbedder.java b/model-integration/src/main/java/ai/vespa/embedding/ColBertEmbedder.java index aafb9877c27..4bb7bcc9225 100644 --- a/model-integration/src/main/java/ai/vespa/embedding/ColBertEmbedder.java +++ b/model-integration/src/main/java/ai/vespa/embedding/ColBertEmbedder.java @@ -27,7 +27,7 @@ import java.util.Arrays; import static com.yahoo.language.huggingface.ModelInfo.TruncationStrategy.LONGEST_FIRST; /** - * A ColBERT embedder implementation that maps text to multiple vectors, one vector per subword id. + * A ColBERT embedder implementation that maps text to multiple vectors, one vector per token subword id. * This embedder uses a HuggingFace tokenizer to produce a token sequence that is then input to a transformer model. * * See col-bert-embedder.def for configurable parameters. @@ -60,10 +60,8 @@ public class ColBertEmbedder extends AbstractComponent implements Embedder { attentionMaskName = config.transformerAttentionMask(); outputName = config.transformerOutput(); maxTransformerTokens = config.transformerMaxTokens(); - if(config.maxDocumentTokens() > maxTransformerTokens) - throw new IllegalArgumentException("maxDocumentTokens must be less than or equal to transformerMaxTokens"); - maxDocumentTokens = config.maxDocumentTokens(); - maxQueryTokens = config.maxQueryTokens(); + maxDocumentTokens = Math.min(config.maxDocumentTokens(), maxTransformerTokens); + maxQueryTokens = Math.min(config.maxQueryTokens(), maxTransformerTokens); startSequenceToken = config.transformerStartSequenceToken(); endSequenceToken = config.transformerEndSequenceToken(); maskSequenceToken = config.transformerMaskToken(); @@ -75,7 +73,8 @@ public class ColBertEmbedder extends AbstractComponent implements Embedder { .setPadding(false); var info = HuggingFaceTokenizer.getModelInfo(tokenizerPath); if (info.maxLength() == -1 || info.truncation() != LONGEST_FIRST) { - // Force truncation to max token vector length accepted by model if tokenizer.json contains no valid truncation configuration + // Force truncation + // to max length accepted by model if tokenizer.json contains no valid truncation configuration int maxLength = info.maxLength() > 0 && info.maxLength() <= config.transformerMaxTokens() ? info.maxLength() : config.transformerMaxTokens(); @@ -115,8 +114,8 @@ public class ColBertEmbedder extends AbstractComponent implements Embedder { @Override public Tensor embed(String text, Context context, TensorType tensorType) { if(!verifyTensorType(tensorType)) { - throw new IllegalArgumentException("Invalid ColBERT embedder tensor destination." + - "Wanted a mixed 2-d mapped-indexed tensor, got " + tensorType.toString()); + throw new IllegalArgumentException("Invalid ColBERT embedder tensor destination. " + + "Wanted a mixed 2-d mapped-indexed tensor, got " + tensorType); } if (context.getDestination().startsWith("query")) { return embedQuery(text, context, tensorType); @@ -152,6 +151,7 @@ public class ColBertEmbedder extends AbstractComponent implements Embedder { inputIds.add(Q_TOKEN_ID); inputIds.addAll(ids); inputIds.add(endSequenceToken); + int length = inputIds.size(); int padding = maxQueryTokens - length; @@ -177,12 +177,9 @@ public class ColBertEmbedder extends AbstractComponent implements Embedder { throw new IllegalArgumentException("Token dimensionality does not" + " match indexed dimensionality of " + dims); } - Tensor.Builder builder = Tensor.Builder.of(tensorType); - for (int token = 0; token < result.shape()[0]; token++) - for (int d = 0; d < result.shape()[1]; d++) - builder.cell(TensorAddress.of(token, d), result.get(TensorAddress.of(token, d))); + Tensor resultTensor = toFloatTensor(result, tensorType, inputIds.size()); runtime.sampleEmbeddingLatency((System.nanoTime() - start) / 1_000_000d, context); - return builder.build(); + return resultTensor; } protected Tensor embedDocument(String text, Context context, TensorType tensorType) { @@ -193,7 +190,6 @@ public class ColBertEmbedder extends AbstractComponent implements Embedder { List<Long> ids = encoding.ids().stream().filter(token -> !PUNCTUATION_TOKEN_IDS.contains(token)).toList(); - ; if (ids.size() > maxDocumentTokens - 3) ids = ids.subList(0, maxDocumentTokens - 3); @@ -216,29 +212,29 @@ public class ColBertEmbedder extends AbstractComponent implements Embedder { Tensor tokenEmbeddings = outputs.get(outputName); IndexedTensor result = (IndexedTensor) tokenEmbeddings.reduce(Reduce.Aggregator.min, "d0"); Tensor contextualEmbeddings; + int retainedTokens = inputIds.size() -1; //Do not retain last PAD if(tensorType.valueType() == TensorType.Value.INT8) { - contextualEmbeddings = toBitTensor(result, tensorType); + contextualEmbeddings = toBitTensor(result, tensorType, retainedTokens); } else { - contextualEmbeddings = toFloatTensor(result, tensorType); + contextualEmbeddings = toFloatTensor(result, tensorType, retainedTokens); } - runtime.sampleEmbeddingLatency((System.nanoTime() - start) / 1_000_000d, context); return contextualEmbeddings; } - public static Tensor toFloatTensor(IndexedTensor result, TensorType type) { + public static Tensor toFloatTensor(IndexedTensor result, TensorType type, int nTokens) { int size = type.indexedSubtype().dimensions().size(); if (size != 1) throw new IllegalArgumentException("Indexed tensor must have one dimension"); - int dims = type.indexedSubtype().dimensions().get(0).size().get().intValue(); - int resultDim = (int)result.shape()[1]; - if(resultDim != dims) { - throw new IllegalArgumentException("Not possible to map token vector embedding with " + resultDim - + " + dimensions into tensor with " + dims); + int wantedDimensionality = type.indexedSubtype().dimensions().get(0).size().get().intValue(); + int resultDimensionality = (int)result.shape()[1]; + if(resultDimensionality != wantedDimensionality) { + throw new IllegalArgumentException("Not possible to map token vector embedding with " + resultDimensionality + + " + dimensions into tensor with " + wantedDimensionality); } Tensor.Builder builder = Tensor.Builder.of(type); - for (int token = 0; token < result.shape()[0]; token++) { - for (int d = 0; d < result.shape()[1]; d++) { + for (int token = 0; token < nTokens; token++) { + for (int d = 0; d < resultDimensionality; d++) { var value = result.get(TensorAddress.of(token, d)); builder.cell(TensorAddress.of(token,d),value); } @@ -246,21 +242,21 @@ public class ColBertEmbedder extends AbstractComponent implements Embedder { return builder.build(); } - public static Tensor toBitTensor(IndexedTensor result, TensorType type) { + public static Tensor toBitTensor(IndexedTensor result, TensorType type, int nTokens) { if (type.valueType() != TensorType.Value.INT8) throw new IllegalArgumentException("Only a int8 tensor type can be" + " the destination of bit packing"); int size = type.indexedSubtype().dimensions().size(); if (size != 1) throw new IllegalArgumentException("Indexed tensor must have one dimension"); - int dims = type.indexedSubtype().dimensions().get(0).size().get().intValue(); - int resultDim = (int)result.shape()[1]; - if(resultDim/8 != dims) { - throw new IllegalArgumentException("Not possible to pack " + resultDim - + " + dimensions into " + dims); + int wantedDimensionality = type.indexedSubtype().dimensions().get(0).size().get().intValue(); + int resultDimensionality = (int)result.shape()[1]; + if(resultDimensionality/8 != wantedDimensionality) { + throw new IllegalArgumentException("Not possible to pack " + resultDimensionality + + " + dimensions into " + wantedDimensionality + " dimensions"); } Tensor.Builder builder = Tensor.Builder.of(type); - for (int token = 0; token < result.shape()[0]; token++) { + for (int token = 0; token < nTokens; token++) { BitSet bitSet = new BitSet(8); int key = 0; for (int d = 0; d < result.shape()[1]; d++) { diff --git a/model-integration/src/test/java/ai/vespa/embedding/ColBertEmbedderTest.java b/model-integration/src/test/java/ai/vespa/embedding/ColBertEmbedderTest.java index 8516f6e6689..4e398f7245d 100644 --- a/model-integration/src/test/java/ai/vespa/embedding/ColBertEmbedderTest.java +++ b/model-integration/src/test/java/ai/vespa/embedding/ColBertEmbedderTest.java @@ -31,7 +31,7 @@ public class ColBertEmbedderTest { "[1, 1, 1, 1, 1, 1, 1, 1]" + "]", TensorType.fromSpec("tensor<int8>(dt{},x[1])"), - "tensor<int8>(dt{},x[1]):{0:1.0, 1:5.0, 2:3.0, 3:127.0, 4:-128.0, 5:-1.0}" + "tensor<int8>(dt{},x[1]):{0:1.0, 1:5.0, 2:3.0, 3:127.0, 4:-128.0, 5:-1.0}", 6 ); assertPackedRight( "" + @@ -41,7 +41,7 @@ public class ColBertEmbedderTest { "[0, 0, 0, 0, 0, 1, 0, 1, 0, 0, 0, 0, 0, 0, 0, 1]" + "]", TensorType.fromSpec("tensor<int8>(dt{},x[2])"), - "tensor<int8>(dt{},x[2]):{0:[1.0, -128.0], 1:[5.0, 1.0]}" + "tensor<int8>(dt{},x[2]):{0:[1.0, -128.0], 1:[5.0, 1.0]}",2 ); } @@ -75,18 +75,18 @@ public class ColBertEmbedderTest { } String text = sb.toString(); Tensor fullFloat = assertEmbed("tensor<float>(dt{},x[128])", text, indexingContext); - assertEquals(512*128,fullFloat.size()); + assertEquals(511*128,fullFloat.size()); Tensor query = assertEmbed("tensor<float>(dt{},x[128])", text, queryContext); assertEquals(32*128,query.size()); Tensor binaryRep = assertEmbed("tensor<int8>(dt{},x[16])", text, indexingContext); - assertEquals(512*16,binaryRep.size()); + assertEquals(511*16,binaryRep.size()); Tensor shortDoc = assertEmbed("tensor<int8>(dt{},x[16])", "annoyance", indexingContext); - // 4 tokens, 16 bytes each = 64 bytes - //because of CLS, special, sequence, SEP - assertEquals(4*16,shortDoc.size());; + // 3 tokens, 16 bytes each = 48 bytes + //CLS [unused1] sequence + assertEquals(3*16,shortDoc.size());; } static Tensor assertEmbed(String tensorSpec, String text, Embedder.Context context) { @@ -100,8 +100,8 @@ public class ColBertEmbedderTest { return result; } - static void assertPackedRight(String numbers, TensorType destination,String expected) { - Tensor packed = ColBertEmbedder.toBitTensor((IndexedTensor) Tensor.from(numbers), destination); + static void assertPackedRight(String numbers, TensorType destination,String expected, int size) { + Tensor packed = ColBertEmbedder.toBitTensor((IndexedTensor) Tensor.from(numbers), destination, size); assertEquals(expected,packed.toString()); } |