summaryrefslogtreecommitdiffstats
path: root/model-integration
diff options
context:
space:
mode:
authorJo Kristian Bergum <bergum@yahooinc.com>2024-01-15 18:54:03 +0100
committerJo Kristian Bergum <bergum@yahooinc.com>2024-01-15 18:54:03 +0100
commit250aec1b7f1156ded4ef8eed2b4f029dafe4bc8a (patch)
treeba80f21c0bba5d9b6d8d4169c57479ff1360f947 /model-integration
parent348ba0774b8047aeb15d8f96c189991dac4180b1 (diff)
Avoid generic reduce and keep PAD token embedding
Diffstat (limited to 'model-integration')
-rw-r--r--model-integration/src/main/java/ai/vespa/embedding/ColBertEmbedder.java27
-rw-r--r--model-integration/src/test/java/ai/vespa/embedding/ColBertEmbedderTest.java44
2 files changed, 47 insertions, 24 deletions
diff --git a/model-integration/src/main/java/ai/vespa/embedding/ColBertEmbedder.java b/model-integration/src/main/java/ai/vespa/embedding/ColBertEmbedder.java
index 0bee03a65af..8c39cc8c813 100644
--- a/model-integration/src/main/java/ai/vespa/embedding/ColBertEmbedder.java
+++ b/model-integration/src/main/java/ai/vespa/embedding/ColBertEmbedder.java
@@ -191,10 +191,10 @@ public class ColBertEmbedder extends AbstractComponent implements Embedder {
attentionMaskName, attentionMaskTensor.expand("d0"));
Map<String, Tensor> outputs = evaluator.evaluate(inputs);
Tensor tokenEmbeddings = outputs.get(outputName);
- IndexedTensor result = (IndexedTensor) tokenEmbeddings.reduce(Reduce.Aggregator.min, "d0");
+ IndexedTensor result = (IndexedTensor) tokenEmbeddings;
int dims = tensorType.indexedSubtype().dimensions().get(0).size().get().intValue();
- if (dims != result.shape()[1]) {
+ if (dims != result.shape()[2]) {
throw new IllegalArgumentException("Token vector dimensionality does not" +
" match indexed dimensionality of " + dims);
}
@@ -217,9 +217,9 @@ public class ColBertEmbedder extends AbstractComponent implements Embedder {
Map<String, Tensor> outputs = evaluator.evaluate(inputs);
Tensor tokenEmbeddings = outputs.get(outputName);
- IndexedTensor result = (IndexedTensor) tokenEmbeddings.reduce(Reduce.Aggregator.min, "d0");
+ IndexedTensor result = (IndexedTensor) tokenEmbeddings;
Tensor contextualEmbeddings;
- int maxTokens = input.inputIds.size() -1; //Do not retain last PAD
+ int maxTokens = input.inputIds.size(); //Retain all token vectors, including PAD tokens.
if (tensorType.valueType() == TensorType.Value.INT8) {
contextualEmbeddings = toBitTensor(result, tensorType, maxTokens);
} else {
@@ -230,11 +230,13 @@ public class ColBertEmbedder extends AbstractComponent implements Embedder {
}
public static Tensor toFloatTensor(IndexedTensor result, TensorType type, int nTokens) {
+ if(result.shape().length != 3)
+ throw new IllegalArgumentException("Expected onnx result to have 3-dimensions [batch, sequence, dim]");
int size = type.indexedSubtype().dimensions().size();
if (size != 1)
- throw new IllegalArgumentException("Indexed tensor must have one dimension");
+ throw new IllegalArgumentException("Target indexed sub-type must have one dimension");
int wantedDimensionality = type.indexedSubtype().dimensions().get(0).size().get().intValue();
- int resultDimensionality = (int)result.shape()[1];
+ int resultDimensionality = (int)result.shape()[2];
if (resultDimensionality != wantedDimensionality) {
throw new IllegalArgumentException("Not possible to map token vector embedding with " + resultDimensionality
+ " + dimensions into tensor with " + wantedDimensionality);
@@ -242,7 +244,7 @@ public class ColBertEmbedder extends AbstractComponent implements Embedder {
Tensor.Builder builder = Tensor.Builder.of(type);
for (int token = 0; token < nTokens; token++) {
for (int d = 0; d < resultDimensionality; d++) {
- var value = result.get(TensorAddress.of(token, d));
+ var value = result.get(0,token,d); // batch, sequence token, dimension
builder.cell(TensorAddress.of(token,d),value);
}
}
@@ -253,11 +255,14 @@ public class ColBertEmbedder extends AbstractComponent implements Embedder {
if (type.valueType() != TensorType.Value.INT8)
throw new IllegalArgumentException("Only a int8 tensor type can be" +
" the destination of bit packing");
+ if(result.shape().length != 3)
+ throw new IllegalArgumentException("Expected onnx result to have 3-dimensions [batch, sequence, dim]");
+
int size = type.indexedSubtype().dimensions().size();
if (size != 1)
- throw new IllegalArgumentException("Indexed tensor must have one dimension");
+ throw new IllegalArgumentException("Target indexed sub-type must have one dimension");
int wantedDimensionality = type.indexedSubtype().dimensions().get(0).size().get().intValue();
- int resultDimensionality = (int)result.shape()[1];
+ int resultDimensionality = (int)result.shape()[2];
if (resultDimensionality != 8 * wantedDimensionality) {
throw new IllegalArgumentException("Not possible to pack " + resultDimensionality
+ " + dimensions into " + wantedDimensionality + " dimensions");
@@ -266,8 +271,8 @@ public class ColBertEmbedder extends AbstractComponent implements Embedder {
for (int token = 0; token < nTokens; token++) {
BitSet bitSet = new BitSet(8);
int key = 0;
- for (int d = 0; d < result.shape()[1]; d++) {
- var value = result.get(TensorAddress.of(token, d));
+ for (int d = 0; d < result.shape()[2]; d++) {
+ var value = result.get(0, token, d); // batch, sequence token, dimension
int bitIndex = 7 - (d % 8);
if (value > 0.0) {
bitSet.set(bitIndex);
diff --git a/model-integration/src/test/java/ai/vespa/embedding/ColBertEmbedderTest.java b/model-integration/src/test/java/ai/vespa/embedding/ColBertEmbedderTest.java
index f3682e45efc..0cae94c372a 100644
--- a/model-integration/src/test/java/ai/vespa/embedding/ColBertEmbedderTest.java
+++ b/model-integration/src/test/java/ai/vespa/embedding/ColBertEmbedderTest.java
@@ -10,6 +10,7 @@ import com.yahoo.tensor.MixedTensor;
import com.yahoo.tensor.Tensor;
import com.yahoo.tensor.TensorAddress;
import com.yahoo.tensor.TensorType;
+import org.junit.Ignore;
import org.junit.Test;
import java.util.List;
@@ -35,25 +36,25 @@ public class ColBertEmbedderTest {
public void testPacking() {
assertPackedRight(
"" +
- "tensor<float>(d1[6],d2[8]):" +
- "[" +
+ "tensor<float>(d0[1],d1[6],d2[8]):" +
+ "[[" +
"[0, 0, 0, 0, 0, 0, 0, 1]," +
"[0, 0, 0, 0, 0, 1, 0, 1]," +
"[0, 0, 0, 0, 0, 0, 1, 1]," +
"[0, 1, 1, 1, 1, 1, 1, 1]," +
"[1, 0, 0, 0, 0, 0, 0, 0]," +
"[1, 1, 1, 1, 1, 1, 1, 1]" +
- "]",
+ "]]",
TensorType.fromSpec("tensor<int8>(dt{},x[1])"),
"tensor<int8>(dt{},x[1]):{0:1.0, 1:5.0, 2:3.0, 3:127.0, 4:-128.0, 5:-1.0}", 6
);
assertPackedRight(
"" +
- "tensor<float>(d1[2],d2[16]):" +
- "[" +
+ "tensor<float>(d0[1],d1[2],d2[16]):" +
+ "[[" +
"[0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0, 0, 0, 0]," +
"[0, 0, 0, 0, 0, 1, 0, 1, 0, 0, 0, 0, 0, 0, 0, 1]" +
- "]",
+ "]]",
TensorType.fromSpec("tensor<int8>(dt{},x[2])"),
"tensor<int8>(dt{},x[2]):{0:[1.0, -128.0], 1:[5.0, 1.0]}",2
);
@@ -133,18 +134,35 @@ public class ColBertEmbedderTest {
}
String text = sb.toString();
Tensor fullFloat = assertEmbed("tensor<float>(dt{},x[128])", text, indexingContext);
- assertEquals(511*128,fullFloat.size());
+ assertEquals(512*128,fullFloat.size());
Tensor query = assertEmbed("tensor<float>(dt{},x[128])", text, queryContext);
assertEquals(32*128,query.size());
Tensor binaryRep = assertEmbed("tensor<int8>(dt{},x[16])", text, indexingContext);
- assertEquals(511*16,binaryRep.size());
+ assertEquals(512*16,binaryRep.size());
Tensor shortDoc = assertEmbed("tensor<int8>(dt{},x[16])", "annoyance", indexingContext);
- // 3 tokens, 16 bytes each = 48 bytes
+ // 4 tokens, 16 bytes each = 64 bytes
//CLS [unused1] sequence
- assertEquals(3*16,shortDoc.size());;
+ assertEquals(4*16,shortDoc.size());;
+ }
+
+ @Ignore
+ public void testPerf() {
+ StringBuilder sb = new StringBuilder();
+ for(int i = 0; i < 256; i++) {
+ sb.append("annoyance");
+ sb.append(" ");
+ }
+ String text = sb.toString();
+ Long now = System.currentTimeMillis();
+ int n = 1000;
+ for (int i = 0; i < n; i++) {
+ assertEmbed("tensor<float>(dt{},x[128])", text, indexingContext);
+ }
+ Long elapsed = (System.currentTimeMillis() - now);
+ System.out.println("Elapsed time: " + elapsed + " ms");
}
static Tensor assertEmbed(String tensorSpec, String text, Embedder.Context context) {
@@ -163,11 +181,11 @@ public class ColBertEmbedderTest {
Tensor packed = ColBertEmbedder.toBitTensor(in, destination, size);
assertEquals(expected, packed.toString());
Tensor unpacked = ColBertEmbedder.expandBitTensor(packed);
- assertEquals(in.shape()[1], unpacked.type().indexedSubtype().dimensions().get(0).size().get().longValue());
+ assertEquals(in.shape()[2], unpacked.type().indexedSubtype().dimensions().get(0).size().get().longValue());
for (int dOuter = 0; dOuter < size; dOuter++) {
- for (int dInner = 0; dInner < in.shape()[1]; dInner++) {
+ for (int dInner = 0; dInner < in.shape()[2]; dInner++) {
var addr = TensorAddress.of(dOuter, dInner);
- double oldVal = in.get(addr);
+ double oldVal = in.get(TensorAddress.of(0,dOuter, dInner));
if (oldVal > 0) {
assertEquals(unpacked.get(addr), 1.0, 0.0);
} else {