diff options
Diffstat (limited to 'model-integration')
-rw-r--r-- | model-integration/src/main/java/ai/vespa/embedding/huggingface/HuggingFaceEmbedder.java | 12 |
1 files changed, 1 insertions, 11 deletions
diff --git a/model-integration/src/main/java/ai/vespa/embedding/huggingface/HuggingFaceEmbedder.java b/model-integration/src/main/java/ai/vespa/embedding/huggingface/HuggingFaceEmbedder.java index b17ee89be7f..17b63fb1c7d 100644 --- a/model-integration/src/main/java/ai/vespa/embedding/huggingface/HuggingFaceEmbedder.java +++ b/model-integration/src/main/java/ai/vespa/embedding/huggingface/HuggingFaceEmbedder.java @@ -103,17 +103,7 @@ public class HuggingFaceEmbedder extends AbstractComponent implements Embedder { Map<String, Tensor> outputs = evaluator.evaluate(inputs); Tensor tokenEmbeddings = outputs.get(outputName); - Tensor.Builder builder = Tensor.Builder.of(tensorType); - - // Mean pooling implementation - Tensor summedEmbeddings = tokenEmbeddings.sum("d1"); - Tensor summedAttentionMask = attentionMask.expand("d0").sum("d1"); - Tensor averaged = summedEmbeddings.join(summedAttentionMask, (x, y) -> x / y); - for (int i = 0; i < tensorType.dimensions().get(0).size().get(); i++) { - builder.cell(averaged.get(TensorAddress.of(0,i)), i); - } - - Tensor result = builder.build(); + var result = poolingStrategy.toSentenceEmbedding(tensorType, tokenEmbeddings, attentionMask); return normalize ? normalize(result, tensorType) : result; } |