summaryrefslogtreecommitdiffstats
path: root/config-model/src/test
diff options
context:
space:
mode:
authorJo Kristian Bergum <bergum@yahooinc.com>2023-09-21 19:31:36 +0200
committerJo Kristian Bergum <bergum@yahooinc.com>2023-09-21 19:31:36 +0200
commit3062f88ca79cfa802d0a0c4e36de55d1f9e8245d (patch)
tree93642a0e36c848a9a9e04ab51666ebf6296a954d /config-model/src/test
parentd04a31c05a399df95717accfd656d4c68d78e611 (diff)
Add config options + license
Diffstat (limited to 'config-model/src/test')
-rw-r--r--config-model/src/test/cfg/application/embed/services.xml19
-rw-r--r--config-model/src/test/java/com/yahoo/vespa/model/container/xml/EmbedderTestCase.java33
2 files changed, 52 insertions, 0 deletions
diff --git a/config-model/src/test/cfg/application/embed/services.xml b/config-model/src/test/cfg/application/embed/services.xml
index 70eef7ea54a..efb33d36761 100644
--- a/config-model/src/test/cfg/application/embed/services.xml
+++ b/config-model/src/test/cfg/application/embed/services.xml
@@ -43,6 +43,25 @@
<onnx-gpu-device>1</onnx-gpu-device>
</component>
+ <component id="colbert" type="colbert-embedder">
+ <transformer-model model-id="e5-base-v2" url="https://my/url/model.onnx"/>
+ <tokenizer-model model-id="e5-base-v2-vocab" path="app/tokenizer.json"/>
+ <max-tokens>1024</max-tokens>
+ <max-query-tokens>32</max-query-tokens>
+ <max-document-tokens>512</max-document-tokens>
+ <transformer-start-sequence-token>101</transformer-start-sequence-token>
+ <transformer-end-sequence-token>102</transformer-end-sequence-token>
+ <transformer-mask-token>103</transformer-mask-token>
+ <transformer-input-ids>my_input_ids</transformer-input-ids>
+ <transformer-attention-mask>my_attention_mask</transformer-attention-mask>
+ <transformer-token-type-ids>my_token_type_ids</transformer-token-type-ids>
+ <transformer-output>my_output</transformer-output>
+ <onnx-execution-mode>parallel</onnx-execution-mode>
+ <onnx-intraop-threads>10</onnx-intraop-threads>
+ <onnx-interop-threads>8</onnx-interop-threads>
+ <onnx-gpu-device>1</onnx-gpu-device>
+ </component>
+
<nodes>
<node hostalias="node1" />
</nodes>
diff --git a/config-model/src/test/java/com/yahoo/vespa/model/container/xml/EmbedderTestCase.java b/config-model/src/test/java/com/yahoo/vespa/model/container/xml/EmbedderTestCase.java
index 42b78db66b1..5832445d0d7 100644
--- a/config-model/src/test/java/com/yahoo/vespa/model/container/xml/EmbedderTestCase.java
+++ b/config-model/src/test/java/com/yahoo/vespa/model/container/xml/EmbedderTestCase.java
@@ -9,6 +9,7 @@ import com.yahoo.config.model.application.provider.FilesApplicationPackage;
import com.yahoo.config.model.deploy.DeployState;
import com.yahoo.config.model.deploy.TestProperties;
import com.yahoo.embedding.BertBaseEmbedderConfig;
+import com.yahoo.embedding.ColBertEmbedderConfig;
import com.yahoo.embedding.huggingface.HuggingFaceEmbedderConfig;
import com.yahoo.language.huggingface.config.HuggingFaceTokenizerConfig;
import com.yahoo.path.Path;
@@ -21,6 +22,7 @@ import com.yahoo.vespa.model.container.component.BertEmbedder;
import com.yahoo.vespa.model.container.component.Component;
import com.yahoo.vespa.model.container.component.HuggingFaceEmbedder;
import com.yahoo.vespa.model.container.component.HuggingFaceTokenizer;
+import com.yahoo.vespa.model.container.component.ColBertEmbedder;
import com.yahoo.vespa.model.test.utils.VespaModelCreatorWithFilePkg;
import com.yahoo.yolean.Exceptions;
import org.junit.jupiter.api.Test;
@@ -96,6 +98,29 @@ public class EmbedderTestCase {
assertEquals(-1, tokenizerCfg.maxLength());
}
+ void colBertEmbedder_selfhosted() throws Exception {
+ var model = loadModel(Path.fromString("src/test/cfg/application/embed/"), false);
+ var cluster = model.getContainerClusters().get("container");
+ var embedderCfg = assertColBertEmbedderComponentPresent(cluster);
+ assertEquals("my_input_ids", embedderCfg.transformerInputIds());
+ assertEquals("https://my/url/model.onnx", modelReference(embedderCfg, "transformerModel").url().orElseThrow().value());
+ assertEquals(1024, embedderCfg.transformerMaxTokens());
+ var tokenizerCfg = assertHuggingfaceTokenizerComponentPresent(cluster);
+ assertEquals("https://my/url/tokenizer.json", modelReference(tokenizerCfg.model().get(0), "path").url().orElseThrow().value());
+ assertEquals(-1, tokenizerCfg.maxLength());
+ }
+
+ void colBertEmbedder_hosted() throws Exception {
+ var model = loadModel(Path.fromString("src/test/cfg/application/embed/"), true);
+ var cluster = model.getContainerClusters().get("container");
+ var embedderCfg = assertColBertEmbedderComponentPresent(cluster);
+ assertEquals("my_input_ids", embedderCfg.transformerInputIds());
+ assertEquals("https://data.vespa.oath.cloud/onnx_models/e5-base-v2/model.onnx", modelReference(embedderCfg, "transformerModel").url().orElseThrow().value());
+ assertEquals(1024, embedderCfg.transformerMaxTokens());
+ var tokenizerCfg = assertHuggingfaceTokenizerComponentPresent(cluster);
+ assertEquals("https://data.vespa.oath.cloud/onnx_models/multilingual-e5-base/tokenizer.json", modelReference(tokenizerCfg.model().get(0), "path").url().orElseThrow().value());
+ assertEquals(-1, tokenizerCfg.maxLength());
+ }
@Test
void bertEmbedder_selfhosted() throws Exception {
@@ -233,6 +258,14 @@ public class EmbedderTestCase {
return cfgBuilder.build();
}
+ private static ColBertEmbedderConfig assertColBertEmbedderComponentPresent(ApplicationContainerCluster cluster) {
+ var colbert = (ColBertEmbedder) cluster.getComponentsMap().get(new ComponentId("colbert-embedder"));
+ assertEquals("ai.vespa.embedding.ColBertEmbedder", colbert.getClassId().getName());
+ var cfgBuilder = new ColBertEmbedderConfig.Builder();
+ colbert.getConfig(cfgBuilder);
+ return cfgBuilder.build();
+ }
+
private static BertBaseEmbedderConfig assertBertEmbedderComponentPresent(ApplicationContainerCluster cluster) {
var bertEmbedder = (BertEmbedder) cluster.getComponentsMap().get(new ComponentId("bert-embedder"));
assertEquals("ai.vespa.embedding.BertBaseEmbedder", bertEmbedder.getClassId().getName());