aboutsummaryrefslogtreecommitdiffstats
path: root/model-integration/src
diff options
context:
space:
mode:
authorJo Kristian Bergum <bergum@yahooinc.com>2023-09-26 14:14:18 +0200
committerJo Kristian Bergum <bergum@yahooinc.com>2023-09-26 14:14:18 +0200
commit01deefc0c007995573c5564be7aa4d0ce1e01203 (patch)
treeb4d009b496e5f14b91f0c7f221a378e3ca916bed /model-integration/src
parent4231e6077a18b6fdf96ac899a7301882ef50d742 (diff)
Don't index PAD, and refactor
Diffstat (limited to 'model-integration/src')
-rw-r--r--model-integration/src/main/java/ai/vespa/embedding/ColBertEmbedder.java60
-rw-r--r--model-integration/src/test/java/ai/vespa/embedding/ColBertEmbedderTest.java18
2 files changed, 37 insertions, 41 deletions
diff --git a/model-integration/src/main/java/ai/vespa/embedding/ColBertEmbedder.java b/model-integration/src/main/java/ai/vespa/embedding/ColBertEmbedder.java
index aafb9877c27..4bb7bcc9225 100644
--- a/model-integration/src/main/java/ai/vespa/embedding/ColBertEmbedder.java
+++ b/model-integration/src/main/java/ai/vespa/embedding/ColBertEmbedder.java
@@ -27,7 +27,7 @@ import java.util.Arrays;
import static com.yahoo.language.huggingface.ModelInfo.TruncationStrategy.LONGEST_FIRST;
/**
- * A ColBERT embedder implementation that maps text to multiple vectors, one vector per subword id.
+ * A ColBERT embedder implementation that maps text to multiple vectors, one vector per token subword id.
* This embedder uses a HuggingFace tokenizer to produce a token sequence that is then input to a transformer model.
*
* See col-bert-embedder.def for configurable parameters.
@@ -60,10 +60,8 @@ public class ColBertEmbedder extends AbstractComponent implements Embedder {
attentionMaskName = config.transformerAttentionMask();
outputName = config.transformerOutput();
maxTransformerTokens = config.transformerMaxTokens();
- if(config.maxDocumentTokens() > maxTransformerTokens)
- throw new IllegalArgumentException("maxDocumentTokens must be less than or equal to transformerMaxTokens");
- maxDocumentTokens = config.maxDocumentTokens();
- maxQueryTokens = config.maxQueryTokens();
+ maxDocumentTokens = Math.min(config.maxDocumentTokens(), maxTransformerTokens);
+ maxQueryTokens = Math.min(config.maxQueryTokens(), maxTransformerTokens);
startSequenceToken = config.transformerStartSequenceToken();
endSequenceToken = config.transformerEndSequenceToken();
maskSequenceToken = config.transformerMaskToken();
@@ -75,7 +73,8 @@ public class ColBertEmbedder extends AbstractComponent implements Embedder {
.setPadding(false);
var info = HuggingFaceTokenizer.getModelInfo(tokenizerPath);
if (info.maxLength() == -1 || info.truncation() != LONGEST_FIRST) {
- // Force truncation to max token vector length accepted by model if tokenizer.json contains no valid truncation configuration
+ // Force truncation
+ // to max length accepted by model if tokenizer.json contains no valid truncation configuration
int maxLength = info.maxLength() > 0 && info.maxLength() <= config.transformerMaxTokens()
? info.maxLength()
: config.transformerMaxTokens();
@@ -115,8 +114,8 @@ public class ColBertEmbedder extends AbstractComponent implements Embedder {
@Override
public Tensor embed(String text, Context context, TensorType tensorType) {
if(!verifyTensorType(tensorType)) {
- throw new IllegalArgumentException("Invalid ColBERT embedder tensor destination." +
- "Wanted a mixed 2-d mapped-indexed tensor, got " + tensorType.toString());
+ throw new IllegalArgumentException("Invalid ColBERT embedder tensor destination. " +
+ "Wanted a mixed 2-d mapped-indexed tensor, got " + tensorType);
}
if (context.getDestination().startsWith("query")) {
return embedQuery(text, context, tensorType);
@@ -152,6 +151,7 @@ public class ColBertEmbedder extends AbstractComponent implements Embedder {
inputIds.add(Q_TOKEN_ID);
inputIds.addAll(ids);
inputIds.add(endSequenceToken);
+
int length = inputIds.size();
int padding = maxQueryTokens - length;
@@ -177,12 +177,9 @@ public class ColBertEmbedder extends AbstractComponent implements Embedder {
throw new IllegalArgumentException("Token dimensionality does not" +
" match indexed dimensionality of " + dims);
}
- Tensor.Builder builder = Tensor.Builder.of(tensorType);
- for (int token = 0; token < result.shape()[0]; token++)
- for (int d = 0; d < result.shape()[1]; d++)
- builder.cell(TensorAddress.of(token, d), result.get(TensorAddress.of(token, d)));
+ Tensor resultTensor = toFloatTensor(result, tensorType, inputIds.size());
runtime.sampleEmbeddingLatency((System.nanoTime() - start) / 1_000_000d, context);
- return builder.build();
+ return resultTensor;
}
protected Tensor embedDocument(String text, Context context, TensorType tensorType) {
@@ -193,7 +190,6 @@ public class ColBertEmbedder extends AbstractComponent implements Embedder {
List<Long> ids = encoding.ids().stream().filter(token
-> !PUNCTUATION_TOKEN_IDS.contains(token)).toList();
- ;
if (ids.size() > maxDocumentTokens - 3)
ids = ids.subList(0, maxDocumentTokens - 3);
@@ -216,29 +212,29 @@ public class ColBertEmbedder extends AbstractComponent implements Embedder {
Tensor tokenEmbeddings = outputs.get(outputName);
IndexedTensor result = (IndexedTensor) tokenEmbeddings.reduce(Reduce.Aggregator.min, "d0");
Tensor contextualEmbeddings;
+ int retainedTokens = inputIds.size() -1; //Do not retain last PAD
if(tensorType.valueType() == TensorType.Value.INT8) {
- contextualEmbeddings = toBitTensor(result, tensorType);
+ contextualEmbeddings = toBitTensor(result, tensorType, retainedTokens);
} else {
- contextualEmbeddings = toFloatTensor(result, tensorType);
+ contextualEmbeddings = toFloatTensor(result, tensorType, retainedTokens);
}
-
runtime.sampleEmbeddingLatency((System.nanoTime() - start) / 1_000_000d, context);
return contextualEmbeddings;
}
- public static Tensor toFloatTensor(IndexedTensor result, TensorType type) {
+ public static Tensor toFloatTensor(IndexedTensor result, TensorType type, int nTokens) {
int size = type.indexedSubtype().dimensions().size();
if (size != 1)
throw new IllegalArgumentException("Indexed tensor must have one dimension");
- int dims = type.indexedSubtype().dimensions().get(0).size().get().intValue();
- int resultDim = (int)result.shape()[1];
- if(resultDim != dims) {
- throw new IllegalArgumentException("Not possible to map token vector embedding with " + resultDim
- + " + dimensions into tensor with " + dims);
+ int wantedDimensionality = type.indexedSubtype().dimensions().get(0).size().get().intValue();
+ int resultDimensionality = (int)result.shape()[1];
+ if(resultDimensionality != wantedDimensionality) {
+ throw new IllegalArgumentException("Not possible to map token vector embedding with " + resultDimensionality
+ + " + dimensions into tensor with " + wantedDimensionality);
}
Tensor.Builder builder = Tensor.Builder.of(type);
- for (int token = 0; token < result.shape()[0]; token++) {
- for (int d = 0; d < result.shape()[1]; d++) {
+ for (int token = 0; token < nTokens; token++) {
+ for (int d = 0; d < resultDimensionality; d++) {
var value = result.get(TensorAddress.of(token, d));
builder.cell(TensorAddress.of(token,d),value);
}
@@ -246,21 +242,21 @@ public class ColBertEmbedder extends AbstractComponent implements Embedder {
return builder.build();
}
- public static Tensor toBitTensor(IndexedTensor result, TensorType type) {
+ public static Tensor toBitTensor(IndexedTensor result, TensorType type, int nTokens) {
if (type.valueType() != TensorType.Value.INT8)
throw new IllegalArgumentException("Only a int8 tensor type can be" +
" the destination of bit packing");
int size = type.indexedSubtype().dimensions().size();
if (size != 1)
throw new IllegalArgumentException("Indexed tensor must have one dimension");
- int dims = type.indexedSubtype().dimensions().get(0).size().get().intValue();
- int resultDim = (int)result.shape()[1];
- if(resultDim/8 != dims) {
- throw new IllegalArgumentException("Not possible to pack " + resultDim
- + " + dimensions into " + dims);
+ int wantedDimensionality = type.indexedSubtype().dimensions().get(0).size().get().intValue();
+ int resultDimensionality = (int)result.shape()[1];
+ if(resultDimensionality/8 != wantedDimensionality) {
+ throw new IllegalArgumentException("Not possible to pack " + resultDimensionality
+ + " + dimensions into " + wantedDimensionality + " dimensions");
}
Tensor.Builder builder = Tensor.Builder.of(type);
- for (int token = 0; token < result.shape()[0]; token++) {
+ for (int token = 0; token < nTokens; token++) {
BitSet bitSet = new BitSet(8);
int key = 0;
for (int d = 0; d < result.shape()[1]; d++) {
diff --git a/model-integration/src/test/java/ai/vespa/embedding/ColBertEmbedderTest.java b/model-integration/src/test/java/ai/vespa/embedding/ColBertEmbedderTest.java
index 8516f6e6689..4e398f7245d 100644
--- a/model-integration/src/test/java/ai/vespa/embedding/ColBertEmbedderTest.java
+++ b/model-integration/src/test/java/ai/vespa/embedding/ColBertEmbedderTest.java
@@ -31,7 +31,7 @@ public class ColBertEmbedderTest {
"[1, 1, 1, 1, 1, 1, 1, 1]" +
"]",
TensorType.fromSpec("tensor<int8>(dt{},x[1])"),
- "tensor<int8>(dt{},x[1]):{0:1.0, 1:5.0, 2:3.0, 3:127.0, 4:-128.0, 5:-1.0}"
+ "tensor<int8>(dt{},x[1]):{0:1.0, 1:5.0, 2:3.0, 3:127.0, 4:-128.0, 5:-1.0}", 6
);
assertPackedRight(
"" +
@@ -41,7 +41,7 @@ public class ColBertEmbedderTest {
"[0, 0, 0, 0, 0, 1, 0, 1, 0, 0, 0, 0, 0, 0, 0, 1]" +
"]",
TensorType.fromSpec("tensor<int8>(dt{},x[2])"),
- "tensor<int8>(dt{},x[2]):{0:[1.0, -128.0], 1:[5.0, 1.0]}"
+ "tensor<int8>(dt{},x[2]):{0:[1.0, -128.0], 1:[5.0, 1.0]}",2
);
}
@@ -75,18 +75,18 @@ public class ColBertEmbedderTest {
}
String text = sb.toString();
Tensor fullFloat = assertEmbed("tensor<float>(dt{},x[128])", text, indexingContext);
- assertEquals(512*128,fullFloat.size());
+ assertEquals(511*128,fullFloat.size());
Tensor query = assertEmbed("tensor<float>(dt{},x[128])", text, queryContext);
assertEquals(32*128,query.size());
Tensor binaryRep = assertEmbed("tensor<int8>(dt{},x[16])", text, indexingContext);
- assertEquals(512*16,binaryRep.size());
+ assertEquals(511*16,binaryRep.size());
Tensor shortDoc = assertEmbed("tensor<int8>(dt{},x[16])", "annoyance", indexingContext);
- // 4 tokens, 16 bytes each = 64 bytes
- //because of CLS, special, sequence, SEP
- assertEquals(4*16,shortDoc.size());;
+ // 3 tokens, 16 bytes each = 48 bytes
+ //CLS [unused1] sequence
+ assertEquals(3*16,shortDoc.size());;
}
static Tensor assertEmbed(String tensorSpec, String text, Embedder.Context context) {
@@ -100,8 +100,8 @@ public class ColBertEmbedderTest {
return result;
}
- static void assertPackedRight(String numbers, TensorType destination,String expected) {
- Tensor packed = ColBertEmbedder.toBitTensor((IndexedTensor) Tensor.from(numbers), destination);
+ static void assertPackedRight(String numbers, TensorType destination,String expected, int size) {
+ Tensor packed = ColBertEmbedder.toBitTensor((IndexedTensor) Tensor.from(numbers), destination, size);
assertEquals(expected,packed.toString());
}