aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--config-model/src/test/cfg/application/embed/services.xml20
-rw-r--r--config-model/src/test/java/com/yahoo/vespa/model/container/xml/EmbedderTestCase.java31
2 files changed, 51 insertions, 0 deletions
diff --git a/config-model/src/test/cfg/application/embed/services.xml b/config-model/src/test/cfg/application/embed/services.xml
index 99c89bc4324..2b54d850452 100644
--- a/config-model/src/test/cfg/application/embed/services.xml
+++ b/config-model/src/test/cfg/application/embed/services.xml
@@ -25,6 +25,26 @@
<truncation>true</truncation>
</component>
+ <component id="bert-embedder" type="bert-embedder">
+ <!-- model specifics -->
+ <transformer-model model-id="minilm-l6-v2" url="application-url"/>
+ <tokenizer-vocab path="files/vocab.txt"/>
+ <max-tokens>512</max-tokens>
+ <transformer-input-ids>my_input_ids</transformer-input-ids>
+ <transformer-attention-mask>my_attention_mask</transformer-attention-mask>
+ <transformer-token-type-ids>my_token_type_ids</transformer-token-type-ids>
+ <transformer-output>my_output</transformer-output>
+ <transformer-start-sequence-token>101</transformer-start-sequence-token>
+ <transformer-end-sequence-token>102</transformer-end-sequence-token>
+
+
+ <!-- tunable parameters: number of threads etc -->
+ <onnx-execution-mode>parallel</onnx-execution-mode>
+ <onnx-intraop-threads>4</onnx-intraop-threads>
+ <onnx-interop-threads>8</onnx-interop-threads>
+ <onnx-gpu-device>1</onnx-gpu-device>
+ </component>
+
<component id="transformer" class="ai.vespa.embedding.BertBaseEmbedder" bundle="model-integration">
<config name="embedding.bert-base-embedder">
<!-- model specifics -->
diff --git a/config-model/src/test/java/com/yahoo/vespa/model/container/xml/EmbedderTestCase.java b/config-model/src/test/java/com/yahoo/vespa/model/container/xml/EmbedderTestCase.java
index ce0719b87d2..f2edd0d1dbf 100644
--- a/config-model/src/test/java/com/yahoo/vespa/model/container/xml/EmbedderTestCase.java
+++ b/config-model/src/test/java/com/yahoo/vespa/model/container/xml/EmbedderTestCase.java
@@ -8,6 +8,7 @@ import com.yahoo.config.ModelReference;
import com.yahoo.config.model.application.provider.FilesApplicationPackage;
import com.yahoo.config.model.deploy.DeployState;
import com.yahoo.config.model.deploy.TestProperties;
+import com.yahoo.embedding.BertBaseEmbedderConfig;
import com.yahoo.embedding.huggingface.HuggingFaceEmbedderConfig;
import com.yahoo.language.huggingface.config.HuggingFaceTokenizerConfig;
import com.yahoo.path.Path;
@@ -16,6 +17,7 @@ import com.yahoo.vespa.config.ConfigDefinitionKey;
import com.yahoo.vespa.config.ConfigPayloadBuilder;
import com.yahoo.vespa.model.VespaModel;
import com.yahoo.vespa.model.container.ApplicationContainerCluster;
+import com.yahoo.vespa.model.container.component.BertEmbedder;
import com.yahoo.vespa.model.container.component.Component;
import com.yahoo.vespa.model.container.component.HuggingFaceEmbedder;
import com.yahoo.vespa.model.container.component.HuggingFaceTokenizer;
@@ -142,6 +144,27 @@ public class EmbedderTestCase {
assertEquals(768, tokenizerCfg.maxLength());
}
+
+ @Test
+ void bertEmbedder_selfhosted() throws Exception {
+ var model = loadModel(Path.fromString("src/test/cfg/application/embed/"), false);
+ var cluster = model.getContainerClusters().get("container");
+ var embedderCfg = assertBertEmbedderComponentPresent(cluster);
+ assertEquals("application-url", modelReference(embedderCfg, "transformerModel").url().orElseThrow().value());
+ assertEquals("files/vocab.txt", modelReference(embedderCfg, "tokenizerVocab").path().orElseThrow().value());
+ }
+
+ @Test
+ void bertEmbedder_hosted() throws Exception {
+ var model = loadModel(Path.fromString("src/test/cfg/application/embed/"), true);
+ var cluster = model.getContainerClusters().get("container");
+ var embedderCfg = assertBertEmbedderComponentPresent(cluster);
+ assertEquals("https://data.vespa.oath.cloud/onnx_models/sentence_all_MiniLM_L6_v2.onnx",
+ modelReference(embedderCfg, "transformerModel").url().orElseThrow().value());
+ assertTrue(modelReference(embedderCfg, "tokenizerVocab").url().isEmpty());
+ assertEquals("files/vocab.txt", modelReference(embedderCfg, "tokenizerVocab").path().orElseThrow().value());
+ }
+
@Test
void passesXmlValdiation() {
new VespaModelCreatorWithFilePkg("src/test/cfg/application/embed/").create();
@@ -270,6 +293,14 @@ public class EmbedderTestCase {
return cfgBuilder.build();
}
+ private static BertBaseEmbedderConfig assertBertEmbedderComponentPresent(ApplicationContainerCluster cluster) {
+ var bertEmbedder = (BertEmbedder) cluster.getComponentsMap().get(new ComponentId("bert-embedder"));
+ assertEquals("ai.vespa.embedding.BertBaseEmbedder", bertEmbedder.getClassId().getName());
+ var cfgBuilder = new BertBaseEmbedderConfig.Builder();
+ bertEmbedder.getConfig(cfgBuilder);
+ return cfgBuilder.build();
+ }
+
// Ugly hack to read underlying model reference from config instance
private static ModelReference modelReference(InnerNode cfg, String name) {
try {