diff options
author | Bjørn Christian Seime <bjorn.christian@seime.no> | 2024-04-26 07:56:30 +0200 |
---|---|---|
committer | GitHub <noreply@github.com> | 2024-04-26 07:56:30 +0200 |
commit | 95cccd4698808d320f41461697f74a1c9d161bac (patch) | |
tree | cd9336653627e4f3274a29f5f9da435a97b2150c /config-model/src/main | |
parent | face477e177b1fdd5462a5c961e2d00d43895b97 (diff) | |
parent | a374e90b5f95b3f3c533a4d0302ac0e66c32668f (diff) |
Merge pull request #31049 from vespa-engine/jobergum/add-prepend-embedder-support
add prepend support to embedder
Diffstat (limited to 'config-model/src/main')
-rw-r--r-- | config-model/src/main/java/com/yahoo/vespa/model/container/component/HuggingFaceEmbedder.java | 13 | ||||
-rw-r--r-- | config-model/src/main/resources/schema/common.rnc | 6 |
2 files changed, 19 insertions, 0 deletions
diff --git a/config-model/src/main/java/com/yahoo/vespa/model/container/component/HuggingFaceEmbedder.java b/config-model/src/main/java/com/yahoo/vespa/model/container/component/HuggingFaceEmbedder.java index fe0bb7c8075..91060579c4e 100644 --- a/config-model/src/main/java/com/yahoo/vespa/model/container/component/HuggingFaceEmbedder.java +++ b/config-model/src/main/java/com/yahoo/vespa/model/container/component/HuggingFaceEmbedder.java @@ -12,6 +12,7 @@ import java.util.Set; import static com.yahoo.embedding.huggingface.HuggingFaceEmbedderConfig.PoolingStrategy; import static com.yahoo.embedding.huggingface.HuggingFaceEmbedderConfig.TransformerExecutionMode; +import static com.yahoo.text.XML.getChild; import static com.yahoo.text.XML.getChildValue; import static com.yahoo.vespa.model.container.ContainerModelEvaluation.INTEGRATION_BUNDLE_NAME; import static com.yahoo.vespa.model.container.xml.ModelIdResolver.HF_TOKENIZER; @@ -34,6 +35,10 @@ public class HuggingFaceEmbedder extends TypedComponent implements HuggingFaceEm private final Boolean normalize; private final String poolingStrategy; + private String prependQuery; + + private String prependDocument; + public HuggingFaceEmbedder(ApplicationContainerCluster cluster, Element xml, DeployState state) { super("ai.vespa.embedding.huggingface.HuggingFaceEmbedder", INTEGRATION_BUNDLE_NAME, xml); var model = Model.fromXml(state, xml, "transformer-model", Set.of(ONNX_MODEL)).orElseThrow(); @@ -51,6 +56,12 @@ public class HuggingFaceEmbedder extends TypedComponent implements HuggingFaceEm transformerOutput = getChildValue(xml, "transformer-output").orElse(null); normalize = getChildValue(xml, "normalize").map(Boolean::parseBoolean).orElse(null); poolingStrategy = getChildValue(xml, "pooling-strategy").orElse(null); + Element prepend = getChild(xml, "prepend"); + if (prepend != null) { + prependQuery = getChildValue(prepend, "query").orElse(null); + prependDocument = getChildValue(prepend, "document").orElse(null); + } + model.registerOnnxModelCost(cluster, onnxModelOptions); } @@ -64,6 +75,8 @@ public class HuggingFaceEmbedder extends TypedComponent implements HuggingFaceEm if (transformerOutput != null) b.transformerOutput(transformerOutput); if (normalize != null) b.normalize(normalize); if (poolingStrategy != null) b.poolingStrategy(PoolingStrategy.Enum.valueOf(poolingStrategy)); + if(prependQuery != null) b.prependQuery(prependQuery); + if(prependDocument != null) b.prependDocument(prependDocument); onnxModelOptions.executionMode().ifPresent(value -> b.transformerExecutionMode(TransformerExecutionMode.Enum.valueOf(value))); onnxModelOptions.interOpThreads().ifPresent(b::transformerInterOpThreads); onnxModelOptions.intraOpThreads().ifPresent(b::transformerIntraOpThreads); diff --git a/config-model/src/main/resources/schema/common.rnc b/config-model/src/main/resources/schema/common.rnc index 14fae90678d..d949edbcacf 100644 --- a/config-model/src/main/resources/schema/common.rnc +++ b/config-model/src/main/resources/schema/common.rnc @@ -94,9 +94,15 @@ HuggingFaceEmbedder = element transformer-token-type-ids { xsd:string }? & element transformer-output { xsd:string }? & element normalize { xsd:boolean }? & + PrependResources? & OnnxModelExecutionParams & EmbedderPoolingStrategy +PrependResources = element prepend { + element query { xsd:string }? & + element document { xsd:string }? +} + SpladeEmbedder = attribute type { "splade-embedder" } & element transformer-model { ModelReference } & |