aboutsummaryrefslogtreecommitdiffstats
path: root/model-integration/src/main/java/ai/vespa/embedding/huggingface/HuggingFaceEmbedder.java
diff options
context:
space:
mode:
Diffstat (limited to 'model-integration/src/main/java/ai/vespa/embedding/huggingface/HuggingFaceEmbedder.java')
-rw-r--r--model-integration/src/main/java/ai/vespa/embedding/huggingface/HuggingFaceEmbedder.java3
1 files changed, 3 insertions, 0 deletions
diff --git a/model-integration/src/main/java/ai/vespa/embedding/huggingface/HuggingFaceEmbedder.java b/model-integration/src/main/java/ai/vespa/embedding/huggingface/HuggingFaceEmbedder.java
index 01804656bb6..f93b1a3c1f8 100644
--- a/model-integration/src/main/java/ai/vespa/embedding/huggingface/HuggingFaceEmbedder.java
+++ b/model-integration/src/main/java/ai/vespa/embedding/huggingface/HuggingFaceEmbedder.java
@@ -1,5 +1,6 @@
package ai.vespa.embedding.huggingface;
+import ai.vespa.embedding.PoolingStrategy;
import ai.vespa.modelintegration.evaluator.OnnxEvaluator;
import ai.vespa.modelintegration.evaluator.OnnxEvaluatorOptions;
import ai.vespa.modelintegration.evaluator.OnnxRuntime;
@@ -28,6 +29,7 @@ public class HuggingFaceEmbedder extends AbstractComponent implements Embedder {
private final boolean normalize;
private final HuggingFaceTokenizer tokenizer;
private final OnnxEvaluator evaluator;
+ private final PoolingStrategy poolingStrategy;
@Inject
public HuggingFaceEmbedder(OnnxRuntime onnx, HuggingFaceEmbedderConfig config) {
@@ -42,6 +44,7 @@ public class HuggingFaceEmbedder extends AbstractComponent implements Embedder {
.setTruncation(true)
.setMaxLength(config.transformerMaxTokens())
.build();
+ poolingStrategy = PoolingStrategy.fromString(config.poolingStrategy().toString());
var onnxOpts = new OnnxEvaluatorOptions();
if (config.transformerGpuDevice() >= 0)
onnxOpts.setGpuDevice(config.transformerGpuDevice());