summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--config-model/src/main/java/com/yahoo/vespa/model/builder/xml/dom/DomComponentBuilder.java6
-rw-r--r--config-model/src/main/java/com/yahoo/vespa/model/container/component/ColBertEmbedder.java93
-rw-r--r--config-model/src/main/resources/schema/common.rnc29
-rw-r--r--config-model/src/test/cfg/application/embed/services.xml19
-rw-r--r--config-model/src/test/java/com/yahoo/vespa/model/container/xml/EmbedderTestCase.java33
-rw-r--r--model-integration/src/main/java/ai/vespa/embedding/ColBertEmbedder.java1
-rw-r--r--model-integration/src/test/java/ai/vespa/embedding/ColBertEmbedderTest.java1
7 files changed, 174 insertions, 8 deletions
diff --git a/config-model/src/main/java/com/yahoo/vespa/model/builder/xml/dom/DomComponentBuilder.java b/config-model/src/main/java/com/yahoo/vespa/model/builder/xml/dom/DomComponentBuilder.java
index d0e1ede2cfa..3fad99eaa75 100644
--- a/config-model/src/main/java/com/yahoo/vespa/model/builder/xml/dom/DomComponentBuilder.java
+++ b/config-model/src/main/java/com/yahoo/vespa/model/builder/xml/dom/DomComponentBuilder.java
@@ -7,10 +7,7 @@ import com.yahoo.config.model.producer.AnyConfigProducer;
import com.yahoo.config.model.producer.TreeConfigProducer;
import com.yahoo.osgi.provider.model.ComponentModel;
import com.yahoo.text.XML;
-import com.yahoo.vespa.model.container.component.BertEmbedder;
-import com.yahoo.vespa.model.container.component.Component;
-import com.yahoo.vespa.model.container.component.HuggingFaceEmbedder;
-import com.yahoo.vespa.model.container.component.HuggingFaceTokenizer;
+import com.yahoo.vespa.model.container.component.*;
import com.yahoo.vespa.model.container.xml.BundleInstantiationSpecificationBuilder;
import org.w3c.dom.Element;
@@ -46,6 +43,7 @@ public class DomComponentBuilder extends VespaDomBuilder.DomConfigProducerBuilde
case "hugging-face-embedder" -> new HuggingFaceEmbedder(spec, state);
case "hugging-face-tokenizer" -> new HuggingFaceTokenizer(spec, state);
case "bert-embedder" -> new BertEmbedder(spec, state);
+ case "colbert-embedder" -> new ColBertEmbedder(spec, state);
default -> throw new IllegalArgumentException("Unknown component type '%s'".formatted(type));
};
} else {
diff --git a/config-model/src/main/java/com/yahoo/vespa/model/container/component/ColBertEmbedder.java b/config-model/src/main/java/com/yahoo/vespa/model/container/component/ColBertEmbedder.java
new file mode 100644
index 00000000000..c0fdfe3dc64
--- /dev/null
+++ b/config-model/src/main/java/com/yahoo/vespa/model/container/component/ColBertEmbedder.java
@@ -0,0 +1,93 @@
+// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+
+package com.yahoo.vespa.model.container.component;
+
+import com.yahoo.config.ModelReference;
+import com.yahoo.config.model.deploy.DeployState;
+import com.yahoo.embedding.ColBertEmbedderConfig;
+import com.yahoo.embedding.huggingface.HuggingFaceEmbedderConfig;
+import com.yahoo.vespa.model.container.xml.ModelIdResolver;
+import org.w3c.dom.Element;
+
+import java.util.Optional;
+
+import static com.yahoo.config.model.builder.xml.XmlHelper.getOptionalChild;
+import static com.yahoo.text.XML.getChildValue;
+import static com.yahoo.vespa.model.container.ContainerModelEvaluation.INTEGRATION_BUNDLE_NAME;
+
+
+/**
+ * @author bergum
+ */
+public class ColBertEmbedder extends TypedComponent implements ColBertEmbedderConfig.Producer {
+ private final ModelReference model;
+ private final ModelReference vocab;
+
+ private final Integer maxQueryTokens;
+
+ private final Integer maxDocumentTokens;
+
+ private final Integer transformerStartSequenceToken;
+ private final Integer transformerEndSequenceToken;
+ private final Integer transformerMaskToken;
+ private final Integer maxTokens;
+ private final String transformerInputIds;
+ private final String transformerAttentionMask;
+
+ private final String transformerOutput;
+ private final String onnxExecutionMode;
+ private final Integer onnxInteropThreads;
+ private final Integer onnxIntraopThreads;
+ private final Integer onnxGpuDevice;
+
+ public ColBertEmbedder(Element xml, DeployState state) {
+ super("ai.vespa.embedding.ColBertEmbedder", INTEGRATION_BUNDLE_NAME, xml);
+ var transformerModelElem = getOptionalChild(xml, "transformer-model").orElseThrow();
+ model = ModelIdResolver.resolveToModelReference(transformerModelElem, state);
+ vocab = getOptionalChild(xml, "tokenizer-model")
+ .map(elem -> ModelIdResolver.resolveToModelReference(elem, state))
+ .orElseGet(() -> resolveDefaultVocab(transformerModelElem, state));
+ maxTokens = getChildValue(xml, "max-tokens").map(Integer::parseInt).orElse(null);
+ maxQueryTokens = getChildValue(xml, "max-query-tokens").map(Integer::parseInt).orElse(null);
+ maxDocumentTokens = getChildValue(xml, "max-document-tokens").map(Integer::parseInt).orElse(null);
+ transformerStartSequenceToken = getChildValue(xml, "transformer-start-sequence-token").map(Integer::parseInt).orElse(null);
+ transformerEndSequenceToken = getChildValue(xml, "transformer-end-sequence-token").map(Integer::parseInt).orElse(null);
+ transformerMaskToken = getChildValue(xml, "transformer-mask-token").map(Integer::parseInt).orElse(null);
+ transformerInputIds = getChildValue(xml, "transformer-input-ids").orElse(null);
+ transformerAttentionMask = getChildValue(xml, "transformer-attention-mask").orElse(null);
+ transformerOutput = getChildValue(xml, "transformer-output").orElse(null);
+ onnxExecutionMode = getChildValue(xml, "onnx-execution-mode").orElse(null);
+ onnxInteropThreads = getChildValue(xml, "onnx-interop-threads").map(Integer::parseInt).orElse(null);
+ onnxIntraopThreads = getChildValue(xml, "onnx-intraop-threads").map(Integer::parseInt).orElse(null);
+ onnxGpuDevice = getChildValue(xml, "onnx-gpu-device").map(Integer::parseInt).orElse(null);
+
+ }
+
+ private static ModelReference resolveDefaultVocab(Element model, DeployState state) {
+ if (state.isHosted() && model.hasAttribute("model-id")) {
+ var implicitVocabId = model.getAttribute("model-id") + "-vocab";
+ return ModelIdResolver.resolveToModelReference(
+ "tokenizer-model", Optional.of(implicitVocabId), Optional.empty(), Optional.empty(), state);
+ }
+ throw new IllegalArgumentException("'tokenizer-model' must be specified");
+ }
+
+ @Override
+ public void getConfig(ColBertEmbedderConfig.Builder b) {
+ b.transformerModel(model).tokenizerPath(vocab);
+ if (maxTokens != null) b.transformerMaxTokens(maxTokens);
+ if (transformerInputIds != null) b.transformerInputIds(transformerInputIds);
+ if (transformerAttentionMask != null) b.transformerAttentionMask(transformerAttentionMask);
+ if (transformerOutput != null) b.transformerOutput(transformerOutput);
+ if (maxQueryTokens != null) b.maxQueryTokens(maxQueryTokens);
+ if (maxDocumentTokens != null) b.maxDocumentTokens(maxDocumentTokens);
+ if (transformerStartSequenceToken != null) b.transformerStartSequenceToken(transformerStartSequenceToken);
+ if (transformerEndSequenceToken != null) b.transformerEndSequenceToken(transformerEndSequenceToken);
+ if (transformerMaskToken != null) b.transformerMaskToken(transformerMaskToken);
+ if (onnxExecutionMode != null) b.transformerExecutionMode(
+ ColBertEmbedderConfig.TransformerExecutionMode.Enum.valueOf(onnxExecutionMode));
+ if (onnxInteropThreads != null) b.transformerInterOpThreads(onnxInteropThreads);
+ if (onnxIntraopThreads != null) b.transformerIntraOpThreads(onnxIntraopThreads);
+ if (onnxGpuDevice != null) b.transformerGpuDevice(onnxGpuDevice);
+ }
+}
diff --git a/config-model/src/main/resources/schema/common.rnc b/config-model/src/main/resources/schema/common.rnc
index ba7e2b6674e..e0d5e6a3344 100644
--- a/config-model/src/main/resources/schema/common.rnc
+++ b/config-model/src/main/resources/schema/common.rnc
@@ -80,7 +80,7 @@ ComponentDefinition =
TypedComponentDefinition =
attribute id { xsd:Name } &
- (HuggingFaceEmbedder | HuggingFaceTokenizer | BertBaseEmbedder) &
+ (HuggingFaceEmbedder | HuggingFaceTokenizer | BertBaseEmbedder | ColBertEmbedder) &
GenericConfig* &
Component*
@@ -110,15 +110,36 @@ BertBaseEmbedder =
element transformer-attention-mask { xsd:string }? &
element transformer-token-type-ids { xsd:string }? &
element transformer-output { xsd:string }? &
- element transformer-start-sequence-token { xsd:integer }? &
- element transformer-end-sequence-token { xsd:integer }? &
+ StartOfSequence &
+ EndOfSequence &
OnnxModelExecutionParams &
EmbedderPoolingStrategy
+
+ColBertEmbedder =
+ attribute type { "colbert-embedder" } &
+ element transformer-model { ModelReference } &
+ element tokenizer-model { ModelReference }? &
+ element max-tokens { xsd:positiveInteger }? &
+ element max-query-tokens { xsd:positiveInteger }? &
+ element max-document-tokens { xsd:positiveInteger }? &
+ element transformer-mask-token { xsd:integer }? &
+ element transformer-input-ids { xsd:string }? &
+ element transformer-attention-mask { xsd:string }? &
+ element transformer-token-type-ids { xsd:string }? &
+ element transformer-output { xsd:string }? &
+ element normalize { xsd:boolean }? &
+ OnnxModelExecutionParams &
+ StartOfSequence &
+ EndOfSequence
+
OnnxModelExecutionParams =
element onnx-execution-mode { "parallel" | "sequential" }? &
element onnx-interop-threads { xsd:integer }? &
element onnx-intraop-threads { xsd:integer }? &
element onnx-gpu-device { xsd:integer }?
-EmbedderPoolingStrategy = element pooling-strategy { "cls" | "mean" }? \ No newline at end of file
+EmbedderPoolingStrategy = element pooling-strategy { "cls" | "mean" }?
+
+StartOfSequence = element transformer-start-sequence-token { xsd:integer }?
+EndOfSequence = element transformer-end-sequence-token { xsd:integer }? \ No newline at end of file
diff --git a/config-model/src/test/cfg/application/embed/services.xml b/config-model/src/test/cfg/application/embed/services.xml
index 70eef7ea54a..efb33d36761 100644
--- a/config-model/src/test/cfg/application/embed/services.xml
+++ b/config-model/src/test/cfg/application/embed/services.xml
@@ -43,6 +43,25 @@
<onnx-gpu-device>1</onnx-gpu-device>
</component>
+ <component id="colbert" type="colbert-embedder">
+ <transformer-model model-id="e5-base-v2" url="https://my/url/model.onnx"/>
+ <tokenizer-model model-id="e5-base-v2-vocab" path="app/tokenizer.json"/>
+ <max-tokens>1024</max-tokens>
+ <max-query-tokens>32</max-query-tokens>
+ <max-document-tokens>512</max-document-tokens>
+ <transformer-start-sequence-token>101</transformer-start-sequence-token>
+ <transformer-end-sequence-token>102</transformer-end-sequence-token>
+ <transformer-mask-token>103</transformer-mask-token>
+ <transformer-input-ids>my_input_ids</transformer-input-ids>
+ <transformer-attention-mask>my_attention_mask</transformer-attention-mask>
+ <transformer-token-type-ids>my_token_type_ids</transformer-token-type-ids>
+ <transformer-output>my_output</transformer-output>
+ <onnx-execution-mode>parallel</onnx-execution-mode>
+ <onnx-intraop-threads>10</onnx-intraop-threads>
+ <onnx-interop-threads>8</onnx-interop-threads>
+ <onnx-gpu-device>1</onnx-gpu-device>
+ </component>
+
<nodes>
<node hostalias="node1" />
</nodes>
diff --git a/config-model/src/test/java/com/yahoo/vespa/model/container/xml/EmbedderTestCase.java b/config-model/src/test/java/com/yahoo/vespa/model/container/xml/EmbedderTestCase.java
index 42b78db66b1..5832445d0d7 100644
--- a/config-model/src/test/java/com/yahoo/vespa/model/container/xml/EmbedderTestCase.java
+++ b/config-model/src/test/java/com/yahoo/vespa/model/container/xml/EmbedderTestCase.java
@@ -9,6 +9,7 @@ import com.yahoo.config.model.application.provider.FilesApplicationPackage;
import com.yahoo.config.model.deploy.DeployState;
import com.yahoo.config.model.deploy.TestProperties;
import com.yahoo.embedding.BertBaseEmbedderConfig;
+import com.yahoo.embedding.ColBertEmbedderConfig;
import com.yahoo.embedding.huggingface.HuggingFaceEmbedderConfig;
import com.yahoo.language.huggingface.config.HuggingFaceTokenizerConfig;
import com.yahoo.path.Path;
@@ -21,6 +22,7 @@ import com.yahoo.vespa.model.container.component.BertEmbedder;
import com.yahoo.vespa.model.container.component.Component;
import com.yahoo.vespa.model.container.component.HuggingFaceEmbedder;
import com.yahoo.vespa.model.container.component.HuggingFaceTokenizer;
+import com.yahoo.vespa.model.container.component.ColBertEmbedder;
import com.yahoo.vespa.model.test.utils.VespaModelCreatorWithFilePkg;
import com.yahoo.yolean.Exceptions;
import org.junit.jupiter.api.Test;
@@ -96,6 +98,29 @@ public class EmbedderTestCase {
assertEquals(-1, tokenizerCfg.maxLength());
}
+ void colBertEmbedder_selfhosted() throws Exception {
+ var model = loadModel(Path.fromString("src/test/cfg/application/embed/"), false);
+ var cluster = model.getContainerClusters().get("container");
+ var embedderCfg = assertColBertEmbedderComponentPresent(cluster);
+ assertEquals("my_input_ids", embedderCfg.transformerInputIds());
+ assertEquals("https://my/url/model.onnx", modelReference(embedderCfg, "transformerModel").url().orElseThrow().value());
+ assertEquals(1024, embedderCfg.transformerMaxTokens());
+ var tokenizerCfg = assertHuggingfaceTokenizerComponentPresent(cluster);
+ assertEquals("https://my/url/tokenizer.json", modelReference(tokenizerCfg.model().get(0), "path").url().orElseThrow().value());
+ assertEquals(-1, tokenizerCfg.maxLength());
+ }
+
+ void colBertEmbedder_hosted() throws Exception {
+ var model = loadModel(Path.fromString("src/test/cfg/application/embed/"), true);
+ var cluster = model.getContainerClusters().get("container");
+ var embedderCfg = assertColBertEmbedderComponentPresent(cluster);
+ assertEquals("my_input_ids", embedderCfg.transformerInputIds());
+ assertEquals("https://data.vespa.oath.cloud/onnx_models/e5-base-v2/model.onnx", modelReference(embedderCfg, "transformerModel").url().orElseThrow().value());
+ assertEquals(1024, embedderCfg.transformerMaxTokens());
+ var tokenizerCfg = assertHuggingfaceTokenizerComponentPresent(cluster);
+ assertEquals("https://data.vespa.oath.cloud/onnx_models/multilingual-e5-base/tokenizer.json", modelReference(tokenizerCfg.model().get(0), "path").url().orElseThrow().value());
+ assertEquals(-1, tokenizerCfg.maxLength());
+ }
@Test
void bertEmbedder_selfhosted() throws Exception {
@@ -233,6 +258,14 @@ public class EmbedderTestCase {
return cfgBuilder.build();
}
+ private static ColBertEmbedderConfig assertColBertEmbedderComponentPresent(ApplicationContainerCluster cluster) {
+ var colbert = (ColBertEmbedder) cluster.getComponentsMap().get(new ComponentId("colbert-embedder"));
+ assertEquals("ai.vespa.embedding.ColBertEmbedder", colbert.getClassId().getName());
+ var cfgBuilder = new ColBertEmbedderConfig.Builder();
+ colbert.getConfig(cfgBuilder);
+ return cfgBuilder.build();
+ }
+
private static BertBaseEmbedderConfig assertBertEmbedderComponentPresent(ApplicationContainerCluster cluster) {
var bertEmbedder = (BertEmbedder) cluster.getComponentsMap().get(new ComponentId("bert-embedder"));
assertEquals("ai.vespa.embedding.BertBaseEmbedder", bertEmbedder.getClassId().getName());
diff --git a/model-integration/src/main/java/ai/vespa/embedding/ColBertEmbedder.java b/model-integration/src/main/java/ai/vespa/embedding/ColBertEmbedder.java
index 3069cb93444..aafb9877c27 100644
--- a/model-integration/src/main/java/ai/vespa/embedding/ColBertEmbedder.java
+++ b/model-integration/src/main/java/ai/vespa/embedding/ColBertEmbedder.java
@@ -1,3 +1,4 @@
+// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
package ai.vespa.embedding;
import ai.vespa.modelintegration.evaluator.OnnxEvaluator;
diff --git a/model-integration/src/test/java/ai/vespa/embedding/ColBertEmbedderTest.java b/model-integration/src/test/java/ai/vespa/embedding/ColBertEmbedderTest.java
index 70f91eb44ad..8516f6e6689 100644
--- a/model-integration/src/test/java/ai/vespa/embedding/ColBertEmbedderTest.java
+++ b/model-integration/src/test/java/ai/vespa/embedding/ColBertEmbedderTest.java
@@ -1,3 +1,4 @@
+// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
package ai.vespa.embedding;
import ai.vespa.modelintegration.evaluator.OnnxRuntime;