about | summary | refs | log | tree | commit | diff | stats
path: root/config-model/src/main
diff options
context:
space:
mode:
author Bjørn Christian Seime <bjorncs@yahooinc.com> 2023-06-05 17:17:01 +0200
committer Bjørn Christian Seime <bjorncs@yahooinc.com> 2023-06-05 17:51:00 +0200
commit b2e9037a14c2865d8c6377f9de3e07ad06627d9d (patch)
tree 3204c48c4d3a69119801bdc83d13e5e36a196aae /config-model/src/main
parent 684eb28e3b38ce015f5af060a5ba7b14f8f999c9 (diff)
Make pooling strategy configurable for Huggingface embedder
Diffstat (limited to 'config-model/src/main')
-rw-r--r-- config-model/src/main/java/com/yahoo/vespa/model/container/component/HuggingFaceEmbedder.java 3
-rw-r--r-- config-model/src/main/resources/schema/common.rnc 11
2 files changed, 10 insertions, 4 deletions
diff --git a/config-model/src/main/java/com/yahoo/vespa/model/container/component/HuggingFaceEmbedder.java b/config-model/src/main/java/com/yahoo/vespa/model/container/component/HuggingFaceEmbedder.java
index 1c36716699e..bb26b7e4fd7 100644
--- a/config-model/src/main/java/com/yahoo/vespa/model/container/component/HuggingFaceEmbedder.java
+++ b/config-model/src/main/java/com/yahoo/vespa/model/container/component/HuggingFaceEmbedder.java
@@ -31,6 +31,7 @@ public class HuggingFaceEmbedder extends TypedComponent implements HuggingFaceEm
private final Integer onnxInteropThreads;
private final Integer onnxIntraopThreads;
private final Integer onnxGpuDevice;
+ private final String poolingStrategy;
public HuggingFaceEmbedder(Element xml, DeployState state) {
super("ai.vespa.embedding.huggingface.HuggingFaceEmbedder", INTEGRATION_BUNDLE_NAME, xml);
@@ -50,6 +51,7 @@ public class HuggingFaceEmbedder extends TypedComponent implements HuggingFaceEm
onnxInteropThreads = getOptionalChildValue(xml, "onnx-interop-threads").map(Integer::parseInt).orElse(null);
onnxIntraopThreads = getOptionalChildValue(xml, "onnx-intraop-threads").map(Integer::parseInt).orElse(null);
onnxGpuDevice = getOptionalChildValue(xml, "onnx-gpu-device").map(Integer::parseInt).orElse(null);
+ poolingStrategy = getOptionalChildValue(xml, "pooling-strategy").orElse(null);
}
private static ModelReference resolveDefaultVocab(Element model, boolean hosted) {
@@ -75,5 +77,6 @@ public class HuggingFaceEmbedder extends TypedComponent implements HuggingFaceEm
if (onnxInteropThreads != null) b.transformerInterOpThreads(onnxInteropThreads);
if (onnxIntraopThreads != null) b.transformerIntraOpThreads(onnxIntraopThreads);
if (onnxGpuDevice != null) b.transformerGpuDevice(onnxGpuDevice);
+ if (poolingStrategy != null) b.poolingStrategy(HuggingFaceEmbedderConfig.PoolingStrategy.Enum.valueOf(poolingStrategy));
}
}
diff --git a/config-model/src/main/resources/schema/common.rnc b/config-model/src/main/resources/schema/common.rnc
index b2b71950a0c..061e54740f1 100644
--- a/config-model/src/main/resources/schema/common.rnc
+++ b/config-model/src/main/resources/schema/common.rnc
@@ -94,7 +94,8 @@ HuggingFaceEmbedder =
element transformer-token-type-ids { xsd:string }? &
element transformer-output { xsd:string }? &
element normalize { xsd:boolean }? &
- OnnxModelExecutionParams
+ OnnxModelExecutionParams &
+ EmbedderPoolingStrategy
HuggingFaceTokenizer =
attribute type { "hugging-face-tokenizer" } &
@@ -108,17 +109,19 @@ BertBaseEmbedder =
element transformer-model { ModelReference } &
element tokenizer-vocab { ModelReference } &
element max-tokens { xsd:nonNegativeInteger }? &
- element pooling-strategy { "cls" | "mean" }? &
element transformer-input-ids { xsd:string }? &
element transformer-attention-mask { xsd:string }? &
element transformer-token-type-ids { xsd:string }? &
element transformer-output { xsd:string }? &
element transformer-start-sequence-token { xsd:integer }? &
element transformer-end-sequence-token { xsd:integer }? &
- OnnxModelExecutionParams
+ OnnxModelExecutionParams &
+ EmbedderPoolingStrategy
OnnxModelExecutionParams =
element onnx-execution-mode { "parallel" | "sequential" }? &
element onnx-interop-threads { xsd:integer }? &
element onnx-intraop-threads { xsd:integer }? &
- element onnx-gpu-device { xsd:integer }?
\ No newline at end of file
+ element onnx-gpu-device { xsd:integer }?
+
+EmbedderPoolingStrategy = element pooling-strategy { "cls" | "mean" }?
\ No newline at end of file