aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--config-model/src/main/java/com/yahoo/vespa/model/container/xml/embedder/EmbedderConfig.java20
-rw-r--r--config-model/src/test/cfg/application/embed_generic/services.xml12
-rw-r--r--config-model/src/test/cfg/application/embed_generic_using_provided_model/services.xml20
-rw-r--r--config-model/src/test/java/com/yahoo/vespa/model/container/xml/EmbedderTestCase.java22
4 files changed, 58 insertions, 16 deletions
diff --git a/config-model/src/main/java/com/yahoo/vespa/model/container/xml/embedder/EmbedderConfig.java b/config-model/src/main/java/com/yahoo/vespa/model/container/xml/embedder/EmbedderConfig.java
index a2286647cdd..69343643ef3 100644
--- a/config-model/src/main/java/com/yahoo/vespa/model/container/xml/embedder/EmbedderConfig.java
+++ b/config-model/src/main/java/com/yahoo/vespa/model/container/xml/embedder/EmbedderConfig.java
@@ -27,11 +27,21 @@ import org.w3c.dom.NodeList;
public class EmbedderConfig {
static EmbedderConfigTransformer getEmbedderTransformer(Element spec, boolean hosted) {
- String classId = getEmbedderClass(spec);
- switch (classId) {
- case "ai.vespa.embedding.BertBaseEmbedder": return new EmbedderConfigBertBaseTransformer(spec, hosted);
- }
- return new EmbedderConfigTransformer(spec, hosted);
+ return switch (embedderConfigFrom(spec)) {
+ case "embedding.bert-base-embedder" -> new EmbedderConfigBertBaseTransformer(spec, hosted);
+ default -> new EmbedderConfigTransformer(spec, hosted);
+ };
+ }
+
+ private static String embedderConfigFrom(Element spec) {
+ String explicitDefinition = spec.getAttribute("def");
+ if ( ! explicitDefinition.isEmpty()) return explicitDefinition;
+
+ // Implicit from class name
+ return switch (spec.getAttribute("class")) {
+ case "ai.vespa.embedding.BertBaseEmbedder" -> "embedding.bert-base-embedder";
+ default -> "";
+ };
}
static String modelIdToUrl(String id) {
diff --git a/config-model/src/test/cfg/application/embed_generic/services.xml b/config-model/src/test/cfg/application/embed_generic/services.xml
index 7947b8b707c..ab2c1be9745 100644
--- a/config-model/src/test/cfg/application/embed_generic/services.xml
+++ b/config-model/src/test/cfg/application/embed_generic/services.xml
@@ -4,14 +4,10 @@
<container version="1.0">
- <!--
- <embedder id='transformer' class='ai.vespa.example.paragraph.SentenceBertEmbedder' bundle='exampleEmbedder' def='ai.vespa.example.paragraph.sentence-embedder'>
- <model>files/model.onnx</model>
- <vocab>files/vocab.txt</vocab>
- </embedder>
- -->
-
- <embedder id='transformer' class='ai.vespa.example.paragraph.SentenceBertEmbedder' bundle='exampleEmbedder' def='ai.vespa.example.paragraph.sentence-embedder'>
+ <embedder id='transformer'
+ class='ai.vespa.example.paragraph.ApplicationSpecificEmbedder'
+ bundle='exampleEmbedder'
+ def='ai.vespa.example.paragraph.sentence-embedder'>
<model path="files/model.onnx" /> <!-- Embedder syntax for file path -->
<vocab>files/vocab.txt</vocab> <!-- Generic config syntax for file path -->
</embedder>
diff --git a/config-model/src/test/cfg/application/embed_generic_using_provided_model/services.xml b/config-model/src/test/cfg/application/embed_generic_using_provided_model/services.xml
new file mode 100644
index 00000000000..4190e5c9286
--- /dev/null
+++ b/config-model/src/test/cfg/application/embed_generic_using_provided_model/services.xml
@@ -0,0 +1,20 @@
+<?xml version="1.0" encoding="utf-8" ?>
+<!-- Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. -->
+<services version="1.0">
+
+ <container version="1.0">
+
+ <embedder id='transformer'
+ class='ai.vespa.example.paragraph.ApplicationSpecificEmbedder'
+ bundle='exampleEmbedder'
+ def='embedding.bert-base-embedder'>
+ <model id="test-model-id" url="test-model-url"/>
+ <vocab path="files/vocab.txt"/>
+ </embedder>
+
+ <nodes>
+ <node hostalias='node1'/>
+ </nodes>
+ </container>
+
+</services>
diff --git a/config-model/src/test/java/com/yahoo/vespa/model/container/xml/EmbedderTestCase.java b/config-model/src/test/java/com/yahoo/vespa/model/container/xml/EmbedderTestCase.java
index 6aff0a7f003..2a92963018d 100644
--- a/config-model/src/test/java/com/yahoo/vespa/model/container/xml/EmbedderTestCase.java
+++ b/config-model/src/test/java/com/yahoo/vespa/model/container/xml/EmbedderTestCase.java
@@ -160,9 +160,25 @@ public class EmbedderTestCase {
ApplicationContainerCluster containerCluster = model.getContainerClusters().get("container");
Component<?, ?> testComponent = containerCluster.getComponentsMap().get(new ComponentId("transformer"));
- ConfigPayloadBuilder testConfig = testComponent.getUserConfigs().get(new ConfigDefinitionKey("sentence-embedder", "ai.vespa.example.paragraph"));
- assertEquals("files/vocab.txt", testConfig.getObject("vocab").getValue());
- assertEquals("files/model.onnx", testConfig.getObject("model").getValue());
+ ConfigPayloadBuilder config = testComponent.getUserConfigs().get(new ConfigDefinitionKey("sentence-embedder", "ai.vespa.example.paragraph"));
+ assertEquals("files/vocab.txt", config.getObject("vocab").getValue());
+ assertEquals("files/model.onnx", config.getObject("model").getValue());
+ }
+
+ @Test
+ void testApplicationWithGenericEmbedConfigUsingProvidedModel() throws Exception {
+ final String emptyPathFileName = "services.xml";
+
+ Path applicationDir = Path.fromString("src/test/cfg/application/embed_generic_using_provided_model/");
+ VespaModel model = loadModel(applicationDir, false);
+ ApplicationContainerCluster containerCluster = model.getContainerClusters().get("container");
+
+ Component<?, ?> testComponent = containerCluster.getComponentsMap().get(new ComponentId("transformer"));
+ ConfigPayloadBuilder config = testComponent.getUserConfigs().get(new ConfigDefinitionKey("bert-base-embedder", "embedding"));
+ assertEquals("test-model-url", config.getObject("transformerModelUrl").getValue());
+ assertEquals(emptyPathFileName, config.getObject("transformerModelPath").getValue());
+ assertEquals("", config.getObject("tokenizerVocabUrl").getValue());
+ assertEquals("files/vocab.txt", config.getObject("tokenizerVocabPath").getValue());
}
private VespaModel loadModel(Path path, boolean hosted) throws Exception {