diff options
-rw-r--r-- | config-model/src/test/cfg/application/embed/services.xml | 20 | ||||
-rw-r--r-- | config-model/src/test/java/com/yahoo/vespa/model/container/xml/EmbedderTestCase.java | 31 |
2 files changed, 51 insertions, 0 deletions
diff --git a/config-model/src/test/cfg/application/embed/services.xml b/config-model/src/test/cfg/application/embed/services.xml index 99c89bc4324..2b54d850452 100644 --- a/config-model/src/test/cfg/application/embed/services.xml +++ b/config-model/src/test/cfg/application/embed/services.xml @@ -25,6 +25,26 @@ <truncation>true</truncation> </component> + <component id="bert-embedder" type="bert-embedder"> + <!-- model specifics --> + <transformer-model model-id="minilm-l6-v2" url="application-url"/> + <tokenizer-vocab path="files/vocab.txt"/> + <max-tokens>512</max-tokens> + <transformer-input-ids>my_input_ids</transformer-input-ids> + <transformer-attention-mask>my_attention_mask</transformer-attention-mask> + <transformer-token-type-ids>my_token_type_ids</transformer-token-type-ids> + <transformer-output>my_output</transformer-output> + <transformer-start-sequence-token>101</transformer-start-sequence-token> + <transformer-end-sequence-token>102</transformer-end-sequence-token> + + + <!-- tunable parameters: number of threads etc --> + <onnx-execution-mode>parallel</onnx-execution-mode> + <onnx-intraop-threads>4</onnx-intraop-threads> + <onnx-interop-threads>8</onnx-interop-threads> + <onnx-gpu-device>1</onnx-gpu-device> + </component> + <component id="transformer" class="ai.vespa.embedding.BertBaseEmbedder" bundle="model-integration"> <config name="embedding.bert-base-embedder"> <!-- model specifics --> diff --git a/config-model/src/test/java/com/yahoo/vespa/model/container/xml/EmbedderTestCase.java b/config-model/src/test/java/com/yahoo/vespa/model/container/xml/EmbedderTestCase.java index ce0719b87d2..f2edd0d1dbf 100644 --- a/config-model/src/test/java/com/yahoo/vespa/model/container/xml/EmbedderTestCase.java +++ b/config-model/src/test/java/com/yahoo/vespa/model/container/xml/EmbedderTestCase.java @@ -8,6 +8,7 @@ import com.yahoo.config.ModelReference; import com.yahoo.config.model.application.provider.FilesApplicationPackage; import com.yahoo.config.model.deploy.DeployState; import com.yahoo.config.model.deploy.TestProperties; +import com.yahoo.embedding.BertBaseEmbedderConfig; import com.yahoo.embedding.huggingface.HuggingFaceEmbedderConfig; import com.yahoo.language.huggingface.config.HuggingFaceTokenizerConfig; import com.yahoo.path.Path; @@ -16,6 +17,7 @@ import com.yahoo.vespa.config.ConfigDefinitionKey; import com.yahoo.vespa.config.ConfigPayloadBuilder; import com.yahoo.vespa.model.VespaModel; import com.yahoo.vespa.model.container.ApplicationContainerCluster; +import com.yahoo.vespa.model.container.component.BertEmbedder; import com.yahoo.vespa.model.container.component.Component; import com.yahoo.vespa.model.container.component.HuggingFaceEmbedder; import com.yahoo.vespa.model.container.component.HuggingFaceTokenizer; @@ -142,6 +144,27 @@ public class EmbedderTestCase { assertEquals(768, tokenizerCfg.maxLength()); } + + @Test + void bertEmbedder_selfhosted() throws Exception { + var model = loadModel(Path.fromString("src/test/cfg/application/embed/"), false); + var cluster = model.getContainerClusters().get("container"); + var embedderCfg = assertBertEmbedderComponentPresent(cluster); + assertEquals("application-url", modelReference(embedderCfg, "transformerModel").url().orElseThrow().value()); + assertEquals("files/vocab.txt", modelReference(embedderCfg, "tokenizerVocab").path().orElseThrow().value()); + } + + @Test + void bertEmbedder_hosted() throws Exception { + var model = loadModel(Path.fromString("src/test/cfg/application/embed/"), true); + var cluster = model.getContainerClusters().get("container"); + var embedderCfg = assertBertEmbedderComponentPresent(cluster); + assertEquals("https://data.vespa.oath.cloud/onnx_models/sentence_all_MiniLM_L6_v2.onnx", + modelReference(embedderCfg, "transformerModel").url().orElseThrow().value()); + assertTrue(modelReference(embedderCfg, "tokenizerVocab").url().isEmpty()); + assertEquals("files/vocab.txt", modelReference(embedderCfg, "tokenizerVocab").path().orElseThrow().value()); + } + @Test void passesXmlValdiation() { new VespaModelCreatorWithFilePkg("src/test/cfg/application/embed/").create(); @@ -270,6 +293,14 @@ public class EmbedderTestCase { return cfgBuilder.build(); } + private static BertBaseEmbedderConfig assertBertEmbedderComponentPresent(ApplicationContainerCluster cluster) { + var bertEmbedder = (BertEmbedder) cluster.getComponentsMap().get(new ComponentId("bert-embedder")); + assertEquals("ai.vespa.embedding.BertBaseEmbedder", bertEmbedder.getClassId().getName()); + var cfgBuilder = new BertBaseEmbedderConfig.Builder(); + bertEmbedder.getConfig(cfgBuilder); + return cfgBuilder.build(); + } + // Ugly hack to read underlying model reference from config instance private static ModelReference modelReference(InnerNode cfg, String name) { try { |