diff options
Diffstat (limited to 'model-integration/src/main/java/ai/vespa/embedding/huggingface/HuggingFaceEmbedder.java')
-rw-r--r-- | model-integration/src/main/java/ai/vespa/embedding/huggingface/HuggingFaceEmbedder.java | 60 |
1 files changed, 55 insertions, 5 deletions
diff --git a/model-integration/src/main/java/ai/vespa/embedding/huggingface/HuggingFaceEmbedder.java b/model-integration/src/main/java/ai/vespa/embedding/huggingface/HuggingFaceEmbedder.java index 35645deffa4..169648967d7 100644 --- a/model-integration/src/main/java/ai/vespa/embedding/huggingface/HuggingFaceEmbedder.java +++ b/model-integration/src/main/java/ai/vespa/embedding/huggingface/HuggingFaceEmbedder.java @@ -17,6 +17,7 @@ import com.yahoo.tensor.TensorAddress; import com.yahoo.tensor.TensorType; import java.nio.file.Paths; +import java.util.BitSet; import java.util.List; import java.util.Map; import java.util.logging.Logger; @@ -124,18 +125,44 @@ public class HuggingFaceEmbedder extends AbstractComponent implements Embedder { } Map<String, Tensor> outputs = evaluator.evaluate(inputs); - Tensor tokenEmbeddings = outputs.get(outputName); - var result = poolingStrategy.toSentenceEmbedding(tensorType, tokenEmbeddings, attentionMask); - var normalized = normalize ? normalize(result, tensorType) : result; + IndexedTensor tokenEmbeddings = (IndexedTensor) outputs.get(outputName); + long[] resultShape = tokenEmbeddings.shape(); + //shape batch, sequence, embedding dimensionality + if (resultShape.length != 3) { + throw new IllegalArgumentException("" + + "Expected 3 output dimensions for output name '" + + outputName + "': [batch, sequence, embedding], got " + resultShape.length); + } + Tensor result; + if (tensorType.valueType() == TensorType.Value.INT8) { + long outputDimensions = resultShape[2]; + long targetDim = tensorType.dimensions().get(0).size().get(); + + if(targetDim * 8 > outputDimensions) { + throw new IllegalArgumentException("Cannot pack " + outputDimensions + " into " + targetDim + " int8s"); + } + //Dimensionality flexibility 🪆 - packing only the first 8*targetDim values from the model output + long firstDimensions = 8 * targetDim; + String name = tensorType.indexedSubtype().dimensions().get(0).name(); + //perform pooling and normalizing using 
floating point embeddings before binarizing + //using the firstDimensions as the target dimensionality + TensorType poolingType = new TensorType.Builder(TensorType.Value.FLOAT).indexed(name, firstDimensions).build(); + result = poolingStrategy.toSentenceEmbedding(poolingType, tokenEmbeddings, attentionMask); + result = normalize? normalize(result, poolingType) : result; + result = binarize((IndexedTensor) result, tensorType); + + } else { // regular floating points embeddings + result = poolingStrategy.toSentenceEmbedding(tensorType, tokenEmbeddings, attentionMask); + result = normalize ? normalize(result, tensorType) : result; + } runtime.sampleEmbeddingLatency((System.nanoTime() - start)/1_000_000d, context); - return normalized; + return result; } Tensor normalize(Tensor embedding, TensorType tensorType) { double sumOfSquares = 0.0; Tensor.Builder builder = Tensor.Builder.of(tensorType); - for (int i = 0; i < tensorType.dimensions().get(0).size().get(); i++) { double item = embedding.get(TensorAddress.of(i)); sumOfSquares += item * item; @@ -151,6 +178,29 @@ public class HuggingFaceEmbedder extends AbstractComponent implements Embedder { return builder.build(); } + /** Packs the signs of the values in {@code embedding} into bytes and returns them as a tensor built from {@code tensorType}. Values are read in index order in groups of 8; a value greater than 0.0 sets its bit and any other value clears it. Within a group the first value maps to bit 7 (the highest bit) and the eighth to bit 0. After every eighth value the byte is written to the next cell (cell index starts at 0 and increases by one per group); an all-zero group is written as 0 because the byte array obtained from the BitSet is empty in that case. If the number of values is not a multiple of 8, the trailing values are read but never written to a cell. NOTE(review): presumably tensorType is a 1-d int8 type with at least size/8 cells, as set up by the caller shown in this diff; confirm for other callers. */ + static public Tensor binarize(IndexedTensor embedding, TensorType tensorType) { + Tensor.Builder builder = Tensor.Builder.of(tensorType); + BitSet bitSet = new BitSet(8); + int index = 0; + for (int d = 0; d < embedding.sizeAsInt(); d++) { + var value = embedding.get(d); + int bitIndex = 7 - (d % 8); + if (value > 0.0) { + bitSet.set(bitIndex); + } else { + bitSet.clear(bitIndex); + } + if ((d + 1) % 8 == 0) { + byte[] bytes = bitSet.toByteArray(); + byte packed = (bytes.length == 0) ? 
0 : bytes[0]; + builder.cell(TensorAddress.of(index), packed); + index++; + bitSet = new BitSet(8); + } + } + return builder.build(); + } + private IndexedTensor createTensorRepresentation(List<Long> input, String dimension) { int size = input.size(); TensorType type = new TensorType.Builder(TensorType.Value.FLOAT).indexed(dimension, size).build(); |