summary | refs | log | tree | commit | diff | stats
path: root/config-model/src/main/java/com
diff options
context:
space:
mode:
author: Jo Kristian Bergum <bergum@yahooinc.com> 2024-04-25 15:32:04 +0200
committer: Jo Kristian Bergum <bergum@yahooinc.com> 2024-04-25 15:32:04 +0200
commit: a374e90b5f95b3f3c533a4d0302ac0e66c32668f (patch)
tree: 5496e8603a071aea36a5f0f8d24bcdf266bb3b23 /config-model/src/main/java/com
parent: 117cace612ab00de27b8ec5e77896056e449bf33 (diff)
add prepend support
Diffstat (limited to 'config-model/src/main/java/com')
-rw-r--r-- config-model/src/main/java/com/yahoo/vespa/model/container/component/HuggingFaceEmbedder.java 13
1 files changed, 13 insertions, 0 deletions
diff --git a/config-model/src/main/java/com/yahoo/vespa/model/container/component/HuggingFaceEmbedder.java b/config-model/src/main/java/com/yahoo/vespa/model/container/component/HuggingFaceEmbedder.java
index fe0bb7c8075..91060579c4e 100644
--- a/config-model/src/main/java/com/yahoo/vespa/model/container/component/HuggingFaceEmbedder.java
+++ b/config-model/src/main/java/com/yahoo/vespa/model/container/component/HuggingFaceEmbedder.java
@@ -12,6 +12,7 @@ import java.util.Set;
import static com.yahoo.embedding.huggingface.HuggingFaceEmbedderConfig.PoolingStrategy;
import static com.yahoo.embedding.huggingface.HuggingFaceEmbedderConfig.TransformerExecutionMode;
+import static com.yahoo.text.XML.getChild;
import static com.yahoo.text.XML.getChildValue;
import static com.yahoo.vespa.model.container.ContainerModelEvaluation.INTEGRATION_BUNDLE_NAME;
import static com.yahoo.vespa.model.container.xml.ModelIdResolver.HF_TOKENIZER;
@@ -34,6 +35,10 @@ public class HuggingFaceEmbedder extends TypedComponent implements HuggingFaceEm
private final Boolean normalize;
private final String poolingStrategy;
+ private String prependQuery;
+
+ private String prependDocument;
+
public HuggingFaceEmbedder(ApplicationContainerCluster cluster, Element xml, DeployState state) {
super("ai.vespa.embedding.huggingface.HuggingFaceEmbedder", INTEGRATION_BUNDLE_NAME, xml);
var model = Model.fromXml(state, xml, "transformer-model", Set.of(ONNX_MODEL)).orElseThrow();
@@ -51,6 +56,12 @@ public class HuggingFaceEmbedder extends TypedComponent implements HuggingFaceEm
transformerOutput = getChildValue(xml, "transformer-output").orElse(null);
normalize = getChildValue(xml, "normalize").map(Boolean::parseBoolean).orElse(null);
poolingStrategy = getChildValue(xml, "pooling-strategy").orElse(null);
+ Element prepend = getChild(xml, "prepend");
+ if (prepend != null) {
+ prependQuery = getChildValue(prepend, "query").orElse(null);
+ prependDocument = getChildValue(prepend, "document").orElse(null);
+ }
+
model.registerOnnxModelCost(cluster, onnxModelOptions);
}
@@ -64,6 +75,8 @@ public class HuggingFaceEmbedder extends TypedComponent implements HuggingFaceEm
if (transformerOutput != null) b.transformerOutput(transformerOutput);
if (normalize != null) b.normalize(normalize);
if (poolingStrategy != null) b.poolingStrategy(PoolingStrategy.Enum.valueOf(poolingStrategy));
+ if(prependQuery != null) b.prependQuery(prependQuery);
+ if(prependDocument != null) b.prependDocument(prependDocument);
onnxModelOptions.executionMode().ifPresent(value -> b.transformerExecutionMode(TransformerExecutionMode.Enum.valueOf(value)));
onnxModelOptions.interOpThreads().ifPresent(b::transformerInterOpThreads);
onnxModelOptions.intraOpThreads().ifPresent(b::transformerIntraOpThreads);