diff options
Diffstat (limited to 'model-integration/src/main/java/ai/vespa/embedding/huggingface/HuggingFaceEmbedder.java')
-rw-r--r-- | model-integration/src/main/java/ai/vespa/embedding/huggingface/HuggingFaceEmbedder.java | 3 |
1 files changed, 3 insertions, 0 deletions
diff --git a/model-integration/src/main/java/ai/vespa/embedding/huggingface/HuggingFaceEmbedder.java b/model-integration/src/main/java/ai/vespa/embedding/huggingface/HuggingFaceEmbedder.java index 01804656bb6..f93b1a3c1f8 100644 --- a/model-integration/src/main/java/ai/vespa/embedding/huggingface/HuggingFaceEmbedder.java +++ b/model-integration/src/main/java/ai/vespa/embedding/huggingface/HuggingFaceEmbedder.java @@ -1,5 +1,6 @@ package ai.vespa.embedding.huggingface; +import ai.vespa.embedding.PoolingStrategy; import ai.vespa.modelintegration.evaluator.OnnxEvaluator; import ai.vespa.modelintegration.evaluator.OnnxEvaluatorOptions; import ai.vespa.modelintegration.evaluator.OnnxRuntime; @@ -28,6 +29,7 @@ public class HuggingFaceEmbedder extends AbstractComponent implements Embedder { private final boolean normalize; private final HuggingFaceTokenizer tokenizer; private final OnnxEvaluator evaluator; + private final PoolingStrategy poolingStrategy; @Inject public HuggingFaceEmbedder(OnnxRuntime onnx, HuggingFaceEmbedderConfig config) { @@ -42,6 +44,7 @@ public class HuggingFaceEmbedder extends AbstractComponent implements Embedder { .setTruncation(true) .setMaxLength(config.transformerMaxTokens()) .build(); + poolingStrategy = PoolingStrategy.fromString(config.poolingStrategy().toString()); var onnxOpts = new OnnxEvaluatorOptions(); if (config.transformerGpuDevice() >= 0) onnxOpts.setGpuDevice(config.transformerGpuDevice()); |