diff options
author | Jo Kristian Bergum <bergum@vespa.ai> | 2024-04-10 09:14:30 +0200 |
---|---|---|
committer | GitHub <noreply@github.com> | 2024-04-10 09:14:30 +0200 |
commit | 4d0144a4d249df6cce37539cba13969e9fd4ca4f (patch) | |
tree | 6478d0617dea7b6469a1c269cb54ccad36290095 | |
parent | 8db9ee454f4ae9c677fdf9382fcb51139fbc263d (diff) | |
parent | 4d233b5379b8dc4b94901f8df8acda0a6f2c4420 (diff) |
Merge pull request #30809 from vespa-engine/jobergum/add-context-caching
Add onnx output caching to embedder (allow different post-processing of model outputs)
-rw-r--r-- | indexinglanguage/src/main/java/com/yahoo/vespa/indexinglanguage/expressions/ExecutionContext.java | 6 | ||||
-rw-r--r-- | linguistics/abi-spec.json | 5 | ||||
-rw-r--r-- | linguistics/src/main/java/com/yahoo/language/process/Embedder.java | 21 | ||||
-rw-r--r-- | model-integration/src/main/java/ai/vespa/embedding/ColBertEmbedder.java | 82 | ||||
-rw-r--r-- | model-integration/src/main/java/ai/vespa/embedding/huggingface/HuggingFaceEmbedder.java | 118 | ||||
-rw-r--r-- | model-integration/src/test/java/ai/vespa/embedding/ColBertEmbedderTest.java | 153 | ||||
-rw-r--r-- | model-integration/src/test/java/ai/vespa/embedding/huggingface/HuggingFaceEmbedderTest.java (renamed from model-integration/src/test/java/ai/vespa/embedding/HuggingFaceEmbedderTest.java) | 68 |
7 files changed, 304 insertions, 149 deletions
diff --git a/indexinglanguage/src/main/java/com/yahoo/vespa/indexinglanguage/expressions/ExecutionContext.java b/indexinglanguage/src/main/java/com/yahoo/vespa/indexinglanguage/expressions/ExecutionContext.java index ba07fc00ca8..cdd0c11baac 100644 --- a/indexinglanguage/src/main/java/com/yahoo/vespa/indexinglanguage/expressions/ExecutionContext.java +++ b/indexinglanguage/src/main/java/com/yahoo/vespa/indexinglanguage/expressions/ExecutionContext.java @@ -22,7 +22,7 @@ public class ExecutionContext implements FieldTypeAdapter, FieldValueAdapter { private final FieldValueAdapter adapter; private FieldValue value; private Language language; - private final Map<String, Object> cache = LazyMap.newHashMap(); + private final Map<Object, Object> cache = LazyMap.newHashMap(); public ExecutionContext() { this(null); @@ -125,12 +125,12 @@ public class ExecutionContext implements FieldTypeAdapter, FieldValueAdapter { } /** Returns a cached value, or null if not present. */ - public Object getCachedValue(String key) { + public Object getCachedValue(Object key) { return cache.get(key); } /** Returns a mutable reference to the cache of this. */ - public Map<String, Object> getCache() { + public Map<Object, Object> getCache() { return cache; } diff --git a/linguistics/abi-spec.json b/linguistics/abi-spec.json index 9f91c32cf62..a4adacc5905 100644 --- a/linguistics/abi-spec.json +++ b/linguistics/abi-spec.json @@ -346,8 +346,9 @@ "public com.yahoo.language.process.Embedder$Context setDestination(java.lang.String)", "public java.lang.String getEmbedderId()", "public com.yahoo.language.process.Embedder$Context setEmbedderId(java.lang.String)", - "public void putCachedValue(java.lang.String, java.lang.Object)", - "public java.lang.Object getCachedValue(java.lang.String)" + "public void putCachedValue(java.lang.Object, java.lang.Object)", + "public java.lang.Object getCachedValue(java.lang.Object)", + "public java.lang.Object computeCachedValueIfAbsent(java.lang.Object, java.util.function.Supplier)" ], "fields" : [ ] }, diff --git a/linguistics/src/main/java/com/yahoo/language/process/Embedder.java b/linguistics/src/main/java/com/yahoo/language/process/Embedder.java index 2ab2de303c2..989edcdb18a 100644 --- a/linguistics/src/main/java/com/yahoo/language/process/Embedder.java +++ b/linguistics/src/main/java/com/yahoo/language/process/Embedder.java @@ -7,10 +7,10 @@ import com.yahoo.language.Language; import com.yahoo.tensor.Tensor; import com.yahoo.tensor.TensorType; -import java.util.HashMap; import java.util.List; import java.util.Map; import java.util.Objects; +import java.util.function.Supplier; /** * An embedder converts a text string to a tensor @@ -73,9 +73,10 @@ public interface Embedder { */ @Beta interface Runtime { - /** Sample latency metric for embedding */ + + /** Add a sample embedding latency to this */ void sampleEmbeddingLatency(double millis, Context ctx); - /** Sample sequence length metric for embedding */ + /** Add a sample embedding length to this */ void sampleSequenceLength(long length, Context ctx); static Runtime testInstance() { @@ -91,7 +92,7 @@ public interface Embedder { private Language language = Language.UNKNOWN; private String destination; private String embedderId = "unknown"; - private final Map<String, Object> cache; + private final Map<Object, Object> cache; public Context(String destination) { this(destination, LazyMap.newHashMap()); @@ -101,7 +102,7 @@ public interface Embedder { * @param destination the name of the recipient of this tensor * @param cache a cache shared between all embed invocations for a single request */ - public Context(String destination, Map<String, Object> cache) { + public Context(String destination, Map<Object, Object> cache) { this.destination = destination; this.cache = Objects.requireNonNull(cache); } @@ -153,15 +154,21 @@ public interface Embedder { return this; } - public void putCachedValue(String key, Object value) { + public void putCachedValue(Object key, Object value) { cache.put(key, value); } /** Returns a cached value, or null if not present. */ - public Object getCachedValue(String key) { + public Object getCachedValue(Object key) { return cache.get(key); } + /** Returns the cached value, or computes and caches it if not present. */ + @SuppressWarnings("unchecked") + public <T> T computeCachedValueIfAbsent(Object key, Supplier<? extends T> supplier) { + return (T) cache.computeIfAbsent(key, __ -> supplier.get()); + } + } class FailingEmbedder implements Embedder { diff --git a/model-integration/src/main/java/ai/vespa/embedding/ColBertEmbedder.java b/model-integration/src/main/java/ai/vespa/embedding/ColBertEmbedder.java index f43f3834a65..2fd8e312a7e 100644 --- a/model-integration/src/main/java/ai/vespa/embedding/ColBertEmbedder.java +++ b/model-integration/src/main/java/ai/vespa/embedding/ColBertEmbedder.java @@ -149,7 +149,7 @@ public class ColBertEmbedder extends AbstractComponent implements Embedder { } protected TransformerInput buildTransformerInput(List<Long> tokens, int maxTokens, boolean isQuery) { - if(!isQuery) { + if (!isQuery) { tokens = tokens.stream().filter(token -> !skipTokens.contains(token)).toList(); } List<Long> inputIds = new ArrayList<>(maxTokens); @@ -172,7 +172,7 @@ public class ColBertEmbedder extends AbstractComponent implements Embedder { attentionMask.add((long) 1); for (int i = 0; i < padding; i++) - attentionMask.add((long) 0);//Do not attend to mask paddings + attentionMask.add((long) 0); // Do not attend to mask paddings return new TransformerInput(inputIds, attentionMask); } @@ -181,56 +181,44 @@ public class ColBertEmbedder extends AbstractComponent implements Embedder { if (tensorType.valueType() == TensorType.Value.INT8) throw new IllegalArgumentException("ColBert query embed does not accept int8 tensor value type"); + EmbeddingResult result = lookupOrEvaluate(context, text, true); + return toFloatTensor((IndexedTensor)result.outputs.get(outputName), tensorType, result.inputIdSize); + } - var start = System.nanoTime(); - var encoding = tokenizer.encode(text, context.getLanguage()); - runtime.sampleSequenceLength(encoding.ids().size(), context); - - TransformerInput input = buildTransformerInput(encoding.ids(), maxQueryTokens, true); - - Tensor inputIdsTensor = createTensorRepresentation(input.inputIds, "d1"); - Tensor attentionMaskTensor = createTensorRepresentation(input.attentionMask, "d1"); - - var inputs = Map.of(inputIdsName, inputIdsTensor.expand("d0"), - attentionMaskName, attentionMaskTensor.expand("d0")); - Map<String, Tensor> outputs = evaluator.evaluate(inputs); - Tensor tokenEmbeddings = outputs.get(outputName); - IndexedTensor result = (IndexedTensor) tokenEmbeddings; + protected Tensor embedDocument(String text, Context context, TensorType tensorType) { + EmbeddingResult result = lookupOrEvaluate(context, text, false); + var modelOutput = (IndexedTensor)result.outputs.get(outputName); + if (tensorType.valueType() == TensorType.Value.INT8) + return toBitTensor(modelOutput, tensorType, result.inputIdSize); + else + return toFloatTensor(modelOutput, tensorType, result.inputIdSize); + } - int dims = tensorType.indexedSubtype().dimensions().get(0).size().get().intValue(); - if (dims != result.shape()[2]) { - throw new IllegalArgumentException("Token vector dimensionality does not" + - " match indexed dimensionality of " + dims); - } - Tensor resultTensor = toFloatTensor(result, tensorType, input.inputIds.size()); - runtime.sampleEmbeddingLatency((System.nanoTime() - start) / 1_000_000d, context); - return resultTensor; + /** + * Evaluate the embedding model if the result is not present in the context cache. + * + * @param context the context accompanying the request + * @param text the text that is embedded + * @return the model output + */ + protected EmbeddingResult lookupOrEvaluate(Context context, String text, boolean isQuery) { + var key = new EmbedderCacheKey(context.getEmbedderId(), text); + return context.computeCachedValueIfAbsent(key, () -> evaluate(context, text, isQuery)); } - protected Tensor embedDocument(String text, Context context, TensorType tensorType) { + private EmbeddingResult evaluate(Context context, String text, boolean isQuery) { var start = System.nanoTime(); var encoding = tokenizer.encode(text, context.getLanguage()); runtime.sampleSequenceLength(encoding.ids().size(), context); - - TransformerInput input = buildTransformerInput(encoding.ids(), maxDocumentTokens, false); + TransformerInput input = buildTransformerInput(encoding.ids(), isQuery ? maxQueryTokens : maxDocumentTokens, isQuery); Tensor inputIdsTensor = createTensorRepresentation(input.inputIds, "d1"); Tensor attentionMaskTensor = createTensorRepresentation(input.attentionMask, "d1"); - - var inputs = Map.of(inputIdsName, inputIdsTensor.expand("d0"), + var inputs = Map.of(inputIdsName, + inputIdsTensor.expand("d0"), attentionMaskName, attentionMaskTensor.expand("d0")); - Map<String, Tensor> outputs = evaluator.evaluate(inputs); - Tensor tokenEmbeddings = outputs.get(outputName); - IndexedTensor result = (IndexedTensor) tokenEmbeddings; - Tensor contextualEmbeddings; - int maxTokens = input.inputIds.size(); // Retain all token vectors, including PAD tokens. - if (tensorType.valueType() == TensorType.Value.INT8) { - contextualEmbeddings = toBitTensor(result, tensorType, maxTokens); - } else { - contextualEmbeddings = toFloatTensor(result, tensorType, maxTokens); - } runtime.sampleEmbeddingLatency((System.nanoTime() - start) / 1_000_000d, context); - return contextualEmbeddings; + return new EmbeddingResult(input.inputIds.size(), outputs); } public static Tensor toFloatTensor(IndexedTensor result, TensorType type, int nTokens) { @@ -241,13 +229,13 @@ public class ColBertEmbedder extends AbstractComponent implements Embedder { throw new IllegalArgumentException("Target indexed sub-type must have one dimension"); int wantedDimensionality = type.indexedSubtype().dimensions().get(0).size().get().intValue(); int resultDimensionality = (int)result.shape()[2]; - if (resultDimensionality != wantedDimensionality) { + if (wantedDimensionality > resultDimensionality) { throw new IllegalArgumentException("Not possible to map token vector embedding with " + resultDimensionality + " dimensions into tensor with " + wantedDimensionality); } Tensor.Builder builder = Tensor.Builder.of(type); for (int token = 0; token < nTokens; token++) { - for (int d = 0; d < resultDimensionality; d++) { + for (int d = 0; d < wantedDimensionality; d++) { var value = result.get(0,token,d); // batch, sequence token, dimension builder.cell(TensorAddress.of(token,d),value); } @@ -265,8 +253,10 @@ public class ColBertEmbedder extends AbstractComponent implements Embedder { if (size != 1) throw new IllegalArgumentException("Target indexed sub-type must have one dimension"); int wantedDimensionality = type.indexedSubtype().dimensions().get(0).size().get().intValue(); + //Allow using the first n float dimensions to pack into int8 + int floatDimensionality = 8 * wantedDimensionality; int resultDimensionality = (int)result.shape()[2]; - if (resultDimensionality != 8 * wantedDimensionality) { + if (floatDimensionality > resultDimensionality) { throw new IllegalArgumentException("Not possible to pack " + resultDimensionality + " + dimensions into " + wantedDimensionality + " dimensions"); } @@ -274,7 +264,7 @@ public class ColBertEmbedder extends AbstractComponent implements Embedder { for (int token = 0; token < nTokens; token++) { BitSet bitSet = new BitSet(8); int key = 0; - for (int d = 0; d < result.shape()[2]; d++) { + for (int d = 0; d < floatDimensionality; d++) { var value = result.get(0, token, d); // batch, sequence token, dimension int bitIndex = 7 - (d % 8); if (value > 0.0) { @@ -319,4 +309,8 @@ public class ColBertEmbedder extends AbstractComponent implements Embedder { return builder.build(); } + record EmbedderCacheKey(String embedderId, Object embeddedValue) { } + + record EmbeddingResult(int inputIdSize, Map<String, Tensor> outputs) { } + } diff --git a/model-integration/src/main/java/ai/vespa/embedding/huggingface/HuggingFaceEmbedder.java b/model-integration/src/main/java/ai/vespa/embedding/huggingface/HuggingFaceEmbedder.java index 169648967d7..20d8b6362d3 100644 --- a/model-integration/src/main/java/ai/vespa/embedding/huggingface/HuggingFaceEmbedder.java +++ b/model-integration/src/main/java/ai/vespa/embedding/huggingface/HuggingFaceEmbedder.java @@ -104,59 +104,23 @@ public class HuggingFaceEmbedder extends AbstractComponent implements Embedder { tokenizer.close(); } + @SuppressWarnings("unchecked") @Override - public Tensor embed(String s, Context context, TensorType tensorType) { - var start = System.nanoTime(); - var encoding = tokenizer.encode(s, context.getLanguage()); - runtime.sampleSequenceLength(encoding.ids().size(), context); - Tensor inputSequence = createTensorRepresentation(encoding.ids(), "d1"); - Tensor attentionMask = createTensorRepresentation(encoding.attentionMask(), "d1"); - Tensor tokenTypeIds = tokenTypeIdsName.isEmpty() ? null : createTensorRepresentation(encoding.typeIds(), "d1"); - - - Map<String, Tensor> inputs; - if (tokenTypeIdsName.isEmpty() || tokenTypeIds.isEmpty()) { - inputs = Map.of(inputIdsName, inputSequence.expand("d0"), - attentionMaskName, attentionMask.expand("d0")); - } else { - inputs = Map.of(inputIdsName, inputSequence.expand("d0"), - attentionMaskName, attentionMask.expand("d0"), - tokenTypeIdsName, tokenTypeIds.expand("d0")); + public Tensor embed(String text, Context context, TensorType tensorType) { + if (tensorType.dimensions().size() != 1) { + throw new IllegalArgumentException("Error in embedding to type '" + tensorType + "': should only have one dimension."); } - - Map<String, Tensor> outputs = evaluator.evaluate(inputs); - IndexedTensor tokenEmbeddings = (IndexedTensor) outputs.get(outputName); - long[] resultShape = tokenEmbeddings.shape(); - //shape batch, sequence, embedding dimensionality - if (resultShape.length != 3) { - throw new IllegalArgumentException("" + - "Expected 3 output dimensions for output name '" + - outputName + "': [batch, sequence, embedding], got " + resultShape.length); + if (!tensorType.dimensions().get(0).isIndexed()) { + throw new IllegalArgumentException("Error in embedding to type '" + tensorType + "': dimension should be indexed."); } - Tensor result; + var embeddingResult = lookupOrEvaluate(context, text); + IndexedTensor tokenEmbeddings = embeddingResult.output; if (tensorType.valueType() == TensorType.Value.INT8) { - long outputDimensions = resultShape[2]; - long targetDim = tensorType.dimensions().get(0).size().get(); - - if(targetDim * 8 > outputDimensions) { - throw new IllegalArgumentException("Cannot pack " + outputDimensions + " into " + targetDim + " int8s"); - } - //Dimensionality flexibility 🪆 - packing only the first 8*targetDim values from the model output - long firstDimensions = 8 * targetDim; - String name = tensorType.indexedSubtype().dimensions().get(0).name(); - //perform pooling and normalizing using floating point embeddings before binarizing - //using the firstDimensions as the target dimensionality - TensorType poolingType = new TensorType.Builder(TensorType.Value.FLOAT).indexed(name, firstDimensions).build(); - result = poolingStrategy.toSentenceEmbedding(poolingType, tokenEmbeddings, attentionMask); - result = normalize? normalize(result, poolingType) : result; - result = binarize((IndexedTensor) result, tensorType); - - } else { // regular floating points embeddings - result = poolingStrategy.toSentenceEmbedding(tensorType, tokenEmbeddings, attentionMask); - result = normalize ? normalize(result, tensorType) : result; + return binaryQuantization(embeddingResult, tensorType); + } else { + Tensor result = poolingStrategy.toSentenceEmbedding(tensorType, tokenEmbeddings, embeddingResult.attentionMask); + return normalize ? normalize(result, tensorType) : result; } - runtime.sampleEmbeddingLatency((System.nanoTime() - start)/1_000_000d, context); - return result; } Tensor normalize(Tensor embedding, TensorType tensorType) { @@ -178,6 +142,61 @@ public class HuggingFaceEmbedder extends AbstractComponent implements Embedder { return builder.build(); } + private HuggingFaceEmbedder.HFEmbeddingResult lookupOrEvaluate(Context context, String text) { + var key = new HFEmbedderCacheKey(context.getEmbedderId(), text); + return context.computeCachedValueIfAbsent(key, () -> evaluate(context, text)); + } + + private HuggingFaceEmbedder.HFEmbeddingResult evaluate(Context context, String text) { + var start = System.nanoTime(); + var encoding = tokenizer.encode(text, context.getLanguage()); + runtime.sampleSequenceLength(encoding.ids().size(), context); + Tensor inputSequence = createTensorRepresentation(encoding.ids(), "d1"); + Tensor attentionMask = createTensorRepresentation(encoding.attentionMask(), "d1"); + Tensor tokenTypeIds = tokenTypeIdsName.isEmpty() ? null : createTensorRepresentation(encoding.typeIds(), "d1"); + + Map<String, Tensor> inputs; + if (tokenTypeIdsName.isEmpty() || tokenTypeIds.isEmpty()) { + inputs = Map.of(inputIdsName, inputSequence.expand("d0"), + attentionMaskName, attentionMask.expand("d0")); + } else { + inputs = Map.of(inputIdsName, inputSequence.expand("d0"), + attentionMaskName, attentionMask.expand("d0"), + tokenTypeIdsName, tokenTypeIds.expand("d0")); + } + IndexedTensor tokenEmbeddings = (IndexedTensor) evaluator.evaluate(inputs).get(outputName); + long[] resultShape = tokenEmbeddings.shape(); + //shape batch, sequence, embedding dimensionality + if (resultShape.length != 3) { + throw new IllegalArgumentException("" + + "Expected 3 output dimensions for output name '" + + outputName + "': [batch, sequence, embedding], got " + resultShape.length); + } + runtime.sampleEmbeddingLatency((System.nanoTime() - start)/1_000_000d, context); + return new HFEmbeddingResult(tokenEmbeddings, attentionMask, context.getEmbedderId()); + } + + private Tensor binaryQuantization(HuggingFaceEmbedder.HFEmbeddingResult embeddingResult, TensorType tensorType) { + long outputDimensions = embeddingResult.output().shape()[2]; + long targetDim = tensorType.dimensions().get(0).size().get(); + //🪆 flexibility - packing only the first 8*targetDim float values from the model output + long floatDimensions = 8 * targetDim; + if(floatDimensions > outputDimensions) { + throw new IllegalArgumentException("Cannot pack " + outputDimensions + " into " + targetDim + " int8s"); + } + //perform pooling and normalizing using float version before binary quantization + TensorType poolingType = new TensorType.Builder(TensorType.Value.FLOAT). + indexed(tensorType.indexedSubtype().dimensions().get(0).name(), + floatDimensions).build(); + Tensor result = poolingStrategy.toSentenceEmbedding(poolingType, embeddingResult.output(), embeddingResult.attentionMask()); + result = normalize? normalize(result, poolingType) : result; + result = binarize((IndexedTensor) result, tensorType); + return result; + } + + /** + * Binary quantization of the embedding into a tensor of type int8 with the specified dimensions. + */ static public Tensor binarize(IndexedTensor embedding, TensorType tensorType) { Tensor.Builder builder = Tensor.Builder.of(tensorType); BitSet bitSet = new BitSet(8); @@ -211,6 +230,7 @@ public class HuggingFaceEmbedder extends AbstractComponent implements Embedder { return builder.build(); } - + protected record HFEmbeddingResult(IndexedTensor output, Tensor attentionMask, String embedderId) {} + protected record HFEmbedderCacheKey(String embedderId, Object embeddedValue) { } } diff --git a/model-integration/src/test/java/ai/vespa/embedding/ColBertEmbedderTest.java b/model-integration/src/test/java/ai/vespa/embedding/ColBertEmbedderTest.java index be75c4d3351..f6216e4149c 100644 --- a/model-integration/src/test/java/ai/vespa/embedding/ColBertEmbedderTest.java +++ b/model-integration/src/test/java/ai/vespa/embedding/ColBertEmbedderTest.java @@ -61,27 +61,94 @@ public class ColBertEmbedderTest { TensorType.fromSpec("tensor<int8>(dt{},x[2])"), "tensor<int8>(dt{},x[2]):{0:[1.0, -128.0], 1:[5.0, 1.0]}",2 ); + assertPackedRight( + "" + + "tensor<float>(d0[1],d1[2],d2[16]):" + + "[[" + + "[0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0, 0, 0, 0]," + + "[0, 0, 0, 0, 0, 1, 0, 1, 0, 0, 0, 0, 0, 0, 0, 1]" + + "]]", + TensorType.fromSpec("tensor<int8>(dt{},x[1])"), + "tensor<int8>(dt{},x[1]):{0:1.0, 1:5.0}",2 + ); + } + + @Test + public void testCachingFloat() { + int initialEmbeddingsDone = runtime.embeddingsDone; + var context = new Embedder.Context("schema.indexing"); + + var input = "This is a test string to embed"; + var t1 = (MixedTensor) embedder.embed(input, context,TensorType.fromSpec("tensor<float>(dt{},x[8])")); + assertEquals(initialEmbeddingsDone + 1, runtime.embeddingsDone); + + var t2 = (MixedTensor)embedder.embed(input, context,TensorType.fromSpec("tensor<float>(dt{},x[4])")); + assertEquals("Cached value was used", initialEmbeddingsDone + 1, runtime.embeddingsDone); + + assertNotEquals(t1,t2); + for(int token = 0; token < 7; token ++) { + for(int dim = 0; dim < 4; dim++) { // the four first should be equal + assertEquals(t1.get(TensorAddress.of(token,dim)),t2.get(TensorAddress.of(token,dim)), 1e-6); + } + } + // t2 only has 4 dimensions so this should be out of bounds which returns 0 + assertEquals(0, t2.get(TensorAddress.of(1,4)), 1e-6); + + input = "This is a different test string to embed"; + embedder.embed(input, context,TensorType.fromSpec("tensor<float>(dt{},x[8])")); + assertEquals(initialEmbeddingsDone + 2, runtime.embeddingsDone); + } + + @Test + public void testCachingInt() { + int initialEmbeddingsDone = runtime.embeddingsDone; + var context = new Embedder.Context("schema.indexing"); + + var input = "This is a test string to embed"; + var t1 = (MixedTensor) embedder.embed(input, context, TensorType.fromSpec("tensor<int8>(dt{},x[8])")); + assertEquals(initialEmbeddingsDone + 1, runtime.embeddingsDone); + + var t2 = (MixedTensor)embedder.embed(input, context, TensorType.fromSpec("tensor<int8>(dt{},x[4])")); + assertEquals("Cached value was used", initialEmbeddingsDone + 1, runtime.embeddingsDone); + + assertNotEquals(t1, t2); + for(int token = 0; token < 7; token ++) { + for(int dim = 0; dim < 4; dim++) { // the four first should be equal + assertEquals(t1.get(TensorAddress.of(token,dim)), t2.get(TensorAddress.of(token,dim)), 1e-6); + } + } + // t2 only has 4 dimensions so this should be out of bounds which returns 0 + assertEquals(0, t2.get(TensorAddress.of(0,4)), 1e-6); + + input = "This is a different test string to embed"; + embedder.embed(input, context,TensorType.fromSpec("tensor<float>(dt{},x[8])")); + assertEquals(initialEmbeddingsDone + 2, runtime.embeddingsDone); } + @Test public void testEmbedder() { - assertEmbed("tensor<float>(dt{},x[128])", "this is a document", indexingContext); - assertEmbed("tensor<int8>(dt{},x[16])", "this is a document", indexingContext); - assertEmbed("tensor<float>(qt{},x[128])", "this is a query", queryContext); + var indexingContext = new Embedder.Context("schema.indexing"); + assertEmbed("tensor<float>(dt{},x[128])", "this is a document", indexingContext,128); + assertEmbed("tensor<float>(dt{},x[64])", "this is a document", indexingContext,64); - assertThrows(IllegalArgumentException.class, () -> { - // throws because int8 is not supported for query context - assertEmbed("tensor<int8>(qt{},x[16])", "this is a query", queryContext); - }); + assertEmbed("tensor<int8>(dt{},x[16])", "this is a document", indexingContext,16); + assertEmbed("tensor<int8>(dt{},x[8])", "this is a document", indexingContext,8); + assertEmbed("tensor<int8>(dt{},x[4])", "this is a document", indexingContext,4); + assertEmbed("tensor<int8>(dt{},x[3])", "this is a document", indexingContext,3); + + var queryContext = new Embedder.Context("query(qt{})"); + assertEmbed("tensor<float>(qt{},x[128])", "this is a query", queryContext,128); + assertEmbed("tensor<float>(qt{},x[64])", "this is a query", queryContext,64); assertThrows(IllegalArgumentException.class, () -> { - // throws because 16 is less than model output (128) and we want float - assertEmbed("tensor<float>(qt{},x[16])", "this is a query", queryContext); + // throws because int8 is not supported for query context + assertEmbed("tensor<int8>(qt{},x[16])", "this is a query", queryContext,16); }); assertThrows(IllegalArgumentException.class, () -> { - // throws because 128/8 does not fit into 15 - assertEmbed("tensor<int8>(qt{},x[15])", "this is a query", indexingContext); + // throws because 8*32 is larger than (128) + assertEmbed("tensor<int8>(qt{},x[32])", "this is a query", queryContext,32); }); } @@ -130,26 +197,32 @@ public class ColBertEmbedderTest { } @Test - public void testLenghtLimits() { + public void testLengthLimits() { StringBuilder sb = new StringBuilder(); for(int i = 0; i < 1024; i++) { sb.append("annoyance"); sb.append(" "); } String text = sb.toString(); - Tensor fullFloat = assertEmbed("tensor<float>(dt{},x[128])", text, indexingContext); - assertEquals(512*128,fullFloat.size()); + var indexingContext = new Embedder.Context("schema.indexing"); - Tensor query = assertEmbed("tensor<float>(dt{},x[128])", text, queryContext); - assertEquals(32*128,query.size()); + Tensor fullFloat = assertEmbed("tensor<float>(dt{},x[128])", text, indexingContext,128); + assertEquals(512*128,fullFloat.size()); - Tensor binaryRep = assertEmbed("tensor<int8>(dt{},x[16])", text, indexingContext); + Tensor binaryRep = assertEmbed("tensor<int8>(dt{},x[16])", text, indexingContext,16); assertEquals(512*16,binaryRep.size()); - Tensor shortDoc = assertEmbed("tensor<int8>(dt{},x[16])", "annoyance", indexingContext); + Tensor shortDoc = assertEmbed("tensor<int8>(dt{},x[16])", "annoyance", indexingContext,16); // 4 tokens, 16 bytes each = 64 bytes //CLS [unused1] sequence assertEquals(4*16,shortDoc.size());; + + var queryContext = new Embedder.Context("query(qt{})"); + Tensor query = assertEmbed("tensor<float>(dt{},x[128])", text, queryContext,128); + assertEquals(32*128,query.size()); + + Tensor shortQuery = assertEmbed("tensor<float>(dt{},x[64])", text, queryContext,64); + assertEquals(32*64,shortQuery.size()); } @Ignore @@ -163,18 +236,19 @@ public class ColBertEmbedderTest { long now = System.currentTimeMillis(); int n = 1000; for (int i = 0; i < n; i++) { - assertEmbed("tensor<float>(dt{},x[128])", text, indexingContext); + assertEmbed("tensor<float>(dt{},x[128])", text, new Embedder.Context("schema.indexing"),128); } long elapsed = (System.currentTimeMillis() - now); System.out.println("Elapsed time: " + elapsed + " ms"); } - static Tensor assertEmbed(String tensorSpec, String text, Embedder.Context context) { + static Tensor assertEmbed(String tensorSpec, String text, Embedder.Context context, int dimSize) { TensorType destType = TensorType.fromSpec(tensorSpec); Tensor result = embedder.embed(text, context, destType); assertEquals(destType,result.type()); MixedTensor mixedTensor = (MixedTensor) result; - if (context == queryContext) { + assertEquals(dimSize,mixedTensor.denseSubspaceSize()); + if (context.getDestination().startsWith("query")) { assertEquals(32*mixedTensor.denseSubspaceSize(),mixedTensor.size()); } return result; @@ -182,12 +256,12 @@ public class ColBertEmbedderTest { static void assertPackedRight(String numbers, TensorType destination, String expected, int size) { var in = (IndexedTensor) Tensor.from(numbers); + int targetDim = destination.indexedSubtype().dimensions().get(0).size().get().intValue(); Tensor packed = ColBertEmbedder.toBitTensor(in, destination, size); assertEquals(expected, packed.toString()); Tensor unpacked = ColBertEmbedder.expandBitTensor(packed); - assertEquals(in.shape()[2], unpacked.type().indexedSubtype().dimensions().get(0).size().get().longValue()); for (int dOuter = 0; dOuter < size; dOuter++) { - for (int dInner = 0; dInner < in.shape()[2]; dInner++) { + for (int dInner = 0; dInner < targetDim*8; dInner++) { var addr = TensorAddress.of(dOuter, dInner); double oldVal = in.get(TensorAddress.of(0,dOuter, dInner)); if (oldVal > 0) { @@ -200,19 +274,16 @@ public class ColBertEmbedderTest { } static final ColBertEmbedder embedder; - static final ColBertEmbedder multiLingualEmbedder; - static final Embedder.Context indexingContext; - static final Embedder.Context queryContext; + static final CountingRuntime runtime; static { - indexingContext = new Embedder.Context("schema.indexing"); - queryContext = new Embedder.Context("query(qt)"); - embedder = getEmbedder(); - multiLingualEmbedder = getMultiLingualEmbedder(); + runtime = new CountingRuntime(); + embedder = createEmbedder(runtime); + multiLingualEmbedder = getMultiLingualEmbedder(runtime); } - private static ColBertEmbedder getEmbedder() { + private static ColBertEmbedder createEmbedder(Embedder.Runtime runtime) { String vocabPath = "src/test/models/onnx/transformer/real_tokenizer.json"; String modelPath = "src/test/models/onnx/transformer/colbert-dummy-v2.onnx"; assumeTrue(OnnxRuntime.isRuntimeAvailable(modelPath)); @@ -220,10 +291,10 @@ public class ColBertEmbedderTest { builder.tokenizerPath(ModelReference.valueOf(vocabPath)); builder.transformerModel(ModelReference.valueOf(modelPath)); builder.transformerGpuDevice(-1); - return new ColBertEmbedder(new OnnxRuntime(), Embedder.Runtime.testInstance(), builder.build()); + return new ColBertEmbedder(new OnnxRuntime(), runtime, builder.build()); } - private static ColBertEmbedder getMultiLingualEmbedder() { + private static ColBertEmbedder getMultiLingualEmbedder(Embedder.Runtime runtime) { String vocabPath = "src/test/models/onnx/transformer/sentence_piece_tokenizer.json"; String modelPath = "src/test/models/onnx/transformer/colbert-dummy-v2.onnx"; assumeTrue(OnnxRuntime.isRuntimeAvailable(modelPath)); @@ -239,7 +310,21 @@ public class ColBertEmbedderTest { builder.queryTokenId(3); builder.documentTokenId(4); - return new ColBertEmbedder(new OnnxRuntime(), Embedder.Runtime.testInstance(), builder.build()); + return new ColBertEmbedder(new OnnxRuntime(), Embedder.Runtime.testInstance(), builder.build()); + } + + private static class CountingRuntime implements Embedder.Runtime { + + int embeddingsDone = 0; + + @Override + public void sampleEmbeddingLatency(double millis, Embedder.Context ctx) { + embeddingsDone++; + } + + @Override + public void sampleSequenceLength(long length, Embedder.Context ctx) { } + } } diff --git a/model-integration/src/test/java/ai/vespa/embedding/HuggingFaceEmbedderTest.java b/model-integration/src/test/java/ai/vespa/embedding/huggingface/HuggingFaceEmbedderTest.java index 1ce1d955b00..d504d77cc9b 100644 --- a/model-integration/src/test/java/ai/vespa/embedding/HuggingFaceEmbedderTest.java +++ b/model-integration/src/test/java/ai/vespa/embedding/huggingface/HuggingFaceEmbedderTest.java @@ -1,7 +1,7 @@ // Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. -package ai.vespa.embedding; +package ai.vespa.embedding.huggingface; + -import ai.vespa.embedding.huggingface.HuggingFaceEmbedder; import ai.vespa.modelintegration.evaluator.OnnxRuntime; import com.yahoo.config.ModelReference; import com.yahoo.embedding.huggingface.HuggingFaceEmbedderConfig; @@ -12,10 +12,10 @@ import com.yahoo.tensor.TensorType; import com.yahoo.tensor.TensorAddress; import org.junit.Test; +import static org.junit.Assert.assertNotEquals; import static org.junit.Assert.assertThrows; import static org.junit.Assume.assumeTrue; -import static org.junit.jupiter.api.Assertions.assertEquals; -import static org.junit.jupiter.api.Assertions.assertTrue; +import static org.junit.jupiter.api.Assertions.*; import com.yahoo.searchlib.rankingexpression.evaluation.MapContext; import com.yahoo.searchlib.rankingexpression.evaluation.TensorValue; @@ -26,7 +26,6 @@ public class HuggingFaceEmbedderTest { static HuggingFaceEmbedder embedder = getEmbedder(); static HuggingFaceEmbedder normalizedEmbedder = getNormalizedEmbedder(); - static Embedder.Context context = new Embedder.Context("schema.indexing"); @Test public void testBinarization() { @@ -48,16 +47,48 @@ public class HuggingFaceEmbedderTest { private void assertPackRight(String input, String expected, TensorType type) { Tensor inputTensor = Tensor.from(input); Tensor result = HuggingFaceEmbedder.binarize((IndexedTensor) inputTensor, type); - assertEquals(expected.toString(), result.toString()); - //Verify against what is done in ranking with unpack_bits + assertEquals(expected, result.toString()); + //Verify that the unpack_bits ranking feature produce compatible output Tensor unpacked = expandBitTensor(result); assertEquals(inputTensor.toString(), unpacked.toString()); } @Test + public void testCaching() { + var context = new Embedder.Context("schema.indexing"); + var myEmbedderId = "my-hf-embedder"; + context.setEmbedderId(myEmbedderId); + + var input = "This is a test string to embed"; + Tensor result = embedder.embed(input, context,TensorType.fromSpec("tensor<float>(x[8])")); + HuggingFaceEmbedder.HFEmbedderCacheKey key = new HuggingFaceEmbedder.HFEmbedderCacheKey(myEmbedderId, input); + var modelOuput = context.getCachedValue(key); + assertNotNull(modelOuput); + + Tensor binaryResult = embedder.embed(input, context,TensorType.fromSpec("tensor<int8>(x[4])")); + var modelOuput2 = context.getCachedValue(key); + assertEquals(modelOuput, modelOuput2); + assertNotEquals(result, binaryResult); + + var anotherInput = "This is a different test string to embed with the same embedder"; + embedder.embed(anotherInput, context,TensorType.fromSpec("tensor<float>(x[4])")); + key = new HuggingFaceEmbedder.HFEmbedderCacheKey(myEmbedderId, anotherInput); + var modelOuput3 = context.getCachedValue(key); + assertNotEquals(modelOuput, modelOuput3); + + //context cache is shared + var copyContext = context.copy(); + var anotherEmbedderId = "another-hf-embedder"; + copyContext.setEmbedderId(anotherEmbedderId); + key = new HuggingFaceEmbedder.HFEmbedderCacheKey(anotherEmbedderId, input); + assertNull(copyContext.getCachedValue(key)); + embedder.embed(input, copyContext,TensorType.fromSpec("tensor<int8>(x[2])")); + assertNotEquals(modelOuput, copyContext.getCachedValue(key)); + } + @Test public void testEmbedder() { + var context = new Embedder.Context("schema.indexing"); String input = "This is a test"; - Tensor expected = Tensor.from("tensor<float>(x[8]):[-0.666, 0.335, 0.227, 0.0919, -0.069, 0.323, 0.422, 0.270]"); Tensor result = embedder.embed(input, context, TensorType.fromSpec(("tensor<float>(x[8])"))); for(int i = 0; i < 8; i++) { @@ -85,16 +116,33 @@ public class HuggingFaceEmbedderTest { @Test public void testEmbedderWithNormalization() { String input = "This is a test"; - + var context = new Embedder.Context("schema.indexing"); Tensor result = normalizedEmbedder.embed(input, context, TensorType.fromSpec(("tensor<float>(x[8])"))); assertEquals(1.0, result.multiply(result).sum().asDouble(), 1e-3); - result = normalizedEmbedder.embed(input, context, TensorType.fromSpec(("tensor<float>(x[16])"))); assertEquals(1.0, result.multiply(result).sum().asDouble(), 1e-3); Tensor binarizedResult = embedder.embed(input, context, TensorType.fromSpec(("tensor<int8>(x[2])"))); assertEquals("tensor<int8>(x[2]):[119, 44]", binarizedResult.toAbbreviatedString()); } + @Test + public void testThatWrongTensorTypeThrows() { + var context = new Embedder.Context("schema.indexing"); + String input = "This is a test"; + assertThrows(IllegalArgumentException.class, () -> { + // throws because the target tensor type is mapped + embedder.embed(input, context, TensorType.fromSpec(("tensor<float>(x{})"))); + }); + assertThrows(IllegalArgumentException.class, () -> { + // throws because the target tensor is 0d + embedder.embed(input, context, TensorType.fromSpec(("tensor<float>(x[0]"))); + }); + assertThrows(IllegalArgumentException.class, () -> { + // throws because the target tensor is 2d + embedder.embed(input, context, TensorType.fromSpec(("tensor<float>(x{}, y[2])"))); + }); + } + private static HuggingFaceEmbedder getEmbedder() { String vocabPath = "src/test/models/onnx/transformer/real_tokenizer.json"; String modelPath = "src/test/models/onnx/transformer/embedding_model.onnx"; |