summaryrefslogtreecommitdiffstats
path: root/model-integration
diff options
context:
space:
mode:
authorJon Bratseth <bratseth@vespa.ai>2024-02-02 12:28:53 +0100
committerJon Bratseth <bratseth@vespa.ai>2024-02-02 12:28:53 +0100
commit1a25431ab58c752c7fc26dd8223bf1ba1079b24a (patch)
tree954d7e2f3e43bb0636a6af7a93195a84e41e609b /model-integration
parent2191193c6e107eb68611ddb106e5f572bea32903 (diff)
Support embedding into rank 3 tensors
Diffstat (limited to 'model-integration')
-rw-r--r--model-integration/src/main/java/ai/vespa/embedding/ColBertEmbedder.java34
-rw-r--r--model-integration/src/main/java/ai/vespa/embedding/SpladeEmbedder.java12
-rw-r--r--model-integration/src/test/java/ai/vespa/embedding/ColBertEmbedderTest.java25
3 files changed, 42 insertions, 29 deletions
diff --git a/model-integration/src/main/java/ai/vespa/embedding/ColBertEmbedder.java b/model-integration/src/main/java/ai/vespa/embedding/ColBertEmbedder.java
index 8c39cc8c813..f76bfd28abf 100644
--- a/model-integration/src/main/java/ai/vespa/embedding/ColBertEmbedder.java
+++ b/model-integration/src/main/java/ai/vespa/embedding/ColBertEmbedder.java
@@ -18,7 +18,7 @@ import com.yahoo.tensor.IndexedTensor;
import com.yahoo.tensor.Tensor;
import com.yahoo.tensor.TensorAddress;
import com.yahoo.tensor.TensorType;
-import com.yahoo.tensor.functions.Reduce;
+
import java.nio.file.Paths;
import java.util.Map;
import java.util.List;
@@ -34,10 +34,14 @@ import static com.yahoo.language.huggingface.ModelInfo.TruncationStrategy.LONGES
* This embedder uses a HuggingFace tokenizer to produce a token sequence that is then input to a transformer model.
*
* See col-bert-embedder.def for configurable parameters.
+ *
* @author bergum
*/
@Beta
public class ColBertEmbedder extends AbstractComponent implements Embedder {
+
+ private static final String PUNCTUATION = "!\"#$%&'()*+,-./:;<=>?@[\\]^_`{|}~";
+
private final Embedder.Runtime runtime;
private final String inputIdsName;
private final String attentionMaskName;
@@ -117,7 +121,7 @@ public class ColBertEmbedder extends AbstractComponent implements Embedder {
private void validateName(Map<String, TensorType> types, String name, String type) {
if (!types.containsKey(name)) {
throw new IllegalArgumentException("Model does not contain required " + type + ": '" + name + "'. " +
- "Model contains: " + String.join(",", types.keySet()));
+ "Model contains: " + String.join(",", types.keySet()));
}
}
@@ -128,9 +132,9 @@ public class ColBertEmbedder extends AbstractComponent implements Embedder {
@Override
public Tensor embed(String text, Context context, TensorType tensorType) {
- if (!verifyTensorType(tensorType)) {
+ if ( ! validTensorType(tensorType)) {
throw new IllegalArgumentException("Invalid colbert embedder tensor target destination. " +
- "Wanted a mixed 2-d mapped-indexed tensor, got " + tensorType);
+ "Wanted a mixed 2-d mapped-indexed tensor, got " + tensorType);
}
if (context.getDestination().startsWith("query")) {
return embedQuery(text, context, tensorType);
@@ -196,7 +200,7 @@ public class ColBertEmbedder extends AbstractComponent implements Embedder {
int dims = tensorType.indexedSubtype().dimensions().get(0).size().get().intValue();
if (dims != result.shape()[2]) {
throw new IllegalArgumentException("Token vector dimensionality does not" +
- " match indexed dimensionality of " + dims);
+ " match indexed dimensionality of " + dims);
}
Tensor resultTensor = toFloatTensor(result, tensorType, input.inputIds.size());
runtime.sampleEmbeddingLatency((System.nanoTime() - start) / 1_000_000d, context);
@@ -213,13 +217,13 @@ public class ColBertEmbedder extends AbstractComponent implements Embedder {
Tensor attentionMaskTensor = createTensorRepresentation(input.attentionMask, "d1");
var inputs = Map.of(inputIdsName, inputIdsTensor.expand("d0"),
- attentionMaskName, attentionMaskTensor.expand("d0"));
+ attentionMaskName, attentionMaskTensor.expand("d0"));
Map<String, Tensor> outputs = evaluator.evaluate(inputs);
Tensor tokenEmbeddings = outputs.get(outputName);
IndexedTensor result = (IndexedTensor) tokenEmbeddings;
Tensor contextualEmbeddings;
- int maxTokens = input.inputIds.size(); //Retain all token vectors, including PAD tokens.
+ int maxTokens = input.inputIds.size(); // Retain all token vectors, including PAD tokens.
if (tensorType.valueType() == TensorType.Value.INT8) {
contextualEmbeddings = toBitTensor(result, tensorType, maxTokens);
} else {
@@ -230,7 +234,7 @@ public class ColBertEmbedder extends AbstractComponent implements Embedder {
}
public static Tensor toFloatTensor(IndexedTensor result, TensorType type, int nTokens) {
- if(result.shape().length != 3)
+ if (result.shape().length != 3)
throw new IllegalArgumentException("Expected onnx result to have 3-dimensions [batch, sequence, dim]");
int size = type.indexedSubtype().dimensions().size();
if (size != 1)
@@ -253,8 +257,7 @@ public class ColBertEmbedder extends AbstractComponent implements Embedder {
public static Tensor toBitTensor(IndexedTensor result, TensorType type, int nTokens) {
if (type.valueType() != TensorType.Value.INT8)
- throw new IllegalArgumentException("Only a int8 tensor type can be" +
- " the destination of bit packing");
+ throw new IllegalArgumentException("Only a int8 tensor type can be the destination of bit packing");
if(result.shape().length != 3)
throw new IllegalArgumentException("Expected onnx result to have 3-dimensions [batch, sequence, dim]");
@@ -264,8 +267,8 @@ public class ColBertEmbedder extends AbstractComponent implements Embedder {
int wantedDimensionality = type.indexedSubtype().dimensions().get(0).size().get().intValue();
int resultDimensionality = (int)result.shape()[2];
if (resultDimensionality != 8 * wantedDimensionality) {
- throw new IllegalArgumentException("Not possible to pack " + resultDimensionality
- + " + dimensions into " + wantedDimensionality + " dimensions");
+ throw new IllegalArgumentException("Not possible to pack " + resultDimensionality +
+ " + dimensions into " + wantedDimensionality + " dimensions");
}
Tensor.Builder builder = Tensor.Builder.of(type);
for (int token = 0; token < nTokens; token++) {
@@ -302,9 +305,8 @@ public class ColBertEmbedder extends AbstractComponent implements Embedder {
return unpacker.evaluate(context).asTensor();
}
- protected boolean verifyTensorType(TensorType target) {
- return target.dimensions().size() == 2 &&
- target.indexedSubtype().rank() == 1 && target.mappedSubtype().rank() == 1;
+ protected boolean validTensorType(TensorType target) {
+ return target.dimensions().size() == 2 && target.indexedSubtype().rank() == 1;
}
private IndexedTensor createTensorRepresentation(List<Long> input, String dimension) {
@@ -316,5 +318,5 @@ public class ColBertEmbedder extends AbstractComponent implements Embedder {
}
return builder.build();
}
- private static final String PUNCTUATION = "!\"#$%&'()*+,-./:;<=>?@[\\]^_`{|}~";
+
}
diff --git a/model-integration/src/main/java/ai/vespa/embedding/SpladeEmbedder.java b/model-integration/src/main/java/ai/vespa/embedding/SpladeEmbedder.java
index 3a64083c623..58bd4deb659 100644
--- a/model-integration/src/main/java/ai/vespa/embedding/SpladeEmbedder.java
+++ b/model-integration/src/main/java/ai/vespa/embedding/SpladeEmbedder.java
@@ -25,9 +25,12 @@ import static com.yahoo.language.huggingface.ModelInfo.TruncationStrategy.LONGES
/**
* A SPLADE embedder that is embedding text to a 1-d mapped tensor. For interpretability, the tensor labels
* are the subword strings from the wordpiece vocabulary that has a score above a threshold (default 0.0).
+ *
+ * @author bergum
*/
@Beta
public class SpladeEmbedder extends AbstractComponent implements Embedder {
+
private final Embedder.Runtime runtime;
private final String inputIdsName;
private final String attentionMaskName;
@@ -110,7 +113,7 @@ public class SpladeEmbedder extends AbstractComponent implements Embedder {
public Tensor embed(String text, Context context, TensorType tensorType) {
if (!verifyTensorType(tensorType)) {
throw new IllegalArgumentException("Invalid splade embedder tensor destination. " +
- "Wanted a mapped 1-d tensor, got " + tensorType);
+ "Wanted a mapped 1-d tensor, got " + tensorType);
}
var start = System.nanoTime();
@@ -132,17 +135,17 @@ public class SpladeEmbedder extends AbstractComponent implements Embedder {
return spladeTensor;
}
-
/**
* Sparsify the output tensor by applying a threshold on the log of the relu of the output.
* This uses generic tensor reduce+map, and is slightly slower than a custom unrolled variant.
+ *
* @param modelOutput the model output tensor of shape d1,dim where d1 is the sequence length and dim is size
- * of the vocabulary
+ * of the vocabulary
* @param tensorType the type of the destination tensor
* @return A mapped tensor with the terms from the vocab that has a score above the threshold
*/
private Tensor sparsifyReduce(Tensor modelOutput, TensorType tensorType) {
- //Remove batch dim, batch size of 1
+ // Remove batch dim, batch size of 1
Tensor output = modelOutput.reduce(Reduce.Aggregator.max, "d0", "d1");
Tensor logOfRelu = output.map((x) -> Math.log(1 + (x > 0 ? x : 0)));
IndexedTensor vocab = (IndexedTensor) logOfRelu;
@@ -227,6 +230,7 @@ public class SpladeEmbedder extends AbstractComponent implements Embedder {
}
return builder.build();
}
+
@Override
public void deconstruct() {
evaluator.close();
diff --git a/model-integration/src/test/java/ai/vespa/embedding/ColBertEmbedderTest.java b/model-integration/src/test/java/ai/vespa/embedding/ColBertEmbedderTest.java
index 0cae94c372a..be75c4d3351 100644
--- a/model-integration/src/test/java/ai/vespa/embedding/ColBertEmbedderTest.java
+++ b/model-integration/src/test/java/ai/vespa/embedding/ColBertEmbedderTest.java
@@ -19,6 +19,9 @@ import java.util.Set;
import static org.junit.Assert.*;
import static org.junit.Assume.assumeTrue;
+/**
+ * @author bergum
+ */
public class ColBertEmbedderTest {
@Test
@@ -67,23 +70,24 @@ public class ColBertEmbedderTest {
assertEmbed("tensor<float>(qt{},x[128])", "this is a query", queryContext);
assertThrows(IllegalArgumentException.class, () -> {
- //throws because int8 is not supported for query context
+ // throws because int8 is not supported for query context
assertEmbed("tensor<int8>(qt{},x[16])", "this is a query", queryContext);
});
+
assertThrows(IllegalArgumentException.class, () -> {
- //throws because 16 is less than model output (128) and we want float
+ // throws because 16 is less than model output (128) and we want float
assertEmbed("tensor<float>(qt{},x[16])", "this is a query", queryContext);
});
assertThrows(IllegalArgumentException.class, () -> {
- //throws because 128/8 does not fit into 15
+ // throws because 128/8 does not fit into 15
assertEmbed("tensor<int8>(qt{},x[15])", "this is a query", indexingContext);
});
}
@Test
public void testInputTensorsWordPiece() {
- //wordPiece tokenizer("this is a query !") -> [2023, 2003, 1037, 23032, 999]
+ // wordPiece tokenizer("this is a query !") -> [2023, 2003, 1037, 23032, 999]
List<Long> tokens = List.of(2023L, 2003L, 1037L, 23032L, 999L);
ColBertEmbedder.TransformerInput input = embedder.buildTransformerInput(tokens,10,true);
assertEquals(10,input.inputIds().size());
@@ -100,7 +104,7 @@ public class ColBertEmbedderTest {
@Test
public void testInputTensorsSentencePiece() {
- //Sentencepiece tokenizer("this is a query !") -> [903, 83, 10, 41, 1294, 711]
+ // Sentencepiece tokenizer("this is a query !") -> [903, 83, 10, 41, 1294, 711]
// ! is mapped to 711 and is a punctuation character
List<Long> tokens = List.of(903L, 83L, 10L, 41L, 1294L, 711L);
ColBertEmbedder.TransformerInput input = multiLingualEmbedder.buildTransformerInput(tokens,10,true);
@@ -109,7 +113,7 @@ public class ColBertEmbedderTest {
assertEquals(List.of(0L, 3L, 903L, 83L, 10L, 41L, 1294L, 711L, 2L, 250001L),input.inputIds());
assertEquals(List.of(1L, 1L, 1L, 1L, 1L, 1L, 1L, 1L, 1L, 0L),input.attentionMask());
- //NO padding for document side and 711 (punctuation) is now filtered out
+ // NO padding for document side and 711 (punctuation) is now filtered out
input = multiLingualEmbedder.buildTransformerInput(tokens,10,false);
assertEquals(8,input.inputIds().size());
assertEquals(8,input.attentionMask().size());
@@ -156,12 +160,12 @@ public class ColBertEmbedderTest {
sb.append(" ");
}
String text = sb.toString();
- Long now = System.currentTimeMillis();
+ long now = System.currentTimeMillis();
int n = 1000;
for (int i = 0; i < n; i++) {
assertEmbed("tensor<float>(dt{},x[128])", text, indexingContext);
}
- Long elapsed = (System.currentTimeMillis() - now);
+ long elapsed = (System.currentTimeMillis() - now);
System.out.println("Elapsed time: " + elapsed + " ms");
}
@@ -170,7 +174,7 @@ public class ColBertEmbedderTest {
Tensor result = embedder.embed(text, context, destType);
assertEquals(destType,result.type());
MixedTensor mixedTensor = (MixedTensor) result;
- if(context == queryContext) {
+ if (context == queryContext) {
assertEquals(32*mixedTensor.denseSubspaceSize(),mixedTensor.size());
}
return result;
@@ -200,12 +204,14 @@ public class ColBertEmbedderTest {
static final ColBertEmbedder multiLingualEmbedder;
static final Embedder.Context indexingContext;
static final Embedder.Context queryContext;
+
static {
indexingContext = new Embedder.Context("schema.indexing");
queryContext = new Embedder.Context("query(qt)");
embedder = getEmbedder();
multiLingualEmbedder = getMultiLingualEmbedder();
}
+
private static ColBertEmbedder getEmbedder() {
String vocabPath = "src/test/models/onnx/transformer/real_tokenizer.json";
String modelPath = "src/test/models/onnx/transformer/colbert-dummy-v2.onnx";
@@ -235,4 +241,5 @@ public class ColBertEmbedderTest {
return new ColBertEmbedder(new OnnxRuntime(), Embedder.Runtime.testInstance(), builder.build());
}
+
}