summaryrefslogtreecommitdiffstats
path: root/model-integration/src/test/java/ai/vespa/embedding/ColBertEmbedderTest.java
diff options
context:
space:
mode:
Diffstat (limited to 'model-integration/src/test/java/ai/vespa/embedding/ColBertEmbedderTest.java')
-rw-r--r--model-integration/src/test/java/ai/vespa/embedding/ColBertEmbedderTest.java44
1 files changed, 31 insertions, 13 deletions
diff --git a/model-integration/src/test/java/ai/vespa/embedding/ColBertEmbedderTest.java b/model-integration/src/test/java/ai/vespa/embedding/ColBertEmbedderTest.java
index f3682e45efc..0cae94c372a 100644
--- a/model-integration/src/test/java/ai/vespa/embedding/ColBertEmbedderTest.java
+++ b/model-integration/src/test/java/ai/vespa/embedding/ColBertEmbedderTest.java
@@ -10,6 +10,7 @@ import com.yahoo.tensor.MixedTensor;
import com.yahoo.tensor.Tensor;
import com.yahoo.tensor.TensorAddress;
import com.yahoo.tensor.TensorType;
+import org.junit.Ignore;
import org.junit.Test;
import java.util.List;
@@ -35,25 +36,25 @@ public class ColBertEmbedderTest {
public void testPacking() {
assertPackedRight(
"" +
- "tensor<float>(d1[6],d2[8]):" +
- "[" +
+ "tensor<float>(d0[1],d1[6],d2[8]):" +
+ "[[" +
"[0, 0, 0, 0, 0, 0, 0, 1]," +
"[0, 0, 0, 0, 0, 1, 0, 1]," +
"[0, 0, 0, 0, 0, 0, 1, 1]," +
"[0, 1, 1, 1, 1, 1, 1, 1]," +
"[1, 0, 0, 0, 0, 0, 0, 0]," +
"[1, 1, 1, 1, 1, 1, 1, 1]" +
- "]",
+ "]]",
TensorType.fromSpec("tensor<int8>(dt{},x[1])"),
"tensor<int8>(dt{},x[1]):{0:1.0, 1:5.0, 2:3.0, 3:127.0, 4:-128.0, 5:-1.0}", 6
);
assertPackedRight(
"" +
- "tensor<float>(d1[2],d2[16]):" +
- "[" +
+ "tensor<float>(d0[1],d1[2],d2[16]):" +
+ "[[" +
"[0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0, 0, 0, 0]," +
"[0, 0, 0, 0, 0, 1, 0, 1, 0, 0, 0, 0, 0, 0, 0, 1]" +
- "]",
+ "]]",
TensorType.fromSpec("tensor<int8>(dt{},x[2])"),
"tensor<int8>(dt{},x[2]):{0:[1.0, -128.0], 1:[5.0, 1.0]}",2
);
@@ -133,18 +134,35 @@ public class ColBertEmbedderTest {
}
String text = sb.toString();
Tensor fullFloat = assertEmbed("tensor<float>(dt{},x[128])", text, indexingContext);
- assertEquals(511*128,fullFloat.size());
+ assertEquals(512*128,fullFloat.size());
Tensor query = assertEmbed("tensor<float>(dt{},x[128])", text, queryContext);
assertEquals(32*128,query.size());
Tensor binaryRep = assertEmbed("tensor<int8>(dt{},x[16])", text, indexingContext);
- assertEquals(511*16,binaryRep.size());
+ assertEquals(512*16,binaryRep.size());
Tensor shortDoc = assertEmbed("tensor<int8>(dt{},x[16])", "annoyance", indexingContext);
- // 3 tokens, 16 bytes each = 48 bytes
+ // 4 tokens, 16 bytes each = 64 bytes
//CLS [unused1] sequence, plus a fourth token (presumably the trailing SEP - TODO confirm)
- assertEquals(3*16,shortDoc.size());;
+ assertEquals(4*16,shortDoc.size());;
+ }
+
+ /** Ad-hoc timing run: embeds a 256-word text n times and prints the total elapsed milliseconds. */
+ @Ignore
+ public void testPerf() {
+ StringBuilder sb = new StringBuilder();
+ for (int i = 0; i < 256; i++) {
+ sb.append("annoyance").append(' ');
+ }
+ String text = sb.toString();
+ int n = 1000;
+ long start = System.nanoTime(); // primitive long avoids boxing; nanoTime is meant for elapsed-time measurement
+ for (int i = 0; i < n; i++) {
+ assertEmbed("tensor<float>(dt{},x[128])", text, indexingContext);
+ }
+ long elapsedMs = (System.nanoTime() - start) / 1_000_000; // nanoseconds -> milliseconds
+ System.out.println("Elapsed time: " + elapsedMs + " ms");
}
static Tensor assertEmbed(String tensorSpec, String text, Embedder.Context context) {
@@ -163,11 +181,11 @@ public class ColBertEmbedderTest {
Tensor packed = ColBertEmbedder.toBitTensor(in, destination, size);
assertEquals(expected, packed.toString());
Tensor unpacked = ColBertEmbedder.expandBitTensor(packed);
- assertEquals(in.shape()[1], unpacked.type().indexedSubtype().dimensions().get(0).size().get().longValue());
+ assertEquals(in.shape()[2], unpacked.type().indexedSubtype().dimensions().get(0).size().get().longValue());
for (int dOuter = 0; dOuter < size; dOuter++) {
- for (int dInner = 0; dInner < in.shape()[1]; dInner++) {
+ for (int dInner = 0; dInner < in.shape()[2]; dInner++) {
var addr = TensorAddress.of(dOuter, dInner);
- double oldVal = in.get(addr);
+ double oldVal = in.get(TensorAddress.of(0,dOuter, dInner));
if (oldVal > 0) {
assertEquals(unpacked.get(addr), 1.0, 0.0);
} else {