diff options
Diffstat (limited to 'model-integration/src/main/java/ai/vespa/embedding/ColBertEmbedder.java')
-rw-r--r-- | model-integration/src/main/java/ai/vespa/embedding/ColBertEmbedder.java | 23 |
1 files changed, 17 insertions, 6 deletions
diff --git a/model-integration/src/main/java/ai/vespa/embedding/ColBertEmbedder.java b/model-integration/src/main/java/ai/vespa/embedding/ColBertEmbedder.java index 5a9fe34ef3d..a3273abff57 100644 --- a/model-integration/src/main/java/ai/vespa/embedding/ColBertEmbedder.java +++ b/model-integration/src/main/java/ai/vespa/embedding/ColBertEmbedder.java @@ -10,6 +10,10 @@ import com.yahoo.component.annotation.Inject; import com.yahoo.embedding.ColBertEmbedderConfig; import com.yahoo.language.huggingface.HuggingFaceTokenizer; import com.yahoo.language.process.Embedder; +import com.yahoo.searchlib.rankingexpression.evaluation.MapContext; +import com.yahoo.searchlib.rankingexpression.evaluation.TensorValue; +import com.yahoo.searchlib.rankingexpression.rule.ReferenceNode; +import com.yahoo.searchlib.rankingexpression.rule.UnpackBitsFromInt8; import com.yahoo.tensor.IndexedTensor; import com.yahoo.tensor.Tensor; import com.yahoo.tensor.TensorAddress; @@ -113,7 +117,7 @@ public class ColBertEmbedder extends AbstractComponent implements Embedder { @Override public Tensor embed(String text, Context context, TensorType tensorType) { - if(!verifyTensorType(tensorType)) { + if (!verifyTensorType(tensorType)) { throw new IllegalArgumentException("Invalid ColBERT embedder tensor destination. " + "Wanted a mixed 2-d mapped-indexed tensor, got " + tensorType); } @@ -131,7 +135,7 @@ public class ColBertEmbedder extends AbstractComponent implements Embedder { } protected Tensor embedQuery(String text, Context context, TensorType tensorType) { - if(tensorType.valueType() == TensorType.Value.INT8) + if (tensorType.valueType() == TensorType.Value.INT8) throw new IllegalArgumentException("ColBert query embed does not accept int8 tensor value type"); long Q_TOKEN_ID = 1; // [unused0] token id used during training to differentiate query versus document. @@ -173,7 +177,7 @@ public class ColBertEmbedder extends AbstractComponent implements Embedder { IndexedTensor result = (IndexedTensor) tokenEmbeddings.reduce(Reduce.Aggregator.min, "d0"); int dims = tensorType.indexedSubtype().dimensions().get(0).size().get().intValue(); - if(dims != result.shape()[1]) { + if (dims != result.shape()[1]) { throw new IllegalArgumentException("Token dimensionality does not" + " match indexed dimensionality of " + dims); } @@ -213,7 +217,7 @@ public class ColBertEmbedder extends AbstractComponent implements Embedder { IndexedTensor result = (IndexedTensor) tokenEmbeddings.reduce(Reduce.Aggregator.min, "d0"); Tensor contextualEmbeddings; int retainedTokens = inputIds.size() -1; //Do not retain last PAD - if(tensorType.valueType() == TensorType.Value.INT8) { + if (tensorType.valueType() == TensorType.Value.INT8) { contextualEmbeddings = toBitTensor(result, tensorType, retainedTokens); } else { contextualEmbeddings = toFloatTensor(result, tensorType, retainedTokens); @@ -228,7 +232,7 @@ public class ColBertEmbedder extends AbstractComponent implements Embedder { throw new IllegalArgumentException("Indexed tensor must have one dimension"); int wantedDimensionality = type.indexedSubtype().dimensions().get(0).size().get().intValue(); int resultDimensionality = (int)result.shape()[1]; - if(resultDimensionality != wantedDimensionality) { + if (resultDimensionality != wantedDimensionality) { throw new IllegalArgumentException("Not possible to map token vector embedding with " + resultDimensionality + " + dimensions into tensor with " + wantedDimensionality); } @@ -251,7 +255,7 @@ public class ColBertEmbedder extends AbstractComponent implements Embedder { throw new IllegalArgumentException("Indexed tensor must have one dimension"); int wantedDimensionality = type.indexedSubtype().dimensions().get(0).size().get().intValue(); int resultDimensionality = (int)result.shape()[1]; - if(resultDimensionality/8 != wantedDimensionality) { + if (resultDimensionality != 8 * wantedDimensionality) { throw new IllegalArgumentException("Not possible to pack " + resultDimensionality + " + dimensions into " + wantedDimensionality + " dimensions"); } @@ -279,6 +283,13 @@ public class ColBertEmbedder extends AbstractComponent implements Embedder { return builder.build(); } + public static Tensor expandBitTensor(Tensor packed) { + var unpacker = new UnpackBitsFromInt8(new ReferenceNode("input"), TensorType.Value.FLOAT, "big"); + var context = new MapContext(); + context.put("input", new TensorValue(packed)); + return unpacker.evaluate(context).asTensor(); + } + protected boolean verifyTensorType(TensorType target) { return target.dimensions().size() == 2 && target.indexedSubtype().rank() == 1 && target.mappedSubtype().rank() == 1; |