diff options
author | Jon Bratseth <bratseth@vespa.ai> | 2024-04-07 20:31:37 +0200 |
---|---|---|
committer | Jon Bratseth <bratseth@vespa.ai> | 2024-04-07 20:31:37 +0200 |
commit | 6715471dceedbbda28d9d29ffb9d441ebfb848a2 (patch) | |
tree | e6255566e40817243f3df7a4667cf9e6822baa62 /model-integration | |
parent | 9f9160985a4f4848fa3f89d83a9f859958bd8e3c (diff) |
Key by embedder id and don't recompute inputs
Diffstat (limited to 'model-integration')
-rw-r--r-- | model-integration/src/main/java/ai/vespa/embedding/ColBertEmbedder.java | 73 | ||||
-rw-r--r-- | model-integration/src/test/java/ai/vespa/embedding/ColBertEmbedderTest.java | 65 |
2 files changed, 73 insertions, 65 deletions
diff --git a/model-integration/src/main/java/ai/vespa/embedding/ColBertEmbedder.java b/model-integration/src/main/java/ai/vespa/embedding/ColBertEmbedder.java index a9d6d308df8..2fd8e312a7e 100644 --- a/model-integration/src/main/java/ai/vespa/embedding/ColBertEmbedder.java +++ b/model-integration/src/main/java/ai/vespa/embedding/ColBertEmbedder.java @@ -149,7 +149,7 @@ public class ColBertEmbedder extends AbstractComponent implements Embedder { } protected TransformerInput buildTransformerInput(List<Long> tokens, int maxTokens, boolean isQuery) { - if(!isQuery) { + if (!isQuery) { tokens = tokens.stream().filter(token -> !skipTokens.contains(token)).toList(); } List<Long> inputIds = new ArrayList<>(maxTokens); @@ -172,7 +172,7 @@ public class ColBertEmbedder extends AbstractComponent implements Embedder { attentionMask.add((long) 1); for (int i = 0; i < padding; i++) - attentionMask.add((long) 0);//Do not attend to mask paddings + attentionMask.add((long) 0); // Do not attend to mask paddings return new TransformerInput(inputIds, attentionMask); } @@ -181,55 +181,44 @@ public class ColBertEmbedder extends AbstractComponent implements Embedder { if (tensorType.valueType() == TensorType.Value.INT8) throw new IllegalArgumentException("ColBert query embed does not accept int8 tensor value type"); - var start = System.nanoTime(); - var encoding = tokenizer.encode(text, context.getLanguage()); - runtime.sampleSequenceLength(encoding.ids().size(), context); + EmbeddingResult result = lookupOrEvaluate(context, text, true); + return toFloatTensor((IndexedTensor)result.outputs.get(outputName), tensorType, result.inputIdSize); + } - TransformerInput input = buildTransformerInput(encoding.ids(), maxQueryTokens, true); - Tensor inputIdsTensor = createTensorRepresentation(input.inputIds, "d1"); - Tensor attentionMaskTensor = createTensorRepresentation(input.attentionMask, "d1"); + protected Tensor embedDocument(String text, Context context, TensorType tensorType) { + EmbeddingResult result = lookupOrEvaluate(context, text, false); + var modelOutput = (IndexedTensor)result.outputs.get(outputName); + if (tensorType.valueType() == TensorType.Value.INT8) + return toBitTensor(modelOutput, tensorType, result.inputIdSize); + else + return toFloatTensor(modelOutput, tensorType, result.inputIdSize); + } - var inputs = Map.of(inputIdsName, inputIdsTensor.expand("d0"), - attentionMaskName, attentionMaskTensor.expand("d0")); - IndexedTensor modelOutput = (IndexedTensor) evaluateIfNotPresent(inputs, context, text).get(outputName); - Tensor resultTensor = toFloatTensor(modelOutput, tensorType, input.inputIds.size()); - runtime.sampleEmbeddingLatency((System.nanoTime() - start) / 1_000_000d, context); - return resultTensor; + /** + * Evaluate the embedding model if the result is not present in the context cache. + * + * @param context the context accompanying the request + * @param text the text that is embedded + * @return the model output + */ + protected EmbeddingResult lookupOrEvaluate(Context context, String text, boolean isQuery) { + var key = new EmbedderCacheKey(context.getEmbedderId(), text); + return context.computeCachedValueIfAbsent(key, () -> evaluate(context, text, isQuery)); } - protected Tensor embedDocument(String text, Context context, TensorType tensorType) { + private EmbeddingResult evaluate(Context context, String text, boolean isQuery) { var start = System.nanoTime(); - var encoding = tokenizer.encode(text, context.getLanguage()); runtime.sampleSequenceLength(encoding.ids().size(), context); - - TransformerInput input = buildTransformerInput(encoding.ids(), maxDocumentTokens, false); + TransformerInput input = buildTransformerInput(encoding.ids(), isQuery ? maxQueryTokens : maxDocumentTokens, isQuery); Tensor inputIdsTensor = createTensorRepresentation(input.inputIds, "d1"); Tensor attentionMaskTensor = createTensorRepresentation(input.attentionMask, "d1"); - - var inputs = Map.of(inputIdsName, inputIdsTensor.expand("d0"), + var inputs = Map.of(inputIdsName, + inputIdsTensor.expand("d0"), attentionMaskName, attentionMaskTensor.expand("d0")); - IndexedTensor modelOutput = (IndexedTensor) evaluateIfNotPresent(inputs, context, text).get(outputName); - Tensor resultEmbeddings; - int maxTokens = input.inputIds.size(); - if (tensorType.valueType() == TensorType.Value.INT8) { - resultEmbeddings = toBitTensor(modelOutput, tensorType, maxTokens); - } else { - resultEmbeddings = toFloatTensor(modelOutput, tensorType, maxTokens); - } + Map<String, Tensor> outputs = evaluator.evaluate(inputs); runtime.sampleEmbeddingLatency((System.nanoTime() - start) / 1_000_000d, context); - return resultEmbeddings; - } - - /** - * Evaluate the model if the result is not present in the context cache. - * @param inputs the tensor inputs - * @param context the context accompanying the request, a singleton per embedder instance and request - * @param hashKey the key to the cached value - * @return the model output - */ - protected Map<String, Tensor> evaluateIfNotPresent(Map<String, Tensor> inputs, Context context, String hashKey) { - return context.computeCachedValueIfAbsent(hashKey, () -> evaluator.evaluate(inputs)); + return new EmbeddingResult(input.inputIds.size(), outputs); } public static Tensor toFloatTensor(IndexedTensor result, TensorType type, int nTokens) { @@ -320,4 +309,8 @@ public class ColBertEmbedder extends AbstractComponent implements Embedder { return builder.build(); } + record EmbedderCacheKey(String embedderId, Object embeddedValue) { } + + record EmbeddingResult(int inputIdSize, Map<String, Tensor> outputs) { } + } diff --git a/model-integration/src/test/java/ai/vespa/embedding/ColBertEmbedderTest.java b/model-integration/src/test/java/ai/vespa/embedding/ColBertEmbedderTest.java index 5fd0afad2c4..f6216e4149c 100644 --- a/model-integration/src/test/java/ai/vespa/embedding/ColBertEmbedderTest.java +++ b/model-integration/src/test/java/ai/vespa/embedding/ColBertEmbedderTest.java @@ -75,14 +75,15 @@ public class ColBertEmbedderTest { @Test public void testCachingFloat() { + int initialEmbeddingsDone = runtime.embeddingsDone; var context = new Embedder.Context("schema.indexing"); + var input = "This is a test string to embed"; var t1 = (MixedTensor) embedder.embed(input, context,TensorType.fromSpec("tensor<float>(dt{},x[8])")); - var modelOuput = context.getCachedValue(input); + assertEquals(initialEmbeddingsDone + 1, runtime.embeddingsDone); var t2 = (MixedTensor)embedder.embed(input, context,TensorType.fromSpec("tensor<float>(dt{},x[4])")); - var modelOuput2 = context.getCachedValue(input); - assertEquals(modelOuput, modelOuput2); + assertEquals("Cached value was used", initialEmbeddingsDone + 1, runtime.embeddingsDone); assertNotEquals(t1,t2); for(int token = 0; token < 7; token ++) { @@ -90,39 +91,38 @@ public class ColBertEmbedderTest { assertEquals(t1.get(TensorAddress.of(token,dim)),t2.get(TensorAddress.of(token,dim)), 1e-6); } } - //t2 only has 4 dimensions so this should be out of bounds which returns 0 + // t2 only has 4 dimensions so this should be out of bounds which returns 0 assertEquals(0, t2.get(TensorAddress.of(1,4)), 1e-6); input = "This is a different test string to embed"; embedder.embed(input, context,TensorType.fromSpec("tensor<float>(dt{},x[8])")); - var modelOuput3 = context.getCachedValue(input); - assertNotEquals(modelOuput, modelOuput3); - assertNotEquals(modelOuput2, modelOuput3); + assertEquals(initialEmbeddingsDone + 2, runtime.embeddingsDone); } @Test public void testCachingInt() { + int initialEmbeddingsDone = runtime.embeddingsDone; var context = new Embedder.Context("schema.indexing"); + var input = "This is a test string to embed"; - var t1 = (MixedTensor) embedder.embed(input, context,TensorType.fromSpec("tensor<int8>(dt{},x[8])")); - var modelOuput = context.getCachedValue(input); + var t1 = (MixedTensor) embedder.embed(input, context, TensorType.fromSpec("tensor<int8>(dt{},x[8])")); + assertEquals(initialEmbeddingsDone + 1, runtime.embeddingsDone); - var t2 = (MixedTensor)embedder.embed(input, context,TensorType.fromSpec("tensor<int8>(dt{},x[4])")); - var modelOuput2 = context.getCachedValue(input); - assertEquals(modelOuput, modelOuput2); - assertNotEquals(t1,t2); + var t2 = (MixedTensor)embedder.embed(input, context, TensorType.fromSpec("tensor<int8>(dt{},x[4])")); + assertEquals("Cached value was used", initialEmbeddingsDone + 1, runtime.embeddingsDone); + + assertNotEquals(t1, t2); for(int token = 0; token < 7; token ++) { for(int dim = 0; dim < 4; dim++) { // the four first should be equal - assertEquals(t1.get(TensorAddress.of(token,dim)),t2.get(TensorAddress.of(token,dim)), 1e-6); + assertEquals(t1.get(TensorAddress.of(token,dim)), t2.get(TensorAddress.of(token,dim)), 1e-6); } } - //t2 only has 4 dimensions so this should be out of bounds which returns 0 + // t2 only has 4 dimensions so this should be out of bounds which returns 0 assertEquals(0, t2.get(TensorAddress.of(0,4)), 1e-6); + input = "This is a different test string to embed"; embedder.embed(input, context,TensorType.fromSpec("tensor<float>(dt{},x[8])")); - var modelOuput3 = context.getCachedValue(input); - assertNotEquals(modelOuput, modelOuput3); - assertNotEquals(modelOuput2, modelOuput3); + assertEquals(initialEmbeddingsDone + 2, runtime.embeddingsDone); } @@ -274,15 +274,16 @@ public class ColBertEmbedderTest { } static final ColBertEmbedder embedder; - static final ColBertEmbedder multiLingualEmbedder; + static final CountingRuntime runtime; static { - embedder = getEmbedder(); - multiLingualEmbedder = getMultiLingualEmbedder(); + runtime = new CountingRuntime(); + embedder = createEmbedder(runtime); + multiLingualEmbedder = getMultiLingualEmbedder(runtime); } - private static ColBertEmbedder getEmbedder() { + private static ColBertEmbedder createEmbedder(Embedder.Runtime runtime) { String vocabPath = "src/test/models/onnx/transformer/real_tokenizer.json"; String modelPath = "src/test/models/onnx/transformer/colbert-dummy-v2.onnx"; assumeTrue(OnnxRuntime.isRuntimeAvailable(modelPath)); @@ -290,10 +291,10 @@ public class ColBertEmbedderTest { builder.tokenizerPath(ModelReference.valueOf(vocabPath)); builder.transformerModel(ModelReference.valueOf(modelPath)); builder.transformerGpuDevice(-1); - return new ColBertEmbedder(new OnnxRuntime(), Embedder.Runtime.testInstance(), builder.build()); + return new ColBertEmbedder(new OnnxRuntime(), runtime, builder.build()); } - private static ColBertEmbedder getMultiLingualEmbedder() { + private static ColBertEmbedder getMultiLingualEmbedder(Embedder.Runtime runtime) { String vocabPath = "src/test/models/onnx/transformer/sentence_piece_tokenizer.json"; String modelPath = "src/test/models/onnx/transformer/colbert-dummy-v2.onnx"; assumeTrue(OnnxRuntime.isRuntimeAvailable(modelPath)); @@ -309,7 +310,21 @@ public class ColBertEmbedderTest { builder.queryTokenId(3); builder.documentTokenId(4); - return new ColBertEmbedder(new OnnxRuntime(), Embedder.Runtime.testInstance(), builder.build()); + return new ColBertEmbedder(new OnnxRuntime(), Embedder.Runtime.testInstance(), builder.build()); + } + + private static class CountingRuntime implements Embedder.Runtime { + + int embeddingsDone = 0; + + @Override + public void sampleEmbeddingLatency(double millis, Embedder.Context ctx) { + embeddingsDone++; + } + + @Override + public void sampleSequenceLength(long length, Embedder.Context ctx) { } + } } |