diff options
Diffstat (limited to 'model-integration/src/main')
-rw-r--r-- | model-integration/src/main/java/ai/vespa/embedding/ColBertEmbedder.java | 27 |
1 files changed, 16 insertions, 11 deletions
diff --git a/model-integration/src/main/java/ai/vespa/embedding/ColBertEmbedder.java b/model-integration/src/main/java/ai/vespa/embedding/ColBertEmbedder.java index 0bee03a65af..8c39cc8c813 100644 --- a/model-integration/src/main/java/ai/vespa/embedding/ColBertEmbedder.java +++ b/model-integration/src/main/java/ai/vespa/embedding/ColBertEmbedder.java @@ -191,10 +191,10 @@ public class ColBertEmbedder extends AbstractComponent implements Embedder { attentionMaskName, attentionMaskTensor.expand("d0")); Map<String, Tensor> outputs = evaluator.evaluate(inputs); Tensor tokenEmbeddings = outputs.get(outputName); - IndexedTensor result = (IndexedTensor) tokenEmbeddings.reduce(Reduce.Aggregator.min, "d0"); + IndexedTensor result = (IndexedTensor) tokenEmbeddings; int dims = tensorType.indexedSubtype().dimensions().get(0).size().get().intValue(); - if (dims != result.shape()[1]) { + if (dims != result.shape()[2]) { throw new IllegalArgumentException("Token vector dimensionality does not" + " match indexed dimensionality of " + dims); } @@ -217,9 +217,9 @@ public class ColBertEmbedder extends AbstractComponent implements Embedder { Map<String, Tensor> outputs = evaluator.evaluate(inputs); Tensor tokenEmbeddings = outputs.get(outputName); - IndexedTensor result = (IndexedTensor) tokenEmbeddings.reduce(Reduce.Aggregator.min, "d0"); + IndexedTensor result = (IndexedTensor) tokenEmbeddings; Tensor contextualEmbeddings; - int maxTokens = input.inputIds.size() -1; //Do not retain last PAD + int maxTokens = input.inputIds.size(); //Retain all token vectors, including PAD tokens. if (tensorType.valueType() == TensorType.Value.INT8) { contextualEmbeddings = toBitTensor(result, tensorType, maxTokens); } else { @@ -230,11 +230,13 @@ public class ColBertEmbedder extends AbstractComponent implements Embedder { } public static Tensor toFloatTensor(IndexedTensor result, TensorType type, int nTokens) { + if(result.shape().length != 3) + throw new IllegalArgumentException("Expected onnx result to have 3-dimensions [batch, sequence, dim]"); int size = type.indexedSubtype().dimensions().size(); if (size != 1) - throw new IllegalArgumentException("Indexed tensor must have one dimension"); + throw new IllegalArgumentException("Target indexed sub-type must have one dimension"); int wantedDimensionality = type.indexedSubtype().dimensions().get(0).size().get().intValue(); - int resultDimensionality = (int)result.shape()[1]; + int resultDimensionality = (int)result.shape()[2]; if (resultDimensionality != wantedDimensionality) { throw new IllegalArgumentException("Not possible to map token vector embedding with " + resultDimensionality + " + dimensions into tensor with " + wantedDimensionality); @@ -242,7 +244,7 @@ public class ColBertEmbedder extends AbstractComponent implements Embedder { Tensor.Builder builder = Tensor.Builder.of(type); for (int token = 0; token < nTokens; token++) { for (int d = 0; d < resultDimensionality; d++) { - var value = result.get(TensorAddress.of(token, d)); + var value = result.get(0,token,d); // batch, sequence token, dimension builder.cell(TensorAddress.of(token,d),value); } } @@ -253,11 +255,14 @@ public class ColBertEmbedder extends AbstractComponent implements Embedder { if (type.valueType() != TensorType.Value.INT8) throw new IllegalArgumentException("Only a int8 tensor type can be" + " the destination of bit packing"); + if(result.shape().length != 3) + throw new IllegalArgumentException("Expected onnx result to have 3-dimensions [batch, sequence, dim]"); + int size = type.indexedSubtype().dimensions().size(); if (size != 1) - throw new IllegalArgumentException("Indexed tensor must have one dimension"); + throw new IllegalArgumentException("Target indexed sub-type must have one dimension"); int wantedDimensionality = type.indexedSubtype().dimensions().get(0).size().get().intValue(); - int resultDimensionality = (int)result.shape()[1]; + int resultDimensionality = (int)result.shape()[2]; if (resultDimensionality != 8 * wantedDimensionality) { throw new IllegalArgumentException("Not possible to pack " + resultDimensionality + " + dimensions into " + wantedDimensionality + " dimensions"); @@ -266,8 +271,8 @@ public class ColBertEmbedder extends AbstractComponent implements Embedder { for (int token = 0; token < nTokens; token++) { BitSet bitSet = new BitSet(8); int key = 0; - for (int d = 0; d < result.shape()[1]; d++) { - var value = result.get(TensorAddress.of(token, d)); + for (int d = 0; d < result.shape()[2]; d++) { + var value = result.get(0, token, d); // batch, sequence token, dimension int bitIndex = 7 - (d % 8); if (value > 0.0) { bitSet.set(bitIndex); |