diff options
author | Jon Bratseth <bratseth@vespa.ai> | 2024-02-02 12:28:53 +0100 |
---|---|---|
committer | Jon Bratseth <bratseth@vespa.ai> | 2024-02-02 12:28:53 +0100 |
commit | 1a25431ab58c752c7fc26dd8223bf1ba1079b24a (patch) | |
tree | 954d7e2f3e43bb0636a6af7a93195a84e41e609b /model-integration | |
parent | 2191193c6e107eb68611ddb106e5f572bea32903 (diff) |
Support embedding into rank 3 tensors
Diffstat (limited to 'model-integration')
3 files changed, 42 insertions, 29 deletions
diff --git a/model-integration/src/main/java/ai/vespa/embedding/ColBertEmbedder.java b/model-integration/src/main/java/ai/vespa/embedding/ColBertEmbedder.java index 8c39cc8c813..f76bfd28abf 100644 --- a/model-integration/src/main/java/ai/vespa/embedding/ColBertEmbedder.java +++ b/model-integration/src/main/java/ai/vespa/embedding/ColBertEmbedder.java @@ -18,7 +18,7 @@ import com.yahoo.tensor.IndexedTensor; import com.yahoo.tensor.Tensor; import com.yahoo.tensor.TensorAddress; import com.yahoo.tensor.TensorType; -import com.yahoo.tensor.functions.Reduce; + import java.nio.file.Paths; import java.util.Map; import java.util.List; @@ -34,10 +34,14 @@ import static com.yahoo.language.huggingface.ModelInfo.TruncationStrategy.LONGES * This embedder uses a HuggingFace tokenizer to produce a token sequence that is then input to a transformer model. * * See col-bert-embedder.def for configurable parameters. + * * @author bergum */ @Beta public class ColBertEmbedder extends AbstractComponent implements Embedder { + + private static final String PUNCTUATION = "!\"#$%&'()*+,-./:;<=>?@[\\]^_`{|}~"; + private final Embedder.Runtime runtime; private final String inputIdsName; private final String attentionMaskName; @@ -117,7 +121,7 @@ public class ColBertEmbedder extends AbstractComponent implements Embedder { private void validateName(Map<String, TensorType> types, String name, String type) { if (!types.containsKey(name)) { throw new IllegalArgumentException("Model does not contain required " + type + ": '" + name + "'. " + - "Model contains: " + String.join(",", types.keySet())); + "Model contains: " + String.join(",", types.keySet())); } } @@ -128,9 +132,9 @@ public class ColBertEmbedder extends AbstractComponent implements Embedder { @Override public Tensor embed(String text, Context context, TensorType tensorType) { - if (!verifyTensorType(tensorType)) { + if ( ! validTensorType(tensorType)) { throw new IllegalArgumentException("Invalid colbert embedder tensor target destination. " + - "Wanted a mixed 2-d mapped-indexed tensor, got " + tensorType); + "Wanted a mixed 2-d mapped-indexed tensor, got " + tensorType); } if (context.getDestination().startsWith("query")) { return embedQuery(text, context, tensorType); @@ -196,7 +200,7 @@ public class ColBertEmbedder extends AbstractComponent implements Embedder { int dims = tensorType.indexedSubtype().dimensions().get(0).size().get().intValue(); if (dims != result.shape()[2]) { throw new IllegalArgumentException("Token vector dimensionality does not" + - " match indexed dimensionality of " + dims); + " match indexed dimensionality of " + dims); } Tensor resultTensor = toFloatTensor(result, tensorType, input.inputIds.size()); runtime.sampleEmbeddingLatency((System.nanoTime() - start) / 1_000_000d, context); @@ -213,13 +217,13 @@ public class ColBertEmbedder extends AbstractComponent implements Embedder { Tensor attentionMaskTensor = createTensorRepresentation(input.attentionMask, "d1"); var inputs = Map.of(inputIdsName, inputIdsTensor.expand("d0"), - attentionMaskName, attentionMaskTensor.expand("d0")); + attentionMaskName, attentionMaskTensor.expand("d0")); Map<String, Tensor> outputs = evaluator.evaluate(inputs); Tensor tokenEmbeddings = outputs.get(outputName); IndexedTensor result = (IndexedTensor) tokenEmbeddings; Tensor contextualEmbeddings; - int maxTokens = input.inputIds.size(); //Retain all token vectors, including PAD tokens. + int maxTokens = input.inputIds.size(); // Retain all token vectors, including PAD tokens. if (tensorType.valueType() == TensorType.Value.INT8) { contextualEmbeddings = toBitTensor(result, tensorType, maxTokens); } else { @@ -230,7 +234,7 @@ public class ColBertEmbedder extends AbstractComponent implements Embedder { } public static Tensor toFloatTensor(IndexedTensor result, TensorType type, int nTokens) { - if(result.shape().length != 3) + if (result.shape().length != 3) throw new IllegalArgumentException("Expected onnx result to have 3-dimensions [batch, sequence, dim]"); int size = type.indexedSubtype().dimensions().size(); if (size != 1) @@ -253,8 +257,7 @@ public class ColBertEmbedder extends AbstractComponent implements Embedder { public static Tensor toBitTensor(IndexedTensor result, TensorType type, int nTokens) { if (type.valueType() != TensorType.Value.INT8) - throw new IllegalArgumentException("Only a int8 tensor type can be" + - " the destination of bit packing"); + throw new IllegalArgumentException("Only a int8 tensor type can be the destination of bit packing"); if(result.shape().length != 3) throw new IllegalArgumentException("Expected onnx result to have 3-dimensions [batch, sequence, dim]"); @@ -264,8 +267,8 @@ public class ColBertEmbedder extends AbstractComponent implements Embedder { int wantedDimensionality = type.indexedSubtype().dimensions().get(0).size().get().intValue(); int resultDimensionality = (int)result.shape()[2]; if (resultDimensionality != 8 * wantedDimensionality) { - throw new IllegalArgumentException("Not possible to pack " + resultDimensionality - + " + dimensions into " + wantedDimensionality + " dimensions"); + throw new IllegalArgumentException("Not possible to pack " + resultDimensionality + + " + dimensions into " + wantedDimensionality + " dimensions"); } Tensor.Builder builder = Tensor.Builder.of(type); for (int token = 0; token < nTokens; token++) { @@ -302,9 +305,8 @@ public class ColBertEmbedder extends AbstractComponent implements Embedder { return unpacker.evaluate(context).asTensor(); } - protected boolean verifyTensorType(TensorType target) { - return target.dimensions().size() == 2 && - target.indexedSubtype().rank() == 1 && target.mappedSubtype().rank() == 1; + protected boolean validTensorType(TensorType target) { + return target.dimensions().size() == 2 && target.indexedSubtype().rank() == 1; } private IndexedTensor createTensorRepresentation(List<Long> input, String dimension) { @@ -316,5 +318,5 @@ public class ColBertEmbedder extends AbstractComponent implements Embedder { } return builder.build(); } - private static final String PUNCTUATION = "!\"#$%&'()*+,-./:;<=>?@[\\]^_`{|}~"; + } diff --git a/model-integration/src/main/java/ai/vespa/embedding/SpladeEmbedder.java b/model-integration/src/main/java/ai/vespa/embedding/SpladeEmbedder.java index 3a64083c623..58bd4deb659 100644 --- a/model-integration/src/main/java/ai/vespa/embedding/SpladeEmbedder.java +++ b/model-integration/src/main/java/ai/vespa/embedding/SpladeEmbedder.java @@ -25,9 +25,12 @@ import static com.yahoo.language.huggingface.ModelInfo.TruncationStrategy.LONGES /** * A SPLADE embedder that is embedding text to a 1-d mapped tensor. For interpretability, the tensor labels * are the subword strings from the wordpiece vocabulary that has a score above a threshold (default 0.0). + * + * @author bergum */ @Beta public class SpladeEmbedder extends AbstractComponent implements Embedder { + private final Embedder.Runtime runtime; private final String inputIdsName; private final String attentionMaskName; @@ -110,7 +113,7 @@ public class SpladeEmbedder extends AbstractComponent implements Embedder { public Tensor embed(String text, Context context, TensorType tensorType) { if (!verifyTensorType(tensorType)) { throw new IllegalArgumentException("Invalid splade embedder tensor destination. " + - "Wanted a mapped 1-d tensor, got " + tensorType); + "Wanted a mapped 1-d tensor, got " + tensorType); } var start = System.nanoTime(); @@ -132,17 +135,17 @@ public class SpladeEmbedder extends AbstractComponent implements Embedder { return spladeTensor; } - /** * Sparsify the output tensor by applying a threshold on the log of the relu of the output. * This uses generic tensor reduce+map, and is slightly slower than a custom unrolled variant. + * * @param modelOutput the model output tensor of shape d1,dim where d1 is the sequence length and dim is size - * of the vocabulary + * of the vocabulary * @param tensorType the type of the destination tensor * @return A mapped tensor with the terms from the vocab that has a score above the threshold */ private Tensor sparsifyReduce(Tensor modelOutput, TensorType tensorType) { - //Remove batch dim, batch size of 1 + // Remove batch dim, batch size of 1 Tensor output = modelOutput.reduce(Reduce.Aggregator.max, "d0", "d1"); Tensor logOfRelu = output.map((x) -> Math.log(1 + (x > 0 ? x : 0))); IndexedTensor vocab = (IndexedTensor) logOfRelu; @@ -227,6 +230,7 @@ public class SpladeEmbedder extends AbstractComponent implements Embedder { } return builder.build(); } + @Override public void deconstruct() { evaluator.close(); diff --git a/model-integration/src/test/java/ai/vespa/embedding/ColBertEmbedderTest.java b/model-integration/src/test/java/ai/vespa/embedding/ColBertEmbedderTest.java index 0cae94c372a..be75c4d3351 100644 --- a/model-integration/src/test/java/ai/vespa/embedding/ColBertEmbedderTest.java +++ b/model-integration/src/test/java/ai/vespa/embedding/ColBertEmbedderTest.java @@ -19,6 +19,9 @@ import java.util.Set; import static org.junit.Assert.*; import static org.junit.Assume.assumeTrue; +/** + * @author bergum + */ public class ColBertEmbedderTest { @Test @@ -67,23 +70,24 @@ public class ColBertEmbedderTest { assertEmbed("tensor<float>(qt{},x[128])", "this is a query", queryContext); assertThrows(IllegalArgumentException.class, () -> { - //throws because int8 is not supported for query context + // throws because int8 is not supported for query context assertEmbed("tensor<int8>(qt{},x[16])", "this is a query", queryContext); }); + assertThrows(IllegalArgumentException.class, () -> { - //throws because 16 is less than model output (128) and we want float + // throws because 16 is less than model output (128) and we want float assertEmbed("tensor<float>(qt{},x[16])", "this is a query", queryContext); }); assertThrows(IllegalArgumentException.class, () -> { - //throws because 128/8 does not fit into 15 + // throws because 128/8 does not fit into 15 assertEmbed("tensor<int8>(qt{},x[15])", "this is a query", indexingContext); }); } @Test public void testInputTensorsWordPiece() { - //wordPiece tokenizer("this is a query !") -> [2023, 2003, 1037, 23032, 999] + // wordPiece tokenizer("this is a query !") -> [2023, 2003, 1037, 23032, 999] List<Long> tokens = List.of(2023L, 2003L, 1037L, 23032L, 999L); ColBertEmbedder.TransformerInput input = embedder.buildTransformerInput(tokens,10,true); assertEquals(10,input.inputIds().size()); @@ -100,7 +104,7 @@ public class ColBertEmbedderTest { @Test public void testInputTensorsSentencePiece() { - //Sentencepiece tokenizer("this is a query !") -> [903, 83, 10, 41, 1294, 711] + // Sentencepiece tokenizer("this is a query !") -> [903, 83, 10, 41, 1294, 711] // ! is mapped to 711 and is a punctuation character List<Long> tokens = List.of(903L, 83L, 10L, 41L, 1294L, 711L); ColBertEmbedder.TransformerInput input = multiLingualEmbedder.buildTransformerInput(tokens,10,true); @@ -109,7 +113,7 @@ public class ColBertEmbedderTest { assertEquals(List.of(0L, 3L, 903L, 83L, 10L, 41L, 1294L, 711L, 2L, 250001L),input.inputIds()); assertEquals(List.of(1L, 1L, 1L, 1L, 1L, 1L, 1L, 1L, 1L, 0L),input.attentionMask()); - //NO padding for document side and 711 (punctuation) is now filtered out + // NO padding for document side and 711 (punctuation) is now filtered out input = multiLingualEmbedder.buildTransformerInput(tokens,10,false); assertEquals(8,input.inputIds().size()); assertEquals(8,input.attentionMask().size()); @@ -156,12 +160,12 @@ public class ColBertEmbedderTest { sb.append(" "); } String text = sb.toString(); - Long now = System.currentTimeMillis(); + long now = System.currentTimeMillis(); int n = 1000; for (int i = 0; i < n; i++) { assertEmbed("tensor<float>(dt{},x[128])", text, indexingContext); } - Long elapsed = (System.currentTimeMillis() - now); + long elapsed = (System.currentTimeMillis() - now); System.out.println("Elapsed time: " + elapsed + " ms"); } @@ -170,7 +174,7 @@ public class ColBertEmbedderTest { Tensor result = embedder.embed(text, context, destType); assertEquals(destType,result.type()); MixedTensor mixedTensor = (MixedTensor) result; - if(context == queryContext) { + if (context == queryContext) { assertEquals(32*mixedTensor.denseSubspaceSize(),mixedTensor.size()); } return result; @@ -200,12 +204,14 @@ public class ColBertEmbedderTest { static final ColBertEmbedder multiLingualEmbedder; static final Embedder.Context indexingContext; static final Embedder.Context queryContext; + static { indexingContext = new Embedder.Context("schema.indexing"); queryContext = new Embedder.Context("query(qt)"); embedder = getEmbedder(); multiLingualEmbedder = getMultiLingualEmbedder(); } + private static ColBertEmbedder getEmbedder() { String vocabPath = "src/test/models/onnx/transformer/real_tokenizer.json"; String modelPath = "src/test/models/onnx/transformer/colbert-dummy-v2.onnx"; @@ -235,4 +241,5 @@ public class ColBertEmbedderTest { return new ColBertEmbedder(new OnnxRuntime(), Embedder.Runtime.testInstance(), builder.build()); } + } |