diff options
-rw-r--r-- | model-integration/src/main/java/ai/vespa/embedding/ColBertEmbedder.java | 23 | ||||
-rw-r--r-- | model-integration/src/test/java/ai/vespa/embedding/ColBertEmbedderTest.java | 21 |
2 files changed, 35 insertions, 9 deletions
diff --git a/model-integration/src/main/java/ai/vespa/embedding/ColBertEmbedder.java b/model-integration/src/main/java/ai/vespa/embedding/ColBertEmbedder.java index 5a9fe34ef3d..a3273abff57 100644 --- a/model-integration/src/main/java/ai/vespa/embedding/ColBertEmbedder.java +++ b/model-integration/src/main/java/ai/vespa/embedding/ColBertEmbedder.java @@ -10,6 +10,10 @@ import com.yahoo.component.annotation.Inject; import com.yahoo.embedding.ColBertEmbedderConfig; import com.yahoo.language.huggingface.HuggingFaceTokenizer; import com.yahoo.language.process.Embedder; +import com.yahoo.searchlib.rankingexpression.evaluation.MapContext; +import com.yahoo.searchlib.rankingexpression.evaluation.TensorValue; +import com.yahoo.searchlib.rankingexpression.rule.ReferenceNode; +import com.yahoo.searchlib.rankingexpression.rule.UnpackBitsFromInt8; import com.yahoo.tensor.IndexedTensor; import com.yahoo.tensor.Tensor; import com.yahoo.tensor.TensorAddress; @@ -113,7 +117,7 @@ public class ColBertEmbedder extends AbstractComponent implements Embedder { @Override public Tensor embed(String text, Context context, TensorType tensorType) { - if(!verifyTensorType(tensorType)) { + if (!verifyTensorType(tensorType)) { throw new IllegalArgumentException("Invalid ColBERT embedder tensor destination. " + "Wanted a mixed 2-d mapped-indexed tensor, got " + tensorType); } @@ -131,7 +135,7 @@ public class ColBertEmbedder extends AbstractComponent implements Embedder { } protected Tensor embedQuery(String text, Context context, TensorType tensorType) { - if(tensorType.valueType() == TensorType.Value.INT8) + if (tensorType.valueType() == TensorType.Value.INT8) throw new IllegalArgumentException("ColBert query embed does not accept int8 tensor value type"); long Q_TOKEN_ID = 1; // [unused0] token id used during training to differentiate query versus document. @@ -173,7 +177,7 @@ public class ColBertEmbedder extends AbstractComponent implements Embedder { IndexedTensor result = (IndexedTensor) tokenEmbeddings.reduce(Reduce.Aggregator.min, "d0"); int dims = tensorType.indexedSubtype().dimensions().get(0).size().get().intValue(); - if(dims != result.shape()[1]) { + if (dims != result.shape()[1]) { throw new IllegalArgumentException("Token dimensionality does not" + " match indexed dimensionality of " + dims); } @@ -213,7 +217,7 @@ public class ColBertEmbedder extends AbstractComponent implements Embedder { IndexedTensor result = (IndexedTensor) tokenEmbeddings.reduce(Reduce.Aggregator.min, "d0"); Tensor contextualEmbeddings; int retainedTokens = inputIds.size() -1; //Do not retain last PAD - if(tensorType.valueType() == TensorType.Value.INT8) { + if (tensorType.valueType() == TensorType.Value.INT8) { contextualEmbeddings = toBitTensor(result, tensorType, retainedTokens); } else { contextualEmbeddings = toFloatTensor(result, tensorType, retainedTokens); @@ -228,7 +232,7 @@ public class ColBertEmbedder extends AbstractComponent implements Embedder { throw new IllegalArgumentException("Indexed tensor must have one dimension"); int wantedDimensionality = type.indexedSubtype().dimensions().get(0).size().get().intValue(); int resultDimensionality = (int)result.shape()[1]; - if(resultDimensionality != wantedDimensionality) { + if (resultDimensionality != wantedDimensionality) { throw new IllegalArgumentException("Not possible to map token vector embedding with " + resultDimensionality + " + dimensions into tensor with " + wantedDimensionality); } @@ -251,7 +255,7 @@ public class ColBertEmbedder extends AbstractComponent implements Embedder { throw new IllegalArgumentException("Indexed tensor must have one dimension"); int wantedDimensionality = type.indexedSubtype().dimensions().get(0).size().get().intValue(); int resultDimensionality = (int)result.shape()[1]; - if(resultDimensionality/8 != wantedDimensionality) { + if (resultDimensionality != 8 * wantedDimensionality) { throw new IllegalArgumentException("Not possible to pack " + resultDimensionality + " + dimensions into " + wantedDimensionality + " dimensions"); } @@ -279,6 +283,13 @@ public class ColBertEmbedder extends AbstractComponent implements Embedder { return builder.build(); } + public static Tensor expandBitTensor(Tensor packed) { + var unpacker = new UnpackBitsFromInt8(new ReferenceNode("input"), TensorType.Value.FLOAT, "big"); + var context = new MapContext(); + context.put("input", new TensorValue(packed)); + return unpacker.evaluate(context).asTensor(); + } + protected boolean verifyTensorType(TensorType target) { return target.dimensions().size() == 2 && target.indexedSubtype().rank() == 1 && target.mappedSubtype().rank() == 1; diff --git a/model-integration/src/test/java/ai/vespa/embedding/ColBertEmbedderTest.java b/model-integration/src/test/java/ai/vespa/embedding/ColBertEmbedderTest.java index a32170b0a63..5b6aa9a3fe7 100644 --- a/model-integration/src/test/java/ai/vespa/embedding/ColBertEmbedderTest.java +++ b/model-integration/src/test/java/ai/vespa/embedding/ColBertEmbedderTest.java @@ -8,6 +8,7 @@ import com.yahoo.language.process.Embedder; import com.yahoo.tensor.IndexedTensor; import com.yahoo.tensor.MixedTensor; import com.yahoo.tensor.Tensor; +import com.yahoo.tensor.TensorAddress; import com.yahoo.tensor.TensorType; import org.junit.Test; @@ -100,9 +101,23 @@ public class ColBertEmbedderTest { return result; } - static void assertPackedRight(String numbers, TensorType destination,String expected, int size) { - Tensor packed = ColBertEmbedder.toBitTensor((IndexedTensor) Tensor.from(numbers), destination, size); - assertEquals(expected,packed.toString()); + static void assertPackedRight(String numbers, TensorType destination, String expected, int size) { + var in = (IndexedTensor) Tensor.from(numbers); + Tensor packed = ColBertEmbedder.toBitTensor(in, destination, size); + assertEquals(expected, packed.toString()); + Tensor unpacked = ColBertEmbedder.expandBitTensor(packed); + assertEquals(in.shape()[1], unpacked.type().indexedSubtype().dimensions().get(0).size().get().longValue()); + for (int dOuter = 0; dOuter < size; dOuter++) { + for (int dInner = 0; dInner < in.shape()[1]; dInner++) { + var addr = TensorAddress.of(dOuter, dInner); + double oldVal = in.get(addr); + if (oldVal > 0) { + assertEquals(unpacked.get(addr), 1.0, 0.0); + } else { + assertEquals(unpacked.get(addr), 0.0, 0.0); + } + } + } } static final Embedder embedder; |