aboutsummaryrefslogtreecommitdiffstats
path: root/config-model/src/main/java/com/yahoo/vespa/model/container/component/HuggingFaceEmbedder.java
diff options
context:
space:
mode:
Diffstat (limited to 'config-model/src/main/java/com/yahoo/vespa/model/container/component/HuggingFaceEmbedder.java')
-rw-r--r--config-model/src/main/java/com/yahoo/vespa/model/container/component/HuggingFaceEmbedder.java3
1 files changed, 3 insertions, 0 deletions
diff --git a/config-model/src/main/java/com/yahoo/vespa/model/container/component/HuggingFaceEmbedder.java b/config-model/src/main/java/com/yahoo/vespa/model/container/component/HuggingFaceEmbedder.java
index 1c36716699e..bb26b7e4fd7 100644
--- a/config-model/src/main/java/com/yahoo/vespa/model/container/component/HuggingFaceEmbedder.java
+++ b/config-model/src/main/java/com/yahoo/vespa/model/container/component/HuggingFaceEmbedder.java
@@ -31,6 +31,7 @@ public class HuggingFaceEmbedder extends TypedComponent implements HuggingFaceEm
private final Integer onnxInteropThreads;
private final Integer onnxIntraopThreads;
private final Integer onnxGpuDevice;
+ private final String poolingStrategy;
public HuggingFaceEmbedder(Element xml, DeployState state) {
super("ai.vespa.embedding.huggingface.HuggingFaceEmbedder", INTEGRATION_BUNDLE_NAME, xml);
@@ -50,6 +51,7 @@ public class HuggingFaceEmbedder extends TypedComponent implements HuggingFaceEm
onnxInteropThreads = getOptionalChildValue(xml, "onnx-interop-threads").map(Integer::parseInt).orElse(null);
onnxIntraopThreads = getOptionalChildValue(xml, "onnx-intraop-threads").map(Integer::parseInt).orElse(null);
onnxGpuDevice = getOptionalChildValue(xml, "onnx-gpu-device").map(Integer::parseInt).orElse(null);
+ poolingStrategy = getOptionalChildValue(xml, "pooling-strategy").orElse(null);
}
private static ModelReference resolveDefaultVocab(Element model, boolean hosted) {
@@ -75,5 +77,6 @@ public class HuggingFaceEmbedder extends TypedComponent implements HuggingFaceEm
if (onnxInteropThreads != null) b.transformerInterOpThreads(onnxInteropThreads);
if (onnxIntraopThreads != null) b.transformerIntraOpThreads(onnxIntraopThreads);
if (onnxGpuDevice != null) b.transformerGpuDevice(onnxGpuDevice);
+ if (poolingStrategy != null) b.poolingStrategy(HuggingFaceEmbedderConfig.PoolingStrategy.Enum.valueOf(poolingStrategy));
}
}