diff options
author | Jo Kristian Bergum <bergum@yahooinc.com> | 2024-04-25 15:32:04 +0200 |
---|---|---|
committer | Jo Kristian Bergum <bergum@yahooinc.com> | 2024-04-25 15:32:04 +0200 |
commit | a374e90b5f95b3f3c533a4d0302ac0e66c32668f (patch) | |
tree | 5496e8603a071aea36a5f0f8d24bcdf266bb3b23 /config-model/src/main/java | |
parent | 117cace612ab00de27b8ec5e77896056e449bf33 (diff) |
add prepend support
Diffstat (limited to 'config-model/src/main/java')
-rw-r--r-- | config-model/src/main/java/com/yahoo/vespa/model/container/component/HuggingFaceEmbedder.java | 13 |
1 file changed, 13 insertions, 0 deletions
diff --git a/config-model/src/main/java/com/yahoo/vespa/model/container/component/HuggingFaceEmbedder.java b/config-model/src/main/java/com/yahoo/vespa/model/container/component/HuggingFaceEmbedder.java index fe0bb7c8075..91060579c4e 100644 --- a/config-model/src/main/java/com/yahoo/vespa/model/container/component/HuggingFaceEmbedder.java +++ b/config-model/src/main/java/com/yahoo/vespa/model/container/component/HuggingFaceEmbedder.java @@ -12,6 +12,7 @@ import java.util.Set; import static com.yahoo.embedding.huggingface.HuggingFaceEmbedderConfig.PoolingStrategy; import static com.yahoo.embedding.huggingface.HuggingFaceEmbedderConfig.TransformerExecutionMode; +import static com.yahoo.text.XML.getChild; import static com.yahoo.text.XML.getChildValue; import static com.yahoo.vespa.model.container.ContainerModelEvaluation.INTEGRATION_BUNDLE_NAME; import static com.yahoo.vespa.model.container.xml.ModelIdResolver.HF_TOKENIZER; @@ -34,6 +35,10 @@ public class HuggingFaceEmbedder extends TypedComponent implements HuggingFaceEm private final Boolean normalize; private final String poolingStrategy; + private String prependQuery; + + private String prependDocument; + public HuggingFaceEmbedder(ApplicationContainerCluster cluster, Element xml, DeployState state) { super("ai.vespa.embedding.huggingface.HuggingFaceEmbedder", INTEGRATION_BUNDLE_NAME, xml); var model = Model.fromXml(state, xml, "transformer-model", Set.of(ONNX_MODEL)).orElseThrow(); @@ -51,6 +56,12 @@ public class HuggingFaceEmbedder extends TypedComponent implements HuggingFaceEm transformerOutput = getChildValue(xml, "transformer-output").orElse(null); normalize = getChildValue(xml, "normalize").map(Boolean::parseBoolean).orElse(null); poolingStrategy = getChildValue(xml, "pooling-strategy").orElse(null); + Element prepend = getChild(xml, "prepend"); + if (prepend != null) { + prependQuery = getChildValue(prepend, "query").orElse(null); + prependDocument = getChildValue(prepend, "document").orElse(null); + } + 
model.registerOnnxModelCost(cluster, onnxModelOptions); } @@ -64,6 +75,8 @@ public class HuggingFaceEmbedder extends TypedComponent implements HuggingFaceEm if (transformerOutput != null) b.transformerOutput(transformerOutput); if (normalize != null) b.normalize(normalize); if (poolingStrategy != null) b.poolingStrategy(PoolingStrategy.Enum.valueOf(poolingStrategy)); + if(prependQuery != null) b.prependQuery(prependQuery); + if(prependDocument != null) b.prependDocument(prependDocument); onnxModelOptions.executionMode().ifPresent(value -> b.transformerExecutionMode(TransformerExecutionMode.Enum.valueOf(value))); onnxModelOptions.interOpThreads().ifPresent(b::transformerInterOpThreads); onnxModelOptions.intraOpThreads().ifPresent(b::transformerIntraOpThreads); |