summaryrefslogtreecommitdiffstats
path: root/model-integration
diff options
context:
space:
mode:
authorBjørn Christian Seime <bjorncs@yahooinc.com>2023-06-08 14:39:16 +0200
committerBjørn Christian Seime <bjorncs@yahooinc.com>2023-06-08 14:39:16 +0200
commit794b62b71cc64e1ad2cb3a40865ff65653d4240f (patch)
tree24b493ecdbc37e2356fc1c9cb7553ea339c992f8 /model-integration
parentc0652d7794a90e0afb593fc1a3db17c99606a808 (diff)
Add missing wiring of pooling strategy
Diffstat (limited to 'model-integration')
-rw-r--r--model-integration/src/main/java/ai/vespa/embedding/huggingface/HuggingFaceEmbedder.java12
1 files changed, 1 insertions, 11 deletions
diff --git a/model-integration/src/main/java/ai/vespa/embedding/huggingface/HuggingFaceEmbedder.java b/model-integration/src/main/java/ai/vespa/embedding/huggingface/HuggingFaceEmbedder.java
index b17ee89be7f..17b63fb1c7d 100644
--- a/model-integration/src/main/java/ai/vespa/embedding/huggingface/HuggingFaceEmbedder.java
+++ b/model-integration/src/main/java/ai/vespa/embedding/huggingface/HuggingFaceEmbedder.java
@@ -103,17 +103,7 @@ public class HuggingFaceEmbedder extends AbstractComponent implements Embedder {
Map<String, Tensor> outputs = evaluator.evaluate(inputs);
Tensor tokenEmbeddings = outputs.get(outputName);
- Tensor.Builder builder = Tensor.Builder.of(tensorType);
-
- // Mean pooling implementation
- Tensor summedEmbeddings = tokenEmbeddings.sum("d1");
- Tensor summedAttentionMask = attentionMask.expand("d0").sum("d1");
- Tensor averaged = summedEmbeddings.join(summedAttentionMask, (x, y) -> x / y);
- for (int i = 0; i < tensorType.dimensions().get(0).size().get(); i++) {
- builder.cell(averaged.get(TensorAddress.of(0,i)), i);
- }
-
- Tensor result = builder.build();
+ var result = poolingStrategy.toSentenceEmbedding(tensorType, tokenEmbeddings, attentionMask);
return normalize ? normalize(result, tensorType) : result;
}