diff options
7 files changed, 174 insertions, 8 deletions
diff --git a/config-model/src/main/java/com/yahoo/vespa/model/builder/xml/dom/DomComponentBuilder.java b/config-model/src/main/java/com/yahoo/vespa/model/builder/xml/dom/DomComponentBuilder.java index d0e1ede2cfa..3fad99eaa75 100644 --- a/config-model/src/main/java/com/yahoo/vespa/model/builder/xml/dom/DomComponentBuilder.java +++ b/config-model/src/main/java/com/yahoo/vespa/model/builder/xml/dom/DomComponentBuilder.java @@ -7,10 +7,7 @@ import com.yahoo.config.model.producer.AnyConfigProducer; import com.yahoo.config.model.producer.TreeConfigProducer; import com.yahoo.osgi.provider.model.ComponentModel; import com.yahoo.text.XML; -import com.yahoo.vespa.model.container.component.BertEmbedder; -import com.yahoo.vespa.model.container.component.Component; -import com.yahoo.vespa.model.container.component.HuggingFaceEmbedder; -import com.yahoo.vespa.model.container.component.HuggingFaceTokenizer; +import com.yahoo.vespa.model.container.component.*; import com.yahoo.vespa.model.container.xml.BundleInstantiationSpecificationBuilder; import org.w3c.dom.Element; @@ -46,6 +43,7 @@ public class DomComponentBuilder extends VespaDomBuilder.DomConfigProducerBuilde case "hugging-face-embedder" -> new HuggingFaceEmbedder(spec, state); case "hugging-face-tokenizer" -> new HuggingFaceTokenizer(spec, state); case "bert-embedder" -> new BertEmbedder(spec, state); + case "colbert-embedder" -> new ColBertEmbedder(spec, state); default -> throw new IllegalArgumentException("Unknown component type '%s'".formatted(type)); }; } else { diff --git a/config-model/src/main/java/com/yahoo/vespa/model/container/component/ColBertEmbedder.java b/config-model/src/main/java/com/yahoo/vespa/model/container/component/ColBertEmbedder.java new file mode 100644 index 00000000000..c0fdfe3dc64 --- /dev/null +++ b/config-model/src/main/java/com/yahoo/vespa/model/container/component/ColBertEmbedder.java @@ -0,0 +1,93 @@ +// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. + +package com.yahoo.vespa.model.container.component; + +import com.yahoo.config.ModelReference; +import com.yahoo.config.model.deploy.DeployState; +import com.yahoo.embedding.ColBertEmbedderConfig; +import com.yahoo.embedding.huggingface.HuggingFaceEmbedderConfig; +import com.yahoo.vespa.model.container.xml.ModelIdResolver; +import org.w3c.dom.Element; + +import java.util.Optional; + +import static com.yahoo.config.model.builder.xml.XmlHelper.getOptionalChild; +import static com.yahoo.text.XML.getChildValue; +import static com.yahoo.vespa.model.container.ContainerModelEvaluation.INTEGRATION_BUNDLE_NAME; + + +/** + * @author bergum + */ +public class ColBertEmbedder extends TypedComponent implements ColBertEmbedderConfig.Producer { + private final ModelReference model; + private final ModelReference vocab; + + private final Integer maxQueryTokens; + + private final Integer maxDocumentTokens; + + private final Integer transformerStartSequenceToken; + private final Integer transformerEndSequenceToken; + private final Integer transformerMaskToken; + private final Integer maxTokens; + private final String transformerInputIds; + private final String transformerAttentionMask; + + private final String transformerOutput; + private final String onnxExecutionMode; + private final Integer onnxInteropThreads; + private final Integer onnxIntraopThreads; + private final Integer onnxGpuDevice; + + public ColBertEmbedder(Element xml, DeployState state) { + super("ai.vespa.embedding.ColBertEmbedder", INTEGRATION_BUNDLE_NAME, xml); + var transformerModelElem = getOptionalChild(xml, "transformer-model").orElseThrow(); + model = ModelIdResolver.resolveToModelReference(transformerModelElem, state); + vocab = getOptionalChild(xml, "tokenizer-model") + .map(elem -> ModelIdResolver.resolveToModelReference(elem, state)) + .orElseGet(() -> resolveDefaultVocab(transformerModelElem, state)); + maxTokens = getChildValue(xml, "max-tokens").map(Integer::parseInt).orElse(null); + maxQueryTokens = getChildValue(xml, "max-query-tokens").map(Integer::parseInt).orElse(null); + maxDocumentTokens = getChildValue(xml, "max-document-tokens").map(Integer::parseInt).orElse(null); + transformerStartSequenceToken = getChildValue(xml, "transformer-start-sequence-token").map(Integer::parseInt).orElse(null); + transformerEndSequenceToken = getChildValue(xml, "transformer-end-sequence-token").map(Integer::parseInt).orElse(null); + transformerMaskToken = getChildValue(xml, "transformer-mask-token").map(Integer::parseInt).orElse(null); + transformerInputIds = getChildValue(xml, "transformer-input-ids").orElse(null); + transformerAttentionMask = getChildValue(xml, "transformer-attention-mask").orElse(null); + transformerOutput = getChildValue(xml, "transformer-output").orElse(null); + onnxExecutionMode = getChildValue(xml, "onnx-execution-mode").orElse(null); + onnxInteropThreads = getChildValue(xml, "onnx-interop-threads").map(Integer::parseInt).orElse(null); + onnxIntraopThreads = getChildValue(xml, "onnx-intraop-threads").map(Integer::parseInt).orElse(null); + onnxGpuDevice = getChildValue(xml, "onnx-gpu-device").map(Integer::parseInt).orElse(null); + + } + + private static ModelReference resolveDefaultVocab(Element model, DeployState state) { + if (state.isHosted() && model.hasAttribute("model-id")) { + var implicitVocabId = model.getAttribute("model-id") + "-vocab"; + return ModelIdResolver.resolveToModelReference( + "tokenizer-model", Optional.of(implicitVocabId), Optional.empty(), Optional.empty(), state); + } + throw new IllegalArgumentException("'tokenizer-model' must be specified"); + } + + @Override + public void getConfig(ColBertEmbedderConfig.Builder b) { + b.transformerModel(model).tokenizerPath(vocab); + if (maxTokens != null) b.transformerMaxTokens(maxTokens); + if (transformerInputIds != null) b.transformerInputIds(transformerInputIds); + if (transformerAttentionMask != null) b.transformerAttentionMask(transformerAttentionMask); + if (transformerOutput != null) b.transformerOutput(transformerOutput); + if (maxQueryTokens != null) b.maxQueryTokens(maxQueryTokens); + if (maxDocumentTokens != null) b.maxDocumentTokens(maxDocumentTokens); + if (transformerStartSequenceToken != null) b.transformerStartSequenceToken(transformerStartSequenceToken); + if (transformerEndSequenceToken != null) b.transformerEndSequenceToken(transformerEndSequenceToken); + if (transformerMaskToken != null) b.transformerMaskToken(transformerMaskToken); + if (onnxExecutionMode != null) b.transformerExecutionMode( + ColBertEmbedderConfig.TransformerExecutionMode.Enum.valueOf(onnxExecutionMode)); + if (onnxInteropThreads != null) b.transformerInterOpThreads(onnxInteropThreads); + if (onnxIntraopThreads != null) b.transformerIntraOpThreads(onnxIntraopThreads); + if (onnxGpuDevice != null) b.transformerGpuDevice(onnxGpuDevice); + } +} diff --git a/config-model/src/main/resources/schema/common.rnc b/config-model/src/main/resources/schema/common.rnc index ba7e2b6674e..e0d5e6a3344 100644 --- a/config-model/src/main/resources/schema/common.rnc +++ b/config-model/src/main/resources/schema/common.rnc @@ -80,7 +80,7 @@ ComponentDefinition = TypedComponentDefinition = attribute id { xsd:Name } & - (HuggingFaceEmbedder | HuggingFaceTokenizer | BertBaseEmbedder) & + (HuggingFaceEmbedder | HuggingFaceTokenizer | BertBaseEmbedder | ColBertEmbedder) & GenericConfig* & Component* @@ -110,15 +110,36 @@ BertBaseEmbedder = element transformer-attention-mask { xsd:string }? & element transformer-token-type-ids { xsd:string }? & element transformer-output { xsd:string }? & - element transformer-start-sequence-token { xsd:integer }? & - element transformer-end-sequence-token { xsd:integer }? & + StartOfSequence & + EndOfSequence & OnnxModelExecutionParams & EmbedderPoolingStrategy + +ColBertEmbedder = + attribute type { "colbert-embedder" } & + element transformer-model { ModelReference } & + element tokenizer-model { ModelReference }? & + element max-tokens { xsd:positiveInteger }? & + element max-query-tokens { xsd:positiveInteger }? & + element max-document-tokens { xsd:positiveInteger }? & + element transformer-mask-token { xsd:integer }? & + element transformer-input-ids { xsd:string }? & + element transformer-attention-mask { xsd:string }? & + element transformer-token-type-ids { xsd:string }? & + element transformer-output { xsd:string }? & + element normalize { xsd:boolean }? & + OnnxModelExecutionParams & + StartOfSequence & + EndOfSequence + OnnxModelExecutionParams = element onnx-execution-mode { "parallel" | "sequential" }? & element onnx-interop-threads { xsd:integer }? & element onnx-intraop-threads { xsd:integer }? & element onnx-gpu-device { xsd:integer }? -EmbedderPoolingStrategy = element pooling-strategy { "cls" | "mean" }?
\ No newline at end of file +EmbedderPoolingStrategy = element pooling-strategy { "cls" | "mean" }? + +StartOfSequence = element transformer-start-sequence-token { xsd:integer }? +EndOfSequence = element transformer-end-sequence-token { xsd:integer }?
\ No newline at end of file diff --git a/config-model/src/test/cfg/application/embed/services.xml b/config-model/src/test/cfg/application/embed/services.xml index 70eef7ea54a..efb33d36761 100644 --- a/config-model/src/test/cfg/application/embed/services.xml +++ b/config-model/src/test/cfg/application/embed/services.xml @@ -43,6 +43,25 @@ <onnx-gpu-device>1</onnx-gpu-device> </component> + <component id="colbert" type="colbert-embedder"> + <transformer-model model-id="e5-base-v2" url="https://my/url/model.onnx"/> + <tokenizer-model model-id="e5-base-v2-vocab" path="app/tokenizer.json"/> + <max-tokens>1024</max-tokens> + <max-query-tokens>32</max-query-tokens> + <max-document-tokens>512</max-document-tokens> + <transformer-start-sequence-token>101</transformer-start-sequence-token> + <transformer-end-sequence-token>102</transformer-end-sequence-token> + <transformer-mask-token>103</transformer-mask-token> + <transformer-input-ids>my_input_ids</transformer-input-ids> + <transformer-attention-mask>my_attention_mask</transformer-attention-mask> + <transformer-token-type-ids>my_token_type_ids</transformer-token-type-ids> + <transformer-output>my_output</transformer-output> + <onnx-execution-mode>parallel</onnx-execution-mode> + <onnx-intraop-threads>10</onnx-intraop-threads> + <onnx-interop-threads>8</onnx-interop-threads> + <onnx-gpu-device>1</onnx-gpu-device> + </component> + <nodes> <node hostalias="node1" /> </nodes> diff --git a/config-model/src/test/java/com/yahoo/vespa/model/container/xml/EmbedderTestCase.java b/config-model/src/test/java/com/yahoo/vespa/model/container/xml/EmbedderTestCase.java index 42b78db66b1..5832445d0d7 100644 --- a/config-model/src/test/java/com/yahoo/vespa/model/container/xml/EmbedderTestCase.java +++ b/config-model/src/test/java/com/yahoo/vespa/model/container/xml/EmbedderTestCase.java @@ -9,6 +9,7 @@ import com.yahoo.config.model.application.provider.FilesApplicationPackage; import com.yahoo.config.model.deploy.DeployState; import com.yahoo.config.model.deploy.TestProperties; import com.yahoo.embedding.BertBaseEmbedderConfig; +import com.yahoo.embedding.ColBertEmbedderConfig; import com.yahoo.embedding.huggingface.HuggingFaceEmbedderConfig; import com.yahoo.language.huggingface.config.HuggingFaceTokenizerConfig; import com.yahoo.path.Path; @@ -21,6 +22,7 @@ import com.yahoo.vespa.model.container.component.BertEmbedder; import com.yahoo.vespa.model.container.component.Component; import com.yahoo.vespa.model.container.component.HuggingFaceEmbedder; import com.yahoo.vespa.model.container.component.HuggingFaceTokenizer; +import com.yahoo.vespa.model.container.component.ColBertEmbedder; import com.yahoo.vespa.model.test.utils.VespaModelCreatorWithFilePkg; import com.yahoo.yolean.Exceptions; import org.junit.jupiter.api.Test; @@ -96,6 +98,29 @@ public class EmbedderTestCase { assertEquals(-1, tokenizerCfg.maxLength()); } + void colBertEmbedder_selfhosted() throws Exception { + var model = loadModel(Path.fromString("src/test/cfg/application/embed/"), false); + var cluster = model.getContainerClusters().get("container"); + var embedderCfg = assertColBertEmbedderComponentPresent(cluster); + assertEquals("my_input_ids", embedderCfg.transformerInputIds()); + assertEquals("https://my/url/model.onnx", modelReference(embedderCfg, "transformerModel").url().orElseThrow().value()); + assertEquals(1024, embedderCfg.transformerMaxTokens()); + var tokenizerCfg = assertHuggingfaceTokenizerComponentPresent(cluster); + assertEquals("https://my/url/tokenizer.json", modelReference(tokenizerCfg.model().get(0), "path").url().orElseThrow().value()); + assertEquals(-1, tokenizerCfg.maxLength()); + } + + void colBertEmbedder_hosted() throws Exception { + var model = loadModel(Path.fromString("src/test/cfg/application/embed/"), true); + var cluster = model.getContainerClusters().get("container"); + var embedderCfg = assertColBertEmbedderComponentPresent(cluster); + assertEquals("my_input_ids", embedderCfg.transformerInputIds()); + assertEquals("https://data.vespa.oath.cloud/onnx_models/e5-base-v2/model.onnx", modelReference(embedderCfg, "transformerModel").url().orElseThrow().value()); + assertEquals(1024, embedderCfg.transformerMaxTokens()); + var tokenizerCfg = assertHuggingfaceTokenizerComponentPresent(cluster); + assertEquals("https://data.vespa.oath.cloud/onnx_models/multilingual-e5-base/tokenizer.json", modelReference(tokenizerCfg.model().get(0), "path").url().orElseThrow().value()); + assertEquals(-1, tokenizerCfg.maxLength()); + } @Test void bertEmbedder_selfhosted() throws Exception { @@ -233,6 +258,14 @@ public class EmbedderTestCase { return cfgBuilder.build(); } + private static ColBertEmbedderConfig assertColBertEmbedderComponentPresent(ApplicationContainerCluster cluster) { + var colbert = (ColBertEmbedder) cluster.getComponentsMap().get(new ComponentId("colbert-embedder")); + assertEquals("ai.vespa.embedding.ColBertEmbedder", colbert.getClassId().getName()); + var cfgBuilder = new ColBertEmbedderConfig.Builder(); + colbert.getConfig(cfgBuilder); + return cfgBuilder.build(); + } + private static BertBaseEmbedderConfig assertBertEmbedderComponentPresent(ApplicationContainerCluster cluster) { var bertEmbedder = (BertEmbedder) cluster.getComponentsMap().get(new ComponentId("bert-embedder")); assertEquals("ai.vespa.embedding.BertBaseEmbedder", bertEmbedder.getClassId().getName()); diff --git a/model-integration/src/main/java/ai/vespa/embedding/ColBertEmbedder.java b/model-integration/src/main/java/ai/vespa/embedding/ColBertEmbedder.java index 3069cb93444..aafb9877c27 100644 --- a/model-integration/src/main/java/ai/vespa/embedding/ColBertEmbedder.java +++ b/model-integration/src/main/java/ai/vespa/embedding/ColBertEmbedder.java @@ -1,3 +1,4 @@ +// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. package ai.vespa.embedding; import ai.vespa.modelintegration.evaluator.OnnxEvaluator; diff --git a/model-integration/src/test/java/ai/vespa/embedding/ColBertEmbedderTest.java b/model-integration/src/test/java/ai/vespa/embedding/ColBertEmbedderTest.java index 70f91eb44ad..8516f6e6689 100644 --- a/model-integration/src/test/java/ai/vespa/embedding/ColBertEmbedderTest.java +++ b/model-integration/src/test/java/ai/vespa/embedding/ColBertEmbedderTest.java @@ -1,3 +1,4 @@ +// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. package ai.vespa.embedding; import ai.vespa.modelintegration.evaluator.OnnxRuntime; |