author     Jo Kristian Bergum <bergum@vespa.ai>  2024-04-10 09:14:30 +0200
committer  GitHub <noreply@github.com>  2024-04-10 09:14:30 +0200
commit     4d0144a4d249df6cce37539cba13969e9fd4ca4f (patch)
tree       6478d0617dea7b6469a1c269cb54ccad36290095
parent     8db9ee454f4ae9c677fdf9382fcb51139fbc263d (diff)
parent     4d233b5379b8dc4b94901f8df8acda0a6f2c4420 (diff)
Merge pull request #30809 from vespa-engine/jobergum/add-context-caching
Add ONNX output caching to embedder (allow different post-processing of model outputs)
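
This change lets multiple embed invocations in the same request share a single ONNX model evaluation while applying different post-processing to the cached outputs (for example float output vs. bit-packed int8). The key piece is the new Embedder.Context#computeCachedValueIfAbsent shown in the diff below. A minimal sketch of the intended usage pattern, where MyCacheKey and evaluateModel are hypothetical stand-ins for the embedder-specific parts:

    import java.util.Map;

    import com.yahoo.language.process.Embedder;
    import com.yahoo.tensor.Tensor;

    class CachingEmbedderSketch {

        // Key on (embedder id, input text) so embedders sharing the per-request cache do not collide.
        record MyCacheKey(String embedderId, String text) { }

        // Returns the raw model outputs, evaluating the model at most once per key within a request.
        Map<String, Tensor> lookupOrEvaluate(Embedder.Context context, String text) {
            var key = new MyCacheKey(context.getEmbedderId(), text);
            return context.computeCachedValueIfAbsent(key, () -> evaluateModel(text));
        }

        // Hypothetical stand-in for the actual ONNX forward pass.
        private Map<String, Tensor> evaluateModel(String text) {
            return Map.of();
        }
    }

Callers can then derive, say, a float tensor and a bit-packed int8 tensor from the same cached output without a second forward pass, which is what the ColBertEmbedder and HuggingFaceEmbedder changes below do.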
-rw-r--r--  indexinglanguage/src/main/java/com/yahoo/vespa/indexinglanguage/expressions/ExecutionContext.java |   6
-rw-r--r--  linguistics/abi-spec.json |   5
-rw-r--r--  linguistics/src/main/java/com/yahoo/language/process/Embedder.java |  21
-rw-r--r--  model-integration/src/main/java/ai/vespa/embedding/ColBertEmbedder.java |  82
-rw-r--r--  model-integration/src/main/java/ai/vespa/embedding/huggingface/HuggingFaceEmbedder.java | 118
-rw-r--r--  model-integration/src/test/java/ai/vespa/embedding/ColBertEmbedderTest.java | 153
-rw-r--r--  model-integration/src/test/java/ai/vespa/embedding/huggingface/HuggingFaceEmbedderTest.java (renamed from model-integration/src/test/java/ai/vespa/embedding/HuggingFaceEmbedderTest.java) |  68
7 files changed, 304 insertions(+), 149 deletions(-)
diff --git a/indexinglanguage/src/main/java/com/yahoo/vespa/indexinglanguage/expressions/ExecutionContext.java b/indexinglanguage/src/main/java/com/yahoo/vespa/indexinglanguage/expressions/ExecutionContext.java
index ba07fc00ca8..cdd0c11baac 100644
--- a/indexinglanguage/src/main/java/com/yahoo/vespa/indexinglanguage/expressions/ExecutionContext.java
+++ b/indexinglanguage/src/main/java/com/yahoo/vespa/indexinglanguage/expressions/ExecutionContext.java
@@ -22,7 +22,7 @@ public class ExecutionContext implements FieldTypeAdapter, FieldValueAdapter {
private final FieldValueAdapter adapter;
private FieldValue value;
private Language language;
- private final Map<String, Object> cache = LazyMap.newHashMap();
+ private final Map<Object, Object> cache = LazyMap.newHashMap();
public ExecutionContext() {
this(null);
@@ -125,12 +125,12 @@ public class ExecutionContext implements FieldTypeAdapter, FieldValueAdapter {
}
/** Returns a cached value, or null if not present. */
- public Object getCachedValue(String key) {
+ public Object getCachedValue(Object key) {
return cache.get(key);
}
/** Returns a mutable reference to the cache of this. */
- public Map<String, Object> getCache() {
+ public Map<Object, Object> getCache() {
return cache;
}
diff --git a/linguistics/abi-spec.json b/linguistics/abi-spec.json
index 9f91c32cf62..a4adacc5905 100644
--- a/linguistics/abi-spec.json
+++ b/linguistics/abi-spec.json
@@ -346,8 +346,9 @@
"public com.yahoo.language.process.Embedder$Context setDestination(java.lang.String)",
"public java.lang.String getEmbedderId()",
"public com.yahoo.language.process.Embedder$Context setEmbedderId(java.lang.String)",
- "public void putCachedValue(java.lang.String, java.lang.Object)",
- "public java.lang.Object getCachedValue(java.lang.String)"
+ "public void putCachedValue(java.lang.Object, java.lang.Object)",
+ "public java.lang.Object getCachedValue(java.lang.Object)",
+ "public java.lang.Object computeCachedValueIfAbsent(java.lang.Object, java.util.function.Supplier)"
],
"fields" : [ ]
},
diff --git a/linguistics/src/main/java/com/yahoo/language/process/Embedder.java b/linguistics/src/main/java/com/yahoo/language/process/Embedder.java
index 2ab2de303c2..989edcdb18a 100644
--- a/linguistics/src/main/java/com/yahoo/language/process/Embedder.java
+++ b/linguistics/src/main/java/com/yahoo/language/process/Embedder.java
@@ -7,10 +7,10 @@ import com.yahoo.language.Language;
import com.yahoo.tensor.Tensor;
import com.yahoo.tensor.TensorType;
-import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Objects;
+import java.util.function.Supplier;
/**
* An embedder converts a text string to a tensor
@@ -73,9 +73,10 @@ public interface Embedder {
*/
@Beta
interface Runtime {
- /** Sample latency metric for embedding */
+
+ /** Add a sample embedding latency to this */
void sampleEmbeddingLatency(double millis, Context ctx);
- /** Sample sequence length metric for embedding */
+ /** Add a sample sequence length to this */
void sampleSequenceLength(long length, Context ctx);
static Runtime testInstance() {
@@ -91,7 +92,7 @@ public interface Embedder {
private Language language = Language.UNKNOWN;
private String destination;
private String embedderId = "unknown";
- private final Map<String, Object> cache;
+ private final Map<Object, Object> cache;
public Context(String destination) {
this(destination, LazyMap.newHashMap());
@@ -101,7 +102,7 @@ public interface Embedder {
* @param destination the name of the recipient of this tensor
* @param cache a cache shared between all embed invocations for a single request
*/
- public Context(String destination, Map<String, Object> cache) {
+ public Context(String destination, Map<Object, Object> cache) {
this.destination = destination;
this.cache = Objects.requireNonNull(cache);
}
@@ -153,15 +154,21 @@ public interface Embedder {
return this;
}
- public void putCachedValue(String key, Object value) {
+ public void putCachedValue(Object key, Object value) {
cache.put(key, value);
}
/** Returns a cached value, or null if not present. */
- public Object getCachedValue(String key) {
+ public Object getCachedValue(Object key) {
return cache.get(key);
}
+ /** Returns the cached value, or computes and caches it if not present. */
+ @SuppressWarnings("unchecked")
+ public <T> T computeCachedValueIfAbsent(Object key, Supplier<? extends T> supplier) {
+ return (T) cache.computeIfAbsent(key, __ -> supplier.get());
+ }
+
}
class FailingEmbedder implements Embedder {
diff --git a/model-integration/src/main/java/ai/vespa/embedding/ColBertEmbedder.java b/model-integration/src/main/java/ai/vespa/embedding/ColBertEmbedder.java
index f43f3834a65..2fd8e312a7e 100644
--- a/model-integration/src/main/java/ai/vespa/embedding/ColBertEmbedder.java
+++ b/model-integration/src/main/java/ai/vespa/embedding/ColBertEmbedder.java
@@ -149,7 +149,7 @@ public class ColBertEmbedder extends AbstractComponent implements Embedder {
}
protected TransformerInput buildTransformerInput(List<Long> tokens, int maxTokens, boolean isQuery) {
- if(!isQuery) {
+ if (!isQuery) {
tokens = tokens.stream().filter(token -> !skipTokens.contains(token)).toList();
}
List<Long> inputIds = new ArrayList<>(maxTokens);
@@ -172,7 +172,7 @@ public class ColBertEmbedder extends AbstractComponent implements Embedder {
attentionMask.add((long) 1);
for (int i = 0; i < padding; i++)
- attentionMask.add((long) 0);//Do not attend to mask paddings
+ attentionMask.add((long) 0); // Do not attend to mask paddings
return new TransformerInput(inputIds, attentionMask);
}
@@ -181,56 +181,44 @@ public class ColBertEmbedder extends AbstractComponent implements Embedder {
if (tensorType.valueType() == TensorType.Value.INT8)
throw new IllegalArgumentException("ColBert query embed does not accept int8 tensor value type");
+ EmbeddingResult result = lookupOrEvaluate(context, text, true);
+ return toFloatTensor((IndexedTensor)result.outputs.get(outputName), tensorType, result.inputIdSize);
+ }
- var start = System.nanoTime();
- var encoding = tokenizer.encode(text, context.getLanguage());
- runtime.sampleSequenceLength(encoding.ids().size(), context);
-
- TransformerInput input = buildTransformerInput(encoding.ids(), maxQueryTokens, true);
-
- Tensor inputIdsTensor = createTensorRepresentation(input.inputIds, "d1");
- Tensor attentionMaskTensor = createTensorRepresentation(input.attentionMask, "d1");
-
- var inputs = Map.of(inputIdsName, inputIdsTensor.expand("d0"),
- attentionMaskName, attentionMaskTensor.expand("d0"));
- Map<String, Tensor> outputs = evaluator.evaluate(inputs);
- Tensor tokenEmbeddings = outputs.get(outputName);
- IndexedTensor result = (IndexedTensor) tokenEmbeddings;
+ protected Tensor embedDocument(String text, Context context, TensorType tensorType) {
+ EmbeddingResult result = lookupOrEvaluate(context, text, false);
+ var modelOutput = (IndexedTensor)result.outputs.get(outputName);
+ if (tensorType.valueType() == TensorType.Value.INT8)
+ return toBitTensor(modelOutput, tensorType, result.inputIdSize);
+ else
+ return toFloatTensor(modelOutput, tensorType, result.inputIdSize);
+ }
- int dims = tensorType.indexedSubtype().dimensions().get(0).size().get().intValue();
- if (dims != result.shape()[2]) {
- throw new IllegalArgumentException("Token vector dimensionality does not" +
- " match indexed dimensionality of " + dims);
- }
- Tensor resultTensor = toFloatTensor(result, tensorType, input.inputIds.size());
- runtime.sampleEmbeddingLatency((System.nanoTime() - start) / 1_000_000d, context);
- return resultTensor;
+ /**
+ * Evaluate the embedding model if the result is not present in the context cache.
+ *
+ * @param context the context accompanying the request
+ * @param text the text that is embedded
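+ * @param isQuery whether the text is a query (true) or a document (false)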
+ * @return the model output
+ */
+ protected EmbeddingResult lookupOrEvaluate(Context context, String text, boolean isQuery) {
+ var key = new EmbedderCacheKey(context.getEmbedderId(), text);
+ return context.computeCachedValueIfAbsent(key, () -> evaluate(context, text, isQuery));
}
- protected Tensor embedDocument(String text, Context context, TensorType tensorType) {
+ private EmbeddingResult evaluate(Context context, String text, boolean isQuery) {
var start = System.nanoTime();
var encoding = tokenizer.encode(text, context.getLanguage());
runtime.sampleSequenceLength(encoding.ids().size(), context);
-
- TransformerInput input = buildTransformerInput(encoding.ids(), maxDocumentTokens, false);
+ TransformerInput input = buildTransformerInput(encoding.ids(), isQuery ? maxQueryTokens : maxDocumentTokens, isQuery);
Tensor inputIdsTensor = createTensorRepresentation(input.inputIds, "d1");
Tensor attentionMaskTensor = createTensorRepresentation(input.attentionMask, "d1");
-
- var inputs = Map.of(inputIdsName, inputIdsTensor.expand("d0"),
+ var inputs = Map.of(inputIdsName,
+ inputIdsTensor.expand("d0"),
attentionMaskName, attentionMaskTensor.expand("d0"));
-
Map<String, Tensor> outputs = evaluator.evaluate(inputs);
- Tensor tokenEmbeddings = outputs.get(outputName);
- IndexedTensor result = (IndexedTensor) tokenEmbeddings;
- Tensor contextualEmbeddings;
- int maxTokens = input.inputIds.size(); // Retain all token vectors, including PAD tokens.
- if (tensorType.valueType() == TensorType.Value.INT8) {
- contextualEmbeddings = toBitTensor(result, tensorType, maxTokens);
- } else {
- contextualEmbeddings = toFloatTensor(result, tensorType, maxTokens);
- }
runtime.sampleEmbeddingLatency((System.nanoTime() - start) / 1_000_000d, context);
- return contextualEmbeddings;
+ return new EmbeddingResult(input.inputIds.size(), outputs);
}
public static Tensor toFloatTensor(IndexedTensor result, TensorType type, int nTokens) {
@@ -241,13 +229,13 @@ public class ColBertEmbedder extends AbstractComponent implements Embedder {
throw new IllegalArgumentException("Target indexed sub-type must have one dimension");
int wantedDimensionality = type.indexedSubtype().dimensions().get(0).size().get().intValue();
int resultDimensionality = (int)result.shape()[2];
- if (resultDimensionality != wantedDimensionality) {
+ if (wantedDimensionality > resultDimensionality) {
throw new IllegalArgumentException("Not possible to map token vector embedding with " + resultDimensionality +
" dimensions into tensor with " + wantedDimensionality);
}
Tensor.Builder builder = Tensor.Builder.of(type);
for (int token = 0; token < nTokens; token++) {
- for (int d = 0; d < resultDimensionality; d++) {
+ for (int d = 0; d < wantedDimensionality; d++) {
var value = result.get(0,token,d); // batch, sequence token, dimension
builder.cell(TensorAddress.of(token,d),value);
}
@@ -265,8 +253,10 @@ public class ColBertEmbedder extends AbstractComponent implements Embedder {
if (size != 1)
throw new IllegalArgumentException("Target indexed sub-type must have one dimension");
int wantedDimensionality = type.indexedSubtype().dimensions().get(0).size().get().intValue();
+ // Allow using the first n float dimensions to pack into int8
+ int floatDimensionality = 8 * wantedDimensionality;
int resultDimensionality = (int)result.shape()[2];
- if (resultDimensionality != 8 * wantedDimensionality) {
+ if (floatDimensionality > resultDimensionality) {
throw new IllegalArgumentException("Not possible to pack " + resultDimensionality +
" + dimensions into " + wantedDimensionality + " dimensions");
}
@@ -274,7 +264,7 @@ public class ColBertEmbedder extends AbstractComponent implements Embedder {
for (int token = 0; token < nTokens; token++) {
BitSet bitSet = new BitSet(8);
int key = 0;
- for (int d = 0; d < result.shape()[2]; d++) {
+ for (int d = 0; d < floatDimensionality; d++) {
var value = result.get(0, token, d); // batch, sequence token, dimension
int bitIndex = 7 - (d % 8);
if (value > 0.0) {
@@ -319,4 +309,8 @@ public class ColBertEmbedder extends AbstractComponent implements Embedder {
return builder.build();
}
+ record EmbedderCacheKey(String embedderId, Object embeddedValue) { }
+
+ record EmbeddingResult(int inputIdSize, Map<String, Tensor> outputs) { }
+
}
diff --git a/model-integration/src/main/java/ai/vespa/embedding/huggingface/HuggingFaceEmbedder.java b/model-integration/src/main/java/ai/vespa/embedding/huggingface/HuggingFaceEmbedder.java
index 169648967d7..20d8b6362d3 100644
--- a/model-integration/src/main/java/ai/vespa/embedding/huggingface/HuggingFaceEmbedder.java
+++ b/model-integration/src/main/java/ai/vespa/embedding/huggingface/HuggingFaceEmbedder.java
@@ -104,59 +104,23 @@ public class HuggingFaceEmbedder extends AbstractComponent implements Embedder {
tokenizer.close();
}
+ @SuppressWarnings("unchecked")
@Override
- public Tensor embed(String s, Context context, TensorType tensorType) {
- var start = System.nanoTime();
- var encoding = tokenizer.encode(s, context.getLanguage());
- runtime.sampleSequenceLength(encoding.ids().size(), context);
- Tensor inputSequence = createTensorRepresentation(encoding.ids(), "d1");
- Tensor attentionMask = createTensorRepresentation(encoding.attentionMask(), "d1");
- Tensor tokenTypeIds = tokenTypeIdsName.isEmpty() ? null : createTensorRepresentation(encoding.typeIds(), "d1");
-
-
- Map<String, Tensor> inputs;
- if (tokenTypeIdsName.isEmpty() || tokenTypeIds.isEmpty()) {
- inputs = Map.of(inputIdsName, inputSequence.expand("d0"),
- attentionMaskName, attentionMask.expand("d0"));
- } else {
- inputs = Map.of(inputIdsName, inputSequence.expand("d0"),
- attentionMaskName, attentionMask.expand("d0"),
- tokenTypeIdsName, tokenTypeIds.expand("d0"));
+ public Tensor embed(String text, Context context, TensorType tensorType) {
+ if (tensorType.dimensions().size() != 1) {
+ throw new IllegalArgumentException("Error in embedding to type '" + tensorType + "': should only have one dimension.");
}
-
- Map<String, Tensor> outputs = evaluator.evaluate(inputs);
- IndexedTensor tokenEmbeddings = (IndexedTensor) outputs.get(outputName);
- long[] resultShape = tokenEmbeddings.shape();
- //shape batch, sequence, embedding dimensionality
- if (resultShape.length != 3) {
- throw new IllegalArgumentException("" +
- "Expected 3 output dimensions for output name '" +
- outputName + "': [batch, sequence, embedding], got " + resultShape.length);
+ if (!tensorType.dimensions().get(0).isIndexed()) {
+ throw new IllegalArgumentException("Error in embedding to type '" + tensorType + "': dimension should be indexed.");
}
- Tensor result;
+ var embeddingResult = lookupOrEvaluate(context, text);
+ IndexedTensor tokenEmbeddings = embeddingResult.output;
if (tensorType.valueType() == TensorType.Value.INT8) {
- long outputDimensions = resultShape[2];
- long targetDim = tensorType.dimensions().get(0).size().get();
-
- if(targetDim * 8 > outputDimensions) {
- throw new IllegalArgumentException("Cannot pack " + outputDimensions + " into " + targetDim + " int8s");
- }
- //Dimensionality flexibility 🪆 - packing only the first 8*targetDim values from the model output
- long firstDimensions = 8 * targetDim;
- String name = tensorType.indexedSubtype().dimensions().get(0).name();
- //perform pooling and normalizing using floating point embeddings before binarizing
- //using the firstDimensions as the target dimensionality
- TensorType poolingType = new TensorType.Builder(TensorType.Value.FLOAT).indexed(name, firstDimensions).build();
- result = poolingStrategy.toSentenceEmbedding(poolingType, tokenEmbeddings, attentionMask);
- result = normalize? normalize(result, poolingType) : result;
- result = binarize((IndexedTensor) result, tensorType);
-
- } else { // regular floating points embeddings
- result = poolingStrategy.toSentenceEmbedding(tensorType, tokenEmbeddings, attentionMask);
- result = normalize ? normalize(result, tensorType) : result;
+ return binaryQuantization(embeddingResult, tensorType);
+ } else {
+ Tensor result = poolingStrategy.toSentenceEmbedding(tensorType, tokenEmbeddings, embeddingResult.attentionMask);
+ return normalize ? normalize(result, tensorType) : result;
}
- runtime.sampleEmbeddingLatency((System.nanoTime() - start)/1_000_000d, context);
- return result;
}
Tensor normalize(Tensor embedding, TensorType tensorType) {
@@ -178,6 +142,61 @@ public class HuggingFaceEmbedder extends AbstractComponent implements Embedder {
return builder.build();
}
+ private HuggingFaceEmbedder.HFEmbeddingResult lookupOrEvaluate(Context context, String text) {
+ var key = new HFEmbedderCacheKey(context.getEmbedderId(), text);
+ return context.computeCachedValueIfAbsent(key, () -> evaluate(context, text));
+ }
+
+ private HuggingFaceEmbedder.HFEmbeddingResult evaluate(Context context, String text) {
+ var start = System.nanoTime();
+ var encoding = tokenizer.encode(text, context.getLanguage());
+ runtime.sampleSequenceLength(encoding.ids().size(), context);
+ Tensor inputSequence = createTensorRepresentation(encoding.ids(), "d1");
+ Tensor attentionMask = createTensorRepresentation(encoding.attentionMask(), "d1");
+ Tensor tokenTypeIds = tokenTypeIdsName.isEmpty() ? null : createTensorRepresentation(encoding.typeIds(), "d1");
+
+ Map<String, Tensor> inputs;
+ if (tokenTypeIdsName.isEmpty() || tokenTypeIds.isEmpty()) {
+ inputs = Map.of(inputIdsName, inputSequence.expand("d0"),
+ attentionMaskName, attentionMask.expand("d0"));
+ } else {
+ inputs = Map.of(inputIdsName, inputSequence.expand("d0"),
+ attentionMaskName, attentionMask.expand("d0"),
+ tokenTypeIdsName, tokenTypeIds.expand("d0"));
+ }
+ IndexedTensor tokenEmbeddings = (IndexedTensor) evaluator.evaluate(inputs).get(outputName);
+ long[] resultShape = tokenEmbeddings.shape();
+ // shape: [batch, sequence, embedding dimensionality]
+ if (resultShape.length != 3) {
+ throw new IllegalArgumentException("" +
+ "Expected 3 output dimensions for output name '" +
+ outputName + "': [batch, sequence, embedding], got " + resultShape.length);
+ }
+ runtime.sampleEmbeddingLatency((System.nanoTime() - start)/1_000_000d, context);
+ return new HFEmbeddingResult(tokenEmbeddings, attentionMask, context.getEmbedderId());
+ }
+
+ private Tensor binaryQuantization(HuggingFaceEmbedder.HFEmbeddingResult embeddingResult, TensorType tensorType) {
+ long outputDimensions = embeddingResult.output().shape()[2];
+ long targetDim = tensorType.dimensions().get(0).size().get();
+ // 🪆 flexibility: packing only the first 8*targetDim float values from the model output
+ long floatDimensions = 8 * targetDim;
+ if (floatDimensions > outputDimensions) {
+ throw new IllegalArgumentException("Cannot pack " + outputDimensions + " into " + targetDim + " int8s");
+ }
+ // Perform pooling and normalization on the float representation before binary quantization
+ TensorType poolingType = new TensorType.Builder(TensorType.Value.FLOAT)
+ .indexed(tensorType.indexedSubtype().dimensions().get(0).name(), floatDimensions)
+ .build();
+ Tensor result = poolingStrategy.toSentenceEmbedding(poolingType, embeddingResult.output(), embeddingResult.attentionMask());
+ result = normalize ? normalize(result, poolingType) : result;
+ result = binarize((IndexedTensor) result, tensorType);
+ return result;
+ }
+
+ /**
+ * Binary quantization of the embedding into a tensor of type int8 with the specified dimensions.
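+ * Each consecutive group of 8 float values is packed into one int8, setting a bit where the corresponding value is positive (most significant bit first).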
+ */
static public Tensor binarize(IndexedTensor embedding, TensorType tensorType) {
Tensor.Builder builder = Tensor.Builder.of(tensorType);
BitSet bitSet = new BitSet(8);
@@ -211,6 +230,7 @@ public class HuggingFaceEmbedder extends AbstractComponent implements Embedder {
return builder.build();
}
-
+ protected record HFEmbeddingResult(IndexedTensor output, Tensor attentionMask, String embedderId) {}
+ protected record HFEmbedderCacheKey(String embedderId, Object embeddedValue) { }
}
diff --git a/model-integration/src/test/java/ai/vespa/embedding/ColBertEmbedderTest.java b/model-integration/src/test/java/ai/vespa/embedding/ColBertEmbedderTest.java
index be75c4d3351..f6216e4149c 100644
--- a/model-integration/src/test/java/ai/vespa/embedding/ColBertEmbedderTest.java
+++ b/model-integration/src/test/java/ai/vespa/embedding/ColBertEmbedderTest.java
@@ -61,27 +61,94 @@ public class ColBertEmbedderTest {
TensorType.fromSpec("tensor<int8>(dt{},x[2])"),
"tensor<int8>(dt{},x[2]):{0:[1.0, -128.0], 1:[5.0, 1.0]}",2
);
+ assertPackedRight(
+ "" +
+ "tensor<float>(d0[1],d1[2],d2[16]):" +
+ "[[" +
+ "[0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0, 0, 0, 0]," +
+ "[0, 0, 0, 0, 0, 1, 0, 1, 0, 0, 0, 0, 0, 0, 0, 1]" +
+ "]]",
+ TensorType.fromSpec("tensor<int8>(dt{},x[1])"),
+ "tensor<int8>(dt{},x[1]):{0:1.0, 1:5.0}",2
+ );
+ }
+
+ @Test
+ public void testCachingFloat() {
+ int initialEmbeddingsDone = runtime.embeddingsDone;
+ var context = new Embedder.Context("schema.indexing");
+
+ var input = "This is a test string to embed";
+ var t1 = (MixedTensor) embedder.embed(input, context,TensorType.fromSpec("tensor<float>(dt{},x[8])"));
+ assertEquals(initialEmbeddingsDone + 1, runtime.embeddingsDone);
+
+ var t2 = (MixedTensor) embedder.embed(input, context, TensorType.fromSpec("tensor<float>(dt{},x[4])"));
+ assertEquals("Cached value was used", initialEmbeddingsDone + 1, runtime.embeddingsDone);
+
+ assertNotEquals(t1, t2);
+ for (int token = 0; token < 7; token++) {
+ for (int dim = 0; dim < 4; dim++) { // the first four dimensions should be equal
+ assertEquals(t1.get(TensorAddress.of(token, dim)), t2.get(TensorAddress.of(token, dim)), 1e-6);
+ }
+ }
+ // t2 only has 4 dimensions, so this address is out of bounds and returns 0
+ assertEquals(0, t2.get(TensorAddress.of(1, 4)), 1e-6);
+
+ input = "This is a different test string to embed";
+ embedder.embed(input, context,TensorType.fromSpec("tensor<float>(dt{},x[8])"));
+ assertEquals(initialEmbeddingsDone + 2, runtime.embeddingsDone);
+ }
+
+ @Test
+ public void testCachingInt() {
+ int initialEmbeddingsDone = runtime.embeddingsDone;
+ var context = new Embedder.Context("schema.indexing");
+
+ var input = "This is a test string to embed";
+ var t1 = (MixedTensor) embedder.embed(input, context, TensorType.fromSpec("tensor<int8>(dt{},x[8])"));
+ assertEquals(initialEmbeddingsDone + 1, runtime.embeddingsDone);
+
+ var t2 = (MixedTensor) embedder.embed(input, context, TensorType.fromSpec("tensor<int8>(dt{},x[4])"));
+ assertEquals("Cached value was used", initialEmbeddingsDone + 1, runtime.embeddingsDone);
+
+ assertNotEquals(t1, t2);
+ for (int token = 0; token < 7; token++) {
+ for (int dim = 0; dim < 4; dim++) { // the first four dimensions should be equal
+ assertEquals(t1.get(TensorAddress.of(token, dim)), t2.get(TensorAddress.of(token, dim)), 1e-6);
+ }
+ }
+ // t2 only has 4 dimensions, so this address is out of bounds and returns 0
+ assertEquals(0, t2.get(TensorAddress.of(0, 4)), 1e-6);
+
+ input = "This is a different test string to embed";
+ embedder.embed(input, context,TensorType.fromSpec("tensor<float>(dt{},x[8])"));
+ assertEquals(initialEmbeddingsDone + 2, runtime.embeddingsDone);
}
+
@Test
public void testEmbedder() {
- assertEmbed("tensor<float>(dt{},x[128])", "this is a document", indexingContext);
- assertEmbed("tensor<int8>(dt{},x[16])", "this is a document", indexingContext);
- assertEmbed("tensor<float>(qt{},x[128])", "this is a query", queryContext);
+ var indexingContext = new Embedder.Context("schema.indexing");
+ assertEmbed("tensor<float>(dt{},x[128])", "this is a document", indexingContext,128);
+ assertEmbed("tensor<float>(dt{},x[64])", "this is a document", indexingContext,64);
- assertThrows(IllegalArgumentException.class, () -> {
- // throws because int8 is not supported for query context
- assertEmbed("tensor<int8>(qt{},x[16])", "this is a query", queryContext);
- });
+ assertEmbed("tensor<int8>(dt{},x[16])", "this is a document", indexingContext,16);
+ assertEmbed("tensor<int8>(dt{},x[8])", "this is a document", indexingContext,8);
+ assertEmbed("tensor<int8>(dt{},x[4])", "this is a document", indexingContext,4);
+ assertEmbed("tensor<int8>(dt{},x[3])", "this is a document", indexingContext,3);
+
+ var queryContext = new Embedder.Context("query(qt{})");
+ assertEmbed("tensor<float>(qt{},x[128])", "this is a query", queryContext,128);
+ assertEmbed("tensor<float>(qt{},x[64])", "this is a query", queryContext,64);
assertThrows(IllegalArgumentException.class, () -> {
- // throws because 16 is less than model output (128) and we want float
- assertEmbed("tensor<float>(qt{},x[16])", "this is a query", queryContext);
+ // throws because int8 is not supported for query context
+ assertEmbed("tensor<int8>(qt{},x[16])", "this is a query", queryContext,16);
});
assertThrows(IllegalArgumentException.class, () -> {
- // throws because 128/8 does not fit into 15
- assertEmbed("tensor<int8>(qt{},x[15])", "this is a query", indexingContext);
+ // throws because 8*32 = 256 is larger than the model output dimensionality (128)
+ assertEmbed("tensor<int8>(qt{},x[32])", "this is a query", queryContext,32);
});
}
@@ -130,26 +197,32 @@ public class ColBertEmbedderTest {
}
@Test
- public void testLenghtLimits() {
+ public void testLengthLimits() {
StringBuilder sb = new StringBuilder();
for(int i = 0; i < 1024; i++) {
sb.append("annoyance");
sb.append(" ");
}
String text = sb.toString();
- Tensor fullFloat = assertEmbed("tensor<float>(dt{},x[128])", text, indexingContext);
- assertEquals(512*128,fullFloat.size());
+ var indexingContext = new Embedder.Context("schema.indexing");
- Tensor query = assertEmbed("tensor<float>(dt{},x[128])", text, queryContext);
- assertEquals(32*128,query.size());
+ Tensor fullFloat = assertEmbed("tensor<float>(dt{},x[128])", text, indexingContext,128);
+ assertEquals(512*128,fullFloat.size());
- Tensor binaryRep = assertEmbed("tensor<int8>(dt{},x[16])", text, indexingContext);
+ Tensor binaryRep = assertEmbed("tensor<int8>(dt{},x[16])", text, indexingContext,16);
assertEquals(512*16,binaryRep.size());
- Tensor shortDoc = assertEmbed("tensor<int8>(dt{},x[16])", "annoyance", indexingContext);
+ Tensor shortDoc = assertEmbed("tensor<int8>(dt{},x[16])", "annoyance", indexingContext,16);
// 4 tokens, 16 bytes each = 64 bytes
//CLS [unused1] sequence
assertEquals(4*16,shortDoc.size());
+
+ var queryContext = new Embedder.Context("query(qt{})");
+ Tensor query = assertEmbed("tensor<float>(dt{},x[128])", text, queryContext,128);
+ assertEquals(32*128,query.size());
+
+ Tensor shortQuery = assertEmbed("tensor<float>(dt{},x[64])", text, queryContext,64);
+ assertEquals(32*64,shortQuery.size());
}
@Ignore
@@ -163,18 +236,19 @@ public class ColBertEmbedderTest {
long now = System.currentTimeMillis();
int n = 1000;
for (int i = 0; i < n; i++) {
- assertEmbed("tensor<float>(dt{},x[128])", text, indexingContext);
+ assertEmbed("tensor<float>(dt{},x[128])", text, new Embedder.Context("schema.indexing"),128);
}
long elapsed = (System.currentTimeMillis() - now);
System.out.println("Elapsed time: " + elapsed + " ms");
}
- static Tensor assertEmbed(String tensorSpec, String text, Embedder.Context context) {
+ static Tensor assertEmbed(String tensorSpec, String text, Embedder.Context context, int dimSize) {
TensorType destType = TensorType.fromSpec(tensorSpec);
Tensor result = embedder.embed(text, context, destType);
assertEquals(destType,result.type());
MixedTensor mixedTensor = (MixedTensor) result;
- if (context == queryContext) {
+ assertEquals(dimSize,mixedTensor.denseSubspaceSize());
+ if (context.getDestination().startsWith("query")) {
assertEquals(32*mixedTensor.denseSubspaceSize(),mixedTensor.size());
}
return result;
@@ -182,12 +256,12 @@ public class ColBertEmbedderTest {
static void assertPackedRight(String numbers, TensorType destination, String expected, int size) {
var in = (IndexedTensor) Tensor.from(numbers);
+ int targetDim = destination.indexedSubtype().dimensions().get(0).size().get().intValue();
Tensor packed = ColBertEmbedder.toBitTensor(in, destination, size);
assertEquals(expected, packed.toString());
Tensor unpacked = ColBertEmbedder.expandBitTensor(packed);
- assertEquals(in.shape()[2], unpacked.type().indexedSubtype().dimensions().get(0).size().get().longValue());
for (int dOuter = 0; dOuter < size; dOuter++) {
- for (int dInner = 0; dInner < in.shape()[2]; dInner++) {
+ for (int dInner = 0; dInner < targetDim*8; dInner++) {
var addr = TensorAddress.of(dOuter, dInner);
double oldVal = in.get(TensorAddress.of(0,dOuter, dInner));
if (oldVal > 0) {
@@ -200,19 +274,16 @@ public class ColBertEmbedderTest {
}
static final ColBertEmbedder embedder;
-
static final ColBertEmbedder multiLingualEmbedder;
- static final Embedder.Context indexingContext;
- static final Embedder.Context queryContext;
+ static final CountingRuntime runtime;
static {
- indexingContext = new Embedder.Context("schema.indexing");
- queryContext = new Embedder.Context("query(qt)");
- embedder = getEmbedder();
- multiLingualEmbedder = getMultiLingualEmbedder();
+ runtime = new CountingRuntime();
+ embedder = createEmbedder(runtime);
+ multiLingualEmbedder = getMultiLingualEmbedder(runtime);
}
- private static ColBertEmbedder getEmbedder() {
+ private static ColBertEmbedder createEmbedder(Embedder.Runtime runtime) {
String vocabPath = "src/test/models/onnx/transformer/real_tokenizer.json";
String modelPath = "src/test/models/onnx/transformer/colbert-dummy-v2.onnx";
assumeTrue(OnnxRuntime.isRuntimeAvailable(modelPath));
@@ -220,10 +291,10 @@ public class ColBertEmbedderTest {
builder.tokenizerPath(ModelReference.valueOf(vocabPath));
builder.transformerModel(ModelReference.valueOf(modelPath));
builder.transformerGpuDevice(-1);
- return new ColBertEmbedder(new OnnxRuntime(), Embedder.Runtime.testInstance(), builder.build());
+ return new ColBertEmbedder(new OnnxRuntime(), runtime, builder.build());
}
- private static ColBertEmbedder getMultiLingualEmbedder() {
+ private static ColBertEmbedder getMultiLingualEmbedder(Embedder.Runtime runtime) {
String vocabPath = "src/test/models/onnx/transformer/sentence_piece_tokenizer.json";
String modelPath = "src/test/models/onnx/transformer/colbert-dummy-v2.onnx";
assumeTrue(OnnxRuntime.isRuntimeAvailable(modelPath));
@@ -239,7 +310,21 @@ public class ColBertEmbedderTest {
builder.queryTokenId(3);
builder.documentTokenId(4);
- return new ColBertEmbedder(new OnnxRuntime(), Embedder.Runtime.testInstance(), builder.build());
+ return new ColBertEmbedder(new OnnxRuntime(), Embedder.Runtime.testInstance(), builder.build());
+ }
+
+ private static class CountingRuntime implements Embedder.Runtime {
+
+ int embeddingsDone = 0;
+
+ @Override
+ public void sampleEmbeddingLatency(double millis, Embedder.Context ctx) {
+ embeddingsDone++;
+ }
+
+ @Override
+ public void sampleSequenceLength(long length, Embedder.Context ctx) { }
+
}
}
diff --git a/model-integration/src/test/java/ai/vespa/embedding/HuggingFaceEmbedderTest.java b/model-integration/src/test/java/ai/vespa/embedding/huggingface/HuggingFaceEmbedderTest.java
index 1ce1d955b00..d504d77cc9b 100644
--- a/model-integration/src/test/java/ai/vespa/embedding/HuggingFaceEmbedderTest.java
+++ b/model-integration/src/test/java/ai/vespa/embedding/huggingface/HuggingFaceEmbedderTest.java
@@ -1,7 +1,7 @@
// Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
-package ai.vespa.embedding;
+package ai.vespa.embedding.huggingface;
+
-import ai.vespa.embedding.huggingface.HuggingFaceEmbedder;
import ai.vespa.modelintegration.evaluator.OnnxRuntime;
import com.yahoo.config.ModelReference;
import com.yahoo.embedding.huggingface.HuggingFaceEmbedderConfig;
@@ -12,10 +12,10 @@ import com.yahoo.tensor.TensorType;
import com.yahoo.tensor.TensorAddress;
import org.junit.Test;
+import static org.junit.Assert.assertNotEquals;
import static org.junit.Assert.assertThrows;
import static org.junit.Assume.assumeTrue;
-import static org.junit.jupiter.api.Assertions.assertEquals;
-import static org.junit.jupiter.api.Assertions.assertTrue;
+import static org.junit.jupiter.api.Assertions.*;
import com.yahoo.searchlib.rankingexpression.evaluation.MapContext;
import com.yahoo.searchlib.rankingexpression.evaluation.TensorValue;
@@ -26,7 +26,6 @@ public class HuggingFaceEmbedderTest {
static HuggingFaceEmbedder embedder = getEmbedder();
static HuggingFaceEmbedder normalizedEmbedder = getNormalizedEmbedder();
- static Embedder.Context context = new Embedder.Context("schema.indexing");
@Test
public void testBinarization() {
@@ -48,16 +47,48 @@ public class HuggingFaceEmbedderTest {
private void assertPackRight(String input, String expected, TensorType type) {
Tensor inputTensor = Tensor.from(input);
Tensor result = HuggingFaceEmbedder.binarize((IndexedTensor) inputTensor, type);
- assertEquals(expected.toString(), result.toString());
- //Verify against what is done in ranking with unpack_bits
+ assertEquals(expected, result.toString());
+ // Verify that the unpack_bits ranking feature produces compatible output
Tensor unpacked = expandBitTensor(result);
assertEquals(inputTensor.toString(), unpacked.toString());
}
@Test
+ public void testCaching() {
+ var context = new Embedder.Context("schema.indexing");
+ var myEmbedderId = "my-hf-embedder";
+ context.setEmbedderId(myEmbedderId);
+
+ var input = "This is a test string to embed";
+ Tensor result = embedder.embed(input, context,TensorType.fromSpec("tensor<float>(x[8])"));
+ HuggingFaceEmbedder.HFEmbedderCacheKey key = new HuggingFaceEmbedder.HFEmbedderCacheKey(myEmbedderId, input);
+ var modelOutput = context.getCachedValue(key);
+ assertNotNull(modelOutput);
+
+ Tensor binaryResult = embedder.embed(input, context,TensorType.fromSpec("tensor<int8>(x[4])"));
+ var modelOutput2 = context.getCachedValue(key);
+ assertEquals(modelOutput, modelOutput2);
+ assertNotEquals(result, binaryResult);
+
+ var anotherInput = "This is a different test string to embed with the same embedder";
+ embedder.embed(anotherInput, context,TensorType.fromSpec("tensor<float>(x[4])"));
+ key = new HuggingFaceEmbedder.HFEmbedderCacheKey(myEmbedderId, anotherInput);
+ var modelOutput3 = context.getCachedValue(key);
+ assertNotEquals(modelOutput, modelOutput3);
+
+ // the context cache is shared across copies, but entries are keyed by embedder id
+ var copyContext = context.copy();
+ var anotherEmbedderId = "another-hf-embedder";
+ copyContext.setEmbedderId(anotherEmbedderId);
+ key = new HuggingFaceEmbedder.HFEmbedderCacheKey(anotherEmbedderId, input);
+ assertNull(copyContext.getCachedValue(key));
+ embedder.embed(input, copyContext,TensorType.fromSpec("tensor<int8>(x[2])"));
+ assertNotEquals(modelOutput, copyContext.getCachedValue(key));
+ }
+ @Test
public void testEmbedder() {
+ var context = new Embedder.Context("schema.indexing");
String input = "This is a test";
-
Tensor expected = Tensor.from("tensor<float>(x[8]):[-0.666, 0.335, 0.227, 0.0919, -0.069, 0.323, 0.422, 0.270]");
Tensor result = embedder.embed(input, context, TensorType.fromSpec(("tensor<float>(x[8])")));
for(int i = 0; i < 8; i++) {
@@ -85,16 +116,33 @@ public class HuggingFaceEmbedderTest {
@Test
public void testEmbedderWithNormalization() {
String input = "This is a test";
-
+ var context = new Embedder.Context("schema.indexing");
Tensor result = normalizedEmbedder.embed(input, context, TensorType.fromSpec(("tensor<float>(x[8])")));
assertEquals(1.0, result.multiply(result).sum().asDouble(), 1e-3);
-
result = normalizedEmbedder.embed(input, context, TensorType.fromSpec(("tensor<float>(x[16])")));
assertEquals(1.0, result.multiply(result).sum().asDouble(), 1e-3);
Tensor binarizedResult = embedder.embed(input, context, TensorType.fromSpec(("tensor<int8>(x[2])")));
assertEquals("tensor<int8>(x[2]):[119, 44]", binarizedResult.toAbbreviatedString());
}
+ @Test
+ public void testThatWrongTensorTypeThrows() {
+ var context = new Embedder.Context("schema.indexing");
+ String input = "This is a test";
+ assertThrows(IllegalArgumentException.class, () -> {
+ // throws because the target tensor type is mapped
+ embedder.embed(input, context, TensorType.fromSpec(("tensor<float>(x{})")));
+ });
+ assertThrows(IllegalArgumentException.class, () -> {
+ // throws because the target tensor is 0d
+ embedder.embed(input, context, TensorType.fromSpec(("tensor<float>(x[0]")));
+ });
+ assertThrows(IllegalArgumentException.class, () -> {
+ // throws because the target tensor is 2d
+ embedder.embed(input, context, TensorType.fromSpec(("tensor<float>(x{}, y[2])")));
+ });
+ }
+
private static HuggingFaceEmbedder getEmbedder() {
String vocabPath = "src/test/models/onnx/transformer/real_tokenizer.json";
String modelPath = "src/test/models/onnx/transformer/embedding_model.onnx";