diff options
Diffstat (limited to 'config-model/src/main/java/com/yahoo/vespa/model/container/component/HuggingFaceEmbedder.java')
-rw-r--r-- | config-model/src/main/java/com/yahoo/vespa/model/container/component/HuggingFaceEmbedder.java | 3 |
1 files changed, 3 insertions, 0 deletions
diff --git a/config-model/src/main/java/com/yahoo/vespa/model/container/component/HuggingFaceEmbedder.java b/config-model/src/main/java/com/yahoo/vespa/model/container/component/HuggingFaceEmbedder.java index 1c36716699e..bb26b7e4fd7 100644 --- a/config-model/src/main/java/com/yahoo/vespa/model/container/component/HuggingFaceEmbedder.java +++ b/config-model/src/main/java/com/yahoo/vespa/model/container/component/HuggingFaceEmbedder.java @@ -31,6 +31,7 @@ public class HuggingFaceEmbedder extends TypedComponent implements HuggingFaceEm private final Integer onnxInteropThreads; private final Integer onnxIntraopThreads; private final Integer onnxGpuDevice; + private final String poolingStrategy; public HuggingFaceEmbedder(Element xml, DeployState state) { super("ai.vespa.embedding.huggingface.HuggingFaceEmbedder", INTEGRATION_BUNDLE_NAME, xml); @@ -50,6 +51,7 @@ public class HuggingFaceEmbedder extends TypedComponent implements HuggingFaceEm onnxInteropThreads = getOptionalChildValue(xml, "onnx-interop-threads").map(Integer::parseInt).orElse(null); onnxIntraopThreads = getOptionalChildValue(xml, "onnx-intraop-threads").map(Integer::parseInt).orElse(null); onnxGpuDevice = getOptionalChildValue(xml, "onnx-gpu-device").map(Integer::parseInt).orElse(null); + poolingStrategy = getOptionalChildValue(xml, "pooling-strategy").orElse(null); } private static ModelReference resolveDefaultVocab(Element model, boolean hosted) { @@ -75,5 +77,6 @@ public class HuggingFaceEmbedder extends TypedComponent implements HuggingFaceEm if (onnxInteropThreads != null) b.transformerInterOpThreads(onnxInteropThreads); if (onnxIntraopThreads != null) b.transformerIntraOpThreads(onnxIntraopThreads); if (onnxGpuDevice != null) b.transformerGpuDevice(onnxGpuDevice); + if (poolingStrategy != null) b.poolingStrategy(HuggingFaceEmbedderConfig.PoolingStrategy.Enum.valueOf(poolingStrategy)); } } |