diff options
author | Jo Kristian Bergum <bergum@yahooinc.com> | 2023-12-15 08:42:01 +0100 |
---|---|---|
committer | Jo Kristian Bergum <bergum@yahooinc.com> | 2023-12-15 08:42:01 +0100 |
commit | cbc0733c07b57d8563eea40897072cb35042b605 (patch) | |
tree | fa3b564329dc4a88b48c697692c7896d6d4b36b0 /config-model | |
parent | 8af800ba588f726184ffb8296463bb4b7fbea5a1 (diff) |
Add a splade embedder implementation
Diffstat (limited to 'config-model')
5 files changed, 132 insertions, 7 deletions
diff --git a/config-model/src/main/java/com/yahoo/vespa/model/builder/xml/dom/DomComponentBuilder.java b/config-model/src/main/java/com/yahoo/vespa/model/builder/xml/dom/DomComponentBuilder.java index 72b72c369dc..3dcc74e777d 100644 --- a/config-model/src/main/java/com/yahoo/vespa/model/builder/xml/dom/DomComponentBuilder.java +++ b/config-model/src/main/java/com/yahoo/vespa/model/builder/xml/dom/DomComponentBuilder.java @@ -8,11 +8,7 @@ import com.yahoo.config.model.producer.TreeConfigProducer; import com.yahoo.osgi.provider.model.ComponentModel; import com.yahoo.text.XML; import com.yahoo.vespa.model.container.ApplicationContainerCluster; -import com.yahoo.vespa.model.container.component.BertEmbedder; -import com.yahoo.vespa.model.container.component.ColBertEmbedder; -import com.yahoo.vespa.model.container.component.Component; -import com.yahoo.vespa.model.container.component.HuggingFaceEmbedder; -import com.yahoo.vespa.model.container.component.HuggingFaceTokenizer; +import com.yahoo.vespa.model.container.component.*; import com.yahoo.vespa.model.container.xml.BundleInstantiationSpecificationBuilder; import org.w3c.dom.Element; @@ -50,6 +46,7 @@ public class DomComponentBuilder extends VespaDomBuilder.DomConfigProducerBuilde case "hugging-face-tokenizer" -> new HuggingFaceTokenizer(spec, state); case "colbert-embedder" -> new ColBertEmbedder((ApplicationContainerCluster)ancestor, spec, state); case "bert-embedder" -> new BertEmbedder((ApplicationContainerCluster)ancestor, spec, state); + case "splade-embedder" -> new SpladeEmbedder((ApplicationContainerCluster)ancestor, spec, state); default -> throw new IllegalArgumentException("Unknown component type '%s'".formatted(type)); }; } else { diff --git a/config-model/src/main/java/com/yahoo/vespa/model/container/component/SpladeEmbedder.java b/config-model/src/main/java/com/yahoo/vespa/model/container/component/SpladeEmbedder.java new file mode 100644 index 00000000000..96554e91d38 --- /dev/null +++ b/config-model/src/main/java/com/yahoo/vespa/model/container/component/SpladeEmbedder.java @@ -0,0 +1,73 @@ +// Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. + +package com.yahoo.vespa.model.container.component; +import com.yahoo.config.ModelReference; +import com.yahoo.config.model.api.OnnxModelOptions; +import com.yahoo.config.model.deploy.DeployState; +import com.yahoo.embedding.SpladeEmbedderConfig; +import com.yahoo.vespa.model.container.ApplicationContainerCluster; +import org.w3c.dom.Element; +import static com.yahoo.text.XML.getChildValue; +import static com.yahoo.vespa.model.container.ContainerModelEvaluation.INTEGRATION_BUNDLE_NAME; + +public class SpladeEmbedder extends TypedComponent implements SpladeEmbedderConfig.Producer { + + private final OnnxModelOptions onnxModelOptions; + + private final ModelReference modelRef; + private final ModelReference vocabRef; + private final Integer maxTokens; + private final String transformerInputIds; + private final String transformerAttentionMask; + private final String transformerTokenTypeIds; + private final String transformerOutput; + + private final Double termScoreThreshold; + + public SpladeEmbedder(ApplicationContainerCluster cluster, Element xml, DeployState state) { + super("ai.vespa.embedding.SpladeEmbedder", INTEGRATION_BUNDLE_NAME, xml); + var model = Model.fromXml(state, xml, "transformer-model").orElseThrow(); + this.onnxModelOptions = new OnnxModelOptions( + getChildValue(xml, "onnx-execution-mode"), + getChildValue(xml, "onnx-interop-threads").map(Integer::parseInt), + getChildValue(xml, "onnx-intraop-threads").map(Integer::parseInt), + getChildValue(xml, "onnx-gpu-device").map(Integer::parseInt).map(OnnxModelOptions.GpuDevice::new)); + modelRef = model.modelReference(); + vocabRef = Model.fromXml(state, xml, "tokenizer-model") + .map(Model::modelReference) + .orElseGet(() -> resolveDefaultVocab(model, state)); + maxTokens = getChildValue(xml, "max-tokens").map(Integer::parseInt).orElse(null); + transformerInputIds = getChildValue(xml, "transformer-input-ids").orElse(null); + transformerAttentionMask = getChildValue(xml, "transformer-attention-mask").orElse(null); + transformerTokenTypeIds = getChildValue(xml, "transformer-token-type-ids").orElse(null); + transformerOutput = getChildValue(xml, "transformer-output").orElse(null); + termScoreThreshold = getChildValue(xml, "term-score-threshold").map(Double::parseDouble).orElse(null); + model.registerOnnxModelCost(cluster, onnxModelOptions); + } + + private static ModelReference resolveDefaultVocab(Model model, DeployState state) { + var modelId = model.modelId().orElse(null); + if (state.isHosted() && modelId != null) { + return Model.fromParams(state, model.name(), modelId + "-vocab", null, null).modelReference(); + } + throw new IllegalArgumentException("'tokenizer-model' must be specified"); + } + + + @Override + public void getConfig(SpladeEmbedderConfig.Builder b) { + b.transformerModel(modelRef).tokenizerPath(vocabRef); + if (maxTokens != null) b.transformerMaxTokens(maxTokens); + if (transformerInputIds != null) b.transformerInputIds(transformerInputIds); + if (transformerAttentionMask != null) b.transformerAttentionMask(transformerAttentionMask); + if (transformerTokenTypeIds != null) b.transformerTokenTypeIds(transformerTokenTypeIds); + if (transformerOutput != null) b.transformerOutput(transformerOutput); + if (termScoreThreshold != null) b.termScoreThreshold(termScoreThreshold); + + onnxModelOptions.executionMode().ifPresent(value -> b.transformerExecutionMode(SpladeEmbedderConfig.TransformerExecutionMode.Enum.valueOf(value))); + onnxModelOptions.interOpThreads().ifPresent(b::transformerInterOpThreads); + onnxModelOptions.intraOpThreads().ifPresent(b::transformerIntraOpThreads); + onnxModelOptions.gpuDevice().ifPresent(value -> b.transformerGpuDevice(value.deviceNumber())); + } +} + diff --git a/config-model/src/main/resources/schema/common.rnc b/config-model/src/main/resources/schema/common.rnc index 2f3b10742f3..919253977ca 100644 --- a/config-model/src/main/resources/schema/common.rnc +++ b/config-model/src/main/resources/schema/common.rnc @@ -80,7 +80,7 @@ ComponentDefinition = TypedComponentDefinition = attribute id { xsd:Name } & - (HuggingFaceEmbedder | HuggingFaceTokenizer | BertBaseEmbedder | ColBertEmbedder) & + (HuggingFaceEmbedder | HuggingFaceTokenizer | BertBaseEmbedder | ColBertEmbedder | SpladeEmbedder ) & GenericConfig* & Component* @@ -97,6 +97,18 @@ HuggingFaceEmbedder = OnnxModelExecutionParams & EmbedderPoolingStrategy +SpladeEmbedder = + attribute type { "splade-embedder" } & + element transformer-model { ModelReference } & + element tokenizer-model { ModelReference }? & + element max-tokens { xsd:positiveInteger }? & + element transformer-input-ids { xsd:string }? & + element transformer-attention-mask { xsd:string }? & + element transformer-token-type-ids { xsd:string }? & + element transformer-output { xsd:string }? & + element term-score-threshold { xsd:double }? & + OnnxModelExecutionParams + HuggingFaceTokenizer = attribute type { "hugging-face-tokenizer" } & element model { attribute language { xsd:string }? & ModelReference }+ diff --git a/config-model/src/test/cfg/application/embed/services.xml b/config-model/src/test/cfg/application/embed/services.xml index 1840063d70d..59c29aefc6a 100644 --- a/config-model/src/test/cfg/application/embed/services.xml +++ b/config-model/src/test/cfg/application/embed/services.xml @@ -19,6 +19,21 @@ <pooling-strategy>mean</pooling-strategy> </component> + <component id="splade" type="splade-embedder"> + <transformer-model model-id="e5-base-v2" url="https://my/url/model.onnx"/> + <tokenizer-model model-id="e5-base-v2-vocab" path="app/tokenizer.json"/> + <max-tokens>1024</max-tokens> + <transformer-input-ids>my_input_ids</transformer-input-ids> + <transformer-attention-mask>my_attention_mask</transformer-attention-mask> + <transformer-token-type-ids>my_token_type_ids</transformer-token-type-ids> + <transformer-output>my_output</transformer-output> + <term-score-threshold>0.2</term-score-threshold> + <onnx-execution-mode>parallel</onnx-execution-mode> + <onnx-intraop-threads>10</onnx-intraop-threads> + <onnx-interop-threads>8</onnx-interop-threads> + <onnx-gpu-device>1</onnx-gpu-device> + </component> + <component id="hf-tokenizer" type="hugging-face-tokenizer"> <model language="no" model-id="multilingual-e5-base-vocab" url="https://my/url/tokenizer.json"/> </component> diff --git a/config-model/src/test/java/com/yahoo/vespa/model/container/xml/EmbedderTestCase.java b/config-model/src/test/java/com/yahoo/vespa/model/container/xml/EmbedderTestCase.java index 138bef3ae73..2532a5be863 100644 --- a/config-model/src/test/java/com/yahoo/vespa/model/container/xml/EmbedderTestCase.java +++ b/config-model/src/test/java/com/yahoo/vespa/model/container/xml/EmbedderTestCase.java @@ -12,6 +12,7 @@ import com.yahoo.config.model.deploy.DeployState; import com.yahoo.config.model.deploy.TestProperties; import com.yahoo.embedding.BertBaseEmbedderConfig; import com.yahoo.embedding.ColBertEmbedderConfig; +import com.yahoo.embedding.SpladeEmbedderConfig; import com.yahoo.embedding.huggingface.HuggingFaceEmbedderConfig; import com.yahoo.language.huggingface.config.HuggingFaceTokenizerConfig; import com.yahoo.path.Path; @@ -24,6 +25,7 @@ import com.yahoo.vespa.model.container.component.BertEmbedder; import com.yahoo.vespa.model.container.component.ColBertEmbedder; import com.yahoo.vespa.model.container.component.Component; import com.yahoo.vespa.model.container.component.HuggingFaceEmbedder; +import com.yahoo.vespa.model.container.component.SpladeEmbedder; import com.yahoo.vespa.model.container.component.HuggingFaceTokenizer; import com.yahoo.vespa.model.test.utils.VespaModelCreatorWithFilePkg; import com.yahoo.yolean.Exceptions; @@ -101,6 +103,23 @@ public class EmbedderTestCase { assertEquals(-1, tokenizerCfg.maxLength()); } + @Test + void spladeEmbedder_selfhosted() throws Exception { + var model = loadModel(Path.fromString("src/test/cfg/application/embed/"), false); + var cluster = model.getContainerClusters().get("container"); + var embedderCfg = assertSpladeEmbedderComponentPresent(cluster); + + assertEquals("my_input_ids", embedderCfg.transformerInputIds()); + assertEquals("https://my/url/model.onnx", modelReference(embedderCfg, "transformerModel").url().orElseThrow().value()); + assertEquals(0.2, embedderCfg.termScoreThreshold()); + assertEquals(1024, embedderCfg.transformerMaxTokens()); + + var tokenizerCfg = assertHuggingfaceTokenizerComponentPresent(cluster); + assertEquals("https://my/url/tokenizer.json", modelReference(tokenizerCfg.model().get(0), "path").url().orElseThrow().value()); + assertEquals(-1, tokenizerCfg.maxLength()); + } + + @Test void colBertEmbedder_selfhosted() throws Exception { var model = loadModel(Path.fromString("src/test/cfg/application/embed/"), false); var cluster = model.getContainerClusters().get("container"); @@ -113,6 +132,7 @@ public class EmbedderTestCase { assertEquals(-1, tokenizerCfg.maxLength()); } + @Test void colBertEmbedder_hosted() throws Exception { var model = loadModel(Path.fromString("src/test/cfg/application/embed/"), true); var cluster = model.getContainerClusters().get("container"); @@ -266,13 +286,21 @@ public class EmbedderTestCase { } private static ColBertEmbedderConfig assertColBertEmbedderComponentPresent(ApplicationContainerCluster cluster) { - var colbert = (ColBertEmbedder) cluster.getComponentsMap().get(new ComponentId("colbert-embedder")); + var colbert = (ColBertEmbedder) cluster.getComponentsMap().get(new ComponentId("colbert")); assertEquals("ai.vespa.embedding.ColBertEmbedder", colbert.getClassId().getName()); var cfgBuilder = new ColBertEmbedderConfig.Builder(); colbert.getConfig(cfgBuilder); return cfgBuilder.build(); } + private static SpladeEmbedderConfig assertSpladeEmbedderComponentPresent(ApplicationContainerCluster cluster) { + var splade = (SpladeEmbedder) cluster.getComponentsMap().get(new ComponentId("splade")); + assertEquals("ai.vespa.embedding.SpladeEmbedder", splade.getClassId().getName()); + var cfgBuilder = new SpladeEmbedderConfig.Builder(); + splade.getConfig(cfgBuilder); + return cfgBuilder.build(); + } + private static BertBaseEmbedderConfig assertBertEmbedderComponentPresent(ApplicationContainerCluster cluster) { var bertEmbedder = (BertEmbedder) cluster.getComponentsMap().get(new ComponentId("bert-embedder")); assertEquals("ai.vespa.embedding.BertBaseEmbedder", bertEmbedder.getClassId().getName()); |