diff options
Diffstat (limited to 'model-integration/src/main/java/ai/vespa/embedding/huggingface/HuggingFaceEmbedder.java')
-rw-r--r-- | model-integration/src/main/java/ai/vespa/embedding/huggingface/HuggingFaceEmbedder.java | 118 |
1 files changed, 69 insertions, 49 deletions
diff --git a/model-integration/src/main/java/ai/vespa/embedding/huggingface/HuggingFaceEmbedder.java b/model-integration/src/main/java/ai/vespa/embedding/huggingface/HuggingFaceEmbedder.java index 169648967d7..20d8b6362d3 100644 --- a/model-integration/src/main/java/ai/vespa/embedding/huggingface/HuggingFaceEmbedder.java +++ b/model-integration/src/main/java/ai/vespa/embedding/huggingface/HuggingFaceEmbedder.java @@ -104,59 +104,23 @@ public class HuggingFaceEmbedder extends AbstractComponent implements Embedder { tokenizer.close(); } + @SuppressWarnings("unchecked") @Override - public Tensor embed(String s, Context context, TensorType tensorType) { - var start = System.nanoTime(); - var encoding = tokenizer.encode(s, context.getLanguage()); - runtime.sampleSequenceLength(encoding.ids().size(), context); - Tensor inputSequence = createTensorRepresentation(encoding.ids(), "d1"); - Tensor attentionMask = createTensorRepresentation(encoding.attentionMask(), "d1"); - Tensor tokenTypeIds = tokenTypeIdsName.isEmpty() ? null : createTensorRepresentation(encoding.typeIds(), "d1"); - - - Map<String, Tensor> inputs; - if (tokenTypeIdsName.isEmpty() || tokenTypeIds.isEmpty()) { - inputs = Map.of(inputIdsName, inputSequence.expand("d0"), - attentionMaskName, attentionMask.expand("d0")); - } else { - inputs = Map.of(inputIdsName, inputSequence.expand("d0"), - attentionMaskName, attentionMask.expand("d0"), - tokenTypeIdsName, tokenTypeIds.expand("d0")); + public Tensor embed(String text, Context context, TensorType tensorType) { + if (tensorType.dimensions().size() != 1) { + throw new IllegalArgumentException("Error in embedding to type '" + tensorType + "': should only have one dimension."); } - - Map<String, Tensor> outputs = evaluator.evaluate(inputs); - IndexedTensor tokenEmbeddings = (IndexedTensor) outputs.get(outputName); - long[] resultShape = tokenEmbeddings.shape(); - //shape batch, sequence, embedding dimensionality - if (resultShape.length != 3) { - throw new IllegalArgumentException("" + - "Expected 3 output dimensions for output name '" + - outputName + "': [batch, sequence, embedding], got " + resultShape.length); + if (!tensorType.dimensions().get(0).isIndexed()) { + throw new IllegalArgumentException("Error in embedding to type '" + tensorType + "': dimension should be indexed."); } - Tensor result; + var embeddingResult = lookupOrEvaluate(context, text); + IndexedTensor tokenEmbeddings = embeddingResult.output; if (tensorType.valueType() == TensorType.Value.INT8) { - long outputDimensions = resultShape[2]; - long targetDim = tensorType.dimensions().get(0).size().get(); - - if(targetDim * 8 > outputDimensions) { - throw new IllegalArgumentException("Cannot pack " + outputDimensions + " into " + targetDim + " int8s"); - } - //Dimensionality flexibility 🪆 - packing only the first 8*targetDim values from the model output - long firstDimensions = 8 * targetDim; - String name = tensorType.indexedSubtype().dimensions().get(0).name(); - //perform pooling and normalizing using floating point embeddings before binarizing - //using the firstDimensions as the target dimensionality - TensorType poolingType = new TensorType.Builder(TensorType.Value.FLOAT).indexed(name, firstDimensions).build(); - result = poolingStrategy.toSentenceEmbedding(poolingType, tokenEmbeddings, attentionMask); - result = normalize? normalize(result, poolingType) : result; - result = binarize((IndexedTensor) result, tensorType); - - } else { // regular floating points embeddings - result = poolingStrategy.toSentenceEmbedding(tensorType, tokenEmbeddings, attentionMask); - result = normalize ? normalize(result, tensorType) : result; + return binaryQuantization(embeddingResult, tensorType); + } else { + Tensor result = poolingStrategy.toSentenceEmbedding(tensorType, tokenEmbeddings, embeddingResult.attentionMask); + return normalize ? normalize(result, tensorType) : result; } - runtime.sampleEmbeddingLatency((System.nanoTime() - start)/1_000_000d, context); - return result; } Tensor normalize(Tensor embedding, TensorType tensorType) { @@ -178,6 +142,61 @@ public class HuggingFaceEmbedder extends AbstractComponent implements Embedder { return builder.build(); } + private HuggingFaceEmbedder.HFEmbeddingResult lookupOrEvaluate(Context context, String text) { + var key = new HFEmbedderCacheKey(context.getEmbedderId(), text); + return context.computeCachedValueIfAbsent(key, () -> evaluate(context, text)); + } + + private HuggingFaceEmbedder.HFEmbeddingResult evaluate(Context context, String text) { + var start = System.nanoTime(); + var encoding = tokenizer.encode(text, context.getLanguage()); + runtime.sampleSequenceLength(encoding.ids().size(), context); + Tensor inputSequence = createTensorRepresentation(encoding.ids(), "d1"); + Tensor attentionMask = createTensorRepresentation(encoding.attentionMask(), "d1"); + Tensor tokenTypeIds = tokenTypeIdsName.isEmpty() ? null : createTensorRepresentation(encoding.typeIds(), "d1"); + + Map<String, Tensor> inputs; + if (tokenTypeIdsName.isEmpty() || tokenTypeIds.isEmpty()) { + inputs = Map.of(inputIdsName, inputSequence.expand("d0"), + attentionMaskName, attentionMask.expand("d0")); + } else { + inputs = Map.of(inputIdsName, inputSequence.expand("d0"), + attentionMaskName, attentionMask.expand("d0"), + tokenTypeIdsName, tokenTypeIds.expand("d0")); + } + IndexedTensor tokenEmbeddings = (IndexedTensor) evaluator.evaluate(inputs).get(outputName); + long[] resultShape = tokenEmbeddings.shape(); + //shape batch, sequence, embedding dimensionality + if (resultShape.length != 3) { + throw new IllegalArgumentException("" + + "Expected 3 output dimensions for output name '" + + outputName + "': [batch, sequence, embedding], got " + resultShape.length); + } + runtime.sampleEmbeddingLatency((System.nanoTime() - start)/1_000_000d, context); + return new HFEmbeddingResult(tokenEmbeddings, attentionMask, context.getEmbedderId()); + } + + private Tensor binaryQuantization(HuggingFaceEmbedder.HFEmbeddingResult embeddingResult, TensorType tensorType) { + long outputDimensions = embeddingResult.output().shape()[2]; + long targetDim = tensorType.dimensions().get(0).size().get(); + //🪆 flexibility - packing only the first 8*targetDim float values from the model output + long floatDimensions = 8 * targetDim; + if(floatDimensions > outputDimensions) { + throw new IllegalArgumentException("Cannot pack " + outputDimensions + " into " + targetDim + " int8s"); + } + //perform pooling and normalizing using float version before binary quantization + TensorType poolingType = new TensorType.Builder(TensorType.Value.FLOAT). + indexed(tensorType.indexedSubtype().dimensions().get(0).name(), + floatDimensions).build(); + Tensor result = poolingStrategy.toSentenceEmbedding(poolingType, embeddingResult.output(), embeddingResult.attentionMask()); + result = normalize? normalize(result, poolingType) : result; + result = binarize((IndexedTensor) result, tensorType); + return result; + } + + /** + * Binary quantization of the embedding into a tensor of type int8 with the specified dimensions. + */ static public Tensor binarize(IndexedTensor embedding, TensorType tensorType) { Tensor.Builder builder = Tensor.Builder.of(tensorType); BitSet bitSet = new BitSet(8); @@ -211,6 +230,7 @@ public class HuggingFaceEmbedder extends AbstractComponent implements Embedder { return builder.build(); } - + protected record HFEmbeddingResult(IndexedTensor output, Tensor attentionMask, String embedderId) {} + protected record HFEmbedderCacheKey(String embedderId, Object embeddedValue) { } } |