aboutsummaryrefslogtreecommitdiffstats
path: root/config-model/src/main/java/com/yahoo/vespa
diff options
context:
space:
mode:
authorBjørn Christian Seime <bjorn.christian@seime.no>2024-04-26 07:56:30 +0200
committerGitHub <noreply@github.com>2024-04-26 07:56:30 +0200
commit95cccd4698808d320f41461697f74a1c9d161bac (patch)
treecd9336653627e4f3274a29f5f9da435a97b2150c /config-model/src/main/java/com/yahoo/vespa
parentface477e177b1fdd5462a5c961e2d00d43895b97 (diff)
parenta374e90b5f95b3f3c533a4d0302ac0e66c32668f (diff)
Merge pull request #31049 from vespa-engine/jobergum/add-prepend-embedder-support
add prepend support to embedder
Diffstat (limited to 'config-model/src/main/java/com/yahoo/vespa')
-rw-r--r--config-model/src/main/java/com/yahoo/vespa/model/container/component/HuggingFaceEmbedder.java13
1 files changed, 13 insertions, 0 deletions
diff --git a/config-model/src/main/java/com/yahoo/vespa/model/container/component/HuggingFaceEmbedder.java b/config-model/src/main/java/com/yahoo/vespa/model/container/component/HuggingFaceEmbedder.java
index fe0bb7c8075..91060579c4e 100644
--- a/config-model/src/main/java/com/yahoo/vespa/model/container/component/HuggingFaceEmbedder.java
+++ b/config-model/src/main/java/com/yahoo/vespa/model/container/component/HuggingFaceEmbedder.java
@@ -12,6 +12,7 @@ import java.util.Set;
import static com.yahoo.embedding.huggingface.HuggingFaceEmbedderConfig.PoolingStrategy;
import static com.yahoo.embedding.huggingface.HuggingFaceEmbedderConfig.TransformerExecutionMode;
+import static com.yahoo.text.XML.getChild;
import static com.yahoo.text.XML.getChildValue;
import static com.yahoo.vespa.model.container.ContainerModelEvaluation.INTEGRATION_BUNDLE_NAME;
import static com.yahoo.vespa.model.container.xml.ModelIdResolver.HF_TOKENIZER;
@@ -34,6 +35,10 @@ public class HuggingFaceEmbedder extends TypedComponent implements HuggingFaceEm
private final Boolean normalize;
private final String poolingStrategy;
+ private String prependQuery;
+
+ private String prependDocument;
+
public HuggingFaceEmbedder(ApplicationContainerCluster cluster, Element xml, DeployState state) {
super("ai.vespa.embedding.huggingface.HuggingFaceEmbedder", INTEGRATION_BUNDLE_NAME, xml);
var model = Model.fromXml(state, xml, "transformer-model", Set.of(ONNX_MODEL)).orElseThrow();
@@ -51,6 +56,12 @@ public class HuggingFaceEmbedder extends TypedComponent implements HuggingFaceEm
transformerOutput = getChildValue(xml, "transformer-output").orElse(null);
normalize = getChildValue(xml, "normalize").map(Boolean::parseBoolean).orElse(null);
poolingStrategy = getChildValue(xml, "pooling-strategy").orElse(null);
+ Element prepend = getChild(xml, "prepend");
+ if (prepend != null) {
+ prependQuery = getChildValue(prepend, "query").orElse(null);
+ prependDocument = getChildValue(prepend, "document").orElse(null);
+ }
+
model.registerOnnxModelCost(cluster, onnxModelOptions);
}
@@ -64,6 +75,8 @@ public class HuggingFaceEmbedder extends TypedComponent implements HuggingFaceEm
if (transformerOutput != null) b.transformerOutput(transformerOutput);
if (normalize != null) b.normalize(normalize);
if (poolingStrategy != null) b.poolingStrategy(PoolingStrategy.Enum.valueOf(poolingStrategy));
+ if(prependQuery != null) b.prependQuery(prependQuery);
+ if(prependDocument != null) b.prependDocument(prependDocument);
onnxModelOptions.executionMode().ifPresent(value -> b.transformerExecutionMode(TransformerExecutionMode.Enum.valueOf(value)));
onnxModelOptions.interOpThreads().ifPresent(b::transformerInterOpThreads);
onnxModelOptions.intraOpThreads().ifPresent(b::transformerIntraOpThreads);