summary | refs | log | tree | commit | diff | stats
path: root/config-model
diff options
context:
space:
mode:
author: Bjørn Christian Seime <bjorn.christian@seime.no> 2024-04-26 07:56:30 +0200
committer: GitHub <noreply@github.com> 2024-04-26 07:56:30 +0200
commit: 95cccd4698808d320f41461697f74a1c9d161bac (patch)
tree: cd9336653627e4f3274a29f5f9da435a97b2150c /config-model
parent: face477e177b1fdd5462a5c961e2d00d43895b97 (diff)
parent: a374e90b5f95b3f3c533a4d0302ac0e66c32668f (diff)
Merge pull request #31049 from vespa-engine/jobergum/add-prepend-embedder-support
add prepend support to embedder
Diffstat (limited to 'config-model')
-rw-r--r-- config-model/src/main/java/com/yahoo/vespa/model/container/component/HuggingFaceEmbedder.java | 13
-rw-r--r-- config-model/src/main/resources/schema/common.rnc | 6
-rw-r--r-- config-model/src/test/cfg/application/embed/services.xml | 4
-rw-r--r-- config-model/src/test/java/com/yahoo/vespa/model/container/xml/EmbedderTestCase.java | 4
4 files changed, 27 insertions, 0 deletions
diff --git a/config-model/src/main/java/com/yahoo/vespa/model/container/component/HuggingFaceEmbedder.java b/config-model/src/main/java/com/yahoo/vespa/model/container/component/HuggingFaceEmbedder.java
index fe0bb7c8075..91060579c4e 100644
--- a/config-model/src/main/java/com/yahoo/vespa/model/container/component/HuggingFaceEmbedder.java
+++ b/config-model/src/main/java/com/yahoo/vespa/model/container/component/HuggingFaceEmbedder.java
@@ -12,6 +12,7 @@ import java.util.Set;
import static com.yahoo.embedding.huggingface.HuggingFaceEmbedderConfig.PoolingStrategy;
import static com.yahoo.embedding.huggingface.HuggingFaceEmbedderConfig.TransformerExecutionMode;
+import static com.yahoo.text.XML.getChild;
import static com.yahoo.text.XML.getChildValue;
import static com.yahoo.vespa.model.container.ContainerModelEvaluation.INTEGRATION_BUNDLE_NAME;
import static com.yahoo.vespa.model.container.xml.ModelIdResolver.HF_TOKENIZER;
@@ -34,6 +35,10 @@ public class HuggingFaceEmbedder extends TypedComponent implements HuggingFaceEm
private final Boolean normalize;
private final String poolingStrategy;
+ private String prependQuery;
+
+ private String prependDocument;
+
public HuggingFaceEmbedder(ApplicationContainerCluster cluster, Element xml, DeployState state) {
super("ai.vespa.embedding.huggingface.HuggingFaceEmbedder", INTEGRATION_BUNDLE_NAME, xml);
var model = Model.fromXml(state, xml, "transformer-model", Set.of(ONNX_MODEL)).orElseThrow();
@@ -51,6 +56,12 @@ public class HuggingFaceEmbedder extends TypedComponent implements HuggingFaceEm
transformerOutput = getChildValue(xml, "transformer-output").orElse(null);
normalize = getChildValue(xml, "normalize").map(Boolean::parseBoolean).orElse(null);
poolingStrategy = getChildValue(xml, "pooling-strategy").orElse(null);
+ Element prepend = getChild(xml, "prepend");
+ if (prepend != null) {
+ prependQuery = getChildValue(prepend, "query").orElse(null);
+ prependDocument = getChildValue(prepend, "document").orElse(null);
+ }
+
model.registerOnnxModelCost(cluster, onnxModelOptions);
}
@@ -64,6 +75,8 @@ public class HuggingFaceEmbedder extends TypedComponent implements HuggingFaceEm
if (transformerOutput != null) b.transformerOutput(transformerOutput);
if (normalize != null) b.normalize(normalize);
if (poolingStrategy != null) b.poolingStrategy(PoolingStrategy.Enum.valueOf(poolingStrategy));
+ if(prependQuery != null) b.prependQuery(prependQuery);
+ if(prependDocument != null) b.prependDocument(prependDocument);
onnxModelOptions.executionMode().ifPresent(value -> b.transformerExecutionMode(TransformerExecutionMode.Enum.valueOf(value)));
onnxModelOptions.interOpThreads().ifPresent(b::transformerInterOpThreads);
onnxModelOptions.intraOpThreads().ifPresent(b::transformerIntraOpThreads);
diff --git a/config-model/src/main/resources/schema/common.rnc b/config-model/src/main/resources/schema/common.rnc
index 14fae90678d..d949edbcacf 100644
--- a/config-model/src/main/resources/schema/common.rnc
+++ b/config-model/src/main/resources/schema/common.rnc
@@ -94,9 +94,15 @@ HuggingFaceEmbedder =
element transformer-token-type-ids { xsd:string }? &
element transformer-output { xsd:string }? &
element normalize { xsd:boolean }? &
+ PrependResources? &
OnnxModelExecutionParams &
EmbedderPoolingStrategy
+PrependResources = element prepend {
+ element query { xsd:string }? &
+ element document { xsd:string }?
+}
+
SpladeEmbedder =
attribute type { "splade-embedder" } &
element transformer-model { ModelReference } &
diff --git a/config-model/src/test/cfg/application/embed/services.xml b/config-model/src/test/cfg/application/embed/services.xml
index 089adbd7517..ae0d0952630 100644
--- a/config-model/src/test/cfg/application/embed/services.xml
+++ b/config-model/src/test/cfg/application/embed/services.xml
@@ -12,6 +12,10 @@
<transformer-token-type-ids>my_token_type_ids</transformer-token-type-ids>
<transformer-output>my_output</transformer-output>
<normalize>true</normalize>
+ <prepend>
+ <query>Represent this sentence for searching relevant passages:</query>
+ <document>passage:</document>
+ </prepend>
<onnx-execution-mode>parallel</onnx-execution-mode>
<onnx-intraop-threads>10</onnx-intraop-threads>
<onnx-interop-threads>8</onnx-interop-threads>
diff --git a/config-model/src/test/java/com/yahoo/vespa/model/container/xml/EmbedderTestCase.java b/config-model/src/test/java/com/yahoo/vespa/model/container/xml/EmbedderTestCase.java
index fb1e176f707..6f629d99c92 100644
--- a/config-model/src/test/java/com/yahoo/vespa/model/container/xml/EmbedderTestCase.java
+++ b/config-model/src/test/java/com/yahoo/vespa/model/container/xml/EmbedderTestCase.java
@@ -88,6 +88,8 @@ public class EmbedderTestCase {
var tokenizerCfg = assertHuggingfaceTokenizerComponentPresent(cluster);
assertEquals("https://my/url/tokenizer.json", modelReference(tokenizerCfg.model().get(0), "path").url().orElseThrow().value());
assertEquals(-1, tokenizerCfg.maxLength());
+ assertEquals("Represent this sentence for searching relevant passages:", embedderCfg.prependQuery());
+ assertEquals("passage:", embedderCfg.prependDocument());
}
@Test
@@ -101,6 +103,8 @@ public class EmbedderTestCase {
var tokenizerCfg = assertHuggingfaceTokenizerComponentPresent(cluster);
assertEquals("https://data.vespa.oath.cloud/onnx_models/multilingual-e5-base/tokenizer.json", modelReference(tokenizerCfg.model().get(0), "path").url().orElseThrow().value());
assertEquals(-1, tokenizerCfg.maxLength());
+ assertEquals("Represent this sentence for searching relevant passages:", embedderCfg.prependQuery());
+ assertEquals("passage:", embedderCfg.prependDocument());
}
@Test