summaryrefslogtreecommitdiffstats
path: root/model-integration/src/main/java/ai/vespa/embedding/ColBertEmbedder.java
diff options
context:
space:
mode:
Diffstat (limited to 'model-integration/src/main/java/ai/vespa/embedding/ColBertEmbedder.java')
-rw-r--r--model-integration/src/main/java/ai/vespa/embedding/ColBertEmbedder.java27
1 files changed, 16 insertions, 11 deletions
diff --git a/model-integration/src/main/java/ai/vespa/embedding/ColBertEmbedder.java b/model-integration/src/main/java/ai/vespa/embedding/ColBertEmbedder.java
index 0bee03a65af..8c39cc8c813 100644
--- a/model-integration/src/main/java/ai/vespa/embedding/ColBertEmbedder.java
+++ b/model-integration/src/main/java/ai/vespa/embedding/ColBertEmbedder.java
@@ -191,10 +191,10 @@ public class ColBertEmbedder extends AbstractComponent implements Embedder {
attentionMaskName, attentionMaskTensor.expand("d0"));
Map<String, Tensor> outputs = evaluator.evaluate(inputs);
Tensor tokenEmbeddings = outputs.get(outputName);
- IndexedTensor result = (IndexedTensor) tokenEmbeddings.reduce(Reduce.Aggregator.min, "d0");
+ IndexedTensor result = (IndexedTensor) tokenEmbeddings;
int dims = tensorType.indexedSubtype().dimensions().get(0).size().get().intValue();
- if (dims != result.shape()[1]) {
+ if (dims != result.shape()[2]) {
throw new IllegalArgumentException("Token vector dimensionality does not" +
" match indexed dimensionality of " + dims);
}
@@ -217,9 +217,9 @@ public class ColBertEmbedder extends AbstractComponent implements Embedder {
Map<String, Tensor> outputs = evaluator.evaluate(inputs);
Tensor tokenEmbeddings = outputs.get(outputName);
- IndexedTensor result = (IndexedTensor) tokenEmbeddings.reduce(Reduce.Aggregator.min, "d0");
+ IndexedTensor result = (IndexedTensor) tokenEmbeddings;
Tensor contextualEmbeddings;
- int maxTokens = input.inputIds.size() -1; //Do not retain last PAD
+ int maxTokens = input.inputIds.size(); //Retain all token vectors, including PAD tokens.
if (tensorType.valueType() == TensorType.Value.INT8) {
contextualEmbeddings = toBitTensor(result, tensorType, maxTokens);
} else {
@@ -230,11 +230,13 @@ public class ColBertEmbedder extends AbstractComponent implements Embedder {
}
public static Tensor toFloatTensor(IndexedTensor result, TensorType type, int nTokens) {
+ if(result.shape().length != 3)
+ throw new IllegalArgumentException("Expected onnx result to have 3-dimensions [batch, sequence, dim]");
int size = type.indexedSubtype().dimensions().size();
if (size != 1)
- throw new IllegalArgumentException("Indexed tensor must have one dimension");
+ throw new IllegalArgumentException("Target indexed sub-type must have one dimension");
int wantedDimensionality = type.indexedSubtype().dimensions().get(0).size().get().intValue();
- int resultDimensionality = (int)result.shape()[1];
+ int resultDimensionality = (int)result.shape()[2];
if (resultDimensionality != wantedDimensionality) {
throw new IllegalArgumentException("Not possible to map token vector embedding with " + resultDimensionality
+ " + dimensions into tensor with " + wantedDimensionality);
@@ -242,7 +244,7 @@ public class ColBertEmbedder extends AbstractComponent implements Embedder {
Tensor.Builder builder = Tensor.Builder.of(type);
for (int token = 0; token < nTokens; token++) {
for (int d = 0; d < resultDimensionality; d++) {
- var value = result.get(TensorAddress.of(token, d));
+ var value = result.get(0,token,d); // batch, sequence token, dimension
builder.cell(TensorAddress.of(token,d),value);
}
}
@@ -253,11 +255,14 @@ public class ColBertEmbedder extends AbstractComponent implements Embedder {
if (type.valueType() != TensorType.Value.INT8)
throw new IllegalArgumentException("Only a int8 tensor type can be" +
" the destination of bit packing");
+ if(result.shape().length != 3)
+ throw new IllegalArgumentException("Expected onnx result to have 3-dimensions [batch, sequence, dim]");
+
int size = type.indexedSubtype().dimensions().size();
if (size != 1)
- throw new IllegalArgumentException("Indexed tensor must have one dimension");
+ throw new IllegalArgumentException("Target indexed sub-type must have one dimension");
int wantedDimensionality = type.indexedSubtype().dimensions().get(0).size().get().intValue();
- int resultDimensionality = (int)result.shape()[1];
+ int resultDimensionality = (int)result.shape()[2];
if (resultDimensionality != 8 * wantedDimensionality) {
throw new IllegalArgumentException("Not possible to pack " + resultDimensionality
+ " + dimensions into " + wantedDimensionality + " dimensions");
@@ -266,8 +271,8 @@ public class ColBertEmbedder extends AbstractComponent implements Embedder {
for (int token = 0; token < nTokens; token++) {
BitSet bitSet = new BitSet(8);
int key = 0;
- for (int d = 0; d < result.shape()[1]; d++) {
- var value = result.get(TensorAddress.of(token, d));
+ for (int d = 0; d < result.shape()[2]; d++) {
+ var value = result.get(0, token, d); // batch, sequence token, dimension
int bitIndex = 7 - (d % 8);
if (value > 0.0) {
bitSet.set(bitIndex);