aboutsummaryrefslogtreecommitdiffstats
path: root/config-model/src/test/java/com/yahoo/vespa/model/container
diff options
context:
space:
mode:
authorJo Kristian Bergum <bergum@yahooinc.com>2023-12-15 08:42:01 +0100
committerJo Kristian Bergum <bergum@yahooinc.com>2023-12-15 08:42:01 +0100
commitcbc0733c07b57d8563eea40897072cb35042b605 (patch)
treefa3b564329dc4a88b48c697692c7896d6d4b36b0 /config-model/src/test/java/com/yahoo/vespa/model/container
parent8af800ba588f726184ffb8296463bb4b7fbea5a1 (diff)
Add a splade embedder implementation
Diffstat (limited to 'config-model/src/test/java/com/yahoo/vespa/model/container')
-rw-r--r--config-model/src/test/java/com/yahoo/vespa/model/container/xml/EmbedderTestCase.java30
1 files changed, 29 insertions, 1 deletions
diff --git a/config-model/src/test/java/com/yahoo/vespa/model/container/xml/EmbedderTestCase.java b/config-model/src/test/java/com/yahoo/vespa/model/container/xml/EmbedderTestCase.java
index 138bef3ae73..2532a5be863 100644
--- a/config-model/src/test/java/com/yahoo/vespa/model/container/xml/EmbedderTestCase.java
+++ b/config-model/src/test/java/com/yahoo/vespa/model/container/xml/EmbedderTestCase.java
@@ -12,6 +12,7 @@ import com.yahoo.config.model.deploy.DeployState;
import com.yahoo.config.model.deploy.TestProperties;
import com.yahoo.embedding.BertBaseEmbedderConfig;
import com.yahoo.embedding.ColBertEmbedderConfig;
+import com.yahoo.embedding.SpladeEmbedderConfig;
import com.yahoo.embedding.huggingface.HuggingFaceEmbedderConfig;
import com.yahoo.language.huggingface.config.HuggingFaceTokenizerConfig;
import com.yahoo.path.Path;
@@ -24,6 +25,7 @@ import com.yahoo.vespa.model.container.component.BertEmbedder;
import com.yahoo.vespa.model.container.component.ColBertEmbedder;
import com.yahoo.vespa.model.container.component.Component;
import com.yahoo.vespa.model.container.component.HuggingFaceEmbedder;
+import com.yahoo.vespa.model.container.component.SpladeEmbedder;
import com.yahoo.vespa.model.container.component.HuggingFaceTokenizer;
import com.yahoo.vespa.model.test.utils.VespaModelCreatorWithFilePkg;
import com.yahoo.yolean.Exceptions;
@@ -101,6 +103,23 @@ public class EmbedderTestCase {
assertEquals(-1, tokenizerCfg.maxLength());
}
+ @Test
+ void spladeEmbedder_selfhosted() throws Exception {
+ var model = loadModel(Path.fromString("src/test/cfg/application/embed/"), false);
+ var cluster = model.getContainerClusters().get("container");
+ var embedderCfg = assertSpladeEmbedderComponentPresent(cluster);
+
+ assertEquals("my_input_ids", embedderCfg.transformerInputIds());
+ assertEquals("https://my/url/model.onnx", modelReference(embedderCfg, "transformerModel").url().orElseThrow().value());
+ assertEquals(0.2, embedderCfg.termScoreThreshold());
+ assertEquals(1024, embedderCfg.transformerMaxTokens());
+
+ var tokenizerCfg = assertHuggingfaceTokenizerComponentPresent(cluster);
+ assertEquals("https://my/url/tokenizer.json", modelReference(tokenizerCfg.model().get(0), "path").url().orElseThrow().value());
+ assertEquals(-1, tokenizerCfg.maxLength());
+ }
+
+ @Test
void colBertEmbedder_selfhosted() throws Exception {
var model = loadModel(Path.fromString("src/test/cfg/application/embed/"), false);
var cluster = model.getContainerClusters().get("container");
@@ -113,6 +132,7 @@ public class EmbedderTestCase {
assertEquals(-1, tokenizerCfg.maxLength());
}
+ @Test
void colBertEmbedder_hosted() throws Exception {
var model = loadModel(Path.fromString("src/test/cfg/application/embed/"), true);
var cluster = model.getContainerClusters().get("container");
@@ -266,13 +286,21 @@ public class EmbedderTestCase {
}
private static ColBertEmbedderConfig assertColBertEmbedderComponentPresent(ApplicationContainerCluster cluster) {
- var colbert = (ColBertEmbedder) cluster.getComponentsMap().get(new ComponentId("colbert-embedder"));
+ var colbert = (ColBertEmbedder) cluster.getComponentsMap().get(new ComponentId("colbert"));
assertEquals("ai.vespa.embedding.ColBertEmbedder", colbert.getClassId().getName());
var cfgBuilder = new ColBertEmbedderConfig.Builder();
colbert.getConfig(cfgBuilder);
return cfgBuilder.build();
}
+ private static SpladeEmbedderConfig assertSpladeEmbedderComponentPresent(ApplicationContainerCluster cluster) {
+ var splade = (SpladeEmbedder) cluster.getComponentsMap().get(new ComponentId("splade"));
+ assertEquals("ai.vespa.embedding.SpladeEmbedder", splade.getClassId().getName());
+ var cfgBuilder = new SpladeEmbedderConfig.Builder();
+ splade.getConfig(cfgBuilder);
+ return cfgBuilder.build();
+ }
+
private static BertBaseEmbedderConfig assertBertEmbedderComponentPresent(ApplicationContainerCluster cluster) {
var bertEmbedder = (BertEmbedder) cluster.getComponentsMap().get(new ComponentId("bert-embedder"));
assertEquals("ai.vespa.embedding.BertBaseEmbedder", bertEmbedder.getClassId().getName());