summary | refs | log | tree | commit | diff | stats
path: root/model-integration
diff options
context:
space:
mode:
author: Arne Juul <arnej@yahooinc.com>, 2023-11-06 17:46:19 +0000
committer: Arne Juul <arnej@yahooinc.com>, 2023-11-10 09:55:58 +0000
commit: 27c62ead49c88382f128034c48e22d77a83f8104 (patch)
tree: e5f1114e862ad8ad9ae3e57d31cdf66be9056587 /model-integration
parent: 9fcbb132148e858173170df90a502d7c3ca2d5d5 (diff)
add simple expandBitTensor function
Diffstat (limited to 'model-integration')
-rw-r--r-- model-integration/src/main/java/ai/vespa/embedding/ColBertEmbedder.java | 23
-rw-r--r-- model-integration/src/test/java/ai/vespa/embedding/ColBertEmbedderTest.java | 21
2 files changed, 35 insertions, 9 deletions
diff --git a/model-integration/src/main/java/ai/vespa/embedding/ColBertEmbedder.java b/model-integration/src/main/java/ai/vespa/embedding/ColBertEmbedder.java
index 5a9fe34ef3d..a3273abff57 100644
--- a/model-integration/src/main/java/ai/vespa/embedding/ColBertEmbedder.java
+++ b/model-integration/src/main/java/ai/vespa/embedding/ColBertEmbedder.java
@@ -10,6 +10,10 @@ import com.yahoo.component.annotation.Inject;
import com.yahoo.embedding.ColBertEmbedderConfig;
import com.yahoo.language.huggingface.HuggingFaceTokenizer;
import com.yahoo.language.process.Embedder;
+import com.yahoo.searchlib.rankingexpression.evaluation.MapContext;
+import com.yahoo.searchlib.rankingexpression.evaluation.TensorValue;
+import com.yahoo.searchlib.rankingexpression.rule.ReferenceNode;
+import com.yahoo.searchlib.rankingexpression.rule.UnpackBitsFromInt8;
import com.yahoo.tensor.IndexedTensor;
import com.yahoo.tensor.Tensor;
import com.yahoo.tensor.TensorAddress;
@@ -113,7 +117,7 @@ public class ColBertEmbedder extends AbstractComponent implements Embedder {
@Override
public Tensor embed(String text, Context context, TensorType tensorType) {
- if(!verifyTensorType(tensorType)) {
+ if (!verifyTensorType(tensorType)) {
throw new IllegalArgumentException("Invalid ColBERT embedder tensor destination. " +
"Wanted a mixed 2-d mapped-indexed tensor, got " + tensorType);
}
@@ -131,7 +135,7 @@ public class ColBertEmbedder extends AbstractComponent implements Embedder {
}
protected Tensor embedQuery(String text, Context context, TensorType tensorType) {
- if(tensorType.valueType() == TensorType.Value.INT8)
+ if (tensorType.valueType() == TensorType.Value.INT8)
throw new IllegalArgumentException("ColBert query embed does not accept int8 tensor value type");
long Q_TOKEN_ID = 1; // [unused0] token id used during training to differentiate query versus document.
@@ -173,7 +177,7 @@ public class ColBertEmbedder extends AbstractComponent implements Embedder {
IndexedTensor result = (IndexedTensor) tokenEmbeddings.reduce(Reduce.Aggregator.min, "d0");
int dims = tensorType.indexedSubtype().dimensions().get(0).size().get().intValue();
- if(dims != result.shape()[1]) {
+ if (dims != result.shape()[1]) {
throw new IllegalArgumentException("Token dimensionality does not" +
" match indexed dimensionality of " + dims);
}
@@ -213,7 +217,7 @@ public class ColBertEmbedder extends AbstractComponent implements Embedder {
IndexedTensor result = (IndexedTensor) tokenEmbeddings.reduce(Reduce.Aggregator.min, "d0");
Tensor contextualEmbeddings;
int retainedTokens = inputIds.size() -1; //Do not retain last PAD
- if(tensorType.valueType() == TensorType.Value.INT8) {
+ if (tensorType.valueType() == TensorType.Value.INT8) {
contextualEmbeddings = toBitTensor(result, tensorType, retainedTokens);
} else {
contextualEmbeddings = toFloatTensor(result, tensorType, retainedTokens);
@@ -228,7 +232,7 @@ public class ColBertEmbedder extends AbstractComponent implements Embedder {
throw new IllegalArgumentException("Indexed tensor must have one dimension");
int wantedDimensionality = type.indexedSubtype().dimensions().get(0).size().get().intValue();
int resultDimensionality = (int)result.shape()[1];
- if(resultDimensionality != wantedDimensionality) {
+ if (resultDimensionality != wantedDimensionality) {
throw new IllegalArgumentException("Not possible to map token vector embedding with " + resultDimensionality
+ " + dimensions into tensor with " + wantedDimensionality);
}
@@ -251,7 +255,7 @@ public class ColBertEmbedder extends AbstractComponent implements Embedder {
throw new IllegalArgumentException("Indexed tensor must have one dimension");
int wantedDimensionality = type.indexedSubtype().dimensions().get(0).size().get().intValue();
int resultDimensionality = (int)result.shape()[1];
- if(resultDimensionality/8 != wantedDimensionality) {
+ if (resultDimensionality != 8 * wantedDimensionality) {
throw new IllegalArgumentException("Not possible to pack " + resultDimensionality
+ " + dimensions into " + wantedDimensionality + " dimensions");
}
@@ -279,6 +283,13 @@ public class ColBertEmbedder extends AbstractComponent implements Embedder {
return builder.build();
}
+ public static Tensor expandBitTensor(Tensor packed) {
+ var unpacker = new UnpackBitsFromInt8(new ReferenceNode("input"), TensorType.Value.FLOAT, "big");
+ var context = new MapContext();
+ context.put("input", new TensorValue(packed));
+ return unpacker.evaluate(context).asTensor();
+ }
+
protected boolean verifyTensorType(TensorType target) {
return target.dimensions().size() == 2 &&
target.indexedSubtype().rank() == 1 && target.mappedSubtype().rank() == 1;
diff --git a/model-integration/src/test/java/ai/vespa/embedding/ColBertEmbedderTest.java b/model-integration/src/test/java/ai/vespa/embedding/ColBertEmbedderTest.java
index a32170b0a63..5b6aa9a3fe7 100644
--- a/model-integration/src/test/java/ai/vespa/embedding/ColBertEmbedderTest.java
+++ b/model-integration/src/test/java/ai/vespa/embedding/ColBertEmbedderTest.java
@@ -8,6 +8,7 @@ import com.yahoo.language.process.Embedder;
import com.yahoo.tensor.IndexedTensor;
import com.yahoo.tensor.MixedTensor;
import com.yahoo.tensor.Tensor;
+import com.yahoo.tensor.TensorAddress;
import com.yahoo.tensor.TensorType;
import org.junit.Test;
@@ -100,9 +101,23 @@ public class ColBertEmbedderTest {
return result;
}
- static void assertPackedRight(String numbers, TensorType destination,String expected, int size) {
- Tensor packed = ColBertEmbedder.toBitTensor((IndexedTensor) Tensor.from(numbers), destination, size);
- assertEquals(expected,packed.toString());
+ static void assertPackedRight(String numbers, TensorType destination, String expected, int size) {
+ var in = (IndexedTensor) Tensor.from(numbers);
+ Tensor packed = ColBertEmbedder.toBitTensor(in, destination, size);
+ assertEquals(expected, packed.toString());
+ Tensor unpacked = ColBertEmbedder.expandBitTensor(packed);
+ assertEquals(in.shape()[1], unpacked.type().indexedSubtype().dimensions().get(0).size().get().longValue());
+ for (int dOuter = 0; dOuter < size; dOuter++) {
+ for (int dInner = 0; dInner < in.shape()[1]; dInner++) {
+ var addr = TensorAddress.of(dOuter, dInner);
+ double oldVal = in.get(addr);
+ if (oldVal > 0) {
+ assertEquals(unpacked.get(addr), 1.0, 0.0);
+ } else {
+ assertEquals(unpacked.get(addr), 0.0, 0.0);
+ }
+ }
+ }
}
static final Embedder embedder;