summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorJo Kristian Bergum <bergum@yahooinc.com>2024-04-04 09:15:10 +0200
committerJo Kristian Bergum <bergum@yahooinc.com>2024-04-04 09:15:10 +0200
commit531bc532c592703221e232d817850d802cdcfd11 (patch)
tree69d9a60d6a8ea48dbea331906e775589bce15dd7
parenta009cdd704f427282c3c9ed3b70a7caf9d536c7e (diff)
Support for dimensionality flexibility and caching onnx inference output using Context cache
-rw-r--r--model-integration/src/main/java/ai/vespa/embedding/ColBertEmbedder.java60
-rw-r--r--model-integration/src/test/java/ai/vespa/embedding/ColBertEmbedderTest.java124
2 files changed, 131 insertions, 53 deletions
diff --git a/model-integration/src/main/java/ai/vespa/embedding/ColBertEmbedder.java b/model-integration/src/main/java/ai/vespa/embedding/ColBertEmbedder.java
index f43f3834a65..2f4c0343bf6 100644
--- a/model-integration/src/main/java/ai/vespa/embedding/ColBertEmbedder.java
+++ b/model-integration/src/main/java/ai/vespa/embedding/ColBertEmbedder.java
@@ -181,34 +181,25 @@ public class ColBertEmbedder extends AbstractComponent implements Embedder {
if (tensorType.valueType() == TensorType.Value.INT8)
throw new IllegalArgumentException("ColBert query embed does not accept int8 tensor value type");
-
var start = System.nanoTime();
var encoding = tokenizer.encode(text, context.getLanguage());
runtime.sampleSequenceLength(encoding.ids().size(), context);
TransformerInput input = buildTransformerInput(encoding.ids(), maxQueryTokens, true);
-
Tensor inputIdsTensor = createTensorRepresentation(input.inputIds, "d1");
Tensor attentionMaskTensor = createTensorRepresentation(input.attentionMask, "d1");
var inputs = Map.of(inputIdsName, inputIdsTensor.expand("d0"),
attentionMaskName, attentionMaskTensor.expand("d0"));
- Map<String, Tensor> outputs = evaluator.evaluate(inputs);
- Tensor tokenEmbeddings = outputs.get(outputName);
- IndexedTensor result = (IndexedTensor) tokenEmbeddings;
-
- int dims = tensorType.indexedSubtype().dimensions().get(0).size().get().intValue();
- if (dims != result.shape()[2]) {
- throw new IllegalArgumentException("Token vector dimensionality does not" +
- " match indexed dimensionality of " + dims);
- }
- Tensor resultTensor = toFloatTensor(result, tensorType, input.inputIds.size());
+ IndexedTensor modelOutput = (IndexedTensor) evaluateIfNotPresent(inputs, context, text).get(outputName);
+ Tensor resultTensor = toFloatTensor(modelOutput, tensorType, input.inputIds.size());
runtime.sampleEmbeddingLatency((System.nanoTime() - start) / 1_000_000d, context);
return resultTensor;
}
-
+ @SuppressWarnings("unchecked")
protected Tensor embedDocument(String text, Context context, TensorType tensorType) {
var start = System.nanoTime();
+
var encoding = tokenizer.encode(text, context.getLanguage());
runtime.sampleSequenceLength(encoding.ids().size(), context);
@@ -218,19 +209,34 @@ public class ColBertEmbedder extends AbstractComponent implements Embedder {
var inputs = Map.of(inputIdsName, inputIdsTensor.expand("d0"),
attentionMaskName, attentionMaskTensor.expand("d0"));
-
- Map<String, Tensor> outputs = evaluator.evaluate(inputs);
- Tensor tokenEmbeddings = outputs.get(outputName);
- IndexedTensor result = (IndexedTensor) tokenEmbeddings;
- Tensor contextualEmbeddings;
- int maxTokens = input.inputIds.size(); // Retain all token vectors, including PAD tokens.
+ IndexedTensor modelOutput = (IndexedTensor) evaluateIfNotPresent(inputs, context, text).get(outputName);
+ Tensor resultEmbeddings;
+ int maxTokens = input.inputIds.size();
if (tensorType.valueType() == TensorType.Value.INT8) {
- contextualEmbeddings = toBitTensor(result, tensorType, maxTokens);
+ resultEmbeddings = toBitTensor(modelOutput, tensorType, maxTokens);
} else {
- contextualEmbeddings = toFloatTensor(result, tensorType, maxTokens);
+ resultEmbeddings = toFloatTensor(modelOutput, tensorType, maxTokens);
}
runtime.sampleEmbeddingLatency((System.nanoTime() - start) / 1_000_000d, context);
- return contextualEmbeddings;
+ return resultEmbeddings;
+ }
+
+ /**
+ * Evaluate the model if the result is not present in the context cache.
+ * @param inputs the tensor inputs
+ * @param context the context accompanying the request, a singleton per embedder instance and request
+ * @param hashKey the key to the cached value
+ * @return the model output
+ */
+ @SuppressWarnings("unchecked")
+ protected Map<String, Tensor> evaluateIfNotPresent(Map<String, Tensor> inputs, Context context, String hashKey) {
+ if (context.getCachedValue(hashKey) == null) {
+ Map<String, Tensor> outputs = evaluator.evaluate(inputs);
+ context.putCachedValue(hashKey, outputs);
+ return outputs;
+ } else {
+ return (Map<String, Tensor>) context.getCachedValue(hashKey);
+ }
}
public static Tensor toFloatTensor(IndexedTensor result, TensorType type, int nTokens) {
@@ -241,13 +247,13 @@ public class ColBertEmbedder extends AbstractComponent implements Embedder {
throw new IllegalArgumentException("Target indexed sub-type must have one dimension");
int wantedDimensionality = type.indexedSubtype().dimensions().get(0).size().get().intValue();
int resultDimensionality = (int)result.shape()[2];
- if (resultDimensionality != wantedDimensionality) {
+ if (wantedDimensionality > resultDimensionality) {
throw new IllegalArgumentException("Not possible to map token vector embedding with " + resultDimensionality +
" dimensions into tensor with " + wantedDimensionality);
}
Tensor.Builder builder = Tensor.Builder.of(type);
for (int token = 0; token < nTokens; token++) {
- for (int d = 0; d < resultDimensionality; d++) {
+ for (int d = 0; d < wantedDimensionality; d++) {
var value = result.get(0,token,d); // batch, sequence token, dimension
builder.cell(TensorAddress.of(token,d),value);
}
@@ -265,8 +271,10 @@ public class ColBertEmbedder extends AbstractComponent implements Embedder {
if (size != 1)
throw new IllegalArgumentException("Target indexed sub-type must have one dimension");
int wantedDimensionality = type.indexedSubtype().dimensions().get(0).size().get().intValue();
+ //Allow using the first n float dimensions to pack into int8
+ int floatDimensionality = 8 * wantedDimensionality;
int resultDimensionality = (int)result.shape()[2];
- if (resultDimensionality != 8 * wantedDimensionality) {
+ if (floatDimensionality > resultDimensionality) {
throw new IllegalArgumentException("Not possible to pack " + resultDimensionality +
" + dimensions into " + wantedDimensionality + " dimensions");
}
@@ -274,7 +282,7 @@ public class ColBertEmbedder extends AbstractComponent implements Embedder {
for (int token = 0; token < nTokens; token++) {
BitSet bitSet = new BitSet(8);
int key = 0;
- for (int d = 0; d < result.shape()[2]; d++) {
+ for (int d = 0; d < floatDimensionality; d++) {
var value = result.get(0, token, d); // batch, sequence token, dimension
int bitIndex = 7 - (d % 8);
if (value > 0.0) {
diff --git a/model-integration/src/test/java/ai/vespa/embedding/ColBertEmbedderTest.java b/model-integration/src/test/java/ai/vespa/embedding/ColBertEmbedderTest.java
index be75c4d3351..5fd0afad2c4 100644
--- a/model-integration/src/test/java/ai/vespa/embedding/ColBertEmbedderTest.java
+++ b/model-integration/src/test/java/ai/vespa/embedding/ColBertEmbedderTest.java
@@ -61,27 +61,94 @@ public class ColBertEmbedderTest {
TensorType.fromSpec("tensor<int8>(dt{},x[2])"),
"tensor<int8>(dt{},x[2]):{0:[1.0, -128.0], 1:[5.0, 1.0]}",2
);
+ assertPackedRight(
+ "" +
+ "tensor<float>(d0[1],d1[2],d2[16]):" +
+ "[[" +
+ "[0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0, 0, 0, 0]," +
+ "[0, 0, 0, 0, 0, 1, 0, 1, 0, 0, 0, 0, 0, 0, 0, 1]" +
+ "]]",
+ TensorType.fromSpec("tensor<int8>(dt{},x[1])"),
+ "tensor<int8>(dt{},x[1]):{0:1.0, 1:5.0}",2
+ );
+ }
+
+ @Test
+ public void testCachingFloat() {
+ var context = new Embedder.Context("schema.indexing");
+ var input = "This is a test string to embed";
+ var t1 = (MixedTensor) embedder.embed(input, context,TensorType.fromSpec("tensor<float>(dt{},x[8])"));
+ var modelOuput = context.getCachedValue(input);
+
+ var t2 = (MixedTensor)embedder.embed(input, context,TensorType.fromSpec("tensor<float>(dt{},x[4])"));
+ var modelOuput2 = context.getCachedValue(input);
+ assertEquals(modelOuput, modelOuput2);
+
+ assertNotEquals(t1,t2);
+ for(int token = 0; token < 7; token ++) {
+ for(int dim = 0; dim < 4; dim++) { // the four first should be equal
+ assertEquals(t1.get(TensorAddress.of(token,dim)),t2.get(TensorAddress.of(token,dim)), 1e-6);
+ }
+ }
+ //t2 only has 4 dimensions so this should be out of bounds which returns 0
+ assertEquals(0, t2.get(TensorAddress.of(1,4)), 1e-6);
+
+ input = "This is a different test string to embed";
+ embedder.embed(input, context,TensorType.fromSpec("tensor<float>(dt{},x[8])"));
+ var modelOuput3 = context.getCachedValue(input);
+ assertNotEquals(modelOuput, modelOuput3);
+ assertNotEquals(modelOuput2, modelOuput3);
+ }
+
+ @Test
+ public void testCachingInt() {
+ var context = new Embedder.Context("schema.indexing");
+ var input = "This is a test string to embed";
+ var t1 = (MixedTensor) embedder.embed(input, context,TensorType.fromSpec("tensor<int8>(dt{},x[8])"));
+ var modelOuput = context.getCachedValue(input);
+
+ var t2 = (MixedTensor)embedder.embed(input, context,TensorType.fromSpec("tensor<int8>(dt{},x[4])"));
+ var modelOuput2 = context.getCachedValue(input);
+ assertEquals(modelOuput, modelOuput2);
+ assertNotEquals(t1,t2);
+ for(int token = 0; token < 7; token ++) {
+ for(int dim = 0; dim < 4; dim++) { // the four first should be equal
+ assertEquals(t1.get(TensorAddress.of(token,dim)),t2.get(TensorAddress.of(token,dim)), 1e-6);
+ }
+ }
+ //t2 only has 4 dimensions so this should be out of bounds which returns 0
+ assertEquals(0, t2.get(TensorAddress.of(0,4)), 1e-6);
+ input = "This is a different test string to embed";
+ embedder.embed(input, context,TensorType.fromSpec("tensor<float>(dt{},x[8])"));
+ var modelOuput3 = context.getCachedValue(input);
+ assertNotEquals(modelOuput, modelOuput3);
+ assertNotEquals(modelOuput2, modelOuput3);
}
+
@Test
public void testEmbedder() {
- assertEmbed("tensor<float>(dt{},x[128])", "this is a document", indexingContext);
- assertEmbed("tensor<int8>(dt{},x[16])", "this is a document", indexingContext);
- assertEmbed("tensor<float>(qt{},x[128])", "this is a query", queryContext);
+ var indexingContext = new Embedder.Context("schema.indexing");
+ assertEmbed("tensor<float>(dt{},x[128])", "this is a document", indexingContext,128);
+ assertEmbed("tensor<float>(dt{},x[64])", "this is a document", indexingContext,64);
- assertThrows(IllegalArgumentException.class, () -> {
- // throws because int8 is not supported for query context
- assertEmbed("tensor<int8>(qt{},x[16])", "this is a query", queryContext);
- });
+ assertEmbed("tensor<int8>(dt{},x[16])", "this is a document", indexingContext,16);
+ assertEmbed("tensor<int8>(dt{},x[8])", "this is a document", indexingContext,8);
+ assertEmbed("tensor<int8>(dt{},x[4])", "this is a document", indexingContext,4);
+ assertEmbed("tensor<int8>(dt{},x[3])", "this is a document", indexingContext,3);
+
+ var queryContext = new Embedder.Context("query(qt{})");
+ assertEmbed("tensor<float>(qt{},x[128])", "this is a query", queryContext,128);
+ assertEmbed("tensor<float>(qt{},x[64])", "this is a query", queryContext,64);
assertThrows(IllegalArgumentException.class, () -> {
- // throws because 16 is less than model output (128) and we want float
- assertEmbed("tensor<float>(qt{},x[16])", "this is a query", queryContext);
+ // throws because int8 is not supported for query context
+ assertEmbed("tensor<int8>(qt{},x[16])", "this is a query", queryContext,16);
});
assertThrows(IllegalArgumentException.class, () -> {
- // throws because 128/8 does not fit into 15
- assertEmbed("tensor<int8>(qt{},x[15])", "this is a query", indexingContext);
+            // throws because 32 int8 dims require 32*8 = 256 float dims, larger than the model output (128)
+ assertEmbed("tensor<int8>(qt{},x[32])", "this is a query", queryContext,32);
});
}
@@ -130,26 +197,32 @@ public class ColBertEmbedderTest {
}
@Test
- public void testLenghtLimits() {
+ public void testLengthLimits() {
StringBuilder sb = new StringBuilder();
for(int i = 0; i < 1024; i++) {
sb.append("annoyance");
sb.append(" ");
}
String text = sb.toString();
- Tensor fullFloat = assertEmbed("tensor<float>(dt{},x[128])", text, indexingContext);
- assertEquals(512*128,fullFloat.size());
+ var indexingContext = new Embedder.Context("schema.indexing");
- Tensor query = assertEmbed("tensor<float>(dt{},x[128])", text, queryContext);
- assertEquals(32*128,query.size());
+ Tensor fullFloat = assertEmbed("tensor<float>(dt{},x[128])", text, indexingContext,128);
+ assertEquals(512*128,fullFloat.size());
- Tensor binaryRep = assertEmbed("tensor<int8>(dt{},x[16])", text, indexingContext);
+ Tensor binaryRep = assertEmbed("tensor<int8>(dt{},x[16])", text, indexingContext,16);
assertEquals(512*16,binaryRep.size());
- Tensor shortDoc = assertEmbed("tensor<int8>(dt{},x[16])", "annoyance", indexingContext);
+ Tensor shortDoc = assertEmbed("tensor<int8>(dt{},x[16])", "annoyance", indexingContext,16);
// 4 tokens, 16 bytes each = 64 bytes
//CLS [unused1] sequence
assertEquals(4*16,shortDoc.size());;
+
+ var queryContext = new Embedder.Context("query(qt{})");
+ Tensor query = assertEmbed("tensor<float>(dt{},x[128])", text, queryContext,128);
+ assertEquals(32*128,query.size());
+
+ Tensor shortQuery = assertEmbed("tensor<float>(dt{},x[64])", text, queryContext,64);
+ assertEquals(32*64,shortQuery.size());
}
@Ignore
@@ -163,18 +236,19 @@ public class ColBertEmbedderTest {
long now = System.currentTimeMillis();
int n = 1000;
for (int i = 0; i < n; i++) {
- assertEmbed("tensor<float>(dt{},x[128])", text, indexingContext);
+ assertEmbed("tensor<float>(dt{},x[128])", text, new Embedder.Context("schema.indexing"),128);
}
long elapsed = (System.currentTimeMillis() - now);
System.out.println("Elapsed time: " + elapsed + " ms");
}
- static Tensor assertEmbed(String tensorSpec, String text, Embedder.Context context) {
+ static Tensor assertEmbed(String tensorSpec, String text, Embedder.Context context, int dimSize) {
TensorType destType = TensorType.fromSpec(tensorSpec);
Tensor result = embedder.embed(text, context, destType);
assertEquals(destType,result.type());
MixedTensor mixedTensor = (MixedTensor) result;
- if (context == queryContext) {
+ assertEquals(dimSize,mixedTensor.denseSubspaceSize());
+ if (context.getDestination().startsWith("query")) {
assertEquals(32*mixedTensor.denseSubspaceSize(),mixedTensor.size());
}
return result;
@@ -182,12 +256,12 @@ public class ColBertEmbedderTest {
static void assertPackedRight(String numbers, TensorType destination, String expected, int size) {
var in = (IndexedTensor) Tensor.from(numbers);
+ int targetDim = destination.indexedSubtype().dimensions().get(0).size().get().intValue();
Tensor packed = ColBertEmbedder.toBitTensor(in, destination, size);
assertEquals(expected, packed.toString());
Tensor unpacked = ColBertEmbedder.expandBitTensor(packed);
- assertEquals(in.shape()[2], unpacked.type().indexedSubtype().dimensions().get(0).size().get().longValue());
for (int dOuter = 0; dOuter < size; dOuter++) {
- for (int dInner = 0; dInner < in.shape()[2]; dInner++) {
+ for (int dInner = 0; dInner < targetDim*8; dInner++) {
var addr = TensorAddress.of(dOuter, dInner);
double oldVal = in.get(TensorAddress.of(0,dOuter, dInner));
if (oldVal > 0) {
@@ -202,12 +276,8 @@ public class ColBertEmbedderTest {
static final ColBertEmbedder embedder;
static final ColBertEmbedder multiLingualEmbedder;
- static final Embedder.Context indexingContext;
- static final Embedder.Context queryContext;
static {
- indexingContext = new Embedder.Context("schema.indexing");
- queryContext = new Embedder.Context("query(qt)");
embedder = getEmbedder();
multiLingualEmbedder = getMultiLingualEmbedder();
}