diff options
author | Jo Kristian Bergum <bergum@yahooinc.com> | 2023-12-15 08:42:01 +0100 |
---|---|---|
committer | Jo Kristian Bergum <bergum@yahooinc.com> | 2023-12-15 08:42:01 +0100 |
commit | cbc0733c07b57d8563eea40897072cb35042b605 (patch) | |
tree | fa3b564329dc4a88b48c697692c7896d6d4b36b0 /config-model/src/test/java/com/yahoo/vespa/model/container | |
parent | 8af800ba588f726184ffb8296463bb4b7fbea5a1 (diff) |
Add a splade embedder implementation
Diffstat (limited to 'config-model/src/test/java/com/yahoo/vespa/model/container')
-rw-r--r-- | config-model/src/test/java/com/yahoo/vespa/model/container/xml/EmbedderTestCase.java | 30 |
1 files changed, 29 insertions, 1 deletions
diff --git a/config-model/src/test/java/com/yahoo/vespa/model/container/xml/EmbedderTestCase.java b/config-model/src/test/java/com/yahoo/vespa/model/container/xml/EmbedderTestCase.java index 138bef3ae73..2532a5be863 100644 --- a/config-model/src/test/java/com/yahoo/vespa/model/container/xml/EmbedderTestCase.java +++ b/config-model/src/test/java/com/yahoo/vespa/model/container/xml/EmbedderTestCase.java @@ -12,6 +12,7 @@ import com.yahoo.config.model.deploy.DeployState; import com.yahoo.config.model.deploy.TestProperties; import com.yahoo.embedding.BertBaseEmbedderConfig; import com.yahoo.embedding.ColBertEmbedderConfig; +import com.yahoo.embedding.SpladeEmbedderConfig; import com.yahoo.embedding.huggingface.HuggingFaceEmbedderConfig; import com.yahoo.language.huggingface.config.HuggingFaceTokenizerConfig; import com.yahoo.path.Path; @@ -24,6 +25,7 @@ import com.yahoo.vespa.model.container.component.BertEmbedder; import com.yahoo.vespa.model.container.component.ColBertEmbedder; import com.yahoo.vespa.model.container.component.Component; import com.yahoo.vespa.model.container.component.HuggingFaceEmbedder; +import com.yahoo.vespa.model.container.component.SpladeEmbedder; import com.yahoo.vespa.model.container.component.HuggingFaceTokenizer; import com.yahoo.vespa.model.test.utils.VespaModelCreatorWithFilePkg; import com.yahoo.yolean.Exceptions; @@ -101,6 +103,23 @@ public class EmbedderTestCase { assertEquals(-1, tokenizerCfg.maxLength()); } + @Test + void spladeEmbedder_selfhosted() throws Exception { + var model = loadModel(Path.fromString("src/test/cfg/application/embed/"), false); + var cluster = model.getContainerClusters().get("container"); + var embedderCfg = assertSpladeEmbedderComponentPresent(cluster); + + assertEquals("my_input_ids", embedderCfg.transformerInputIds()); + assertEquals("https://my/url/model.onnx", modelReference(embedderCfg, "transformerModel").url().orElseThrow().value()); + assertEquals(0.2, embedderCfg.termScoreThreshold()); + assertEquals(1024, embedderCfg.transformerMaxTokens()); + + var tokenizerCfg = assertHuggingfaceTokenizerComponentPresent(cluster); + assertEquals("https://my/url/tokenizer.json", modelReference(tokenizerCfg.model().get(0), "path").url().orElseThrow().value()); + assertEquals(-1, tokenizerCfg.maxLength()); + } + + @Test void colBertEmbedder_selfhosted() throws Exception { var model = loadModel(Path.fromString("src/test/cfg/application/embed/"), false); var cluster = model.getContainerClusters().get("container"); @@ -113,6 +132,7 @@ public class EmbedderTestCase { assertEquals(-1, tokenizerCfg.maxLength()); } + @Test void colBertEmbedder_hosted() throws Exception { var model = loadModel(Path.fromString("src/test/cfg/application/embed/"), true); var cluster = model.getContainerClusters().get("container"); @@ -266,13 +286,21 @@ public class EmbedderTestCase { } private static ColBertEmbedderConfig assertColBertEmbedderComponentPresent(ApplicationContainerCluster cluster) { - var colbert = (ColBertEmbedder) cluster.getComponentsMap().get(new ComponentId("colbert-embedder")); + var colbert = (ColBertEmbedder) cluster.getComponentsMap().get(new ComponentId("colbert")); assertEquals("ai.vespa.embedding.ColBertEmbedder", colbert.getClassId().getName()); var cfgBuilder = new ColBertEmbedderConfig.Builder(); colbert.getConfig(cfgBuilder); return cfgBuilder.build(); } + private static SpladeEmbedderConfig assertSpladeEmbedderComponentPresent(ApplicationContainerCluster cluster) { + var splade = (SpladeEmbedder) cluster.getComponentsMap().get(new ComponentId("splade")); + assertEquals("ai.vespa.embedding.SpladeEmbedder", splade.getClassId().getName()); + var cfgBuilder = new SpladeEmbedderConfig.Builder(); + splade.getConfig(cfgBuilder); + return cfgBuilder.build(); + } + private static BertBaseEmbedderConfig assertBertEmbedderComponentPresent(ApplicationContainerCluster cluster) { var bertEmbedder = (BertEmbedder) cluster.getComponentsMap().get(new ComponentId("bert-embedder")); assertEquals("ai.vespa.embedding.BertBaseEmbedder", bertEmbedder.getClassId().getName()); |