summary | refs | log | tree | commit | diff | stats
path: root/model-integration/src/test
diff options
context:
space:
mode:
author Jo Kristian Bergum <bergum@yahooinc.com> 2023-09-26 14:14:18 +0200
committer Jo Kristian Bergum <bergum@yahooinc.com> 2023-09-26 14:14:18 +0200
commit 01deefc0c007995573c5564be7aa4d0ce1e01203 (patch)
tree b4d009b496e5f14b91f0c7f221a378e3ca916bed /model-integration/src/test
parent 4231e6077a18b6fdf96ac899a7301882ef50d742 (diff)
Don't index PAD and re-factoring
Diffstat (limited to 'model-integration/src/test')
-rw-r--r-- model-integration/src/test/java/ai/vespa/embedding/ColBertEmbedderTest.java 18
1 files changed, 9 insertions, 9 deletions
diff --git a/model-integration/src/test/java/ai/vespa/embedding/ColBertEmbedderTest.java b/model-integration/src/test/java/ai/vespa/embedding/ColBertEmbedderTest.java
index 8516f6e6689..4e398f7245d 100644
--- a/model-integration/src/test/java/ai/vespa/embedding/ColBertEmbedderTest.java
+++ b/model-integration/src/test/java/ai/vespa/embedding/ColBertEmbedderTest.java
@@ -31,7 +31,7 @@ public class ColBertEmbedderTest {
"[1, 1, 1, 1, 1, 1, 1, 1]" +
"]",
TensorType.fromSpec("tensor<int8>(dt{},x[1])"),
- "tensor<int8>(dt{},x[1]):{0:1.0, 1:5.0, 2:3.0, 3:127.0, 4:-128.0, 5:-1.0}"
+ "tensor<int8>(dt{},x[1]):{0:1.0, 1:5.0, 2:3.0, 3:127.0, 4:-128.0, 5:-1.0}", 6
);
assertPackedRight(
"" +
@@ -41,7 +41,7 @@ public class ColBertEmbedderTest {
"[0, 0, 0, 0, 0, 1, 0, 1, 0, 0, 0, 0, 0, 0, 0, 1]" +
"]",
TensorType.fromSpec("tensor<int8>(dt{},x[2])"),
- "tensor<int8>(dt{},x[2]):{0:[1.0, -128.0], 1:[5.0, 1.0]}"
+ "tensor<int8>(dt{},x[2]):{0:[1.0, -128.0], 1:[5.0, 1.0]}",2
);
}
@@ -75,18 +75,18 @@ public class ColBertEmbedderTest {
}
String text = sb.toString();
Tensor fullFloat = assertEmbed("tensor<float>(dt{},x[128])", text, indexingContext);
- assertEquals(512*128,fullFloat.size());
+ assertEquals(511*128,fullFloat.size());
Tensor query = assertEmbed("tensor<float>(dt{},x[128])", text, queryContext);
assertEquals(32*128,query.size());
Tensor binaryRep = assertEmbed("tensor<int8>(dt{},x[16])", text, indexingContext);
- assertEquals(512*16,binaryRep.size());
+ assertEquals(511*16,binaryRep.size());
Tensor shortDoc = assertEmbed("tensor<int8>(dt{},x[16])", "annoyance", indexingContext);
- // 4 tokens, 16 bytes each = 64 bytes
- //because of CLS, special, sequence, SEP
- assertEquals(4*16,shortDoc.size());;
+ // 3 tokens, 16 bytes each = 48 bytes
+ //CLS [unused1] sequence
+ assertEquals(3*16,shortDoc.size());;
}
static Tensor assertEmbed(String tensorSpec, String text, Embedder.Context context) {
@@ -100,8 +100,8 @@ public class ColBertEmbedderTest {
return result;
}
- static void assertPackedRight(String numbers, TensorType destination,String expected) {
- Tensor packed = ColBertEmbedder.toBitTensor((IndexedTensor) Tensor.from(numbers), destination);
+ static void assertPackedRight(String numbers, TensorType destination,String expected, int size) {
+ Tensor packed = ColBertEmbedder.toBitTensor((IndexedTensor) Tensor.from(numbers), destination, size);
assertEquals(expected,packed.toString());
}