diff options
author | Bjørn Christian Seime <bjorn.christian@seime.no> | 2024-04-26 07:56:30 +0200 |
---|---|---|
committer | GitHub <noreply@github.com> | 2024-04-26 07:56:30 +0200 |
commit | 95cccd4698808d320f41461697f74a1c9d161bac (patch) | |
tree | cd9336653627e4f3274a29f5f9da435a97b2150c /config-model | |
parent | face477e177b1fdd5462a5c961e2d00d43895b97 (diff) | |
parent | a374e90b5f95b3f3c533a4d0302ac0e66c32668f (diff) |
Merge pull request #31049 from vespa-engine/jobergum/add-prepend-embedder-support
add prepend support to embedder
Diffstat (limited to 'config-model')
4 files changed, 27 insertions, 0 deletions
diff --git a/config-model/src/main/java/com/yahoo/vespa/model/container/component/HuggingFaceEmbedder.java b/config-model/src/main/java/com/yahoo/vespa/model/container/component/HuggingFaceEmbedder.java index fe0bb7c8075..91060579c4e 100644 --- a/config-model/src/main/java/com/yahoo/vespa/model/container/component/HuggingFaceEmbedder.java +++ b/config-model/src/main/java/com/yahoo/vespa/model/container/component/HuggingFaceEmbedder.java @@ -12,6 +12,7 @@ import java.util.Set; import static com.yahoo.embedding.huggingface.HuggingFaceEmbedderConfig.PoolingStrategy; import static com.yahoo.embedding.huggingface.HuggingFaceEmbedderConfig.TransformerExecutionMode; +import static com.yahoo.text.XML.getChild; import static com.yahoo.text.XML.getChildValue; import static com.yahoo.vespa.model.container.ContainerModelEvaluation.INTEGRATION_BUNDLE_NAME; import static com.yahoo.vespa.model.container.xml.ModelIdResolver.HF_TOKENIZER; @@ -34,6 +35,10 @@ public class HuggingFaceEmbedder extends TypedComponent implements HuggingFaceEm private final Boolean normalize; private final String poolingStrategy; + private String prependQuery; + + private String prependDocument; + public HuggingFaceEmbedder(ApplicationContainerCluster cluster, Element xml, DeployState state) { super("ai.vespa.embedding.huggingface.HuggingFaceEmbedder", INTEGRATION_BUNDLE_NAME, xml); var model = Model.fromXml(state, xml, "transformer-model", Set.of(ONNX_MODEL)).orElseThrow(); @@ -51,6 +56,12 @@ public class HuggingFaceEmbedder extends TypedComponent implements HuggingFaceEm transformerOutput = getChildValue(xml, "transformer-output").orElse(null); normalize = getChildValue(xml, "normalize").map(Boolean::parseBoolean).orElse(null); poolingStrategy = getChildValue(xml, "pooling-strategy").orElse(null); + Element prepend = getChild(xml, "prepend"); + if (prepend != null) { + prependQuery = getChildValue(prepend, "query").orElse(null); + prependDocument = getChildValue(prepend, "document").orElse(null); + } + model.registerOnnxModelCost(cluster, onnxModelOptions); } @@ -64,6 +75,8 @@ public class HuggingFaceEmbedder extends TypedComponent implements HuggingFaceEm if (transformerOutput != null) b.transformerOutput(transformerOutput); if (normalize != null) b.normalize(normalize); if (poolingStrategy != null) b.poolingStrategy(PoolingStrategy.Enum.valueOf(poolingStrategy)); + if(prependQuery != null) b.prependQuery(prependQuery); + if(prependDocument != null) b.prependDocument(prependDocument); onnxModelOptions.executionMode().ifPresent(value -> b.transformerExecutionMode(TransformerExecutionMode.Enum.valueOf(value))); onnxModelOptions.interOpThreads().ifPresent(b::transformerInterOpThreads); onnxModelOptions.intraOpThreads().ifPresent(b::transformerIntraOpThreads); diff --git a/config-model/src/main/resources/schema/common.rnc b/config-model/src/main/resources/schema/common.rnc index 14fae90678d..d949edbcacf 100644 --- a/config-model/src/main/resources/schema/common.rnc +++ b/config-model/src/main/resources/schema/common.rnc @@ -94,9 +94,15 @@ HuggingFaceEmbedder = element transformer-token-type-ids { xsd:string }? & element transformer-output { xsd:string }? & element normalize { xsd:boolean }? & + PrependResources? & OnnxModelExecutionParams & EmbedderPoolingStrategy +PrependResources = element prepend { + element query { xsd:string }? & + element document { xsd:string }? +} + SpladeEmbedder = attribute type { "splade-embedder" } & element transformer-model { ModelReference } & diff --git a/config-model/src/test/cfg/application/embed/services.xml b/config-model/src/test/cfg/application/embed/services.xml index 089adbd7517..ae0d0952630 100644 --- a/config-model/src/test/cfg/application/embed/services.xml +++ b/config-model/src/test/cfg/application/embed/services.xml @@ -12,6 +12,10 @@ <transformer-token-type-ids>my_token_type_ids</transformer-token-type-ids> <transformer-output>my_output</transformer-output> <normalize>true</normalize> + <prepend> + <query>Represent this sentence for searching relevant passages:</query> + <document>passage:</document> + </prepend> <onnx-execution-mode>parallel</onnx-execution-mode> <onnx-intraop-threads>10</onnx-intraop-threads> <onnx-interop-threads>8</onnx-interop-threads> diff --git a/config-model/src/test/java/com/yahoo/vespa/model/container/xml/EmbedderTestCase.java b/config-model/src/test/java/com/yahoo/vespa/model/container/xml/EmbedderTestCase.java index fb1e176f707..6f629d99c92 100644 --- a/config-model/src/test/java/com/yahoo/vespa/model/container/xml/EmbedderTestCase.java +++ b/config-model/src/test/java/com/yahoo/vespa/model/container/xml/EmbedderTestCase.java @@ -88,6 +88,8 @@ public class EmbedderTestCase { var tokenizerCfg = assertHuggingfaceTokenizerComponentPresent(cluster); assertEquals("https://my/url/tokenizer.json", modelReference(tokenizerCfg.model().get(0), "path").url().orElseThrow().value()); assertEquals(-1, tokenizerCfg.maxLength()); + assertEquals("Represent this sentence for searching relevant passages:", embedderCfg.prependQuery()); + assertEquals("passage:", embedderCfg.prependDocument()); } @Test @@ -101,6 +103,8 @@ public class EmbedderTestCase { var tokenizerCfg = assertHuggingfaceTokenizerComponentPresent(cluster); assertEquals("https://data.vespa.oath.cloud/onnx_models/multilingual-e5-base/tokenizer.json", modelReference(tokenizerCfg.model().get(0), "path").url().orElseThrow().value()); assertEquals(-1, tokenizerCfg.maxLength()); + assertEquals("Represent this sentence for searching relevant passages:", embedderCfg.prependQuery()); + assertEquals("passage:", embedderCfg.prependDocument()); } @Test |