diff options
author | Jo Kristian Bergum <bergum@yahooinc.com> | 2023-09-21 19:31:36 +0200 |
---|---|---|
committer | Jo Kristian Bergum <bergum@yahooinc.com> | 2023-09-21 19:31:36 +0200 |
commit | 3062f88ca79cfa802d0a0c4e36de55d1f9e8245d (patch) | |
tree | 93642a0e36c848a9a9e04ab51666ebf6296a954d /config-model/src/test | |
parent | d04a31c05a399df95717accfd656d4c68d78e611 (diff) |
Add config options + license
Diffstat (limited to 'config-model/src/test')
-rw-r--r-- | config-model/src/test/cfg/application/embed/services.xml | 19 | ||||
-rw-r--r-- | config-model/src/test/java/com/yahoo/vespa/model/container/xml/EmbedderTestCase.java | 33 |
2 files changed, 52 insertions, 0 deletions
diff --git a/config-model/src/test/cfg/application/embed/services.xml b/config-model/src/test/cfg/application/embed/services.xml index 70eef7ea54a..efb33d36761 100644 --- a/config-model/src/test/cfg/application/embed/services.xml +++ b/config-model/src/test/cfg/application/embed/services.xml @@ -43,6 +43,25 @@ <onnx-gpu-device>1</onnx-gpu-device> </component> + <component id="colbert" type="colbert-embedder"> + <transformer-model model-id="e5-base-v2" url="https://my/url/model.onnx"/> + <tokenizer-model model-id="e5-base-v2-vocab" path="app/tokenizer.json"/> + <max-tokens>1024</max-tokens> + <max-query-tokens>32</max-query-tokens> + <max-document-tokens>512</max-document-tokens> + <transformer-start-sequence-token>101</transformer-start-sequence-token> + <transformer-end-sequence-token>102</transformer-end-sequence-token> + <transformer-mask-token>103</transformer-mask-token> + <transformer-input-ids>my_input_ids</transformer-input-ids> + <transformer-attention-mask>my_attention_mask</transformer-attention-mask> + <transformer-token-type-ids>my_token_type_ids</transformer-token-type-ids> + <transformer-output>my_output</transformer-output> + <onnx-execution-mode>parallel</onnx-execution-mode> + <onnx-intraop-threads>10</onnx-intraop-threads> + <onnx-interop-threads>8</onnx-interop-threads> + <onnx-gpu-device>1</onnx-gpu-device> + </component> + <nodes> <node hostalias="node1" /> </nodes> diff --git a/config-model/src/test/java/com/yahoo/vespa/model/container/xml/EmbedderTestCase.java b/config-model/src/test/java/com/yahoo/vespa/model/container/xml/EmbedderTestCase.java index 42b78db66b1..5832445d0d7 100644 --- a/config-model/src/test/java/com/yahoo/vespa/model/container/xml/EmbedderTestCase.java +++ b/config-model/src/test/java/com/yahoo/vespa/model/container/xml/EmbedderTestCase.java @@ -9,6 +9,7 @@ import com.yahoo.config.model.application.provider.FilesApplicationPackage; import com.yahoo.config.model.deploy.DeployState; import com.yahoo.config.model.deploy.TestProperties; import com.yahoo.embedding.BertBaseEmbedderConfig; +import com.yahoo.embedding.ColBertEmbedderConfig; import com.yahoo.embedding.huggingface.HuggingFaceEmbedderConfig; import com.yahoo.language.huggingface.config.HuggingFaceTokenizerConfig; import com.yahoo.path.Path; @@ -21,6 +22,7 @@ import com.yahoo.vespa.model.container.component.BertEmbedder; import com.yahoo.vespa.model.container.component.Component; import com.yahoo.vespa.model.container.component.HuggingFaceEmbedder; import com.yahoo.vespa.model.container.component.HuggingFaceTokenizer; +import com.yahoo.vespa.model.container.component.ColBertEmbedder; import com.yahoo.vespa.model.test.utils.VespaModelCreatorWithFilePkg; import com.yahoo.yolean.Exceptions; import org.junit.jupiter.api.Test; @@ -96,6 +98,29 @@ public class EmbedderTestCase { assertEquals(-1, tokenizerCfg.maxLength()); } + void colBertEmbedder_selfhosted() throws Exception { + var model = loadModel(Path.fromString("src/test/cfg/application/embed/"), false); + var cluster = model.getContainerClusters().get("container"); + var embedderCfg = assertColBertEmbedderComponentPresent(cluster); + assertEquals("my_input_ids", embedderCfg.transformerInputIds()); + assertEquals("https://my/url/model.onnx", modelReference(embedderCfg, "transformerModel").url().orElseThrow().value()); + assertEquals(1024, embedderCfg.transformerMaxTokens()); + var tokenizerCfg = assertHuggingfaceTokenizerComponentPresent(cluster); + assertEquals("https://my/url/tokenizer.json", modelReference(tokenizerCfg.model().get(0), "path").url().orElseThrow().value()); + assertEquals(-1, tokenizerCfg.maxLength()); + } + + void colBertEmbedder_hosted() throws Exception { + var model = loadModel(Path.fromString("src/test/cfg/application/embed/"), true); + var cluster = model.getContainerClusters().get("container"); + var embedderCfg = assertColBertEmbedderComponentPresent(cluster); + assertEquals("my_input_ids", embedderCfg.transformerInputIds()); + assertEquals("https://data.vespa.oath.cloud/onnx_models/e5-base-v2/model.onnx", modelReference(embedderCfg, "transformerModel").url().orElseThrow().value()); + assertEquals(1024, embedderCfg.transformerMaxTokens()); + var tokenizerCfg = assertHuggingfaceTokenizerComponentPresent(cluster); + assertEquals("https://data.vespa.oath.cloud/onnx_models/multilingual-e5-base/tokenizer.json", modelReference(tokenizerCfg.model().get(0), "path").url().orElseThrow().value()); + assertEquals(-1, tokenizerCfg.maxLength()); + } @Test void bertEmbedder_selfhosted() throws Exception { @@ -233,6 +258,14 @@ public class EmbedderTestCase { return cfgBuilder.build(); } + private static ColBertEmbedderConfig assertColBertEmbedderComponentPresent(ApplicationContainerCluster cluster) { + var colbert = (ColBertEmbedder) cluster.getComponentsMap().get(new ComponentId("colbert-embedder")); + assertEquals("ai.vespa.embedding.ColBertEmbedder", colbert.getClassId().getName()); + var cfgBuilder = new ColBertEmbedderConfig.Builder(); + colbert.getConfig(cfgBuilder); + return cfgBuilder.build(); + } + private static BertBaseEmbedderConfig assertBertEmbedderComponentPresent(ApplicationContainerCluster cluster) { var bertEmbedder = (BertEmbedder) cluster.getComponentsMap().get(new ComponentId("bert-embedder")); assertEquals("ai.vespa.embedding.BertBaseEmbedder", bertEmbedder.getClassId().getName()); |