summaryrefslogtreecommitdiffstats
path: root/config-model/src/test
diff options
context:
space:
mode:
authorJo Kristian Bergum <bergum@yahoo-inc.com>2024-01-04 15:45:02 +0100
committerGitHub <noreply@github.com>2024-01-04 15:45:02 +0100
commit6d2226e2bf35e32cf618ff12a7e2968c85eabf1f (patch)
tree5ac11755dbe8998c2b3f4dc5cae8859b0a2e1b9f /config-model/src/test
parentb10d1fd87d7013847b19fc89a620fbe9c7136e61 (diff)
parent79bb01aa94375b6b9ce464fbdc5db24d1549e7d9 (diff)
Merge pull request #29667 from vespa-engine/jobergum/splade-embedder
Add SPLADE embedder
Diffstat (limited to 'config-model/src/test')
-rw-r--r--config-model/src/test/cfg/application/embed/services.xml15
-rw-r--r--config-model/src/test/java/com/yahoo/vespa/model/container/xml/EmbedderTestCase.java30
2 files changed, 44 insertions, 1 deletions
diff --git a/config-model/src/test/cfg/application/embed/services.xml b/config-model/src/test/cfg/application/embed/services.xml
index 1840063d70d..59c29aefc6a 100644
--- a/config-model/src/test/cfg/application/embed/services.xml
+++ b/config-model/src/test/cfg/application/embed/services.xml
@@ -19,6 +19,21 @@
<pooling-strategy>mean</pooling-strategy>
</component>
+ <component id="splade" type="splade-embedder">
+ <transformer-model model-id="e5-base-v2" url="https://my/url/model.onnx"/>
+ <tokenizer-model model-id="e5-base-v2-vocab" path="app/tokenizer.json"/>
+ <max-tokens>1024</max-tokens>
+ <transformer-input-ids>my_input_ids</transformer-input-ids>
+ <transformer-attention-mask>my_attention_mask</transformer-attention-mask>
+ <transformer-token-type-ids>my_token_type_ids</transformer-token-type-ids>
+ <transformer-output>my_output</transformer-output>
+ <term-score-threshold>0.2</term-score-threshold>
+ <onnx-execution-mode>parallel</onnx-execution-mode>
+ <onnx-intraop-threads>10</onnx-intraop-threads>
+ <onnx-interop-threads>8</onnx-interop-threads>
+ <onnx-gpu-device>1</onnx-gpu-device>
+ </component>
+
<component id="hf-tokenizer" type="hugging-face-tokenizer">
<model language="no" model-id="multilingual-e5-base-vocab" url="https://my/url/tokenizer.json"/>
</component>
diff --git a/config-model/src/test/java/com/yahoo/vespa/model/container/xml/EmbedderTestCase.java b/config-model/src/test/java/com/yahoo/vespa/model/container/xml/EmbedderTestCase.java
index 138bef3ae73..2532a5be863 100644
--- a/config-model/src/test/java/com/yahoo/vespa/model/container/xml/EmbedderTestCase.java
+++ b/config-model/src/test/java/com/yahoo/vespa/model/container/xml/EmbedderTestCase.java
@@ -12,6 +12,7 @@ import com.yahoo.config.model.deploy.DeployState;
import com.yahoo.config.model.deploy.TestProperties;
import com.yahoo.embedding.BertBaseEmbedderConfig;
import com.yahoo.embedding.ColBertEmbedderConfig;
+import com.yahoo.embedding.SpladeEmbedderConfig;
import com.yahoo.embedding.huggingface.HuggingFaceEmbedderConfig;
import com.yahoo.language.huggingface.config.HuggingFaceTokenizerConfig;
import com.yahoo.path.Path;
@@ -24,6 +25,7 @@ import com.yahoo.vespa.model.container.component.BertEmbedder;
import com.yahoo.vespa.model.container.component.ColBertEmbedder;
import com.yahoo.vespa.model.container.component.Component;
import com.yahoo.vespa.model.container.component.HuggingFaceEmbedder;
+import com.yahoo.vespa.model.container.component.SpladeEmbedder;
import com.yahoo.vespa.model.container.component.HuggingFaceTokenizer;
import com.yahoo.vespa.model.test.utils.VespaModelCreatorWithFilePkg;
import com.yahoo.yolean.Exceptions;
@@ -101,6 +103,23 @@ public class EmbedderTestCase {
assertEquals(-1, tokenizerCfg.maxLength());
}
+ @Test
+ void spladeEmbedder_selfhosted() throws Exception {
+ var model = loadModel(Path.fromString("src/test/cfg/application/embed/"), false);
+ var cluster = model.getContainerClusters().get("container");
+ var embedderCfg = assertSpladeEmbedderComponentPresent(cluster);
+
+ assertEquals("my_input_ids", embedderCfg.transformerInputIds());
+ assertEquals("https://my/url/model.onnx", modelReference(embedderCfg, "transformerModel").url().orElseThrow().value());
+ assertEquals(0.2, embedderCfg.termScoreThreshold());
+ assertEquals(1024, embedderCfg.transformerMaxTokens());
+
+ var tokenizerCfg = assertHuggingfaceTokenizerComponentPresent(cluster);
+ assertEquals("https://my/url/tokenizer.json", modelReference(tokenizerCfg.model().get(0), "path").url().orElseThrow().value());
+ assertEquals(-1, tokenizerCfg.maxLength());
+ }
+
+ @Test
void colBertEmbedder_selfhosted() throws Exception {
var model = loadModel(Path.fromString("src/test/cfg/application/embed/"), false);
var cluster = model.getContainerClusters().get("container");
@@ -113,6 +132,7 @@ public class EmbedderTestCase {
assertEquals(-1, tokenizerCfg.maxLength());
}
+ @Test
void colBertEmbedder_hosted() throws Exception {
var model = loadModel(Path.fromString("src/test/cfg/application/embed/"), true);
var cluster = model.getContainerClusters().get("container");
@@ -266,13 +286,21 @@ public class EmbedderTestCase {
}
private static ColBertEmbedderConfig assertColBertEmbedderComponentPresent(ApplicationContainerCluster cluster) {
- var colbert = (ColBertEmbedder) cluster.getComponentsMap().get(new ComponentId("colbert-embedder"));
+ var colbert = (ColBertEmbedder) cluster.getComponentsMap().get(new ComponentId("colbert"));
assertEquals("ai.vespa.embedding.ColBertEmbedder", colbert.getClassId().getName());
var cfgBuilder = new ColBertEmbedderConfig.Builder();
colbert.getConfig(cfgBuilder);
return cfgBuilder.build();
}
+ private static SpladeEmbedderConfig assertSpladeEmbedderComponentPresent(ApplicationContainerCluster cluster) {
+ var splade = (SpladeEmbedder) cluster.getComponentsMap().get(new ComponentId("splade"));
+ assertEquals("ai.vespa.embedding.SpladeEmbedder", splade.getClassId().getName());
+ var cfgBuilder = new SpladeEmbedderConfig.Builder();
+ splade.getConfig(cfgBuilder);
+ return cfgBuilder.build();
+ }
+
private static BertBaseEmbedderConfig assertBertEmbedderComponentPresent(ApplicationContainerCluster cluster) {
var bertEmbedder = (BertEmbedder) cluster.getComponentsMap().get(new ComponentId("bert-embedder"));
assertEquals("ai.vespa.embedding.BertBaseEmbedder", bertEmbedder.getClassId().getName());