diff options
author | Bjørn Christian Seime <bjorncs@yahooinc.com> | 2023-06-05 17:17:01 +0200 |
---|---|---|
committer | Bjørn Christian Seime <bjorncs@yahooinc.com> | 2023-06-05 17:51:00 +0200 |
commit | b2e9037a14c2865d8c6377f9de3e07ad06627d9d (patch) | |
tree | 3204c48c4d3a69119801bdc83d13e5e36a196aae /config-model/src/main | |
parent | 684eb28e3b38ce015f5af060a5ba7b14f8f999c9 (diff) |
Make pooling strategy configurable for Huggingface embedder
Diffstat (limited to 'config-model/src/main')
-rw-r--r-- | config-model/src/main/java/com/yahoo/vespa/model/container/component/HuggingFaceEmbedder.java | 3 | ||||
-rw-r--r-- | config-model/src/main/resources/schema/common.rnc | 11 |
2 files changed, 10 insertions, 4 deletions
diff --git a/config-model/src/main/java/com/yahoo/vespa/model/container/component/HuggingFaceEmbedder.java b/config-model/src/main/java/com/yahoo/vespa/model/container/component/HuggingFaceEmbedder.java index 1c36716699e..bb26b7e4fd7 100644 --- a/config-model/src/main/java/com/yahoo/vespa/model/container/component/HuggingFaceEmbedder.java +++ b/config-model/src/main/java/com/yahoo/vespa/model/container/component/HuggingFaceEmbedder.java @@ -31,6 +31,7 @@ public class HuggingFaceEmbedder extends TypedComponent implements HuggingFaceEm private final Integer onnxInteropThreads; private final Integer onnxIntraopThreads; private final Integer onnxGpuDevice; + private final String poolingStrategy; public HuggingFaceEmbedder(Element xml, DeployState state) { super("ai.vespa.embedding.huggingface.HuggingFaceEmbedder", INTEGRATION_BUNDLE_NAME, xml); @@ -50,6 +51,7 @@ public class HuggingFaceEmbedder extends TypedComponent implements HuggingFaceEm onnxInteropThreads = getOptionalChildValue(xml, "onnx-interop-threads").map(Integer::parseInt).orElse(null); onnxIntraopThreads = getOptionalChildValue(xml, "onnx-intraop-threads").map(Integer::parseInt).orElse(null); onnxGpuDevice = getOptionalChildValue(xml, "onnx-gpu-device").map(Integer::parseInt).orElse(null); + poolingStrategy = getOptionalChildValue(xml, "pooling-strategy").orElse(null); } private static ModelReference resolveDefaultVocab(Element model, boolean hosted) { @@ -75,5 +77,6 @@ public class HuggingFaceEmbedder extends TypedComponent implements HuggingFaceEm if (onnxInteropThreads != null) b.transformerInterOpThreads(onnxInteropThreads); if (onnxIntraopThreads != null) b.transformerIntraOpThreads(onnxIntraopThreads); if (onnxGpuDevice != null) b.transformerGpuDevice(onnxGpuDevice); + if (poolingStrategy != null) b.poolingStrategy(HuggingFaceEmbedderConfig.PoolingStrategy.Enum.valueOf(poolingStrategy)); } } diff --git a/config-model/src/main/resources/schema/common.rnc b/config-model/src/main/resources/schema/common.rnc index b2b71950a0c..061e54740f1 100644 --- a/config-model/src/main/resources/schema/common.rnc +++ b/config-model/src/main/resources/schema/common.rnc @@ -94,7 +94,8 @@ HuggingFaceEmbedder = element transformer-token-type-ids { xsd:string }? & element transformer-output { xsd:string }? & element normalize { xsd:boolean }? & - OnnxModelExecutionParams + OnnxModelExecutionParams & + EmbedderPoolingStrategy HuggingFaceTokenizer = attribute type { "hugging-face-tokenizer" } & @@ -108,17 +109,19 @@ BertBaseEmbedder = element transformer-model { ModelReference } & element tokenizer-vocab { ModelReference } & element max-tokens { xsd:nonNegativeInteger }? & - element pooling-strategy { "cls" | "mean" }? & element transformer-input-ids { xsd:string }? & element transformer-attention-mask { xsd:string }? & element transformer-token-type-ids { xsd:string }? & element transformer-output { xsd:string }? & element transformer-start-sequence-token { xsd:integer }? & element transformer-end-sequence-token { xsd:integer }? & - OnnxModelExecutionParams + OnnxModelExecutionParams & + EmbedderPoolingStrategy OnnxModelExecutionParams = element onnx-execution-mode { "parallel" | "sequential" }? & element onnx-interop-threads { xsd:integer }? & element onnx-intraop-threads { xsd:integer }? & - element onnx-gpu-device { xsd:integer }?
\ No newline at end of file + element onnx-gpu-device { xsd:integer }? + +EmbedderPoolingStrategy = element pooling-strategy { "cls" | "mean" }?
\ No newline at end of file |