author	Jo Kristian Bergum <bergum@yahooinc.com>	2024-01-06 10:54:58 +0100
committer	Jo Kristian Bergum <bergum@yahooinc.com>	2024-01-06 10:54:58 +0100
commit	18ae21bce56e018cef2c17d03e63617530af59ae (patch)
tree	3c1dcee63395fee2e476be9ce33e2437262b00d7 /model-integration/src/main/java/ai/vespa
parent	e4da75db4556a3cd72b034c4406027f9bba73918 (diff)
handle multilingual models better
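The query/document marker token ids ([unused0]/[unused1] in BERT vocabularies), the pad token,
and the set of punctuation tokens to skip were previously hardcoded, which assumed a BERT-style
English vocabulary. They are now read from the embedder config, and the skip-token set is derived
by running each punctuation character through the model's own tokenizer, so any (multilingual)
vocabulary works. Query and document input construction is unified in a new buildTransformerInput
method.

A minimal standalone sketch of the skip-token derivation follows; the Tokenizer interface and the
toy id scheme are stand-ins for com.yahoo.language.huggingface.HuggingFaceTokenizer, not the real
API:

import java.util.HashSet;
import java.util.List;
import java.util.Set;

public class SkipTokensSketch {
    // Hypothetical stand-in for HuggingFaceTokenizer.encode(text, language).ids()
    interface Tokenizer { List<Long> encodeToIds(String text); }

    static final String PUNCTUATION = "!\"#$%&'()*+,-./:;<=>?@[\\]^_`{|}~";

    static Set<Long> skipTokens(Tokenizer tokenizer) {
        Set<Long> skip = new HashSet<>();
        // Encode each punctuation character with the model's own tokenizer, so the
        // ids match whatever vocabulary the model ships with.
        PUNCTUATION.chars().forEach(
                c -> skip.addAll(tokenizer.encodeToIds(Character.toString((char) c))));
        return skip;
    }

    public static void main(String[] args) {
        Tokenizer toy = text -> List.of((long) text.codePointAt(0)); // toy id scheme
        System.out.println(skipTokens(toy).size() + " punctuation skip-token ids");
    }
}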
Diffstat (limited to 'model-integration/src/main/java/ai/vespa')
-rw-r--r--	model-integration/src/main/java/ai/vespa/embedding/ColBertEmbedder.java	122
1 file changed, 62 insertions(+), 60 deletions(-)
diff --git a/model-integration/src/main/java/ai/vespa/embedding/ColBertEmbedder.java b/model-integration/src/main/java/ai/vespa/embedding/ColBertEmbedder.java
index d42ec629bf7..faba1bfac4c 100644
--- a/model-integration/src/main/java/ai/vespa/embedding/ColBertEmbedder.java
+++ b/model-integration/src/main/java/ai/vespa/embedding/ColBertEmbedder.java
@@ -26,7 +26,6 @@ import java.util.ArrayList;
import java.util.Set;
import java.util.HashSet;
import java.util.BitSet;
-import java.util.Arrays;
import static com.yahoo.language.huggingface.ModelInfo.TruncationStrategy.LONGEST_FIRST;
@@ -42,12 +41,9 @@ public class ColBertEmbedder extends AbstractComponent implements Embedder {
private final Embedder.Runtime runtime;
private final String inputIdsName;
private final String attentionMaskName;
-
private final String outputName;
-
private final HuggingFaceTokenizer tokenizer;
private final OnnxEvaluator evaluator;
-
private final int maxTransformerTokens;
private final int maxQueryTokens;
private final int maxDocumentTokens;
@@ -56,6 +52,14 @@ public class ColBertEmbedder extends AbstractComponent implements Embedder {
private final long endSequenceToken;
private final long maskSequenceToken;
+ private final long padSequenceToken;
+
+ private final long querySequenceToken;
+
+ private final long documentSequenceToken;
+ private final Set<Long> skipTokens;
+
+ public record TransformerInput(List<Long> inputIds, List<Long> attentionMask) {}
@Inject
public ColBertEmbedder(OnnxRuntime onnx, Embedder.Runtime runtime, ColBertEmbedderConfig config) {
@@ -69,6 +73,9 @@ public class ColBertEmbedder extends AbstractComponent implements Embedder {
startSequenceToken = config.transformerStartSequenceToken();
endSequenceToken = config.transformerEndSequenceToken();
maskSequenceToken = config.transformerMaskToken();
+ padSequenceToken = config.transformerPadToken();
+ querySequenceToken = config.queryTokenId();
+ documentSequenceToken = config.documentTokenId();
var tokenizerPath = Paths.get(config.tokenizerPath().toString());
var builder = new HuggingFaceTokenizer.Builder()
@@ -85,8 +92,12 @@ public class ColBertEmbedder extends AbstractComponent implements Embedder {
builder.setTruncation(true).setMaxLength(maxLength);
}
this.tokenizer = builder.build();
+ this.skipTokens = new HashSet<>();
+ PUNCTUATION.chars().forEach(
+ c -> this.skipTokens.addAll(
+ tokenizer.encode(Character.toString((char) c), null).ids())
+ );
var onnxOpts = new OnnxEvaluatorOptions();
-
if (config.transformerGpuDevice() >= 0)
onnxOpts.setGpuDevice(config.transformerGpuDevice());
onnxOpts.setExecutionMode(config.transformerExecutionMode().toString());
@@ -118,7 +129,7 @@ public class ColBertEmbedder extends AbstractComponent implements Embedder {
@Override
public Tensor embed(String text, Context context, TensorType tensorType) {
if (!verifyTensorType(tensorType)) {
- throw new IllegalArgumentException("Invalid ColBERT embedder tensor destination. " +
+ throw new IllegalArgumentException("Invalid colbert embedder tensor target destination. " +
"Wanted a mixed 2-d mapped-indexed tensor, got " + tensorType);
}
if (context.getDestination().startsWith("query")) {
@@ -127,48 +138,54 @@ public class ColBertEmbedder extends AbstractComponent implements Embedder {
return embedDocument(text, context, tensorType);
}
}
-
@Override
public void deconstruct() {
evaluator.close();
tokenizer.close();
}
- protected Tensor embedQuery(String text, Context context, TensorType tensorType) {
- if (tensorType.valueType() == TensorType.Value.INT8)
- throw new IllegalArgumentException("ColBert query embed does not accept int8 tensor value type");
-
- long Q_TOKEN_ID = 1; // [unused0] token id used during training to differentiate query versus document.
-
- var start = System.nanoTime();
- var encoding = tokenizer.encode(text, context.getLanguage());
- runtime.sampleSequenceLength(encoding.ids().size(), context);
-
- List<Long> ids = encoding.ids();
- if (ids.size() > maxQueryTokens - 3)
- ids = ids.subList(0, maxQueryTokens - 3);
-
- List<Long> inputIds = new ArrayList<>(maxQueryTokens);
- List<Long> attentionMask = new ArrayList<>(maxQueryTokens);
-
+ protected TransformerInput buildTransformerInput(List<Long> tokens, int maxTokens, boolean isQuery) {
+ if (!isQuery) {
+ tokens = tokens.stream().filter(token -> !skipTokens.contains(token)).toList();
+ }
+ List<Long> inputIds = new ArrayList<>(maxTokens);
+ List<Long> attentionMask = new ArrayList<>(maxTokens);
+ if (tokens.size() > maxTokens - 3)
+ tokens = tokens.subList(0, maxTokens - 3);
inputIds.add(startSequenceToken);
- inputIds.add(Q_TOKEN_ID);
- inputIds.addAll(ids);
+ inputIds.add(isQuery ? querySequenceToken : documentSequenceToken);
+ inputIds.addAll(tokens);
inputIds.add(endSequenceToken);
- int length = inputIds.size();
+ int inputLength = inputIds.size();
+ long padTokenId = isQuery ? maskSequenceToken : padSequenceToken;
- int padding = maxQueryTokens - length;
+ int padding = isQuery ? maxTokens - inputLength : 0;
for (int i = 0; i < padding; i++)
- inputIds.add(maskSequenceToken);
+ inputIds.add(padTokenId);
- for (int i = 0; i < length; i++)
+ for (int i = 0; i < inputLength; i++)
attentionMask.add((long) 1);
+
for (int i = 0; i < padding; i++)
attentionMask.add((long) 0); // Do not attend to mask paddings
- Tensor inputIdsTensor = createTensorRepresentation(inputIds, "d1");
- Tensor attentionMaskTensor = createTensorRepresentation(attentionMask, "d1");
+ return new TransformerInput(inputIds, attentionMask);
+ }
+
+ protected Tensor embedQuery(String text, Context context, TensorType tensorType) {
+ if (tensorType.valueType() == TensorType.Value.INT8)
+ throw new IllegalArgumentException("ColBert query embed does not accept int8 tensor value type");
+
+
+ var start = System.nanoTime();
+ var encoding = tokenizer.encode(text, context.getLanguage());
+ runtime.sampleSequenceLength(encoding.ids().size(), context);
+
+ TransformerInput input = buildTransformerInput(encoding.ids(), maxQueryTokens, true);
+
+ Tensor inputIdsTensor = createTensorRepresentation(input.inputIds(), "d1");
+ Tensor attentionMaskTensor = createTensorRepresentation(input.attentionMask(), "d1");
var inputs = Map.of(inputIdsName, inputIdsTensor.expand("d0"),
attentionMaskName, attentionMaskTensor.expand("d0"));
@@ -178,36 +195,22 @@ public class ColBertEmbedder extends AbstractComponent implements Embedder {
int dims = tensorType.indexedSubtype().dimensions().get(0).size().get().intValue();
if (dims != result.shape()[1]) {
- throw new IllegalArgumentException("Token dimensionality does not" +
+ throw new IllegalArgumentException("Token vector dimensionality does not" +
" match indexed dimensionality of " + dims);
}
- Tensor resultTensor = toFloatTensor(result, tensorType, inputIds.size());
+ Tensor resultTensor = toFloatTensor(result, tensorType, input.inputIds().size());
runtime.sampleEmbeddingLatency((System.nanoTime() - start) / 1_000_000d, context);
return resultTensor;
}
protected Tensor embedDocument(String text, Context context, TensorType tensorType) {
- long D_TOKEN_ID = 2; // [unused1] token id used during training to differentiate query versus document.
var start = System.nanoTime();
var encoding = tokenizer.encode(text, context.getLanguage());
runtime.sampleSequenceLength(encoding.ids().size(), context);
- List<Long> ids = encoding.ids().stream().filter(token
- -> !PUNCTUATION_TOKEN_IDS.contains(token)).toList();
-
- if (ids.size() > maxDocumentTokens - 3)
- ids = ids.subList(0, maxDocumentTokens - 3);
- List<Long> inputIds = new ArrayList<>(maxDocumentTokens);
- List<Long> attentionMask = new ArrayList<>(maxDocumentTokens);
- inputIds.add(startSequenceToken);
- inputIds.add(D_TOKEN_ID);
- inputIds.addAll(ids);
- inputIds.add(endSequenceToken);
- for (int i = 0; i < inputIds.size(); i++)
- attentionMask.add((long) 1);
-
- Tensor inputIdsTensor = createTensorRepresentation(inputIds, "d1");
- Tensor attentionMaskTensor = createTensorRepresentation(attentionMask, "d1");
+ TransformerInput input = buildTransformerInput(encoding.ids(), maxDocumentTokens, false);
+ Tensor inputIdsTensor = createTensorRepresentation(input.inputIds, "d1");
+ Tensor attentionMaskTensor = createTensorRepresentation(input.attentionMask, "d1");
var inputs = Map.of(inputIdsName, inputIdsTensor.expand("d0"),
attentionMaskName, attentionMaskTensor.expand("d0"));
@@ -216,11 +219,11 @@ public class ColBertEmbedder extends AbstractComponent implements Embedder {
Tensor tokenEmbeddings = outputs.get(outputName);
IndexedTensor result = (IndexedTensor) tokenEmbeddings.reduce(Reduce.Aggregator.min, "d0");
Tensor contextualEmbeddings;
- int retainedTokens = inputIds.size() -1; //Do not retain last PAD
+ int maxTokens = input.inputIds().size() - 1; // Do not retain the trailing SEP token
if (tensorType.valueType() == TensorType.Value.INT8) {
- contextualEmbeddings = toBitTensor(result, tensorType, retainedTokens);
+ contextualEmbeddings = toBitTensor(result, tensorType, maxTokens);
} else {
- contextualEmbeddings = toFloatTensor(result, tensorType, retainedTokens);
+ contextualEmbeddings = toFloatTensor(result, tensorType, maxTokens);
}
runtime.sampleEmbeddingLatency((System.nanoTime() - start) / 1_000_000d, context);
return contextualEmbeddings;
@@ -283,6 +286,10 @@ public class ColBertEmbedder extends AbstractComponent implements Embedder {
return builder.build();
}
+ public Set<Long> getSkipTokens() {
+ return this.skipTokens;
+ }
+
public static Tensor expandBitTensor(Tensor packed) {
var unpacker = new UnpackBitsNode(new ReferenceNode("input"), TensorType.Value.FLOAT, "big");
var context = new MapContext();
@@ -304,10 +311,5 @@ public class ColBertEmbedder extends AbstractComponent implements Embedder {
}
return builder.build();
}
-
- private static final Set<Long> PUNCTUATION_TOKEN_IDS = new HashSet<>(
- Arrays.asList(999L, 1000L, 1001L, 1002L, 1003L, 1004L, 1005L, 1006L,
- 1007L, 1008L, 1009L, 1010L, 1011L, 1012L, 1013L, 1024L,
- 1025L, 1026L, 1027L, 1028L, 1029L, 1030L, 1031L, 1032L,
- 1033L, 1034L, 1035L, 1036L, 1063L, 1064L, 1065L, 1066L));
+ private static final String PUNCTUATION = "!\"#$%&'()*+,-./:;<=>?@[\\]^_`{|}~";
}
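For reference, a self-contained toy reproduction of what buildTransformerInput now produces for
queries versus documents. All token ids here are invented (CLS=101, Q=1, D=2, SEP=102, MASK=103);
the real values come from ColBertEmbedderConfig:

import java.util.ArrayList;
import java.util.List;

public class InputSketch {
    static List<Long> build(List<Long> tokens, int maxTokens, boolean isQuery) {
        if (tokens.size() > maxTokens - 3)
            tokens = tokens.subList(0, maxTokens - 3); // leave room for CLS, marker, SEP
        List<Long> ids = new ArrayList<>();
        ids.add(101L);                                 // startSequenceToken (CLS)
        ids.add(isQuery ? 1L : 2L);                    // query or document marker
        ids.addAll(tokens);
        ids.add(102L);                                 // endSequenceToken (SEP)
        int padding = isQuery ? maxTokens - ids.size() : 0;
        for (int i = 0; i < padding; i++)
            ids.add(103L);                             // queries are MASK-padded (ColBERT query augmentation)
        return ids;
    }

    public static void main(String[] args) {
        List<Long> tokens = List.of(7592L, 2088L);     // invented token ids
        System.out.println(build(tokens, 8, true));    // [101, 1, 7592, 2088, 102, 103, 103, 103]
        System.out.println(build(tokens, 8, false));   // [101, 2, 7592, 2088, 102]
    }
}

Documents are no longer padded at all: only real tokens are attended to, and the trailing SEP
position is dropped when the output tensor is built.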