diff options
author | Bjørn Christian Seime <bjorncs@yahooinc.com> | 2023-06-08 14:39:16 +0200 |
---|---|---|
committer | Bjørn Christian Seime <bjorncs@yahooinc.com> | 2023-06-08 14:39:16 +0200 |
commit | 794b62b71cc64e1ad2cb3a40865ff65653d4240f (patch) | |
tree | 24b493ecdbc37e2356fc1c9cb7553ea339c992f8 /model-integration | |
parent | c0652d7794a90e0afb593fc1a3db17c99606a808 (diff) |
Add missing wiring of pooling strategy
Diffstat (limited to 'model-integration')
-rw-r--r-- | model-integration/src/main/java/ai/vespa/embedding/huggingface/HuggingFaceEmbedder.java | 12 |
1 file changed, 1 insertion, 11 deletions
```diff
diff --git a/model-integration/src/main/java/ai/vespa/embedding/huggingface/HuggingFaceEmbedder.java b/model-integration/src/main/java/ai/vespa/embedding/huggingface/HuggingFaceEmbedder.java
index b17ee89be7f..17b63fb1c7d 100644
--- a/model-integration/src/main/java/ai/vespa/embedding/huggingface/HuggingFaceEmbedder.java
+++ b/model-integration/src/main/java/ai/vespa/embedding/huggingface/HuggingFaceEmbedder.java
@@ -103,17 +103,7 @@ public class HuggingFaceEmbedder extends AbstractComponent implements Embedder {
         Map<String, Tensor> outputs = evaluator.evaluate(inputs);
         Tensor tokenEmbeddings = outputs.get(outputName);
-        Tensor.Builder builder = Tensor.Builder.of(tensorType);
-
-        // Mean pooling implementation
-        Tensor summedEmbeddings = tokenEmbeddings.sum("d1");
-        Tensor summedAttentionMask = attentionMask.expand("d0").sum("d1");
-        Tensor averaged = summedEmbeddings.join(summedAttentionMask, (x, y) -> x / y);
-        for (int i = 0; i < tensorType.dimensions().get(0).size().get(); i++) {
-            builder.cell(averaged.get(TensorAddress.of(0,i)), i);
-        }
-
-        Tensor result = builder.build();
+        var result = poolingStrategy.toSentenceEmbedding(tensorType, tokenEmbeddings, attentionMask);
         return normalize ? normalize(result, tensorType) : result;
     }
```