summaryrefslogtreecommitdiffstats
path: root/config-model
diff options
context:
space:
mode:
authorJo Kristian Bergum <bergum@yahooinc.com>2023-12-15 08:42:01 +0100
committerJo Kristian Bergum <bergum@yahooinc.com>2023-12-15 08:42:01 +0100
commitcbc0733c07b57d8563eea40897072cb35042b605 (patch)
treefa3b564329dc4a88b48c697692c7896d6d4b36b0 /config-model
parent8af800ba588f726184ffb8296463bb4b7fbea5a1 (diff)
Add a splade embedder implementation
Diffstat (limited to 'config-model')
-rw-r--r--config-model/src/main/java/com/yahoo/vespa/model/builder/xml/dom/DomComponentBuilder.java7
-rw-r--r--config-model/src/main/java/com/yahoo/vespa/model/container/component/SpladeEmbedder.java73
-rw-r--r--config-model/src/main/resources/schema/common.rnc14
-rw-r--r--config-model/src/test/cfg/application/embed/services.xml15
-rw-r--r--config-model/src/test/java/com/yahoo/vespa/model/container/xml/EmbedderTestCase.java30
5 files changed, 132 insertions, 7 deletions
diff --git a/config-model/src/main/java/com/yahoo/vespa/model/builder/xml/dom/DomComponentBuilder.java b/config-model/src/main/java/com/yahoo/vespa/model/builder/xml/dom/DomComponentBuilder.java
index 72b72c369dc..3dcc74e777d 100644
--- a/config-model/src/main/java/com/yahoo/vespa/model/builder/xml/dom/DomComponentBuilder.java
+++ b/config-model/src/main/java/com/yahoo/vespa/model/builder/xml/dom/DomComponentBuilder.java
@@ -8,11 +8,7 @@ import com.yahoo.config.model.producer.TreeConfigProducer;
import com.yahoo.osgi.provider.model.ComponentModel;
import com.yahoo.text.XML;
import com.yahoo.vespa.model.container.ApplicationContainerCluster;
-import com.yahoo.vespa.model.container.component.BertEmbedder;
-import com.yahoo.vespa.model.container.component.ColBertEmbedder;
-import com.yahoo.vespa.model.container.component.Component;
-import com.yahoo.vespa.model.container.component.HuggingFaceEmbedder;
-import com.yahoo.vespa.model.container.component.HuggingFaceTokenizer;
+import com.yahoo.vespa.model.container.component.*;
import com.yahoo.vespa.model.container.xml.BundleInstantiationSpecificationBuilder;
import org.w3c.dom.Element;
@@ -50,6 +46,7 @@ public class DomComponentBuilder extends VespaDomBuilder.DomConfigProducerBuilde
case "hugging-face-tokenizer" -> new HuggingFaceTokenizer(spec, state);
case "colbert-embedder" -> new ColBertEmbedder((ApplicationContainerCluster)ancestor, spec, state);
case "bert-embedder" -> new BertEmbedder((ApplicationContainerCluster)ancestor, spec, state);
+ case "splade-embedder" -> new SpladeEmbedder((ApplicationContainerCluster)ancestor, spec, state);
default -> throw new IllegalArgumentException("Unknown component type '%s'".formatted(type));
};
} else {
diff --git a/config-model/src/main/java/com/yahoo/vespa/model/container/component/SpladeEmbedder.java b/config-model/src/main/java/com/yahoo/vespa/model/container/component/SpladeEmbedder.java
new file mode 100644
index 00000000000..96554e91d38
--- /dev/null
+++ b/config-model/src/main/java/com/yahoo/vespa/model/container/component/SpladeEmbedder.java
@@ -0,0 +1,73 @@
+// Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+
+package com.yahoo.vespa.model.container.component;
+import com.yahoo.config.ModelReference;
+import com.yahoo.config.model.api.OnnxModelOptions;
+import com.yahoo.config.model.deploy.DeployState;
+import com.yahoo.embedding.SpladeEmbedderConfig;
+import com.yahoo.vespa.model.container.ApplicationContainerCluster;
+import org.w3c.dom.Element;
+import static com.yahoo.text.XML.getChildValue;
+import static com.yahoo.vespa.model.container.ContainerModelEvaluation.INTEGRATION_BUNDLE_NAME;
+
+public class SpladeEmbedder extends TypedComponent implements SpladeEmbedderConfig.Producer {
+
+ private final OnnxModelOptions onnxModelOptions;
+
+ private final ModelReference modelRef;
+ private final ModelReference vocabRef;
+ private final Integer maxTokens;
+ private final String transformerInputIds;
+ private final String transformerAttentionMask;
+ private final String transformerTokenTypeIds;
+ private final String transformerOutput;
+
+ private final Double termScoreThreshold;
+
+ public SpladeEmbedder(ApplicationContainerCluster cluster, Element xml, DeployState state) {
+ super("ai.vespa.embedding.SpladeEmbedder", INTEGRATION_BUNDLE_NAME, xml);
+ var model = Model.fromXml(state, xml, "transformer-model").orElseThrow();
+ this.onnxModelOptions = new OnnxModelOptions(
+ getChildValue(xml, "onnx-execution-mode"),
+ getChildValue(xml, "onnx-interop-threads").map(Integer::parseInt),
+ getChildValue(xml, "onnx-intraop-threads").map(Integer::parseInt),
+ getChildValue(xml, "onnx-gpu-device").map(Integer::parseInt).map(OnnxModelOptions.GpuDevice::new));
+ modelRef = model.modelReference();
+ vocabRef = Model.fromXml(state, xml, "tokenizer-model")
+ .map(Model::modelReference)
+ .orElseGet(() -> resolveDefaultVocab(model, state));
+ maxTokens = getChildValue(xml, "max-tokens").map(Integer::parseInt).orElse(null);
+ transformerInputIds = getChildValue(xml, "transformer-input-ids").orElse(null);
+ transformerAttentionMask = getChildValue(xml, "transformer-attention-mask").orElse(null);
+ transformerTokenTypeIds = getChildValue(xml, "transformer-token-type-ids").orElse(null);
+ transformerOutput = getChildValue(xml, "transformer-output").orElse(null);
+ termScoreThreshold = getChildValue(xml, "term-score-threshold").map(Double::parseDouble).orElse(null);
+ model.registerOnnxModelCost(cluster, onnxModelOptions);
+ }
+
+ private static ModelReference resolveDefaultVocab(Model model, DeployState state) {
+ var modelId = model.modelId().orElse(null);
+ if (state.isHosted() && modelId != null) {
+ return Model.fromParams(state, model.name(), modelId + "-vocab", null, null).modelReference();
+ }
+ throw new IllegalArgumentException("'tokenizer-model' must be specified");
+ }
+
+
+ @Override
+ public void getConfig(SpladeEmbedderConfig.Builder b) {
+ b.transformerModel(modelRef).tokenizerPath(vocabRef);
+ if (maxTokens != null) b.transformerMaxTokens(maxTokens);
+ if (transformerInputIds != null) b.transformerInputIds(transformerInputIds);
+ if (transformerAttentionMask != null) b.transformerAttentionMask(transformerAttentionMask);
+ if (transformerTokenTypeIds != null) b.transformerTokenTypeIds(transformerTokenTypeIds);
+ if (transformerOutput != null) b.transformerOutput(transformerOutput);
+ if (termScoreThreshold != null) b.termScoreThreshold(termScoreThreshold);
+
+ onnxModelOptions.executionMode().ifPresent(value -> b.transformerExecutionMode(SpladeEmbedderConfig.TransformerExecutionMode.Enum.valueOf(value)));
+ onnxModelOptions.interOpThreads().ifPresent(b::transformerInterOpThreads);
+ onnxModelOptions.intraOpThreads().ifPresent(b::transformerIntraOpThreads);
+ onnxModelOptions.gpuDevice().ifPresent(value -> b.transformerGpuDevice(value.deviceNumber()));
+ }
+}
+
diff --git a/config-model/src/main/resources/schema/common.rnc b/config-model/src/main/resources/schema/common.rnc
index 2f3b10742f3..919253977ca 100644
--- a/config-model/src/main/resources/schema/common.rnc
+++ b/config-model/src/main/resources/schema/common.rnc
@@ -80,7 +80,7 @@ ComponentDefinition =
TypedComponentDefinition =
attribute id { xsd:Name } &
- (HuggingFaceEmbedder | HuggingFaceTokenizer | BertBaseEmbedder | ColBertEmbedder) &
+ (HuggingFaceEmbedder | HuggingFaceTokenizer | BertBaseEmbedder | ColBertEmbedder | SpladeEmbedder ) &
GenericConfig* &
Component*
@@ -97,6 +97,18 @@ HuggingFaceEmbedder =
OnnxModelExecutionParams &
EmbedderPoolingStrategy
+SpladeEmbedder =
+ attribute type { "splade-embedder" } &
+ element transformer-model { ModelReference } &
+ element tokenizer-model { ModelReference }? &
+ element max-tokens { xsd:positiveInteger }? &
+ element transformer-input-ids { xsd:string }? &
+ element transformer-attention-mask { xsd:string }? &
+ element transformer-token-type-ids { xsd:string }? &
+ element transformer-output { xsd:string }? &
+ element term-score-threshold { xsd:double }? &
+ OnnxModelExecutionParams
+
HuggingFaceTokenizer =
attribute type { "hugging-face-tokenizer" } &
element model { attribute language { xsd:string }? & ModelReference }+
diff --git a/config-model/src/test/cfg/application/embed/services.xml b/config-model/src/test/cfg/application/embed/services.xml
index 1840063d70d..59c29aefc6a 100644
--- a/config-model/src/test/cfg/application/embed/services.xml
+++ b/config-model/src/test/cfg/application/embed/services.xml
@@ -19,6 +19,21 @@
<pooling-strategy>mean</pooling-strategy>
</component>
+ <component id="splade" type="splade-embedder">
+ <transformer-model model-id="e5-base-v2" url="https://my/url/model.onnx"/>
+ <tokenizer-model model-id="e5-base-v2-vocab" path="app/tokenizer.json"/>
+ <max-tokens>1024</max-tokens>
+ <transformer-input-ids>my_input_ids</transformer-input-ids>
+ <transformer-attention-mask>my_attention_mask</transformer-attention-mask>
+ <transformer-token-type-ids>my_token_type_ids</transformer-token-type-ids>
+ <transformer-output>my_output</transformer-output>
+ <term-score-threshold>0.2</term-score-threshold>
+ <onnx-execution-mode>parallel</onnx-execution-mode>
+ <onnx-intraop-threads>10</onnx-intraop-threads>
+ <onnx-interop-threads>8</onnx-interop-threads>
+ <onnx-gpu-device>1</onnx-gpu-device>
+ </component>
+
<component id="hf-tokenizer" type="hugging-face-tokenizer">
<model language="no" model-id="multilingual-e5-base-vocab" url="https://my/url/tokenizer.json"/>
</component>
diff --git a/config-model/src/test/java/com/yahoo/vespa/model/container/xml/EmbedderTestCase.java b/config-model/src/test/java/com/yahoo/vespa/model/container/xml/EmbedderTestCase.java
index 138bef3ae73..2532a5be863 100644
--- a/config-model/src/test/java/com/yahoo/vespa/model/container/xml/EmbedderTestCase.java
+++ b/config-model/src/test/java/com/yahoo/vespa/model/container/xml/EmbedderTestCase.java
@@ -12,6 +12,7 @@ import com.yahoo.config.model.deploy.DeployState;
import com.yahoo.config.model.deploy.TestProperties;
import com.yahoo.embedding.BertBaseEmbedderConfig;
import com.yahoo.embedding.ColBertEmbedderConfig;
+import com.yahoo.embedding.SpladeEmbedderConfig;
import com.yahoo.embedding.huggingface.HuggingFaceEmbedderConfig;
import com.yahoo.language.huggingface.config.HuggingFaceTokenizerConfig;
import com.yahoo.path.Path;
@@ -24,6 +25,7 @@ import com.yahoo.vespa.model.container.component.BertEmbedder;
import com.yahoo.vespa.model.container.component.ColBertEmbedder;
import com.yahoo.vespa.model.container.component.Component;
import com.yahoo.vespa.model.container.component.HuggingFaceEmbedder;
+import com.yahoo.vespa.model.container.component.SpladeEmbedder;
import com.yahoo.vespa.model.container.component.HuggingFaceTokenizer;
import com.yahoo.vespa.model.test.utils.VespaModelCreatorWithFilePkg;
import com.yahoo.yolean.Exceptions;
@@ -101,6 +103,23 @@ public class EmbedderTestCase {
assertEquals(-1, tokenizerCfg.maxLength());
}
+ @Test
+ void spladeEmbedder_selfhosted() throws Exception {
+ var model = loadModel(Path.fromString("src/test/cfg/application/embed/"), false);
+ var cluster = model.getContainerClusters().get("container");
+ var embedderCfg = assertSpladeEmbedderComponentPresent(cluster);
+
+ assertEquals("my_input_ids", embedderCfg.transformerInputIds());
+ assertEquals("https://my/url/model.onnx", modelReference(embedderCfg, "transformerModel").url().orElseThrow().value());
+ assertEquals(0.2, embedderCfg.termScoreThreshold());
+ assertEquals(1024, embedderCfg.transformerMaxTokens());
+
+ var tokenizerCfg = assertHuggingfaceTokenizerComponentPresent(cluster);
+ assertEquals("https://my/url/tokenizer.json", modelReference(tokenizerCfg.model().get(0), "path").url().orElseThrow().value());
+ assertEquals(-1, tokenizerCfg.maxLength());
+ }
+
+ @Test
void colBertEmbedder_selfhosted() throws Exception {
var model = loadModel(Path.fromString("src/test/cfg/application/embed/"), false);
var cluster = model.getContainerClusters().get("container");
@@ -113,6 +132,7 @@ public class EmbedderTestCase {
assertEquals(-1, tokenizerCfg.maxLength());
}
+ @Test
void colBertEmbedder_hosted() throws Exception {
var model = loadModel(Path.fromString("src/test/cfg/application/embed/"), true);
var cluster = model.getContainerClusters().get("container");
@@ -266,13 +286,21 @@ public class EmbedderTestCase {
}
private static ColBertEmbedderConfig assertColBertEmbedderComponentPresent(ApplicationContainerCluster cluster) {
- var colbert = (ColBertEmbedder) cluster.getComponentsMap().get(new ComponentId("colbert-embedder"));
+ var colbert = (ColBertEmbedder) cluster.getComponentsMap().get(new ComponentId("colbert"));
assertEquals("ai.vespa.embedding.ColBertEmbedder", colbert.getClassId().getName());
var cfgBuilder = new ColBertEmbedderConfig.Builder();
colbert.getConfig(cfgBuilder);
return cfgBuilder.build();
}
+ private static SpladeEmbedderConfig assertSpladeEmbedderComponentPresent(ApplicationContainerCluster cluster) {
+ var splade = (SpladeEmbedder) cluster.getComponentsMap().get(new ComponentId("splade"));
+ assertEquals("ai.vespa.embedding.SpladeEmbedder", splade.getClassId().getName());
+ var cfgBuilder = new SpladeEmbedderConfig.Builder();
+ splade.getConfig(cfgBuilder);
+ return cfgBuilder.build();
+ }
+
private static BertBaseEmbedderConfig assertBertEmbedderComponentPresent(ApplicationContainerCluster cluster) {
var bertEmbedder = (BertEmbedder) cluster.getComponentsMap().get(new ComponentId("bert-embedder"));
assertEquals("ai.vespa.embedding.BertBaseEmbedder", bertEmbedder.getClassId().getName());