summaryrefslogtreecommitdiffstats
path: root/model-integration/src/test/java/ai/vespa/embedding/ColBertEmbedderTest.java
diff options
context:
space:
mode:
Diffstat (limited to 'model-integration/src/test/java/ai/vespa/embedding/ColBertEmbedderTest.java')
-rw-r--r--model-integration/src/test/java/ai/vespa/embedding/ColBertEmbedderTest.java21
1 files changed, 18 insertions, 3 deletions
diff --git a/model-integration/src/test/java/ai/vespa/embedding/ColBertEmbedderTest.java b/model-integration/src/test/java/ai/vespa/embedding/ColBertEmbedderTest.java
index a32170b0a63..5b6aa9a3fe7 100644
--- a/model-integration/src/test/java/ai/vespa/embedding/ColBertEmbedderTest.java
+++ b/model-integration/src/test/java/ai/vespa/embedding/ColBertEmbedderTest.java
@@ -8,6 +8,7 @@ import com.yahoo.language.process.Embedder;
import com.yahoo.tensor.IndexedTensor;
import com.yahoo.tensor.MixedTensor;
import com.yahoo.tensor.Tensor;
+import com.yahoo.tensor.TensorAddress;
import com.yahoo.tensor.TensorType;
import org.junit.Test;
@@ -100,9 +101,23 @@ public class ColBertEmbedderTest {
return result;
}
- static void assertPackedRight(String numbers, TensorType destination,String expected, int size) {
- Tensor packed = ColBertEmbedder.toBitTensor((IndexedTensor) Tensor.from(numbers), destination, size);
- assertEquals(expected,packed.toString());
+ static void assertPackedRight(String numbers, TensorType destination, String expected, int size) {
+ var in = (IndexedTensor) Tensor.from(numbers);
+ Tensor packed = ColBertEmbedder.toBitTensor(in, destination, size);
+ assertEquals(expected, packed.toString());
+ Tensor unpacked = ColBertEmbedder.expandBitTensor(packed);
+ assertEquals(in.shape()[1], unpacked.type().indexedSubtype().dimensions().get(0).size().get().longValue());
+ for (int dOuter = 0; dOuter < size; dOuter++) {
+ for (int dInner = 0; dInner < in.shape()[1]; dInner++) {
+ var addr = TensorAddress.of(dOuter, dInner);
+ double oldVal = in.get(addr);
+ if (oldVal > 0) {
+ assertEquals(unpacked.get(addr), 1.0, 0.0);
+ } else {
+ assertEquals(unpacked.get(addr), 0.0, 0.0);
+ }
+ }
+ }
}
static final Embedder embedder;