diff options
author | Jo Kristian Bergum <bergum@yahoo-inc.com> | 2024-01-04 15:45:02 +0100 |
---|---|---|
committer | GitHub <noreply@github.com> | 2024-01-04 15:45:02 +0100 |
commit | 6d2226e2bf35e32cf618ff12a7e2968c85eabf1f (patch) | |
tree | 5ac11755dbe8998c2b3f4dc5cae8859b0a2e1b9f /config-model/src/test | |
parent | b10d1fd87d7013847b19fc89a620fbe9c7136e61 (diff) | |
parent | 79bb01aa94375b6b9ce464fbdc5db24d1549e7d9 (diff) |
Merge pull request #29667 from vespa-engine/jobergum/splade-embedder
Add SPLADE embedder
Diffstat (limited to 'config-model/src/test')
-rw-r--r-- | config-model/src/test/cfg/application/embed/services.xml | 15 | ||||
-rw-r--r-- | config-model/src/test/java/com/yahoo/vespa/model/container/xml/EmbedderTestCase.java | 30 |
2 files changed, 44 insertions, 1 deletion
diff --git a/config-model/src/test/cfg/application/embed/services.xml b/config-model/src/test/cfg/application/embed/services.xml index 1840063d70d..59c29aefc6a 100644 --- a/config-model/src/test/cfg/application/embed/services.xml +++ b/config-model/src/test/cfg/application/embed/services.xml @@ -19,6 +19,21 @@ <pooling-strategy>mean</pooling-strategy> </component> + <component id="splade" type="splade-embedder"> + <transformer-model model-id="e5-base-v2" url="https://my/url/model.onnx"/> + <tokenizer-model model-id="e5-base-v2-vocab" path="app/tokenizer.json"/> + <max-tokens>1024</max-tokens> + <transformer-input-ids>my_input_ids</transformer-input-ids> + <transformer-attention-mask>my_attention_mask</transformer-attention-mask> + <transformer-token-type-ids>my_token_type_ids</transformer-token-type-ids> + <transformer-output>my_output</transformer-output> + <term-score-threshold>0.2</term-score-threshold> + <onnx-execution-mode>parallel</onnx-execution-mode> + <onnx-intraop-threads>10</onnx-intraop-threads> + <onnx-interop-threads>8</onnx-interop-threads> + <onnx-gpu-device>1</onnx-gpu-device> + </component> + <component id="hf-tokenizer" type="hugging-face-tokenizer"> <model language="no" model-id="multilingual-e5-base-vocab" url="https://my/url/tokenizer.json"/> </component> diff --git a/config-model/src/test/java/com/yahoo/vespa/model/container/xml/EmbedderTestCase.java b/config-model/src/test/java/com/yahoo/vespa/model/container/xml/EmbedderTestCase.java index 138bef3ae73..2532a5be863 100644 --- a/config-model/src/test/java/com/yahoo/vespa/model/container/xml/EmbedderTestCase.java +++ b/config-model/src/test/java/com/yahoo/vespa/model/container/xml/EmbedderTestCase.java @@ -12,6 +12,7 @@ import com.yahoo.config.model.deploy.DeployState; import com.yahoo.config.model.deploy.TestProperties; import com.yahoo.embedding.BertBaseEmbedderConfig; import com.yahoo.embedding.ColBertEmbedderConfig; +import com.yahoo.embedding.SpladeEmbedderConfig; import 
com.yahoo.embedding.huggingface.HuggingFaceEmbedderConfig; import com.yahoo.language.huggingface.config.HuggingFaceTokenizerConfig; import com.yahoo.path.Path; @@ -24,6 +25,7 @@ import com.yahoo.vespa.model.container.component.BertEmbedder; import com.yahoo.vespa.model.container.component.ColBertEmbedder; import com.yahoo.vespa.model.container.component.Component; import com.yahoo.vespa.model.container.component.HuggingFaceEmbedder; +import com.yahoo.vespa.model.container.component.SpladeEmbedder; import com.yahoo.vespa.model.container.component.HuggingFaceTokenizer; import com.yahoo.vespa.model.test.utils.VespaModelCreatorWithFilePkg; import com.yahoo.yolean.Exceptions; @@ -101,6 +103,23 @@ public class EmbedderTestCase { assertEquals(-1, tokenizerCfg.maxLength()); } + @Test + void spladeEmbedder_selfhosted() throws Exception { + var model = loadModel(Path.fromString("src/test/cfg/application/embed/"), false); + var cluster = model.getContainerClusters().get("container"); + var embedderCfg = assertSpladeEmbedderComponentPresent(cluster); + + assertEquals("my_input_ids", embedderCfg.transformerInputIds()); + assertEquals("https://my/url/model.onnx", modelReference(embedderCfg, "transformerModel").url().orElseThrow().value()); + assertEquals(0.2, embedderCfg.termScoreThreshold()); + assertEquals(1024, embedderCfg.transformerMaxTokens()); + + var tokenizerCfg = assertHuggingfaceTokenizerComponentPresent(cluster); + assertEquals("https://my/url/tokenizer.json", modelReference(tokenizerCfg.model().get(0), "path").url().orElseThrow().value()); + assertEquals(-1, tokenizerCfg.maxLength()); + } + + @Test void colBertEmbedder_selfhosted() throws Exception { var model = loadModel(Path.fromString("src/test/cfg/application/embed/"), false); var cluster = model.getContainerClusters().get("container"); @@ -113,6 +132,7 @@ public class EmbedderTestCase { assertEquals(-1, tokenizerCfg.maxLength()); } + @Test void colBertEmbedder_hosted() throws Exception { var model = 
loadModel(Path.fromString("src/test/cfg/application/embed/"), true); var cluster = model.getContainerClusters().get("container"); @@ -266,13 +286,21 @@ public class EmbedderTestCase { } private static ColBertEmbedderConfig assertColBertEmbedderComponentPresent(ApplicationContainerCluster cluster) { - var colbert = (ColBertEmbedder) cluster.getComponentsMap().get(new ComponentId("colbert-embedder")); + var colbert = (ColBertEmbedder) cluster.getComponentsMap().get(new ComponentId("colbert")); assertEquals("ai.vespa.embedding.ColBertEmbedder", colbert.getClassId().getName()); var cfgBuilder = new ColBertEmbedderConfig.Builder(); colbert.getConfig(cfgBuilder); return cfgBuilder.build(); } + private static SpladeEmbedderConfig assertSpladeEmbedderComponentPresent(ApplicationContainerCluster cluster) { + var splade = (SpladeEmbedder) cluster.getComponentsMap().get(new ComponentId("splade")); + assertEquals("ai.vespa.embedding.SpladeEmbedder", splade.getClassId().getName()); + var cfgBuilder = new SpladeEmbedderConfig.Builder(); + splade.getConfig(cfgBuilder); + return cfgBuilder.build(); + } + private static BertBaseEmbedderConfig assertBertEmbedderComponentPresent(ApplicationContainerCluster cluster) { var bertEmbedder = (BertEmbedder) cluster.getComponentsMap().get(new ComponentId("bert-embedder")); assertEquals("ai.vespa.embedding.BertBaseEmbedder", bertEmbedder.getClassId().getName()); |