author    Jon Bratseth <bratseth@vespa.ai>  2024-04-07 20:31:37 +0200
committer Jon Bratseth <bratseth@vespa.ai>  2024-04-07 20:31:37 +0200
commit    6715471dceedbbda28d9d29ffb9d441ebfb848a2 (patch)
tree      e6255566e40817243f3df7a4667cf9e6822baa62 /model-integration
parent    9f9160985a4f4848fa3f89d83a9f859958bd8e3c (diff)
Key by embedder id and don't recompute inputs
Diffstat (limited to 'model-integration')
-rw-r--r--  model-integration/src/main/java/ai/vespa/embedding/ColBertEmbedder.java      | 73
-rw-r--r--  model-integration/src/test/java/ai/vespa/embedding/ColBertEmbedderTest.java  | 65
2 files changed, 73 insertions, 65 deletions
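
The change: the per-request cache is now keyed by (embedder id, text) rather than by text alone, and the cached value is the full EmbeddingResult (model outputs plus input id count), so a repeated embed of the same text skips tokenization, input building, and model evaluation entirely. A minimal sketch of that caching pattern, assuming a simplified map-backed context; SimpleContext and its members are illustrative stand-ins, not the Vespa Embedder.Context API (only computeCachedValueIfAbsent mirrors a real Context method):

    import java.util.HashMap;
    import java.util.Map;
    import java.util.function.Supplier;

    // Hypothetical stand-in for the context cache used in the diff below.
    class SimpleContext {

        // Keying by (embedder id, text) keeps two embedders configured with
        // different models from sharing a cache entry for the same input string.
        record CacheKey(String embedderId, String text) { }

        private final Map<CacheKey, Object> cache = new HashMap<>();

        @SuppressWarnings("unchecked")
        <T> T computeCachedValueIfAbsent(CacheKey key, Supplier<T> supplier) {
            // Compute once per key; later lookups return the cached value.
            return (T) cache.computeIfAbsent(key, k -> supplier.get());
        }

    }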
diff --git a/model-integration/src/main/java/ai/vespa/embedding/ColBertEmbedder.java b/model-integration/src/main/java/ai/vespa/embedding/ColBertEmbedder.java
index a9d6d308df8..2fd8e312a7e 100644
--- a/model-integration/src/main/java/ai/vespa/embedding/ColBertEmbedder.java
+++ b/model-integration/src/main/java/ai/vespa/embedding/ColBertEmbedder.java
@@ -149,7 +149,7 @@ public class ColBertEmbedder extends AbstractComponent implements Embedder {
}
protected TransformerInput buildTransformerInput(List<Long> tokens, int maxTokens, boolean isQuery) {
- if(!isQuery) {
+ if (!isQuery) {
tokens = tokens.stream().filter(token -> !skipTokens.contains(token)).toList();
}
List<Long> inputIds = new ArrayList<>(maxTokens);
@@ -172,7 +172,7 @@ public class ColBertEmbedder extends AbstractComponent implements Embedder {
attentionMask.add((long) 1);
for (int i = 0; i < padding; i++)
- attentionMask.add((long) 0);//Do not attend to mask paddings
+ attentionMask.add((long) 0); // Do not attend to mask paddings
return new TransformerInput(inputIds, attentionMask);
}
@@ -181,55 +181,44 @@ public class ColBertEmbedder extends AbstractComponent implements Embedder {
if (tensorType.valueType() == TensorType.Value.INT8)
throw new IllegalArgumentException("ColBert query embed does not accept int8 tensor value type");
- var start = System.nanoTime();
- var encoding = tokenizer.encode(text, context.getLanguage());
- runtime.sampleSequenceLength(encoding.ids().size(), context);
+ EmbeddingResult result = lookupOrEvaluate(context, text, true);
+ return toFloatTensor((IndexedTensor)result.outputs.get(outputName), tensorType, result.inputIdSize);
+ }
- TransformerInput input = buildTransformerInput(encoding.ids(), maxQueryTokens, true);
- Tensor inputIdsTensor = createTensorRepresentation(input.inputIds, "d1");
- Tensor attentionMaskTensor = createTensorRepresentation(input.attentionMask, "d1");
+ protected Tensor embedDocument(String text, Context context, TensorType tensorType) {
+ EmbeddingResult result = lookupOrEvaluate(context, text, false);
+ var modelOutput = (IndexedTensor)result.outputs.get(outputName);
+ if (tensorType.valueType() == TensorType.Value.INT8)
+ return toBitTensor(modelOutput, tensorType, result.inputIdSize);
+ else
+ return toFloatTensor(modelOutput, tensorType, result.inputIdSize);
+ }
- var inputs = Map.of(inputIdsName, inputIdsTensor.expand("d0"),
- attentionMaskName, attentionMaskTensor.expand("d0"));
- IndexedTensor modelOutput = (IndexedTensor) evaluateIfNotPresent(inputs, context, text).get(outputName);
- Tensor resultTensor = toFloatTensor(modelOutput, tensorType, input.inputIds.size());
- runtime.sampleEmbeddingLatency((System.nanoTime() - start) / 1_000_000d, context);
- return resultTensor;
+ /**
+ * Evaluate the embedding model if the result is not present in the context cache.
+ *
+ * @param context the context accompanying the request
+ * @param text the text that is embedded
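+ * @param isQuery whether the text is embedded as a query (true) or as a document (false)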
+ * @return the embedding result holding the model outputs and the input id count
+ */
+ protected EmbeddingResult lookupOrEvaluate(Context context, String text, boolean isQuery) {
+ var key = new EmbedderCacheKey(context.getEmbedderId(), text);
+ return context.computeCachedValueIfAbsent(key, () -> evaluate(context, text, isQuery));
}
- protected Tensor embedDocument(String text, Context context, TensorType tensorType) {
+ private EmbeddingResult evaluate(Context context, String text, boolean isQuery) {
var start = System.nanoTime();
-
var encoding = tokenizer.encode(text, context.getLanguage());
runtime.sampleSequenceLength(encoding.ids().size(), context);
-
- TransformerInput input = buildTransformerInput(encoding.ids(), maxDocumentTokens, false);
+ TransformerInput input = buildTransformerInput(encoding.ids(), isQuery ? maxQueryTokens : maxDocumentTokens, isQuery);
Tensor inputIdsTensor = createTensorRepresentation(input.inputIds, "d1");
Tensor attentionMaskTensor = createTensorRepresentation(input.attentionMask, "d1");
-
- var inputs = Map.of(inputIdsName, inputIdsTensor.expand("d0"),
+ var inputs = Map.of(inputIdsName,
+ inputIdsTensor.expand("d0"),
attentionMaskName, attentionMaskTensor.expand("d0"));
- IndexedTensor modelOutput = (IndexedTensor) evaluateIfNotPresent(inputs, context, text).get(outputName);
- Tensor resultEmbeddings;
- int maxTokens = input.inputIds.size();
- if (tensorType.valueType() == TensorType.Value.INT8) {
- resultEmbeddings = toBitTensor(modelOutput, tensorType, maxTokens);
- } else {
- resultEmbeddings = toFloatTensor(modelOutput, tensorType, maxTokens);
- }
+ Map<String, Tensor> outputs = evaluator.evaluate(inputs);
runtime.sampleEmbeddingLatency((System.nanoTime() - start) / 1_000_000d, context);
- return resultEmbeddings;
- }
-
- /**
- * Evaluate the model if the result is not present in the context cache.
- * @param inputs the tensor inputs
- * @param context the context accompanying the request, a singleton per embedder instance and request
- * @param hashKey the key to the cached value
- * @return the model output
- */
- protected Map<String, Tensor> evaluateIfNotPresent(Map<String, Tensor> inputs, Context context, String hashKey) {
- return context.computeCachedValueIfAbsent(hashKey, () -> evaluator.evaluate(inputs));
+ return new EmbeddingResult(input.inputIds.size(), outputs);
}
public static Tensor toFloatTensor(IndexedTensor result, TensorType type, int nTokens) {
@@ -320,4 +309,8 @@ public class ColBertEmbedder extends AbstractComponent implements Embedder {
return builder.build();
}
+ record EmbedderCacheKey(String embedderId, Object embeddedValue) { }
+
+ record EmbeddingResult(int inputIdSize, Map<String, Tensor> outputs) { }
+
}
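
With the shared lookupOrEvaluate path above, embedQuery and embedDocument differ only in how they project the cached model output into the requested tensor type. A hedged usage sketch of the resulting behavior (a fragment, assuming an embedder and context set up as in the test below; types come from com.yahoo.tensor and com.yahoo.language.process):

    var context = new Embedder.Context("schema.indexing");
    String input = "This is a test string to embed";

    // First call: tokenize, evaluate the model, and cache the EmbeddingResult
    // under (embedder id, input).
    Tensor t8 = embedder.embed(input, context, TensorType.fromSpec("tensor<float>(dt{},x[8])"));

    // Second call with a different target type: served from the cache, so the
    // model is not evaluated again; only the projection to x[4] differs.
    Tensor t4 = embedder.embed(input, context, TensorType.fromSpec("tensor<float>(dt{},x[4])"));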
diff --git a/model-integration/src/test/java/ai/vespa/embedding/ColBertEmbedderTest.java b/model-integration/src/test/java/ai/vespa/embedding/ColBertEmbedderTest.java
index 5fd0afad2c4..f6216e4149c 100644
--- a/model-integration/src/test/java/ai/vespa/embedding/ColBertEmbedderTest.java
+++ b/model-integration/src/test/java/ai/vespa/embedding/ColBertEmbedderTest.java
@@ -75,14 +75,15 @@ public class ColBertEmbedderTest {
@Test
public void testCachingFloat() {
+ int initialEmbeddingsDone = runtime.embeddingsDone;
var context = new Embedder.Context("schema.indexing");
+
var input = "This is a test string to embed";
var t1 = (MixedTensor) embedder.embed(input, context,TensorType.fromSpec("tensor<float>(dt{},x[8])"));
- var modelOuput = context.getCachedValue(input);
+ assertEquals(initialEmbeddingsDone + 1, runtime.embeddingsDone);
var t2 = (MixedTensor)embedder.embed(input, context,TensorType.fromSpec("tensor<float>(dt{},x[4])"));
- var modelOuput2 = context.getCachedValue(input);
- assertEquals(modelOuput, modelOuput2);
+ assertEquals("Cached value was used", initialEmbeddingsDone + 1, runtime.embeddingsDone);
assertNotEquals(t1,t2);
for(int token = 0; token < 7; token ++) {
@@ -90,39 +91,38 @@ public class ColBertEmbedderTest {
assertEquals(t1.get(TensorAddress.of(token,dim)),t2.get(TensorAddress.of(token,dim)), 1e-6);
}
}
- //t2 only has 4 dimensions so this should be out of bounds which returns 0
+ // t2 only has 4 dimensions so this should be out of bounds which returns 0
assertEquals(0, t2.get(TensorAddress.of(1,4)), 1e-6);
input = "This is a different test string to embed";
embedder.embed(input, context,TensorType.fromSpec("tensor<float>(dt{},x[8])"));
- var modelOuput3 = context.getCachedValue(input);
- assertNotEquals(modelOuput, modelOuput3);
- assertNotEquals(modelOuput2, modelOuput3);
+ assertEquals(initialEmbeddingsDone + 2, runtime.embeddingsDone);
}
@Test
public void testCachingInt() {
+ int initialEmbeddingsDone = runtime.embeddingsDone;
var context = new Embedder.Context("schema.indexing");
+
var input = "This is a test string to embed";
- var t1 = (MixedTensor) embedder.embed(input, context,TensorType.fromSpec("tensor<int8>(dt{},x[8])"));
- var modelOuput = context.getCachedValue(input);
+ var t1 = (MixedTensor) embedder.embed(input, context, TensorType.fromSpec("tensor<int8>(dt{},x[8])"));
+ assertEquals(initialEmbeddingsDone + 1, runtime.embeddingsDone);
- var t2 = (MixedTensor)embedder.embed(input, context,TensorType.fromSpec("tensor<int8>(dt{},x[4])"));
- var modelOuput2 = context.getCachedValue(input);
- assertEquals(modelOuput, modelOuput2);
- assertNotEquals(t1,t2);
+ var t2 = (MixedTensor)embedder.embed(input, context, TensorType.fromSpec("tensor<int8>(dt{},x[4])"));
+ assertEquals("Cached value was used", initialEmbeddingsDone + 1, runtime.embeddingsDone);
+
+ assertNotEquals(t1, t2);
for(int token = 0; token < 7; token ++) {
for(int dim = 0; dim < 4; dim++) { // the four first should be equal
- assertEquals(t1.get(TensorAddress.of(token,dim)),t2.get(TensorAddress.of(token,dim)), 1e-6);
+ assertEquals(t1.get(TensorAddress.of(token,dim)), t2.get(TensorAddress.of(token,dim)), 1e-6);
}
}
- //t2 only has 4 dimensions so this should be out of bounds which returns 0
+ // t2 only has 4 dimensions so this should be out of bounds which returns 0
assertEquals(0, t2.get(TensorAddress.of(0,4)), 1e-6);
+
input = "This is a different test string to embed";
embedder.embed(input, context,TensorType.fromSpec("tensor<float>(dt{},x[8])"));
- var modelOuput3 = context.getCachedValue(input);
- assertNotEquals(modelOuput, modelOuput3);
- assertNotEquals(modelOuput2, modelOuput3);
+ assertEquals(initialEmbeddingsDone + 2, runtime.embeddingsDone);
}
@@ -274,15 +274,16 @@ public class ColBertEmbedderTest {
}
static final ColBertEmbedder embedder;
-
static final ColBertEmbedder multiLingualEmbedder;
+ static final CountingRuntime runtime;
static {
- embedder = getEmbedder();
- multiLingualEmbedder = getMultiLingualEmbedder();
+ runtime = new CountingRuntime();
+ embedder = createEmbedder(runtime);
+ multiLingualEmbedder = getMultiLingualEmbedder(runtime);
}
- private static ColBertEmbedder getEmbedder() {
+ private static ColBertEmbedder createEmbedder(Embedder.Runtime runtime) {
String vocabPath = "src/test/models/onnx/transformer/real_tokenizer.json";
String modelPath = "src/test/models/onnx/transformer/colbert-dummy-v2.onnx";
assumeTrue(OnnxRuntime.isRuntimeAvailable(modelPath));
@@ -290,10 +291,10 @@ public class ColBertEmbedderTest {
builder.tokenizerPath(ModelReference.valueOf(vocabPath));
builder.transformerModel(ModelReference.valueOf(modelPath));
builder.transformerGpuDevice(-1);
- return new ColBertEmbedder(new OnnxRuntime(), Embedder.Runtime.testInstance(), builder.build());
+ return new ColBertEmbedder(new OnnxRuntime(), runtime, builder.build());
}
- private static ColBertEmbedder getMultiLingualEmbedder() {
+ private static ColBertEmbedder getMultiLingualEmbedder(Embedder.Runtime runtime) {
String vocabPath = "src/test/models/onnx/transformer/sentence_piece_tokenizer.json";
String modelPath = "src/test/models/onnx/transformer/colbert-dummy-v2.onnx";
assumeTrue(OnnxRuntime.isRuntimeAvailable(modelPath));
@@ -309,7 +310,21 @@ public class ColBertEmbedderTest {
builder.queryTokenId(3);
builder.documentTokenId(4);
- return new ColBertEmbedder(new OnnxRuntime(), Embedder.Runtime.testInstance(), builder.build());
+ return new ColBertEmbedder(new OnnxRuntime(), Embedder.Runtime.testInstance(), builder.build());
+ }
+
+ private static class CountingRuntime implements Embedder.Runtime {
+
+ int embeddingsDone = 0;
+
+ @Override
+ public void sampleEmbeddingLatency(double millis, Embedder.Context ctx) {
+ embeddingsDone++;
+ }
+
+ @Override
+ public void sampleSequenceLength(long length, Embedder.Context ctx) { }
+
}
}
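
Because sampleEmbeddingLatency is invoked exactly once per actual model evaluation (inside evaluate), the tests above assert cache behavior by counting those invocations rather than by inspecting the cache map. The assertion shape, as a fragment matching testCachingFloat:

    int before = runtime.embeddingsDone;
    embedder.embed(input, context, TensorType.fromSpec("tensor<float>(dt{},x[8])"));
    assertEquals(before + 1, runtime.embeddingsDone);   // first embed evaluates the model
    embedder.embed(input, context, TensorType.fromSpec("tensor<float>(dt{},x[4])"));
    assertEquals(before + 1, runtime.embeddingsDone);   // same text: cache hit, no new evaluation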