diff options
author | Jo Kristian Bergum <bergum@yahooinc.com> | 2024-04-04 09:15:10 +0200 |
---|---|---|
committer | Jo Kristian Bergum <bergum@yahooinc.com> | 2024-04-04 09:15:10 +0200 |
commit | 531bc532c592703221e232d817850d802cdcfd11 (patch) | |
tree | 69d9a60d6a8ea48dbea331906e775589bce15dd7 | |
parent | a009cdd704f427282c3c9ed3b70a7caf9d536c7e (diff) |
Support for dimensionality flexibility and caching ONNX inference output using the Context cache
-rw-r--r-- | model-integration/src/main/java/ai/vespa/embedding/ColBertEmbedder.java | 60 | ||||
-rw-r--r-- | model-integration/src/test/java/ai/vespa/embedding/ColBertEmbedderTest.java | 124 |
2 files changed, 131 insertions, 53 deletions
diff --git a/model-integration/src/main/java/ai/vespa/embedding/ColBertEmbedder.java b/model-integration/src/main/java/ai/vespa/embedding/ColBertEmbedder.java index f43f3834a65..2f4c0343bf6 100644 --- a/model-integration/src/main/java/ai/vespa/embedding/ColBertEmbedder.java +++ b/model-integration/src/main/java/ai/vespa/embedding/ColBertEmbedder.java @@ -181,34 +181,25 @@ public class ColBertEmbedder extends AbstractComponent implements Embedder { if (tensorType.valueType() == TensorType.Value.INT8) throw new IllegalArgumentException("ColBert query embed does not accept int8 tensor value type"); - var start = System.nanoTime(); var encoding = tokenizer.encode(text, context.getLanguage()); runtime.sampleSequenceLength(encoding.ids().size(), context); TransformerInput input = buildTransformerInput(encoding.ids(), maxQueryTokens, true); - Tensor inputIdsTensor = createTensorRepresentation(input.inputIds, "d1"); Tensor attentionMaskTensor = createTensorRepresentation(input.attentionMask, "d1"); var inputs = Map.of(inputIdsName, inputIdsTensor.expand("d0"), attentionMaskName, attentionMaskTensor.expand("d0")); - Map<String, Tensor> outputs = evaluator.evaluate(inputs); - Tensor tokenEmbeddings = outputs.get(outputName); - IndexedTensor result = (IndexedTensor) tokenEmbeddings; - - int dims = tensorType.indexedSubtype().dimensions().get(0).size().get().intValue(); - if (dims != result.shape()[2]) { - throw new IllegalArgumentException("Token vector dimensionality does not" + - " match indexed dimensionality of " + dims); - } - Tensor resultTensor = toFloatTensor(result, tensorType, input.inputIds.size()); + IndexedTensor modelOutput = (IndexedTensor) evaluateIfNotPresent(inputs, context, text).get(outputName); + Tensor resultTensor = toFloatTensor(modelOutput, tensorType, input.inputIds.size()); runtime.sampleEmbeddingLatency((System.nanoTime() - start) / 1_000_000d, context); return resultTensor; } - + @SuppressWarnings("unchecked") protected Tensor embedDocument(String 
text, Context context, TensorType tensorType) { var start = System.nanoTime(); + var encoding = tokenizer.encode(text, context.getLanguage()); runtime.sampleSequenceLength(encoding.ids().size(), context); @@ -218,19 +209,34 @@ public class ColBertEmbedder extends AbstractComponent implements Embedder { var inputs = Map.of(inputIdsName, inputIdsTensor.expand("d0"), attentionMaskName, attentionMaskTensor.expand("d0")); - - Map<String, Tensor> outputs = evaluator.evaluate(inputs); - Tensor tokenEmbeddings = outputs.get(outputName); - IndexedTensor result = (IndexedTensor) tokenEmbeddings; - Tensor contextualEmbeddings; - int maxTokens = input.inputIds.size(); // Retain all token vectors, including PAD tokens. + IndexedTensor modelOutput = (IndexedTensor) evaluateIfNotPresent(inputs, context, text).get(outputName); + Tensor resultEmbeddings; + int maxTokens = input.inputIds.size(); if (tensorType.valueType() == TensorType.Value.INT8) { - contextualEmbeddings = toBitTensor(result, tensorType, maxTokens); + resultEmbeddings = toBitTensor(modelOutput, tensorType, maxTokens); } else { - contextualEmbeddings = toFloatTensor(result, tensorType, maxTokens); + resultEmbeddings = toFloatTensor(modelOutput, tensorType, maxTokens); } runtime.sampleEmbeddingLatency((System.nanoTime() - start) / 1_000_000d, context); - return contextualEmbeddings; + return resultEmbeddings; + } + + /** + * Evaluate the model if the result is not present in the context cache. 
+ * @param inputs the tensor inputs + * @param context the context accompanying the request, a singleton per embedder instance and request + * @param hashKey the key to the cached value + * @return the model output + */ + @SuppressWarnings("unchecked") + protected Map<String, Tensor> evaluateIfNotPresent(Map<String, Tensor> inputs, Context context, String hashKey) { + if (context.getCachedValue(hashKey) == null) { + Map<String, Tensor> outputs = evaluator.evaluate(inputs); + context.putCachedValue(hashKey, outputs); + return outputs; + } else { + return (Map<String, Tensor>) context.getCachedValue(hashKey); + } } public static Tensor toFloatTensor(IndexedTensor result, TensorType type, int nTokens) { @@ -241,13 +247,13 @@ public class ColBertEmbedder extends AbstractComponent implements Embedder { throw new IllegalArgumentException("Target indexed sub-type must have one dimension"); int wantedDimensionality = type.indexedSubtype().dimensions().get(0).size().get().intValue(); int resultDimensionality = (int)result.shape()[2]; - if (resultDimensionality != wantedDimensionality) { + if (wantedDimensionality > resultDimensionality) { throw new IllegalArgumentException("Not possible to map token vector embedding with " + resultDimensionality + " dimensions into tensor with " + wantedDimensionality); } Tensor.Builder builder = Tensor.Builder.of(type); for (int token = 0; token < nTokens; token++) { - for (int d = 0; d < resultDimensionality; d++) { + for (int d = 0; d < wantedDimensionality; d++) { var value = result.get(0,token,d); // batch, sequence token, dimension builder.cell(TensorAddress.of(token,d),value); } @@ -265,8 +271,10 @@ public class ColBertEmbedder extends AbstractComponent implements Embedder { if (size != 1) throw new IllegalArgumentException("Target indexed sub-type must have one dimension"); int wantedDimensionality = type.indexedSubtype().dimensions().get(0).size().get().intValue(); + //Allow using the first n float dimensions to pack into int8 + 
int floatDimensionality = 8 * wantedDimensionality; int resultDimensionality = (int)result.shape()[2]; - if (resultDimensionality != 8 * wantedDimensionality) { + if (floatDimensionality > resultDimensionality) { throw new IllegalArgumentException("Not possible to pack " + resultDimensionality + " + dimensions into " + wantedDimensionality + " dimensions"); } @@ -274,7 +282,7 @@ public class ColBertEmbedder extends AbstractComponent implements Embedder { for (int token = 0; token < nTokens; token++) { BitSet bitSet = new BitSet(8); int key = 0; - for (int d = 0; d < result.shape()[2]; d++) { + for (int d = 0; d < floatDimensionality; d++) { var value = result.get(0, token, d); // batch, sequence token, dimension int bitIndex = 7 - (d % 8); if (value > 0.0) { diff --git a/model-integration/src/test/java/ai/vespa/embedding/ColBertEmbedderTest.java b/model-integration/src/test/java/ai/vespa/embedding/ColBertEmbedderTest.java index be75c4d3351..5fd0afad2c4 100644 --- a/model-integration/src/test/java/ai/vespa/embedding/ColBertEmbedderTest.java +++ b/model-integration/src/test/java/ai/vespa/embedding/ColBertEmbedderTest.java @@ -61,27 +61,94 @@ public class ColBertEmbedderTest { TensorType.fromSpec("tensor<int8>(dt{},x[2])"), "tensor<int8>(dt{},x[2]):{0:[1.0, -128.0], 1:[5.0, 1.0]}",2 ); + assertPackedRight( + "" + + "tensor<float>(d0[1],d1[2],d2[16]):" + + "[[" + + "[0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0, 0, 0, 0]," + + "[0, 0, 0, 0, 0, 1, 0, 1, 0, 0, 0, 0, 0, 0, 0, 1]" + + "]]", + TensorType.fromSpec("tensor<int8>(dt{},x[1])"), + "tensor<int8>(dt{},x[1]):{0:1.0, 1:5.0}",2 + ); + } + + @Test + public void testCachingFloat() { + var context = new Embedder.Context("schema.indexing"); + var input = "This is a test string to embed"; + var t1 = (MixedTensor) embedder.embed(input, context,TensorType.fromSpec("tensor<float>(dt{},x[8])")); + var modelOuput = context.getCachedValue(input); + + var t2 = (MixedTensor)embedder.embed(input, 
context,TensorType.fromSpec("tensor<float>(dt{},x[4])")); + var modelOuput2 = context.getCachedValue(input); + assertEquals(modelOuput, modelOuput2); + + assertNotEquals(t1,t2); + for(int token = 0; token < 7; token ++) { + for(int dim = 0; dim < 4; dim++) { // the four first should be equal + assertEquals(t1.get(TensorAddress.of(token,dim)),t2.get(TensorAddress.of(token,dim)), 1e-6); + } + } + //t2 only has 4 dimensions so this should be out of bounds which returns 0 + assertEquals(0, t2.get(TensorAddress.of(1,4)), 1e-6); + + input = "This is a different test string to embed"; + embedder.embed(input, context,TensorType.fromSpec("tensor<float>(dt{},x[8])")); + var modelOuput3 = context.getCachedValue(input); + assertNotEquals(modelOuput, modelOuput3); + assertNotEquals(modelOuput2, modelOuput3); + } + + @Test + public void testCachingInt() { + var context = new Embedder.Context("schema.indexing"); + var input = "This is a test string to embed"; + var t1 = (MixedTensor) embedder.embed(input, context,TensorType.fromSpec("tensor<int8>(dt{},x[8])")); + var modelOuput = context.getCachedValue(input); + + var t2 = (MixedTensor)embedder.embed(input, context,TensorType.fromSpec("tensor<int8>(dt{},x[4])")); + var modelOuput2 = context.getCachedValue(input); + assertEquals(modelOuput, modelOuput2); + assertNotEquals(t1,t2); + for(int token = 0; token < 7; token ++) { + for(int dim = 0; dim < 4; dim++) { // the four first should be equal + assertEquals(t1.get(TensorAddress.of(token,dim)),t2.get(TensorAddress.of(token,dim)), 1e-6); + } + } + //t2 only has 4 dimensions so this should be out of bounds which returns 0 + assertEquals(0, t2.get(TensorAddress.of(0,4)), 1e-6); + input = "This is a different test string to embed"; + embedder.embed(input, context,TensorType.fromSpec("tensor<float>(dt{},x[8])")); + var modelOuput3 = context.getCachedValue(input); + assertNotEquals(modelOuput, modelOuput3); + assertNotEquals(modelOuput2, modelOuput3); } + @Test public void testEmbedder() 
{ - assertEmbed("tensor<float>(dt{},x[128])", "this is a document", indexingContext); - assertEmbed("tensor<int8>(dt{},x[16])", "this is a document", indexingContext); - assertEmbed("tensor<float>(qt{},x[128])", "this is a query", queryContext); + var indexingContext = new Embedder.Context("schema.indexing"); + assertEmbed("tensor<float>(dt{},x[128])", "this is a document", indexingContext,128); + assertEmbed("tensor<float>(dt{},x[64])", "this is a document", indexingContext,64); - assertThrows(IllegalArgumentException.class, () -> { - // throws because int8 is not supported for query context - assertEmbed("tensor<int8>(qt{},x[16])", "this is a query", queryContext); - }); + assertEmbed("tensor<int8>(dt{},x[16])", "this is a document", indexingContext,16); + assertEmbed("tensor<int8>(dt{},x[8])", "this is a document", indexingContext,8); + assertEmbed("tensor<int8>(dt{},x[4])", "this is a document", indexingContext,4); + assertEmbed("tensor<int8>(dt{},x[3])", "this is a document", indexingContext,3); + + var queryContext = new Embedder.Context("query(qt{})"); + assertEmbed("tensor<float>(qt{},x[128])", "this is a query", queryContext,128); + assertEmbed("tensor<float>(qt{},x[64])", "this is a query", queryContext,64); assertThrows(IllegalArgumentException.class, () -> { - // throws because 16 is less than model output (128) and we want float - assertEmbed("tensor<float>(qt{},x[16])", "this is a query", queryContext); + // throws because int8 is not supported for query context + assertEmbed("tensor<int8>(qt{},x[16])", "this is a query", queryContext,16); }); assertThrows(IllegalArgumentException.class, () -> { - // throws because 128/8 does not fit into 15 - assertEmbed("tensor<int8>(qt{},x[15])", "this is a query", indexingContext); + // throws because 8*32 is larger than (128) + assertEmbed("tensor<int8>(qt{},x[32])", "this is a query", queryContext,32); }); } @@ -130,26 +197,32 @@ public class ColBertEmbedderTest { } @Test - public void testLenghtLimits() { + 
public void testLengthLimits() { StringBuilder sb = new StringBuilder(); for(int i = 0; i < 1024; i++) { sb.append("annoyance"); sb.append(" "); } String text = sb.toString(); - Tensor fullFloat = assertEmbed("tensor<float>(dt{},x[128])", text, indexingContext); - assertEquals(512*128,fullFloat.size()); + var indexingContext = new Embedder.Context("schema.indexing"); - Tensor query = assertEmbed("tensor<float>(dt{},x[128])", text, queryContext); - assertEquals(32*128,query.size()); + Tensor fullFloat = assertEmbed("tensor<float>(dt{},x[128])", text, indexingContext,128); + assertEquals(512*128,fullFloat.size()); - Tensor binaryRep = assertEmbed("tensor<int8>(dt{},x[16])", text, indexingContext); + Tensor binaryRep = assertEmbed("tensor<int8>(dt{},x[16])", text, indexingContext,16); assertEquals(512*16,binaryRep.size()); - Tensor shortDoc = assertEmbed("tensor<int8>(dt{},x[16])", "annoyance", indexingContext); + Tensor shortDoc = assertEmbed("tensor<int8>(dt{},x[16])", "annoyance", indexingContext,16); // 4 tokens, 16 bytes each = 64 bytes //CLS [unused1] sequence assertEquals(4*16,shortDoc.size());; + + var queryContext = new Embedder.Context("query(qt{})"); + Tensor query = assertEmbed("tensor<float>(dt{},x[128])", text, queryContext,128); + assertEquals(32*128,query.size()); + + Tensor shortQuery = assertEmbed("tensor<float>(dt{},x[64])", text, queryContext,64); + assertEquals(32*64,shortQuery.size()); } @Ignore @@ -163,18 +236,19 @@ public class ColBertEmbedderTest { long now = System.currentTimeMillis(); int n = 1000; for (int i = 0; i < n; i++) { - assertEmbed("tensor<float>(dt{},x[128])", text, indexingContext); + assertEmbed("tensor<float>(dt{},x[128])", text, new Embedder.Context("schema.indexing"),128); } long elapsed = (System.currentTimeMillis() - now); System.out.println("Elapsed time: " + elapsed + " ms"); } - static Tensor assertEmbed(String tensorSpec, String text, Embedder.Context context) { + static Tensor assertEmbed(String tensorSpec, String 
text, Embedder.Context context, int dimSize) { TensorType destType = TensorType.fromSpec(tensorSpec); Tensor result = embedder.embed(text, context, destType); assertEquals(destType,result.type()); MixedTensor mixedTensor = (MixedTensor) result; - if (context == queryContext) { + assertEquals(dimSize,mixedTensor.denseSubspaceSize()); + if (context.getDestination().startsWith("query")) { assertEquals(32*mixedTensor.denseSubspaceSize(),mixedTensor.size()); } return result; @@ -182,12 +256,12 @@ public class ColBertEmbedderTest { static void assertPackedRight(String numbers, TensorType destination, String expected, int size) { var in = (IndexedTensor) Tensor.from(numbers); + int targetDim = destination.indexedSubtype().dimensions().get(0).size().get().intValue(); Tensor packed = ColBertEmbedder.toBitTensor(in, destination, size); assertEquals(expected, packed.toString()); Tensor unpacked = ColBertEmbedder.expandBitTensor(packed); - assertEquals(in.shape()[2], unpacked.type().indexedSubtype().dimensions().get(0).size().get().longValue()); for (int dOuter = 0; dOuter < size; dOuter++) { - for (int dInner = 0; dInner < in.shape()[2]; dInner++) { + for (int dInner = 0; dInner < targetDim*8; dInner++) { var addr = TensorAddress.of(dOuter, dInner); double oldVal = in.get(TensorAddress.of(0,dOuter, dInner)); if (oldVal > 0) { @@ -202,12 +276,8 @@ public class ColBertEmbedderTest { static final ColBertEmbedder embedder; static final ColBertEmbedder multiLingualEmbedder; - static final Embedder.Context indexingContext; - static final Embedder.Context queryContext; static { - indexingContext = new Embedder.Context("schema.indexing"); - queryContext = new Embedder.Context("query(qt)"); embedder = getEmbedder(); multiLingualEmbedder = getMultiLingualEmbedder(); } |