aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--config-model/src/main/java/com/yahoo/vespa/model/builder/xml/dom/DomComponentBuilder.java2
-rw-r--r--config-model/src/main/java/com/yahoo/vespa/model/container/component/BertEmbedder.java70
-rw-r--r--config-model/src/main/java/com/yahoo/vespa/model/container/component/HuggingFaceEmbedder.java16
-rw-r--r--config-model/src/main/java/com/yahoo/vespa/model/container/component/HuggingFaceTokenizer.java2
-rw-r--r--config-model/src/main/java/com/yahoo/vespa/model/container/xml/ModelIdResolver.java29
-rw-r--r--config-model/src/main/resources/schema/common.rnc32
-rw-r--r--config-model/src/test/cfg/application/embed/configdefinitions/embedding.bert-base-embedder.def30
-rw-r--r--config-model/src/test/cfg/application/embed/configdefinitions/sentence-embedder.def26
-rw-r--r--config-model/src/test/cfg/application/embed/services.xml26
-rw-r--r--config-model/src/test/cfg/application/embed_cloud_only/configdefinitions/embedding.bert-base-embedder.def30
-rw-r--r--config-model/src/test/cfg/application/embed_cloud_only/configdefinitions/sentence-embedder.def26
-rw-r--r--config-model/src/test/cfg/application/embed_cloud_only/services.xml13
-rw-r--r--config-model/src/test/java/com/yahoo/vespa/model/container/xml/EmbedderTestCase.java182
-rw-r--r--configdefinitions/src/vespa/embedding.bert-base-embedder.def (renamed from model-integration/src/main/resources/configdefinitions/embedding.bert-base-embedder.def)0
-rw-r--r--configdefinitions/src/vespa/hugging-face-embedder.def2
-rw-r--r--model-integration/src/main/java/ai/vespa/embedding/BertBaseEmbedder.java20
-rw-r--r--model-integration/src/main/java/ai/vespa/embedding/PoolingStrategy.java48
-rw-r--r--model-integration/src/main/java/ai/vespa/embedding/huggingface/HuggingFaceEmbedder.java3
18 files changed, 346 insertions, 211 deletions
diff --git a/config-model/src/main/java/com/yahoo/vespa/model/builder/xml/dom/DomComponentBuilder.java b/config-model/src/main/java/com/yahoo/vespa/model/builder/xml/dom/DomComponentBuilder.java
index fa0ee3f9857..d0e1ede2cfa 100644
--- a/config-model/src/main/java/com/yahoo/vespa/model/builder/xml/dom/DomComponentBuilder.java
+++ b/config-model/src/main/java/com/yahoo/vespa/model/builder/xml/dom/DomComponentBuilder.java
@@ -7,6 +7,7 @@ import com.yahoo.config.model.producer.AnyConfigProducer;
import com.yahoo.config.model.producer.TreeConfigProducer;
import com.yahoo.osgi.provider.model.ComponentModel;
import com.yahoo.text.XML;
+import com.yahoo.vespa.model.container.component.BertEmbedder;
import com.yahoo.vespa.model.container.component.Component;
import com.yahoo.vespa.model.container.component.HuggingFaceEmbedder;
import com.yahoo.vespa.model.container.component.HuggingFaceTokenizer;
@@ -44,6 +45,7 @@ public class DomComponentBuilder extends VespaDomBuilder.DomConfigProducerBuilde
return switch (type) {
case "hugging-face-embedder" -> new HuggingFaceEmbedder(spec, state);
case "hugging-face-tokenizer" -> new HuggingFaceTokenizer(spec, state);
+ case "bert-embedder" -> new BertEmbedder(spec, state);
default -> throw new IllegalArgumentException("Unknown component type '%s'".formatted(type));
};
} else {
diff --git a/config-model/src/main/java/com/yahoo/vespa/model/container/component/BertEmbedder.java b/config-model/src/main/java/com/yahoo/vespa/model/container/component/BertEmbedder.java
new file mode 100644
index 00000000000..56aa974da48
--- /dev/null
+++ b/config-model/src/main/java/com/yahoo/vespa/model/container/component/BertEmbedder.java
@@ -0,0 +1,70 @@
+// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+
+package com.yahoo.vespa.model.container.component;
+
+import com.yahoo.config.ModelReference;
+import com.yahoo.config.model.deploy.DeployState;
+import com.yahoo.embedding.BertBaseEmbedderConfig;
+import com.yahoo.vespa.model.container.xml.ModelIdResolver;
+import org.w3c.dom.Element;
+
+import static com.yahoo.config.model.builder.xml.XmlHelper.getOptionalChildValue;
+import static com.yahoo.text.XML.getChild;
+import static com.yahoo.vespa.model.container.ContainerModelEvaluation.INTEGRATION_BUNDLE_NAME;
+
+/**
+ * @author bjorncs
+ */
+public class BertEmbedder extends TypedComponent implements BertBaseEmbedderConfig.Producer {
+
+ private final ModelReference model;
+ private final ModelReference vocab;
+ private final Integer maxTokens;
+ private final String transformerInputIds;
+ private final String transformerAttentionMask;
+ private final String transformerTokenTypeIds;
+ private final String transformerOutput;
+ private final Integer tranformerStartSequenceToken;
+ private final Integer transformerEndSequenceToken;
+ private final String poolingStrategy;
+ private final String onnxExecutionMode;
+ private final Integer onnxInteropThreads;
+ private final Integer onnxIntraopThreads;
+ private final Integer onnxGpuDevice;
+
+
+ public BertEmbedder(Element xml, DeployState state) {
+ super("ai.vespa.embedding.BertBaseEmbedder", INTEGRATION_BUNDLE_NAME, xml);
+ model = ModelIdResolver.resolveToModelReference(getChild(xml, "transformer-model"), state);
+ vocab = ModelIdResolver.resolveToModelReference(getChild(xml, "tokenizer-vocab"), state);
+ maxTokens = getOptionalChildValue(xml, "max-tokens").map(Integer::parseInt).orElse(null);
+ transformerInputIds = getOptionalChildValue(xml, "transformer-input-ids").orElse(null);
+ transformerAttentionMask = getOptionalChildValue(xml, "transformer-attention-mask").orElse(null);
+ transformerTokenTypeIds = getOptionalChildValue(xml, "transformer-token-type-ids").orElse(null);
+ transformerOutput = getOptionalChildValue(xml, "transformer-output").orElse(null);
+ tranformerStartSequenceToken = getOptionalChildValue(xml, "transformer-start-sequence-token").map(Integer::parseInt).orElse(null);
+ transformerEndSequenceToken = getOptionalChildValue(xml, "transformer-end-sequence-token").map(Integer::parseInt).orElse(null);
+ poolingStrategy = getOptionalChildValue(xml, "pooling-strategy").orElse(null);
+ onnxExecutionMode = getOptionalChildValue(xml, "onnx-execution-mode").orElse(null);
+ onnxInteropThreads = getOptionalChildValue(xml, "onnx-interop-threads").map(Integer::parseInt).orElse(null);
+ onnxIntraopThreads = getOptionalChildValue(xml, "onnx-intraop-threads").map(Integer::parseInt).orElse(null);
+ onnxGpuDevice = getOptionalChildValue(xml, "onnx-gpu-device").map(Integer::parseInt).orElse(null);
+ }
+
+ @Override
+ public void getConfig(BertBaseEmbedderConfig.Builder b) {
+ b.transformerModel(model).tokenizerVocab(vocab);
+ if (maxTokens != null) b.transformerMaxTokens(maxTokens);
+ if (transformerInputIds != null) b.transformerInputIds(transformerInputIds);
+ if (transformerAttentionMask != null) b.transformerAttentionMask(transformerAttentionMask);
+ if (transformerTokenTypeIds != null) b.transformerTokenTypeIds(transformerTokenTypeIds);
+ if (transformerOutput != null) b.transformerOutput(transformerOutput);
+ if (tranformerStartSequenceToken != null) b.transformerStartSequenceToken(tranformerStartSequenceToken);
+ if (transformerEndSequenceToken != null) b.transformerEndSequenceToken(transformerEndSequenceToken);
+ if (poolingStrategy != null) b.poolingStrategy(BertBaseEmbedderConfig.PoolingStrategy.Enum.valueOf(poolingStrategy));
+ if (onnxExecutionMode != null) b.onnxExecutionMode(BertBaseEmbedderConfig.OnnxExecutionMode.Enum.valueOf(onnxExecutionMode));
+ if (onnxInteropThreads != null) b.onnxInterOpThreads(onnxInteropThreads);
+ if (onnxIntraopThreads != null) b.onnxIntraOpThreads(onnxIntraopThreads);
+ if (onnxGpuDevice != null) b.onnxGpuDevice(onnxGpuDevice);
+ }
+}
diff --git a/config-model/src/main/java/com/yahoo/vespa/model/container/component/HuggingFaceEmbedder.java b/config-model/src/main/java/com/yahoo/vespa/model/container/component/HuggingFaceEmbedder.java
index 1c36716699e..6e7a1cc31dd 100644
--- a/config-model/src/main/java/com/yahoo/vespa/model/container/component/HuggingFaceEmbedder.java
+++ b/config-model/src/main/java/com/yahoo/vespa/model/container/component/HuggingFaceEmbedder.java
@@ -31,15 +31,15 @@ public class HuggingFaceEmbedder extends TypedComponent implements HuggingFaceEm
private final Integer onnxInteropThreads;
private final Integer onnxIntraopThreads;
private final Integer onnxGpuDevice;
+ private final String poolingStrategy;
public HuggingFaceEmbedder(Element xml, DeployState state) {
super("ai.vespa.embedding.huggingface.HuggingFaceEmbedder", INTEGRATION_BUNDLE_NAME, xml);
- boolean hosted = state.isHosted();
var transformerModelElem = getOptionalChild(xml, "transformer-model").orElseThrow();
- model = ModelIdResolver.resolveToModelReference(transformerModelElem, hosted);
+ model = ModelIdResolver.resolveToModelReference(transformerModelElem, state);
vocab = getOptionalChild(xml, "tokenizer-model")
- .map(elem -> ModelIdResolver.resolveToModelReference(elem, hosted))
- .orElseGet(() -> resolveDefaultVocab(transformerModelElem, hosted));
+ .map(elem -> ModelIdResolver.resolveToModelReference(elem, state))
+ .orElseGet(() -> resolveDefaultVocab(transformerModelElem, state));
maxTokens = getOptionalChildValue(xml, "max-tokens").map(Integer::parseInt).orElse(null);
transformerInputIds = getOptionalChildValue(xml, "transformer-input-ids").orElse(null);
transformerAttentionMask = getOptionalChildValue(xml, "transformer-attention-mask").orElse(null);
@@ -50,13 +50,14 @@ public class HuggingFaceEmbedder extends TypedComponent implements HuggingFaceEm
onnxInteropThreads = getOptionalChildValue(xml, "onnx-interop-threads").map(Integer::parseInt).orElse(null);
onnxIntraopThreads = getOptionalChildValue(xml, "onnx-intraop-threads").map(Integer::parseInt).orElse(null);
onnxGpuDevice = getOptionalChildValue(xml, "onnx-gpu-device").map(Integer::parseInt).orElse(null);
+ poolingStrategy = getOptionalChildValue(xml, "pooling-strategy").orElse(null);
}
- private static ModelReference resolveDefaultVocab(Element model, boolean hosted) {
- if (hosted && model.hasAttribute("model-id")) {
+ private static ModelReference resolveDefaultVocab(Element model, DeployState state) {
+ if (state.isHosted() && model.hasAttribute("model-id")) {
var implicitVocabId = model.getAttribute("model-id") + "-vocab";
return ModelIdResolver.resolveToModelReference(
- "tokenizer-model", Optional.of(implicitVocabId), Optional.empty(), Optional.empty(), true);
+ "tokenizer-model", Optional.of(implicitVocabId), Optional.empty(), Optional.empty(), state);
}
throw new IllegalArgumentException("'tokenizer-model' must be specified");
}
@@ -75,5 +76,6 @@ public class HuggingFaceEmbedder extends TypedComponent implements HuggingFaceEm
if (onnxInteropThreads != null) b.transformerInterOpThreads(onnxInteropThreads);
if (onnxIntraopThreads != null) b.transformerIntraOpThreads(onnxIntraopThreads);
if (onnxGpuDevice != null) b.transformerGpuDevice(onnxGpuDevice);
+ if (poolingStrategy != null) b.poolingStrategy(HuggingFaceEmbedderConfig.PoolingStrategy.Enum.valueOf(poolingStrategy));
}
}
diff --git a/config-model/src/main/java/com/yahoo/vespa/model/container/component/HuggingFaceTokenizer.java b/config-model/src/main/java/com/yahoo/vespa/model/container/component/HuggingFaceTokenizer.java
index ba8521a0089..966dbe8260a 100644
--- a/config-model/src/main/java/com/yahoo/vespa/model/container/component/HuggingFaceTokenizer.java
+++ b/config-model/src/main/java/com/yahoo/vespa/model/container/component/HuggingFaceTokenizer.java
@@ -28,7 +28,7 @@ public class HuggingFaceTokenizer extends TypedComponent implements HuggingFaceT
super("com.yahoo.language.huggingface.HuggingFaceTokenizer", LINGUISTICS_BUNDLE_NAME, xml);
for (Element element : XML.getChildren(xml, "model")) {
var lang = element.hasAttribute("language") ? element.getAttribute("language") : "unknown";
- langToModel.put(lang, ModelIdResolver.resolveToModelReference(element, state.isHosted()));
+ langToModel.put(lang, ModelIdResolver.resolveToModelReference(element, state));
}
specialTokens = getOptionalChildValue(xml, "special-tokens").map(Boolean::parseBoolean).orElse(null);
maxLength = getOptionalChildValue(xml, "max-length").map(Integer::parseInt).orElse(null);
diff --git a/config-model/src/main/java/com/yahoo/vespa/model/container/xml/ModelIdResolver.java b/config-model/src/main/java/com/yahoo/vespa/model/container/xml/ModelIdResolver.java
index c0f49f3148d..96f653bf793 100644
--- a/config-model/src/main/java/com/yahoo/vespa/model/container/xml/ModelIdResolver.java
+++ b/config-model/src/main/java/com/yahoo/vespa/model/container/xml/ModelIdResolver.java
@@ -1,10 +1,10 @@
// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
package com.yahoo.vespa.model.container.xml;
-import com.yahoo.config.FileReference;
import com.yahoo.config.ModelReference;
import com.yahoo.config.UrlReference;
import com.yahoo.config.model.builder.xml.XmlHelper;
+import com.yahoo.config.model.deploy.DeployState;
import com.yahoo.text.XML;
import org.w3c.dom.Element;
@@ -80,25 +80,24 @@ public class ModelIdResolver {
}
- public static ModelReference resolveToModelReference(Element elem, boolean hosted) {
+ public static ModelReference resolveToModelReference(Element elem, DeployState state) {
return resolveToModelReference(
elem.getTagName(), XmlHelper.getOptionalAttribute(elem, "model-id"),
- XmlHelper.getOptionalAttribute(elem, "url"), XmlHelper.getOptionalAttribute(elem, "path"), hosted);
+ XmlHelper.getOptionalAttribute(elem, "url"), XmlHelper.getOptionalAttribute(elem, "path"), state);
}
public static ModelReference resolveToModelReference(
- String paramName, Optional<String> id, Optional<String> url, Optional<String> path, boolean hosted) {
- if (id.isEmpty()) return ModelReference.unresolved(
- Optional.empty(), url.map(UrlReference::valueOf), path.map(FileReference::new));
- else if (hosted) {
- return ModelReference.unresolved(
- id, Optional.of(UrlReference.valueOf(modelIdToUrl(paramName, id.get()))), Optional.empty());
- } else if (url.isEmpty() && path.isEmpty()) {
- throw onlyModelIdInHostedException(paramName);
- } else {
- return ModelReference.unresolved(
- Optional.empty(), url.map(UrlReference::valueOf), path.map(FileReference::new));
- }
+ String paramName, Optional<String> id, Optional<String> url, Optional<String> path, DeployState state) {
+ if (id.isEmpty()) return createModelReference(Optional.empty(), url, path, state);
+ else if (state.isHosted())
+ return createModelReference(id, Optional.of(modelIdToUrl(paramName, id.get())), Optional.empty(), state);
+ else if (url.isEmpty() && path.isEmpty()) throw onlyModelIdInHostedException(paramName);
+ else return createModelReference(id, url, path, state);
+ }
+
+ private static ModelReference createModelReference(Optional<String> id, Optional<String> url, Optional<String> path, DeployState state) {
+ var fileRef = path.map(p -> state.getFileRegistry().addFile(p));
+ return ModelReference.unresolved(id, url.map(UrlReference::valueOf), fileRef);
}
private static IllegalArgumentException onlyModelIdInHostedException(String paramName) {
diff --git a/config-model/src/main/resources/schema/common.rnc b/config-model/src/main/resources/schema/common.rnc
index 4e7cb526efb..061e54740f1 100644
--- a/config-model/src/main/resources/schema/common.rnc
+++ b/config-model/src/main/resources/schema/common.rnc
@@ -80,7 +80,7 @@ ComponentDefinition =
TypedComponentDefinition =
attribute id { xsd:Name } &
- (HuggingFaceEmbedder | HuggingFaceTokenizer) &
+ (HuggingFaceEmbedder | HuggingFaceTokenizer | BertBaseEmbedder) &
GenericConfig* &
Component*
@@ -94,14 +94,34 @@ HuggingFaceEmbedder =
element transformer-token-type-ids { xsd:string }? &
element transformer-output { xsd:string }? &
element normalize { xsd:boolean }? &
- element onnx-execution-mode { "parallel" | "sequential" }? &
- element onnx-interop-threads { xsd:integer }? &
- element onnx-intraop-threads { xsd:integer }? &
- element onnx-gpu-device { xsd:integer }?
+ OnnxModelExecutionParams &
+ EmbedderPoolingStrategy
HuggingFaceTokenizer =
attribute type { "hugging-face-tokenizer" } &
element model { attribute language { xsd:string }? & ModelReference }+ &
element special-tokens { xsd:boolean }? &
element max-length { xsd:integer }? &
- element truncation { xsd:boolean }? \ No newline at end of file
+ element truncation { xsd:boolean }?
+
+BertBaseEmbedder =
+ attribute type { "bert-embedder" } &
+ element transformer-model { ModelReference } &
+ element tokenizer-vocab { ModelReference } &
+ element max-tokens { xsd:nonNegativeInteger }? &
+ element transformer-input-ids { xsd:string }? &
+ element transformer-attention-mask { xsd:string }? &
+ element transformer-token-type-ids { xsd:string }? &
+ element transformer-output { xsd:string }? &
+ element transformer-start-sequence-token { xsd:integer }? &
+ element transformer-end-sequence-token { xsd:integer }? &
+ OnnxModelExecutionParams &
+ EmbedderPoolingStrategy
+
+OnnxModelExecutionParams =
+ element onnx-execution-mode { "parallel" | "sequential" }? &
+ element onnx-interop-threads { xsd:integer }? &
+ element onnx-intraop-threads { xsd:integer }? &
+ element onnx-gpu-device { xsd:integer }?
+
+EmbedderPoolingStrategy = element pooling-strategy { "cls" | "mean" }? \ No newline at end of file
diff --git a/config-model/src/test/cfg/application/embed/configdefinitions/embedding.bert-base-embedder.def b/config-model/src/test/cfg/application/embed/configdefinitions/embedding.bert-base-embedder.def
deleted file mode 100644
index 144dfbd0001..00000000000
--- a/config-model/src/test/cfg/application/embed/configdefinitions/embedding.bert-base-embedder.def
+++ /dev/null
@@ -1,30 +0,0 @@
-# Copy of this Vespa config stored here because Vespa config definitions are not
-# available in unit tests, and are needed (by DomConfigPayloadBuilder.parseLeaf)
-# Alternatively, we could make that not need it as it is not strictly necessaery.
-
-namespace=embedding
-
-# Wordpiece tokenizer
-tokenizerVocab model
-
-transformerModel model
-
-# Max length of token sequence model can handle
-transformerMaxTokens int default=384
-
-# Pooling strategy
-poolingStrategy enum { cls, mean } default=mean
-
-# Input names
-transformerInputIds string default=input_ids
-transformerAttentionMask string default=attention_mask
-transformerTokenTypeIds string default=token_type_ids
-
-# Output name
-transformerOutput string default=output_0
-
-# Settings for ONNX model evaluation
-onnxExecutionMode enum { parallel, sequential } default=sequential
-onnxInterOpThreads int default=1
-onnxIntraOpThreads int default=-4 # n=number of threads -> n<0: CPUs/(-n), n==0: CPUs, n>0: n
-
diff --git a/config-model/src/test/cfg/application/embed/configdefinitions/sentence-embedder.def b/config-model/src/test/cfg/application/embed/configdefinitions/sentence-embedder.def
new file mode 100644
index 00000000000..87b80f1051a
--- /dev/null
+++ b/config-model/src/test/cfg/application/embed/configdefinitions/sentence-embedder.def
@@ -0,0 +1,26 @@
+package=ai.vespa.example.paragraph
+
+# WordPiece tokenizer vocabulary
+vocab model
+
+model model
+
+myValue string
+
+# Max length of token sequence model can handle
+transforerMaxTokens int default=128
+
+# Pooling strategy
+poolingStrategy enum { cls, mean } default=mean
+
+# Input names
+transformerInputIds string default=input_ids
+transformerAttentionMask string default=attention_mask
+
+# Output name
+transformerOutput string default=last_hidden_state
+
+# Settings for ONNX model evaluation
+onnxExecutionMode enum { parallel, sequential } default=sequential
+onnxInterOpThreads int default=1
+onnxIntraOpThreads int default=-4
diff --git a/config-model/src/test/cfg/application/embed/services.xml b/config-model/src/test/cfg/application/embed/services.xml
index 99c89bc4324..6823ef900ae 100644
--- a/config-model/src/test/cfg/application/embed/services.xml
+++ b/config-model/src/test/cfg/application/embed/services.xml
@@ -16,6 +16,7 @@
<onnx-intraop-threads>10</onnx-intraop-threads>
<onnx-interop-threads>8</onnx-interop-threads>
<onnx-gpu-device>1</onnx-gpu-device>
+ <pooling-strategy>mean</pooling-strategy>
</component>
<component id="hf-tokenizer" type="hugging-face-tokenizer">
@@ -25,15 +26,24 @@
<truncation>true</truncation>
</component>
- <component id="transformer" class="ai.vespa.embedding.BertBaseEmbedder" bundle="model-integration">
- <config name="embedding.bert-base-embedder">
- <!-- model specifics -->
- <transformerModel model-id="minilm-l6-v2" url="application-url"/>
- <tokenizerVocab path="files/vocab.txt"/>
+ <component id="bert-embedder" type="bert-embedder">
+ <!-- model specifics -->
+ <transformer-model model-id="minilm-l6-v2" url="application-url"/>
+ <tokenizer-vocab path="files/vocab.txt"/>
+ <max-tokens>512</max-tokens>
+ <transformer-input-ids>my_input_ids</transformer-input-ids>
+ <transformer-attention-mask>my_attention_mask</transformer-attention-mask>
+ <transformer-token-type-ids>my_token_type_ids</transformer-token-type-ids>
+ <transformer-output>my_output</transformer-output>
+ <transformer-start-sequence-token>101</transformer-start-sequence-token>
+ <transformer-end-sequence-token>102</transformer-end-sequence-token>
- <!-- tunable parameters: number of threads etc -->
- <onnxIntraOpThreads>4</onnxIntraOpThreads>
- </config>
+
+ <!-- tunable parameters: number of threads etc -->
+ <onnx-execution-mode>parallel</onnx-execution-mode>
+ <onnx-intraop-threads>4</onnx-intraop-threads>
+ <onnx-interop-threads>8</onnx-interop-threads>
+ <onnx-gpu-device>1</onnx-gpu-device>
</component>
<nodes>
diff --git a/config-model/src/test/cfg/application/embed_cloud_only/configdefinitions/embedding.bert-base-embedder.def b/config-model/src/test/cfg/application/embed_cloud_only/configdefinitions/embedding.bert-base-embedder.def
deleted file mode 100644
index 144dfbd0001..00000000000
--- a/config-model/src/test/cfg/application/embed_cloud_only/configdefinitions/embedding.bert-base-embedder.def
+++ /dev/null
@@ -1,30 +0,0 @@
-# Copy of this Vespa config stored here because Vespa config definitions are not
-# available in unit tests, and are needed (by DomConfigPayloadBuilder.parseLeaf)
-# Alternatively, we could make that not need it as it is not strictly necessaery.
-
-namespace=embedding
-
-# Wordpiece tokenizer
-tokenizerVocab model
-
-transformerModel model
-
-# Max length of token sequence model can handle
-transformerMaxTokens int default=384
-
-# Pooling strategy
-poolingStrategy enum { cls, mean } default=mean
-
-# Input names
-transformerInputIds string default=input_ids
-transformerAttentionMask string default=attention_mask
-transformerTokenTypeIds string default=token_type_ids
-
-# Output name
-transformerOutput string default=output_0
-
-# Settings for ONNX model evaluation
-onnxExecutionMode enum { parallel, sequential } default=sequential
-onnxInterOpThreads int default=1
-onnxIntraOpThreads int default=-4 # n=number of threads -> n<0: CPUs/(-n), n==0: CPUs, n>0: n
-
diff --git a/config-model/src/test/cfg/application/embed_cloud_only/configdefinitions/sentence-embedder.def b/config-model/src/test/cfg/application/embed_cloud_only/configdefinitions/sentence-embedder.def
new file mode 100644
index 00000000000..87b80f1051a
--- /dev/null
+++ b/config-model/src/test/cfg/application/embed_cloud_only/configdefinitions/sentence-embedder.def
@@ -0,0 +1,26 @@
+package=ai.vespa.example.paragraph
+
+# WordPiece tokenizer vocabulary
+vocab model
+
+model model
+
+myValue string
+
+# Max length of token sequence model can handle
+transforerMaxTokens int default=128
+
+# Pooling strategy
+poolingStrategy enum { cls, mean } default=mean
+
+# Input names
+transformerInputIds string default=input_ids
+transformerAttentionMask string default=attention_mask
+
+# Output name
+transformerOutput string default=last_hidden_state
+
+# Settings for ONNX model evaluation
+onnxExecutionMode enum { parallel, sequential } default=sequential
+onnxInterOpThreads int default=1
+onnxIntraOpThreads int default=-4
diff --git a/config-model/src/test/cfg/application/embed_cloud_only/services.xml b/config-model/src/test/cfg/application/embed_cloud_only/services.xml
index 57db4f5bfae..e203ec56669 100644
--- a/config-model/src/test/cfg/application/embed_cloud_only/services.xml
+++ b/config-model/src/test/cfg/application/embed_cloud_only/services.xml
@@ -4,14 +4,11 @@
<container version="1.0">
- <component id="transformer" class="ai.vespa.embedding.BertBaseEmbedder" bundle="model-integration">
- <config name="embedding.bert-base-embedder">
- <!-- No fallback to url or path when deploying outside cloud -->
- <transformerModel model-id="minilm-l6-v2"/>
- <tokenizerVocab path="files/vocab.txt"/>
-
- <!-- tunable parameters: number of threads etc -->
- <onnxIntraOpThreads>4</onnxIntraOpThreads>
+ <component id="transformer" class="ai.vespa.example.paragraph.ApplicationSpecificEmbedder" bundle="app">
+ <config name='ai.vespa.example.paragraph.sentence-embedder'>
+ <model model-id="minilm-l6-v2"/>
+ <vocab path="files/vocab.txt"/>
+ <myValue>foo</myValue>
</config>
</component>
diff --git a/config-model/src/test/java/com/yahoo/vespa/model/container/xml/EmbedderTestCase.java b/config-model/src/test/java/com/yahoo/vespa/model/container/xml/EmbedderTestCase.java
index 69981233c3f..2a82daef9e3 100644
--- a/config-model/src/test/java/com/yahoo/vespa/model/container/xml/EmbedderTestCase.java
+++ b/config-model/src/test/java/com/yahoo/vespa/model/container/xml/EmbedderTestCase.java
@@ -2,9 +2,13 @@
package com.yahoo.vespa.model.container.xml;
import com.yahoo.component.ComponentId;
+import com.yahoo.config.InnerNode;
+import com.yahoo.config.ModelNode;
+import com.yahoo.config.ModelReference;
import com.yahoo.config.model.application.provider.FilesApplicationPackage;
import com.yahoo.config.model.deploy.DeployState;
import com.yahoo.config.model.deploy.TestProperties;
+import com.yahoo.embedding.BertBaseEmbedderConfig;
import com.yahoo.embedding.huggingface.HuggingFaceEmbedderConfig;
import com.yahoo.language.huggingface.config.HuggingFaceTokenizerConfig;
import com.yahoo.path.Path;
@@ -13,6 +17,7 @@ import com.yahoo.vespa.config.ConfigDefinitionKey;
import com.yahoo.vespa.config.ConfigPayloadBuilder;
import com.yahoo.vespa.model.VespaModel;
import com.yahoo.vespa.model.container.ApplicationContainerCluster;
+import com.yahoo.vespa.model.container.component.BertEmbedder;
import com.yahoo.vespa.model.container.component.Component;
import com.yahoo.vespa.model.container.component.HuggingFaceEmbedder;
import com.yahoo.vespa.model.container.component.HuggingFaceTokenizer;
@@ -35,55 +40,18 @@ import static org.junit.jupiter.api.Assertions.fail;
public class EmbedderTestCase {
- private static final String BUNDLED_EMBEDDER_CLASS = "ai.vespa.embedding.BertBaseEmbedder";
- private static final String BUNDLED_EMBEDDER_CONFIG = "embedding.bert-base-embedder";
-
- @Test
- void testBundledEmbedder_selfhosted() throws IOException, SAXException {
- String input = "<component id='test' class='" + BUNDLED_EMBEDDER_CLASS + "' bundle='model-integration'>" +
- " <config name='" + BUNDLED_EMBEDDER_CONFIG + "'>" +
- " <transformerModel id='my_model_id' url='my-model-url' />" +
- " <tokenizerVocab id='my_vocab_id' url='my-vocab-url' />" +
- " </config>" +
- "</component>";
- String component = "<component id='test' class='" + BUNDLED_EMBEDDER_CLASS + "' bundle='model-integration'>" +
- " <config name='" + BUNDLED_EMBEDDER_CONFIG + "'>" +
- " <transformerModel id='my_model_id' url='my-model-url' />" +
- " <tokenizerVocab id='my_vocab_id' url='my-vocab-url' />" +
- " </config>" +
- "</component>";
- assertTransform(input, component, false);
- }
-
- @Test
- void testBundledEmbedder_hosted() throws IOException, SAXException {
- String input = "<component id='test' class='" + BUNDLED_EMBEDDER_CLASS + "' bundle='model-integration'>" +
- " <config name='" + BUNDLED_EMBEDDER_CONFIG + "'>" +
- " <transformerModel model-id='minilm-l6-v2' />" +
- " <tokenizerVocab model-id='bert-base-uncased' path='ignored.txt'/>" +
- " </config>" +
- "</component>";
- String component = "<component id='test' class='" + BUNDLED_EMBEDDER_CLASS + "' bundle='model-integration'>" +
- " <config name='" + BUNDLED_EMBEDDER_CONFIG + "'>" +
- " <transformerModel model-id='minilm-l6-v2' url='https://data.vespa.oath.cloud/onnx_models/sentence_all_MiniLM_L6_v2.onnx' />" +
- " <tokenizerVocab model-id='bert-base-uncased' url='https://data.vespa.oath.cloud/onnx_models/bert-base-uncased-vocab.txt' />" +
- " </config>" +
- "</component>";
- assertTransform(input, component, true);
- }
-
@Test
void testApplicationComponentWithModelReference_hosted() throws IOException, SAXException {
- String input = "<component id='test' class='ApplicationSpecificEmbedder' bundle='model-integration'>" +
- " <config name='" + BUNDLED_EMBEDDER_CONFIG + "'>" +
- " <transformerModel model-id='minilm-l6-v2' />" +
- " <tokenizerVocab model-id='bert-base-uncased' />" +
+ String input = "<component id='test' class='ai.vespa.example.paragraph.ApplicationSpecificEmbedder' bundle='app'>" +
+ " <config name='ai.vespa.example.paragraph.sentence-embedder'>" +
+ " <model model-id='minilm-l6-v2' />" +
+ " <vocab model-id='bert-base-uncased' />" +
" </config>" +
"</component>";
- String component = "<component id='test' class='ApplicationSpecificEmbedder' bundle='model-integration'>" +
- " <config name='" + BUNDLED_EMBEDDER_CONFIG + "'>" +
- " <transformerModel model-id='minilm-l6-v2' url='https://data.vespa.oath.cloud/onnx_models/sentence_all_MiniLM_L6_v2.onnx' />" +
- " <tokenizerVocab model-id='bert-base-uncased' url='https://data.vespa.oath.cloud/onnx_models/bert-base-uncased-vocab.txt' />" +
+ String component = "<component id='test' class='ai.vespa.example.paragraph.ApplicationSpecificEmbedder' bundle='app'>" +
+ " <config name='ai.vespa.example.paragraph.sentence-embedder'>" +
+ " <model model-id='minilm-l6-v2' url='https://data.vespa.oath.cloud/onnx_models/sentence_all_MiniLM_L6_v2.onnx' />" +
+ " <vocab model-id='bert-base-uncased' url='https://data.vespa.oath.cloud/onnx_models/bert-base-uncased-vocab.txt' />" +
" </config>" +
"</component>";
assertTransform(input, component, true);
@@ -91,64 +59,65 @@ public class EmbedderTestCase {
@Test
void testUnknownModelId_hosted() throws IOException, SAXException {
- String embedder = "<component id='test' class='" + BUNDLED_EMBEDDER_CLASS + "'>" +
- " <config name='" + BUNDLED_EMBEDDER_CONFIG + "'>" +
- " <transformerModel model-id='my_model_id' />" +
- " <tokenizerVocab model-id='my_vocab_id' />" +
+ String embedder = "<component id='test' class='ai.vespa.example.paragraph.ApplicationSpecificEmbedder'>" +
+ " <config name='ai.vespa.example.paragraph.sentence-embedder'>" +
+ " <model model-id='my_model_id' />" +
+ " <vocab model-id='my_vocab_id' />" +
" </config>" +
"</component>";
assertTransformThrows(embedder,
- "Unknown model id 'my_model_id' on 'transformerModel'",
+ "Unknown model id 'my_model_id' on 'model'",
true);
}
@Test
- void testApplicationPackageWithEmbedder_selfhosted() throws Exception {
- Path applicationDir = Path.fromString("src/test/cfg/application/embed/");
- VespaModel model = loadModel(applicationDir, false);
- ApplicationContainerCluster containerCluster = model.getContainerClusters().get("container");
+ void huggingfaceEmbedder_selfhosted() throws Exception {
+ var model = loadModel(Path.fromString("src/test/cfg/application/embed/"), false);
+ var cluster = model.getContainerClusters().get("container");
+ var embedderCfg = assertHuggingfaceEmbedderComponentPresent(cluster);
+ assertEquals("my_input_ids", embedderCfg.transformerInputIds());
+ assertEquals("https://my/url/model.onnx", modelReference(embedderCfg, "transformerModel").url().orElseThrow().value());
+ var tokenizerCfg = assertHuggingfaceTokenizerComponentPresent(cluster);
+ assertEquals("https://my/url/tokenizer.json", modelReference(tokenizerCfg.model().get(0), "path").url().orElseThrow().value());
+ assertEquals(768, tokenizerCfg.maxLength());
+ }
- Component<?, ?> transformer = containerCluster.getComponentsMap().get(new ComponentId("transformer"));
- ConfigPayloadBuilder config = transformer.getUserConfigs().get(new ConfigDefinitionKey("bert-base-embedder", "embedding"));
- assertEquals("minilm-l6-v2 application-url \"\"", config.getObject("transformerModel").getValue());
- assertEquals("\"\" \"\" files/vocab.txt", config.getObject("tokenizerVocab").getValue());
- assertEquals("4", config.getObject("onnxIntraOpThreads").getValue());
-
- {
- var hfEmbedder = (HuggingFaceEmbedder)containerCluster.getComponentsMap().get(new ComponentId("hf-embedder"));
- assertEquals("ai.vespa.embedding.huggingface.HuggingFaceEmbedder", hfEmbedder.getClassId().getName());
- var cfgBuilder = new HuggingFaceEmbedderConfig.Builder();
- hfEmbedder.getConfig(cfgBuilder);
- var cfg = cfgBuilder.build();
- assertEquals("my_input_ids", cfg.transformerInputIds());
- }
- {
- var hfTokenizer = (HuggingFaceTokenizer)containerCluster.getComponentsMap().get(new ComponentId("hf-tokenizer"));
- assertEquals("com.yahoo.language.huggingface.HuggingFaceTokenizer", hfTokenizer.getClassId().getName());
- var cfgBuilder = new HuggingFaceTokenizerConfig.Builder();
- hfTokenizer.getConfig(cfgBuilder);
- var cfg = cfgBuilder.build();
- assertEquals(768, cfg.maxLength());
- }
+ @Test
+ void huggingfaceEmbedder_hosted() throws Exception {
+ var model = loadModel(Path.fromString("src/test/cfg/application/embed/"), true);
+ var cluster = model.getContainerClusters().get("container");
+ var embedderCfg = assertHuggingfaceEmbedderComponentPresent(cluster);
+ assertEquals("my_input_ids", embedderCfg.transformerInputIds());
+ assertEquals("https://data.vespa.oath.cloud/onnx_models/e5-base-v2/model.onnx", modelReference(embedderCfg, "transformerModel").url().orElseThrow().value());
+ var tokenizerCfg = assertHuggingfaceTokenizerComponentPresent(cluster);
+ assertEquals("https://data.vespa.oath.cloud/onnx_models/multilingual-e5-base/tokenizer.json", modelReference(tokenizerCfg.model().get(0), "path").url().orElseThrow().value());
+ assertEquals(768, tokenizerCfg.maxLength());
}
+
@Test
- void passesXmlValdiation() {
- new VespaModelCreatorWithFilePkg("src/test/cfg/application/embed/").create();
+ void bertEmbedder_selfhosted() throws Exception {
+ var model = loadModel(Path.fromString("src/test/cfg/application/embed/"), false);
+ var cluster = model.getContainerClusters().get("container");
+ var embedderCfg = assertBertEmbedderComponentPresent(cluster);
+ assertEquals("application-url", modelReference(embedderCfg, "transformerModel").url().orElseThrow().value());
+ assertEquals("files/vocab.txt", modelReference(embedderCfg, "tokenizerVocab").path().orElseThrow().value());
}
@Test
- void testApplicationPackageWithEmbedder_hosted() throws Exception {
- Path applicationDir = Path.fromString("src/test/cfg/application/embed/");
- VespaModel model = loadModel(applicationDir, true);
- ApplicationContainerCluster containerCluster = model.getContainerClusters().get("container");
+ void bertEmbedder_hosted() throws Exception {
+ var model = loadModel(Path.fromString("src/test/cfg/application/embed/"), true);
+ var cluster = model.getContainerClusters().get("container");
+ var embedderCfg = assertBertEmbedderComponentPresent(cluster);
+ assertEquals("https://data.vespa.oath.cloud/onnx_models/sentence_all_MiniLM_L6_v2.onnx",
+ modelReference(embedderCfg, "transformerModel").url().orElseThrow().value());
+ assertTrue(modelReference(embedderCfg, "tokenizerVocab").url().isEmpty());
+ assertEquals("files/vocab.txt", modelReference(embedderCfg, "tokenizerVocab").path().orElseThrow().value());
+ }
- Component<?, ?> transformer = containerCluster.getComponentsMap().get(new ComponentId("transformer"));
- ConfigPayloadBuilder config = transformer.getUserConfigs().get(new ConfigDefinitionKey("bert-base-embedder", "embedding"));
- assertEquals("minilm-l6-v2 https://data.vespa.oath.cloud/onnx_models/sentence_all_MiniLM_L6_v2.onnx \"\"",
- config.getObject("transformerModel").getValue());
- assertEquals("\"\" \"\" files/vocab.txt", config.getObject("tokenizerVocab").getValue());
- assertEquals("4", config.getObject("onnxIntraOpThreads").getValue());
+ @Test
+ void passesXmlValidation() {
+ new VespaModelCreatorWithFilePkg("src/test/cfg/application/embed/").create();
}
@Test
@@ -184,7 +153,7 @@ public class EmbedderTestCase {
fail("Expected failure");
}
catch (IllegalArgumentException e) {
- assertEquals("transformerModel is configured with only a 'model-id'. Add a 'path' or 'url' to deploy this outside Vespa Cloud",
+ assertEquals("model is configured with only a 'model-id'. Add a 'path' or 'url' to deploy this outside Vespa Cloud",
Exceptions.toMessageString(e));
}
}
@@ -244,4 +213,39 @@ public class EmbedderTestCase {
return (Element) doc.getFirstChild();
}
+ private static HuggingFaceTokenizerConfig assertHuggingfaceTokenizerComponentPresent(ApplicationContainerCluster cluster) {
+ var hfTokenizer = (HuggingFaceTokenizer) cluster.getComponentsMap().get(new ComponentId("hf-tokenizer"));
+ assertEquals("com.yahoo.language.huggingface.HuggingFaceTokenizer", hfTokenizer.getClassId().getName());
+ var cfgBuilder = new HuggingFaceTokenizerConfig.Builder();
+ hfTokenizer.getConfig(cfgBuilder);
+ return cfgBuilder.build();
+ }
+
+ private static HuggingFaceEmbedderConfig assertHuggingfaceEmbedderComponentPresent(ApplicationContainerCluster cluster) {
+ var hfEmbedder = (HuggingFaceEmbedder) cluster.getComponentsMap().get(new ComponentId("hf-embedder"));
+ assertEquals("ai.vespa.embedding.huggingface.HuggingFaceEmbedder", hfEmbedder.getClassId().getName());
+ var cfgBuilder = new HuggingFaceEmbedderConfig.Builder();
+ hfEmbedder.getConfig(cfgBuilder);
+ return cfgBuilder.build();
+ }
+
+ private static BertBaseEmbedderConfig assertBertEmbedderComponentPresent(ApplicationContainerCluster cluster) {
+ var bertEmbedder = (BertEmbedder) cluster.getComponentsMap().get(new ComponentId("bert-embedder"));
+ assertEquals("ai.vespa.embedding.BertBaseEmbedder", bertEmbedder.getClassId().getName());
+ var cfgBuilder = new BertBaseEmbedderConfig.Builder();
+ bertEmbedder.getConfig(cfgBuilder);
+ return cfgBuilder.build();
+ }
+
+ // Ugly hack to read underlying model reference from config instance
+ private static ModelReference modelReference(InnerNode cfg, String name) {
+ try {
+ var f = cfg.getClass().getDeclaredField(name);
+ f.setAccessible(true);
+ return ((ModelNode) f.get(cfg)).getModelReference();
+ } catch (NoSuchFieldException | IllegalAccessException e) {
+ throw new RuntimeException(e);
+ }
+ }
+
}
diff --git a/model-integration/src/main/resources/configdefinitions/embedding.bert-base-embedder.def b/configdefinitions/src/vespa/embedding.bert-base-embedder.def
index 2d8e840377b..2d8e840377b 100644
--- a/model-integration/src/main/resources/configdefinitions/embedding.bert-base-embedder.def
+++ b/configdefinitions/src/vespa/embedding.bert-base-embedder.def
diff --git a/configdefinitions/src/vespa/hugging-face-embedder.def b/configdefinitions/src/vespa/hugging-face-embedder.def
index 36957004e02..7ea4227b3cd 100644
--- a/configdefinitions/src/vespa/hugging-face-embedder.def
+++ b/configdefinitions/src/vespa/hugging-face-embedder.def
@@ -21,6 +21,8 @@ transformerOutput string default=last_hidden_state
# Normalize tensors from tokenizer
normalize bool default=false
+poolingStrategy enum { cls, mean } default=mean
+
# Settings for ONNX model evaluation
transformerExecutionMode enum { parallel, sequential } default=sequential
transformerInterOpThreads int default=1
diff --git a/model-integration/src/main/java/ai/vespa/embedding/BertBaseEmbedder.java b/model-integration/src/main/java/ai/vespa/embedding/BertBaseEmbedder.java
index b172ef7beee..a12424c7d12 100644
--- a/model-integration/src/main/java/ai/vespa/embedding/BertBaseEmbedder.java
+++ b/model-integration/src/main/java/ai/vespa/embedding/BertBaseEmbedder.java
@@ -10,7 +10,6 @@ import com.yahoo.language.process.Embedder;
import com.yahoo.language.wordpiece.WordPieceEmbedder;
import com.yahoo.tensor.IndexedTensor;
import com.yahoo.tensor.Tensor;
-import com.yahoo.tensor.TensorAddress;
import com.yahoo.tensor.TensorType;
import java.util.ArrayList;
@@ -39,7 +38,7 @@ public class BertBaseEmbedder extends AbstractComponent implements Embedder {
private final String attentionMaskName;
private final String tokenTypeIdsName;
private final String outputName;
- private final String poolingStrategy;
+ private final PoolingStrategy poolingStrategy;
private final WordPieceEmbedder tokenizer;
private final OnnxEvaluator evaluator;
@@ -53,7 +52,7 @@ public class BertBaseEmbedder extends AbstractComponent implements Embedder {
attentionMaskName = config.transformerAttentionMask();
tokenTypeIdsName = config.transformerTokenTypeIds();
outputName = config.transformerOutput();
- poolingStrategy = config.poolingStrategy().toString();
+ poolingStrategy = PoolingStrategy.fromString(config.poolingStrategy().toString());
OnnxEvaluatorOptions options = new OnnxEvaluatorOptions();
options.setExecutionMode(config.onnxExecutionMode().toString());
@@ -124,20 +123,7 @@ public class BertBaseEmbedder extends AbstractComponent implements Embedder {
Tensor tokenEmbeddings = outputs.get(outputName);
- Tensor.Builder builder = Tensor.Builder.of(type);
- if (poolingStrategy.equals("mean")) { // average over tokens
- Tensor summedEmbeddings = tokenEmbeddings.sum("d1");
- Tensor summedAttentionMask = attentionMask.expand("d0").sum("d1");
- Tensor averaged = summedEmbeddings.join(summedAttentionMask, (x, y) -> x / y);
- for (int i = 0; i < type.dimensions().get(0).size().get(); i++) {
- builder.cell(averaged.get(TensorAddress.of(0,i)), i);
- }
- } else { // CLS - use first token
- for (int i = 0; i < type.dimensions().get(0).size().get(); i++) {
- builder.cell(tokenEmbeddings.get(TensorAddress.of(0,0,i)), i);
- }
- }
- return builder.build();
+ return poolingStrategy.toSentenceEmbedding(type, tokenEmbeddings, attentionMask);
}
private List<Integer> embedWithSeparatorTokens(String text, Context context, int maxLength) {
diff --git a/model-integration/src/main/java/ai/vespa/embedding/PoolingStrategy.java b/model-integration/src/main/java/ai/vespa/embedding/PoolingStrategy.java
new file mode 100644
index 00000000000..28104d8eeef
--- /dev/null
+++ b/model-integration/src/main/java/ai/vespa/embedding/PoolingStrategy.java
@@ -0,0 +1,48 @@
+// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+
+// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+
+package ai.vespa.embedding;
+
+import com.yahoo.tensor.Tensor;
+import com.yahoo.tensor.TensorAddress;
+import com.yahoo.tensor.TensorType;
+
+/**
+ * @author bjorncs
+ */
+public enum PoolingStrategy {
+ MEAN {
+ @Override
+ public Tensor toSentenceEmbedding(TensorType type, Tensor tokenEmbeddings, Tensor attentionMask) {
+ var builder = Tensor.Builder.of(type);
+ var summedEmbeddings = tokenEmbeddings.sum("d1");
+ var summedAttentionMask = attentionMask.expand("d0").sum("d1");
+ var averaged = summedEmbeddings.join(summedAttentionMask, (x, y) -> x / y);
+ for (int i = 0; i < type.dimensions().get(0).size().get(); i++) {
+ builder.cell(averaged.get(TensorAddress.of(0, i)), i);
+ }
+ return builder.build();
+ }
+ },
+ CLS {
+ @Override
+ public Tensor toSentenceEmbedding(TensorType type, Tensor tokenEmbeddings, Tensor ignored) {
+ var builder = Tensor.Builder.of(type);
+ for (int i = 0; i < type.dimensions().get(0).size().get(); i++) {
+ builder.cell(tokenEmbeddings.get(TensorAddress.of(0,0,i)), i);
+ }
+ return builder.build();
+ }
+ };
+
+ public abstract Tensor toSentenceEmbedding(TensorType type, Tensor tokenEmbeddings, Tensor attentionMask);
+
+ public static PoolingStrategy fromString(String strategy) {
+ return switch (strategy.toLowerCase()) {
+ case "mean" -> MEAN;
+ case "cls" -> CLS;
+ default -> throw new IllegalArgumentException("Unknown pooling strategy '%s'".formatted(strategy));
+ };
+ }
+}
diff --git a/model-integration/src/main/java/ai/vespa/embedding/huggingface/HuggingFaceEmbedder.java b/model-integration/src/main/java/ai/vespa/embedding/huggingface/HuggingFaceEmbedder.java
index 01804656bb6..f93b1a3c1f8 100644
--- a/model-integration/src/main/java/ai/vespa/embedding/huggingface/HuggingFaceEmbedder.java
+++ b/model-integration/src/main/java/ai/vespa/embedding/huggingface/HuggingFaceEmbedder.java
@@ -1,5 +1,6 @@
package ai.vespa.embedding.huggingface;
+import ai.vespa.embedding.PoolingStrategy;
import ai.vespa.modelintegration.evaluator.OnnxEvaluator;
import ai.vespa.modelintegration.evaluator.OnnxEvaluatorOptions;
import ai.vespa.modelintegration.evaluator.OnnxRuntime;
@@ -28,6 +29,7 @@ public class HuggingFaceEmbedder extends AbstractComponent implements Embedder {
private final boolean normalize;
private final HuggingFaceTokenizer tokenizer;
private final OnnxEvaluator evaluator;
+ private final PoolingStrategy poolingStrategy;
@Inject
public HuggingFaceEmbedder(OnnxRuntime onnx, HuggingFaceEmbedderConfig config) {
@@ -42,6 +44,7 @@ public class HuggingFaceEmbedder extends AbstractComponent implements Embedder {
.setTruncation(true)
.setMaxLength(config.transformerMaxTokens())
.build();
+ poolingStrategy = PoolingStrategy.fromString(config.poolingStrategy().toString());
var onnxOpts = new OnnxEvaluatorOptions();
if (config.transformerGpuDevice() >= 0)
onnxOpts.setGpuDevice(config.transformerGpuDevice());