diff options
author | Jo Kristian Bergum <bergum@yahooinc.com> | 2024-01-06 10:54:58 +0100 |
---|---|---|
committer | Jo Kristian Bergum <bergum@yahooinc.com> | 2024-01-06 10:54:58 +0100 |
commit | 18ae21bce56e018cef2c17d03e63617530af59ae (patch) | |
tree | 3c1dcee63395fee2e476be9ce33e2437262b00d7 /model-integration/src/main/java/ai/vespa | |
parent | e4da75db4556a3cd72b034c4406027f9bba73918 (diff) |
handle multilingual models better
Diffstat (limited to 'model-integration/src/main/java/ai/vespa')
-rw-r--r-- | model-integration/src/main/java/ai/vespa/embedding/ColBertEmbedder.java | 122 |
1 file changed, 62 insertions, 60 deletions
diff --git a/model-integration/src/main/java/ai/vespa/embedding/ColBertEmbedder.java b/model-integration/src/main/java/ai/vespa/embedding/ColBertEmbedder.java index d42ec629bf7..faba1bfac4c 100644 --- a/model-integration/src/main/java/ai/vespa/embedding/ColBertEmbedder.java +++ b/model-integration/src/main/java/ai/vespa/embedding/ColBertEmbedder.java @@ -26,7 +26,6 @@ import java.util.ArrayList; import java.util.Set; import java.util.HashSet; import java.util.BitSet; -import java.util.Arrays; import static com.yahoo.language.huggingface.ModelInfo.TruncationStrategy.LONGEST_FIRST; @@ -42,12 +41,9 @@ public class ColBertEmbedder extends AbstractComponent implements Embedder { private final Embedder.Runtime runtime; private final String inputIdsName; private final String attentionMaskName; - private final String outputName; - private final HuggingFaceTokenizer tokenizer; private final OnnxEvaluator evaluator; - private final int maxTransformerTokens; private final int maxQueryTokens; private final int maxDocumentTokens; @@ -56,6 +52,14 @@ public class ColBertEmbedder extends AbstractComponent implements Embedder { private final long endSequenceToken; private final long maskSequenceToken; + private final long padSequenceToken; + + private final long querySequenceToken; + + private final long documentSequenceToken; + private Set<Long> skipTokens; + + public record TransformerInput(List<Long> inputIds, List<Long> attentionMask) {} @Inject public ColBertEmbedder(OnnxRuntime onnx, Embedder.Runtime runtime, ColBertEmbedderConfig config) { @@ -69,6 +73,9 @@ public class ColBertEmbedder extends AbstractComponent implements Embedder { startSequenceToken = config.transformerStartSequenceToken(); endSequenceToken = config.transformerEndSequenceToken(); maskSequenceToken = config.transformerMaskToken(); + padSequenceToken = config.transformerPadToken(); + querySequenceToken = config.queryTokenId(); + documentSequenceToken = config.documentTokenId(); var tokenizerPath = 
Paths.get(config.tokenizerPath().toString()); var builder = new HuggingFaceTokenizer.Builder() @@ -85,8 +92,12 @@ public class ColBertEmbedder extends AbstractComponent implements Embedder { builder.setTruncation(true).setMaxLength(maxLength); } this.tokenizer = builder.build(); + this.skipTokens = new HashSet<>(); + PUNCTUATION.chars().forEach( + c -> this.skipTokens.addAll( + tokenizer.encode(Character.toString((char) c), null).ids()) + ); var onnxOpts = new OnnxEvaluatorOptions(); - if (config.transformerGpuDevice() >= 0) onnxOpts.setGpuDevice(config.transformerGpuDevice()); onnxOpts.setExecutionMode(config.transformerExecutionMode().toString()); @@ -118,7 +129,7 @@ public class ColBertEmbedder extends AbstractComponent implements Embedder { @Override public Tensor embed(String text, Context context, TensorType tensorType) { if (!verifyTensorType(tensorType)) { - throw new IllegalArgumentException("Invalid ColBERT embedder tensor destination. " + + throw new IllegalArgumentException("Invalid colbert embedder tensor target destination. " + "Wanted a mixed 2-d mapped-indexed tensor, got " + tensorType); } if (context.getDestination().startsWith("query")) { @@ -127,48 +138,54 @@ public class ColBertEmbedder extends AbstractComponent implements Embedder { return embedDocument(text, context, tensorType); } } - @Override public void deconstruct() { evaluator.close(); tokenizer.close(); } - protected Tensor embedQuery(String text, Context context, TensorType tensorType) { - if (tensorType.valueType() == TensorType.Value.INT8) - throw new IllegalArgumentException("ColBert query embed does not accept int8 tensor value type"); - - long Q_TOKEN_ID = 1; // [unused0] token id used during training to differentiate query versus document. 
- - var start = System.nanoTime(); - var encoding = tokenizer.encode(text, context.getLanguage()); - runtime.sampleSequenceLength(encoding.ids().size(), context); - - List<Long> ids = encoding.ids(); - if (ids.size() > maxQueryTokens - 3) - ids = ids.subList(0, maxQueryTokens - 3); - - List<Long> inputIds = new ArrayList<>(maxQueryTokens); - List<Long> attentionMask = new ArrayList<>(maxQueryTokens); - + protected TransformerInput buildTransformerInput(List<Long> tokens, int maxTokens, boolean isQuery) { + if(!isQuery) { + tokens = tokens.stream().filter(token -> !skipTokens.contains(token)).toList(); + } + List<Long> inputIds = new ArrayList<>(maxTokens); + List<Long> attentionMask = new ArrayList<>(maxTokens); + if (tokens.size() > maxTokens - 3) + tokens = tokens.subList(0, maxTokens - 3); inputIds.add(startSequenceToken); - inputIds.add(Q_TOKEN_ID); - inputIds.addAll(ids); + inputIds.add(isQuery? querySequenceToken: documentSequenceToken); + inputIds.addAll(tokens); inputIds.add(endSequenceToken); - int length = inputIds.size(); + int inputLength = inputIds.size(); + long padTokenId = isQuery? maskSequenceToken: padSequenceToken; - int padding = maxQueryTokens - length; + int padding = isQuery? 
maxTokens - inputLength: 0; for (int i = 0; i < padding; i++) - inputIds.add(maskSequenceToken); + inputIds.add(padTokenId); - for (int i = 0; i < length; i++) + for (int i = 0; i < inputLength; i++) attentionMask.add((long) 1); + for (int i = 0; i < padding; i++) attentionMask.add((long) 0);//Do not attend to mask paddings - Tensor inputIdsTensor = createTensorRepresentation(inputIds, "d1"); - Tensor attentionMaskTensor = createTensorRepresentation(attentionMask, "d1"); + return new TransformerInput(inputIds, attentionMask); + } + + protected Tensor embedQuery(String text, Context context, TensorType tensorType) { + if (tensorType.valueType() == TensorType.Value.INT8) + throw new IllegalArgumentException("ColBert query embed does not accept int8 tensor value type"); + + + var start = System.nanoTime(); + var encoding = tokenizer.encode(text, context.getLanguage()); + runtime.sampleSequenceLength(encoding.ids().size(), context); + + TransformerInput input = buildTransformerInput(encoding.ids(), maxQueryTokens, true); + + Tensor inputIdsTensor = createTensorRepresentation(input.inputIds, "d1"); + Tensor attentionMaskTensor = createTensorRepresentation(input.attentionMask, "d1"); var inputs = Map.of(inputIdsName, inputIdsTensor.expand("d0"), attentionMaskName, attentionMaskTensor.expand("d0")); @@ -178,36 +195,22 @@ public class ColBertEmbedder extends AbstractComponent implements Embedder { int dims = tensorType.indexedSubtype().dimensions().get(0).size().get().intValue(); if (dims != result.shape()[1]) { - throw new IllegalArgumentException("Token dimensionality does not" + + throw new IllegalArgumentException("Token vector dimensionality does not" + " match indexed dimensionality of " + dims); } - Tensor resultTensor = toFloatTensor(result, tensorType, inputIds.size()); + Tensor resultTensor = toFloatTensor(result, tensorType, input.inputIds.size()); runtime.sampleEmbeddingLatency((System.nanoTime() - start) / 1_000_000d, context); return resultTensor; } protected 
Tensor embedDocument(String text, Context context, TensorType tensorType) { - long D_TOKEN_ID = 2; // [unused1] token id used during training to differentiate query versus document. var start = System.nanoTime(); var encoding = tokenizer.encode(text, context.getLanguage()); runtime.sampleSequenceLength(encoding.ids().size(), context); - List<Long> ids = encoding.ids().stream().filter(token - -> !PUNCTUATION_TOKEN_IDS.contains(token)).toList(); - - if (ids.size() > maxDocumentTokens - 3) - ids = ids.subList(0, maxDocumentTokens - 3); - List<Long> inputIds = new ArrayList<>(maxDocumentTokens); - List<Long> attentionMask = new ArrayList<>(maxDocumentTokens); - inputIds.add(startSequenceToken); - inputIds.add(D_TOKEN_ID); - inputIds.addAll(ids); - inputIds.add(endSequenceToken); - for (int i = 0; i < inputIds.size(); i++) - attentionMask.add((long) 1); - - Tensor inputIdsTensor = createTensorRepresentation(inputIds, "d1"); - Tensor attentionMaskTensor = createTensorRepresentation(attentionMask, "d1"); + TransformerInput input = buildTransformerInput(encoding.ids(), maxDocumentTokens, false); + Tensor inputIdsTensor = createTensorRepresentation(input.inputIds, "d1"); + Tensor attentionMaskTensor = createTensorRepresentation(input.attentionMask, "d1"); var inputs = Map.of(inputIdsName, inputIdsTensor.expand("d0"), attentionMaskName, attentionMaskTensor.expand("d0")); @@ -216,11 +219,11 @@ public class ColBertEmbedder extends AbstractComponent implements Embedder { Tensor tokenEmbeddings = outputs.get(outputName); IndexedTensor result = (IndexedTensor) tokenEmbeddings.reduce(Reduce.Aggregator.min, "d0"); Tensor contextualEmbeddings; - int retainedTokens = inputIds.size() -1; //Do not retain last PAD + int maxTokens = input.inputIds.size() -1; //Do not retain last PAD if (tensorType.valueType() == TensorType.Value.INT8) { - contextualEmbeddings = toBitTensor(result, tensorType, retainedTokens); + contextualEmbeddings = toBitTensor(result, tensorType, maxTokens); } else { - 
contextualEmbeddings = toFloatTensor(result, tensorType, retainedTokens); + contextualEmbeddings = toFloatTensor(result, tensorType, maxTokens); } runtime.sampleEmbeddingLatency((System.nanoTime() - start) / 1_000_000d, context); return contextualEmbeddings; @@ -283,6 +286,10 @@ public class ColBertEmbedder extends AbstractComponent implements Embedder { return builder.build(); } + public Set<Long> getSkipTokens() { + return this.skipTokens; + } + public static Tensor expandBitTensor(Tensor packed) { var unpacker = new UnpackBitsNode(new ReferenceNode("input"), TensorType.Value.FLOAT, "big"); var context = new MapContext(); @@ -304,10 +311,5 @@ public class ColBertEmbedder extends AbstractComponent implements Embedder { } return builder.build(); } - - private static final Set<Long> PUNCTUATION_TOKEN_IDS = new HashSet<>( - Arrays.asList(999L, 1000L, 1001L, 1002L, 1003L, 1004L, 1005L, 1006L, - 1007L, 1008L, 1009L, 1010L, 1011L, 1012L, 1013L, 1024L, - 1025L, 1026L, 1027L, 1028L, 1029L, 1030L, 1031L, 1032L, - 1033L, 1034L, 1035L, 1036L, 1063L, 1064L, 1065L, 1066L)); + private static final String PUNCTUATION = "!\"#$%&'()*+,-./:;<=>?@[\\]^_`{|}~"; } |