diff options
Diffstat (limited to 'model-integration/src/test/java/ai/vespa/embedding/ColBertEmbedderTest.java')
-rw-r--r-- | model-integration/src/test/java/ai/vespa/embedding/ColBertEmbedderTest.java | 21 |
1 files changed, 18 insertions, 3 deletions
diff --git a/model-integration/src/test/java/ai/vespa/embedding/ColBertEmbedderTest.java b/model-integration/src/test/java/ai/vespa/embedding/ColBertEmbedderTest.java index a32170b0a63..5b6aa9a3fe7 100644 --- a/model-integration/src/test/java/ai/vespa/embedding/ColBertEmbedderTest.java +++ b/model-integration/src/test/java/ai/vespa/embedding/ColBertEmbedderTest.java @@ -8,6 +8,7 @@ import com.yahoo.language.process.Embedder; import com.yahoo.tensor.IndexedTensor; import com.yahoo.tensor.MixedTensor; import com.yahoo.tensor.Tensor; +import com.yahoo.tensor.TensorAddress; import com.yahoo.tensor.TensorType; import org.junit.Test; @@ -100,9 +101,23 @@ public class ColBertEmbedderTest { return result; } - static void assertPackedRight(String numbers, TensorType destination,String expected, int size) { - Tensor packed = ColBertEmbedder.toBitTensor((IndexedTensor) Tensor.from(numbers), destination, size); - assertEquals(expected,packed.toString()); + static void assertPackedRight(String numbers, TensorType destination, String expected, int size) { + var in = (IndexedTensor) Tensor.from(numbers); + Tensor packed = ColBertEmbedder.toBitTensor(in, destination, size); + assertEquals(expected, packed.toString()); + Tensor unpacked = ColBertEmbedder.expandBitTensor(packed); + assertEquals(in.shape()[1], unpacked.type().indexedSubtype().dimensions().get(0).size().get().longValue()); + for (int dOuter = 0; dOuter < size; dOuter++) { + for (int dInner = 0; dInner < in.shape()[1]; dInner++) { + var addr = TensorAddress.of(dOuter, dInner); + double oldVal = in.get(addr); + if (oldVal > 0) { + assertEquals(unpacked.get(addr), 1.0, 0.0); + } else { + assertEquals(unpacked.get(addr), 0.0, 0.0); + } + } + } } static final Embedder embedder; |