 indexinglanguage/src/main/java/com/yahoo/vespa/indexinglanguage/expressions/ExecutionContext.java |  6
 linguistics/abi-spec.json                                                                         |  6
 linguistics/src/main/java/com/yahoo/language/process/Embedder.java                                | 15
 model-integration/src/main/java/ai/vespa/embedding/ColBertEmbedder.java                           | 73
 model-integration/src/test/java/ai/vespa/embedding/ColBertEmbedderTest.java                       | 65
 5 files changed, 87 insertions(+), 78 deletions(-)
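
Note: this commit widens the embedder cache key type from String to Object so structured keys can be used, and ColBertEmbedder now caches its full model output per (embedder id, input text) pair. A minimal sketch of the resulting pattern, assuming a plain HashMap in place of Vespa's LazyMap; EmbedderCacheKey and the Supplier-based lookup mirror the diff below:

    import java.util.HashMap;
    import java.util.Map;
    import java.util.function.Supplier;

    class RequestCacheSketch {

        // Structured key: including the embedder id avoids collisions when
        // several embedders share one request-scoped cache.
        record EmbedderCacheKey(String embedderId, Object embeddedValue) { }

        // Request-scoped cache; the real code uses LazyMap.newHashMap().
        private final Map<Object, Object> cache = new HashMap<>();

        @SuppressWarnings("unchecked")
        <T> T computeCachedValueIfAbsent(Object key, Supplier<? extends T> supplier) {
            return (T) cache.computeIfAbsent(key, __ -> supplier.get());
        }
    }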
diff --git a/indexinglanguage/src/main/java/com/yahoo/vespa/indexinglanguage/expressions/ExecutionContext.java b/indexinglanguage/src/main/java/com/yahoo/vespa/indexinglanguage/expressions/ExecutionContext.java
index ba07fc00ca8..cdd0c11baac 100644
--- a/indexinglanguage/src/main/java/com/yahoo/vespa/indexinglanguage/expressions/ExecutionContext.java
+++ b/indexinglanguage/src/main/java/com/yahoo/vespa/indexinglanguage/expressions/ExecutionContext.java
@@ -22,7 +22,7 @@ public class ExecutionContext implements FieldTypeAdapter, FieldValueAdapter {
private final FieldValueAdapter adapter;
private FieldValue value;
private Language language;
- private final Map<String, Object> cache = LazyMap.newHashMap();
+ private final Map<Object, Object> cache = LazyMap.newHashMap();
public ExecutionContext() {
this(null);
@@ -125,12 +125,12 @@ public class ExecutionContext implements FieldTypeAdapter, FieldValueAdapter {
}
/** Returns a cached value, or null if not present. */
- public Object getCachedValue(String key) {
+ public Object getCachedValue(Object key) {
return cache.get(key);
}
/** Returns a mutable reference to the cache of this. */
- public Map<String, Object> getCache() {
+ public Map<Object, Object> getCache() {
return cache;
}
diff --git a/linguistics/abi-spec.json b/linguistics/abi-spec.json
index 91574133658..58e28fd7975 100644
--- a/linguistics/abi-spec.json
+++ b/linguistics/abi-spec.json
@@ -346,9 +346,9 @@
"public com.yahoo.language.process.Embedder$Context setDestination(java.lang.String)",
"public java.lang.String getEmbedderId()",
"public com.yahoo.language.process.Embedder$Context setEmbedderId(java.lang.String)",
- "public void putCachedValue(java.lang.String, java.lang.Object)",
- "public java.lang.Object getCachedValue(java.lang.String)",
- "public java.lang.Object computeCachedValueIfAbsent(java.lang.String, java.util.function.Supplier)"
+ "public void putCachedValue(java.lang.Object, java.lang.Object)",
+ "public java.lang.Object getCachedValue(java.lang.Object)",
+ "public java.lang.Object computeCachedValueIfAbsent(java.lang.Object, java.util.function.Supplier)"
],
"fields" : [ ]
},
diff --git a/linguistics/src/main/java/com/yahoo/language/process/Embedder.java b/linguistics/src/main/java/com/yahoo/language/process/Embedder.java
index e53f79d98ec..989edcdb18a 100644
--- a/linguistics/src/main/java/com/yahoo/language/process/Embedder.java
+++ b/linguistics/src/main/java/com/yahoo/language/process/Embedder.java
@@ -73,9 +73,10 @@ public interface Embedder {
*/
@Beta
interface Runtime {
- /** Sample latency metric for embedding */
+
+ /** Add a sample embedding latency to this */
void sampleEmbeddingLatency(double millis, Context ctx);
- /** Sample sequence length metric for embedding */
+ /** Add a sample embedding length to this */
void sampleSequenceLength(long length, Context ctx);
static Runtime testInstance() {
@@ -91,7 +92,7 @@ public interface Embedder {
private Language language = Language.UNKNOWN;
private String destination;
private String embedderId = "unknown";
- private final Map<String, Object> cache;
+ private final Map<Object, Object> cache;
public Context(String destination) {
this(destination, LazyMap.newHashMap());
@@ -101,7 +102,7 @@ public interface Embedder {
* @param destination the name of the recipient of this tensor
* @param cache a cache shared between all embed invocations for a single request
*/
- public Context(String destination, Map<String, Object> cache) {
+ public Context(String destination, Map<Object, Object> cache) {
this.destination = destination;
this.cache = Objects.requireNonNull(cache);
}
@@ -153,18 +154,18 @@ public interface Embedder {
return this;
}
- public void putCachedValue(String key, Object value) {
+ public void putCachedValue(Object key, Object value) {
cache.put(key, value);
}
/** Returns a cached value, or null if not present. */
- public Object getCachedValue(String key) {
+ public Object getCachedValue(Object key) {
return cache.get(key);
}
/** Returns the cached value, or computes and caches it if not present. */
@SuppressWarnings("unchecked")
- public <T> T computeCachedValueIfAbsent(String key, Supplier<? extends T> supplier) {
+ public <T> T computeCachedValueIfAbsent(Object key, Supplier<? extends T> supplier) {
return (T) cache.computeIfAbsent(key, __ -> supplier.get());
}
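
With the key type widened to Object, any key with value semantics works. A hedged usage sketch of the Context API as changed above (MyKey and the embedding values are hypothetical, for illustration only):

    // Hypothetical caller-side use of the widened cache API.
    record MyKey(String model, String text) { }

    Embedder.Context context = new Embedder.Context("myschema.indexing");
    float[] vector = context.computeCachedValueIfAbsent(new MyKey("e5", "hello"),
                                                        () -> new float[] {0.1f, 0.2f});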
diff --git a/model-integration/src/main/java/ai/vespa/embedding/ColBertEmbedder.java b/model-integration/src/main/java/ai/vespa/embedding/ColBertEmbedder.java
index a9d6d308df8..2fd8e312a7e 100644
--- a/model-integration/src/main/java/ai/vespa/embedding/ColBertEmbedder.java
+++ b/model-integration/src/main/java/ai/vespa/embedding/ColBertEmbedder.java
@@ -149,7 +149,7 @@ public class ColBertEmbedder extends AbstractComponent implements Embedder {
}
protected TransformerInput buildTransformerInput(List<Long> tokens, int maxTokens, boolean isQuery) {
- if(!isQuery) {
+ if (!isQuery) {
tokens = tokens.stream().filter(token -> !skipTokens.contains(token)).toList();
}
List<Long> inputIds = new ArrayList<>(maxTokens);
@@ -172,7 +172,7 @@ public class ColBertEmbedder extends AbstractComponent implements Embedder {
attentionMask.add((long) 1);
for (int i = 0; i < padding; i++)
- attentionMask.add((long) 0);//Do not attend to mask paddings
+ attentionMask.add((long) 0); // Do not attend to mask paddings
return new TransformerInput(inputIds, attentionMask);
}
@@ -181,55 +181,44 @@ public class ColBertEmbedder extends AbstractComponent implements Embedder {
if (tensorType.valueType() == TensorType.Value.INT8)
throw new IllegalArgumentException("ColBert query embed does not accept int8 tensor value type");
- var start = System.nanoTime();
- var encoding = tokenizer.encode(text, context.getLanguage());
- runtime.sampleSequenceLength(encoding.ids().size(), context);
+ EmbeddingResult result = lookupOrEvaluate(context, text, true);
+ return toFloatTensor((IndexedTensor)result.outputs.get(outputName), tensorType, result.inputIdSize);
+ }
- TransformerInput input = buildTransformerInput(encoding.ids(), maxQueryTokens, true);
- Tensor inputIdsTensor = createTensorRepresentation(input.inputIds, "d1");
- Tensor attentionMaskTensor = createTensorRepresentation(input.attentionMask, "d1");
+ protected Tensor embedDocument(String text, Context context, TensorType tensorType) {
+ EmbeddingResult result = lookupOrEvaluate(context, text, false);
+ var modelOutput = (IndexedTensor)result.outputs.get(outputName);
+ if (tensorType.valueType() == TensorType.Value.INT8)
+ return toBitTensor(modelOutput, tensorType, result.inputIdSize);
+ else
+ return toFloatTensor(modelOutput, tensorType, result.inputIdSize);
+ }
- var inputs = Map.of(inputIdsName, inputIdsTensor.expand("d0"),
- attentionMaskName, attentionMaskTensor.expand("d0"));
- IndexedTensor modelOutput = (IndexedTensor) evaluateIfNotPresent(inputs, context, text).get(outputName);
- Tensor resultTensor = toFloatTensor(modelOutput, tensorType, input.inputIds.size());
- runtime.sampleEmbeddingLatency((System.nanoTime() - start) / 1_000_000d, context);
- return resultTensor;
+ /**
+ * Evaluate the embedding model if the result is not present in the context cache.
+ *
+ * @param context the context accompanying the request
+ * @param text the text that is embedded
+ * @return the model output
+ */
+ protected EmbeddingResult lookupOrEvaluate(Context context, String text, boolean isQuery) {
+ var key = new EmbedderCacheKey(context.getEmbedderId(), text);
+ return context.computeCachedValueIfAbsent(key, () -> evaluate(context, text, isQuery));
}
- protected Tensor embedDocument(String text, Context context, TensorType tensorType) {
+ private EmbeddingResult evaluate(Context context, String text, boolean isQuery) {
var start = System.nanoTime();
-
var encoding = tokenizer.encode(text, context.getLanguage());
runtime.sampleSequenceLength(encoding.ids().size(), context);
-
- TransformerInput input = buildTransformerInput(encoding.ids(), maxDocumentTokens, false);
+ TransformerInput input = buildTransformerInput(encoding.ids(), isQuery ? maxQueryTokens : maxDocumentTokens, isQuery);
Tensor inputIdsTensor = createTensorRepresentation(input.inputIds, "d1");
Tensor attentionMaskTensor = createTensorRepresentation(input.attentionMask, "d1");
-
- var inputs = Map.of(inputIdsName, inputIdsTensor.expand("d0"),
+ var inputs = Map.of(inputIdsName,
+ inputIdsTensor.expand("d0"),
attentionMaskName, attentionMaskTensor.expand("d0"));
- IndexedTensor modelOutput = (IndexedTensor) evaluateIfNotPresent(inputs, context, text).get(outputName);
- Tensor resultEmbeddings;
- int maxTokens = input.inputIds.size();
- if (tensorType.valueType() == TensorType.Value.INT8) {
- resultEmbeddings = toBitTensor(modelOutput, tensorType, maxTokens);
- } else {
- resultEmbeddings = toFloatTensor(modelOutput, tensorType, maxTokens);
- }
+ Map<String, Tensor> outputs = evaluator.evaluate(inputs);
runtime.sampleEmbeddingLatency((System.nanoTime() - start) / 1_000_000d, context);
- return resultEmbeddings;
- }
-
- /**
- * Evaluate the model if the result is not present in the context cache.
- * @param inputs the tensor inputs
- * @param context the context accompanying the request, a singleton per embedder instance and request
- * @param hashKey the key to the cached value
- * @return the model output
- */
- protected Map<String, Tensor> evaluateIfNotPresent(Map<String, Tensor> inputs, Context context, String hashKey) {
- return context.computeCachedValueIfAbsent(hashKey, () -> evaluator.evaluate(inputs));
+ return new EmbeddingResult(input.inputIds.size(), outputs);
}
public static Tensor toFloatTensor(IndexedTensor result, TensorType type, int nTokens) {
@@ -320,4 +309,8 @@ public class ColBertEmbedder extends AbstractComponent implements Embedder {
return builder.build();
}
+ record EmbedderCacheKey(String embedderId, Object embeddedValue) { }
+
+ record EmbeddingResult(int inputIdSize, Map<String, Tensor> outputs) { }
+
}
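
The refactoring above collapses the duplicated query/document paths into a single evaluate() and caches the raw model output (EmbeddingResult) rather than the finished tensor, so the same text embedded into different tensor types triggers only one model evaluation. A sketch of the resulting behavior, with embedder construction elided (method and type names as in the diff):

    Embedder.Context context = new Embedder.Context("doc.embedding");
    String text = "a passage to embed";
    // First call evaluates the ONNX model and stores the EmbeddingResult under
    // EmbedderCacheKey(embedderId, text) in the request-scoped cache.
    Tensor t8 = embedder.embed(text, context, TensorType.fromSpec("tensor<float>(dt{},x[8])"));
    // Second call with a different target type reuses the cached model output;
    // only the tensor conversion (toFloatTensor/toBitTensor) runs again.
    Tensor t4 = embedder.embed(text, context, TensorType.fromSpec("tensor<float>(dt{},x[4])"));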
diff --git a/model-integration/src/test/java/ai/vespa/embedding/ColBertEmbedderTest.java b/model-integration/src/test/java/ai/vespa/embedding/ColBertEmbedderTest.java
index 5fd0afad2c4..f6216e4149c 100644
--- a/model-integration/src/test/java/ai/vespa/embedding/ColBertEmbedderTest.java
+++ b/model-integration/src/test/java/ai/vespa/embedding/ColBertEmbedderTest.java
@@ -75,14 +75,15 @@ public class ColBertEmbedderTest {
@Test
public void testCachingFloat() {
+ int initialEmbeddingsDone = runtime.embeddingsDone;
var context = new Embedder.Context("schema.indexing");
+
var input = "This is a test string to embed";
var t1 = (MixedTensor) embedder.embed(input, context,TensorType.fromSpec("tensor<float>(dt{},x[8])"));
- var modelOuput = context.getCachedValue(input);
+ assertEquals(initialEmbeddingsDone + 1, runtime.embeddingsDone);
var t2 = (MixedTensor)embedder.embed(input, context,TensorType.fromSpec("tensor<float>(dt{},x[4])"));
- var modelOuput2 = context.getCachedValue(input);
- assertEquals(modelOuput, modelOuput2);
+ assertEquals("Cached value was used", initialEmbeddingsDone + 1, runtime.embeddingsDone);
assertNotEquals(t1,t2);
for(int token = 0; token < 7; token ++) {
@@ -90,39 +91,38 @@ public class ColBertEmbedderTest {
assertEquals(t1.get(TensorAddress.of(token,dim)),t2.get(TensorAddress.of(token,dim)), 1e-6);
}
}
- //t2 only has 4 dimensions so this should be out of bounds which returns 0
+ // t2 only has 4 dimensions so this should be out of bounds which returns 0
assertEquals(0, t2.get(TensorAddress.of(1,4)), 1e-6);
input = "This is a different test string to embed";
embedder.embed(input, context,TensorType.fromSpec("tensor<float>(dt{},x[8])"));
- var modelOuput3 = context.getCachedValue(input);
- assertNotEquals(modelOuput, modelOuput3);
- assertNotEquals(modelOuput2, modelOuput3);
+ assertEquals(initialEmbeddingsDone + 2, runtime.embeddingsDone);
}
@Test
public void testCachingInt() {
+ int initialEmbeddingsDone = runtime.embeddingsDone;
var context = new Embedder.Context("schema.indexing");
+
var input = "This is a test string to embed";
- var t1 = (MixedTensor) embedder.embed(input, context,TensorType.fromSpec("tensor<int8>(dt{},x[8])"));
- var modelOuput = context.getCachedValue(input);
+ var t1 = (MixedTensor) embedder.embed(input, context, TensorType.fromSpec("tensor<int8>(dt{},x[8])"));
+ assertEquals(initialEmbeddingsDone + 1, runtime.embeddingsDone);
- var t2 = (MixedTensor)embedder.embed(input, context,TensorType.fromSpec("tensor<int8>(dt{},x[4])"));
- var modelOuput2 = context.getCachedValue(input);
- assertEquals(modelOuput, modelOuput2);
- assertNotEquals(t1,t2);
+ var t2 = (MixedTensor)embedder.embed(input, context, TensorType.fromSpec("tensor<int8>(dt{},x[4])"));
+ assertEquals("Cached value was used", initialEmbeddingsDone + 1, runtime.embeddingsDone);
+
+ assertNotEquals(t1, t2);
for(int token = 0; token < 7; token ++) {
for(int dim = 0; dim < 4; dim++) { // the four first should be equal
- assertEquals(t1.get(TensorAddress.of(token,dim)),t2.get(TensorAddress.of(token,dim)), 1e-6);
+ assertEquals(t1.get(TensorAddress.of(token,dim)), t2.get(TensorAddress.of(token,dim)), 1e-6);
}
}
- //t2 only has 4 dimensions so this should be out of bounds which returns 0
+ // t2 only has 4 dimensions so this should be out of bounds which returns 0
assertEquals(0, t2.get(TensorAddress.of(0,4)), 1e-6);
+
input = "This is a different test string to embed";
embedder.embed(input, context,TensorType.fromSpec("tensor<float>(dt{},x[8])"));
- var modelOuput3 = context.getCachedValue(input);
- assertNotEquals(modelOuput, modelOuput3);
- assertNotEquals(modelOuput2, modelOuput3);
+ assertEquals(initialEmbeddingsDone + 2, runtime.embeddingsDone);
}
@@ -274,15 +274,16 @@ public class ColBertEmbedderTest {
}
static final ColBertEmbedder embedder;
-
static final ColBertEmbedder multiLingualEmbedder;
+ static final CountingRuntime runtime;
static {
- embedder = getEmbedder();
- multiLingualEmbedder = getMultiLingualEmbedder();
+ runtime = new CountingRuntime();
+ embedder = createEmbedder(runtime);
+ multiLingualEmbedder = getMultiLingualEmbedder(runtime);
}
- private static ColBertEmbedder getEmbedder() {
+ private static ColBertEmbedder createEmbedder(Embedder.Runtime runtime) {
String vocabPath = "src/test/models/onnx/transformer/real_tokenizer.json";
String modelPath = "src/test/models/onnx/transformer/colbert-dummy-v2.onnx";
assumeTrue(OnnxRuntime.isRuntimeAvailable(modelPath));
@@ -290,10 +291,10 @@ public class ColBertEmbedderTest {
builder.tokenizerPath(ModelReference.valueOf(vocabPath));
builder.transformerModel(ModelReference.valueOf(modelPath));
builder.transformerGpuDevice(-1);
- return new ColBertEmbedder(new OnnxRuntime(), Embedder.Runtime.testInstance(), builder.build());
+ return new ColBertEmbedder(new OnnxRuntime(), runtime, builder.build());
}
- private static ColBertEmbedder getMultiLingualEmbedder() {
+ private static ColBertEmbedder getMultiLingualEmbedder(Embedder.Runtime runtime) {
String vocabPath = "src/test/models/onnx/transformer/sentence_piece_tokenizer.json";
String modelPath = "src/test/models/onnx/transformer/colbert-dummy-v2.onnx";
assumeTrue(OnnxRuntime.isRuntimeAvailable(modelPath));
@@ -309,7 +310,21 @@ public class ColBertEmbedderTest {
builder.queryTokenId(3);
builder.documentTokenId(4);
- return new ColBertEmbedder(new OnnxRuntime(), Embedder.Runtime.testInstance(), builder.build());
+ return new ColBertEmbedder(new OnnxRuntime(), runtime, builder.build());
+ }
+
+ private static class CountingRuntime implements Embedder.Runtime {
+
+ int embeddingsDone = 0;
+
+ @Override
+ public void sampleEmbeddingLatency(double millis, Embedder.Context ctx) {
+ embeddingsDone++;
+ }
+
+ @Override
+ public void sampleSequenceLength(long length, Embedder.Context ctx) { }
+
}
}
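
A note on the test strategy: since sampleEmbeddingLatency is now invoked only inside evaluate(), counting its calls observes cache hits without reaching into the cache, which is why the getCachedValue assertions were replaced. The assertion pattern, in the JUnit 4 style of the test file:

    int before = runtime.embeddingsDone;
    embedder.embed(input, context, type);       // cache miss: the model is evaluated
    embedder.embed(input, context, otherType);  // cache hit: no new evaluation
    assertEquals(before + 1, runtime.embeddingsDone);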