diff options
author | Jon Bratseth <bratseth@gmail.com> | 2022-08-22 15:26:59 +0200 |
---|---|---|
committer | Jon Bratseth <bratseth@gmail.com> | 2022-08-22 15:26:59 +0200 |
commit | 9071773dd980b82ed143daa7b874f0537fea2069 (patch) | |
tree | 27d0ed3718a86c59a658ea43d8f333263af2760c /config-model | |
parent | 8cc8437937bcaff5b9fb8338044ce447a5c51b32 (diff) |
Application embedders can reuse Vespa embedder configs
Diffstat (limited to 'config-model')
4 files changed, 58 insertions, 16 deletions
diff --git a/config-model/src/main/java/com/yahoo/vespa/model/container/xml/embedder/EmbedderConfig.java b/config-model/src/main/java/com/yahoo/vespa/model/container/xml/embedder/EmbedderConfig.java index a2286647cdd..69343643ef3 100644 --- a/config-model/src/main/java/com/yahoo/vespa/model/container/xml/embedder/EmbedderConfig.java +++ b/config-model/src/main/java/com/yahoo/vespa/model/container/xml/embedder/EmbedderConfig.java @@ -27,11 +27,21 @@ import org.w3c.dom.NodeList; public class EmbedderConfig { static EmbedderConfigTransformer getEmbedderTransformer(Element spec, boolean hosted) { - String classId = getEmbedderClass(spec); - switch (classId) { - case "ai.vespa.embedding.BertBaseEmbedder": return new EmbedderConfigBertBaseTransformer(spec, hosted); - } - return new EmbedderConfigTransformer(spec, hosted); + return switch (embedderConfigFrom(spec)) { + case "embedding.bert-base-embedder" -> new EmbedderConfigBertBaseTransformer(spec, hosted); + default -> new EmbedderConfigTransformer(spec, hosted); + }; + } + + private static String embedderConfigFrom(Element spec) { + String explicitDefinition = spec.getAttribute("def"); + if ( ! 
explicitDefinition.isEmpty()) return explicitDefinition; + + // Implicit from class name + return switch (spec.getAttribute("class")) { + case "ai.vespa.embedding.BertBaseEmbedder" -> "embedding.bert-base-embedder"; + default -> ""; + }; } static String modelIdToUrl(String id) { diff --git a/config-model/src/test/cfg/application/embed_generic/services.xml b/config-model/src/test/cfg/application/embed_generic/services.xml index 7947b8b707c..ab2c1be9745 100644 --- a/config-model/src/test/cfg/application/embed_generic/services.xml +++ b/config-model/src/test/cfg/application/embed_generic/services.xml @@ -4,14 +4,10 @@ <container version="1.0"> - <!-- - <embedder id='transformer' class='ai.vespa.example.paragraph.SentenceBertEmbedder' bundle='exampleEmbedder' def='ai.vespa.example.paragraph.sentence-embedder'> - <model>files/model.onnx</model> - <vocab>files/vocab.txt</vocab> - </embedder> - --> - - <embedder id='transformer' class='ai.vespa.example.paragraph.SentenceBertEmbedder' bundle='exampleEmbedder' def='ai.vespa.example.paragraph.sentence-embedder'> + <embedder id='transformer' + class='ai.vespa.example.paragraph.ApplicationSpecificEmbedder' + bundle='exampleEmbedder' + def='ai.vespa.example.paragraph.sentence-embedder'> <model path="files/model.onnx" /> <!-- Embedder syntax for file path --> <vocab>files/vocab.txt</vocab> <!-- Generic config syntax for file path --> </embedder> diff --git a/config-model/src/test/cfg/application/embed_generic_using_provided_model/services.xml b/config-model/src/test/cfg/application/embed_generic_using_provided_model/services.xml new file mode 100644 index 00000000000..4190e5c9286 --- /dev/null +++ b/config-model/src/test/cfg/application/embed_generic_using_provided_model/services.xml @@ -0,0 +1,20 @@ +<?xml version="1.0" encoding="utf-8" ?> +<!-- Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. 
--> +<services version="1.0"> + + <container version="1.0"> + + <embedder id='transformer' + class='ai.vespa.example.paragraph.ApplicationSpecificEmbedder' + bundle='exampleEmbedder' + def='embedding.bert-base-embedder'> + <model id="test-model-id" url="test-model-url"/> + <vocab path="files/vocab.txt"/> + </embedder> + + <nodes> + <node hostalias='node1'/> + </nodes> + </container> + +</services> diff --git a/config-model/src/test/java/com/yahoo/vespa/model/container/xml/EmbedderTestCase.java b/config-model/src/test/java/com/yahoo/vespa/model/container/xml/EmbedderTestCase.java index 6aff0a7f003..2a92963018d 100644 --- a/config-model/src/test/java/com/yahoo/vespa/model/container/xml/EmbedderTestCase.java +++ b/config-model/src/test/java/com/yahoo/vespa/model/container/xml/EmbedderTestCase.java @@ -160,9 +160,25 @@ public class EmbedderTestCase { ApplicationContainerCluster containerCluster = model.getContainerClusters().get("container"); Component<?, ?> testComponent = containerCluster.getComponentsMap().get(new ComponentId("transformer")); - ConfigPayloadBuilder testConfig = testComponent.getUserConfigs().get(new ConfigDefinitionKey("sentence-embedder", "ai.vespa.example.paragraph")); - assertEquals("files/vocab.txt", testConfig.getObject("vocab").getValue()); - assertEquals("files/model.onnx", testConfig.getObject("model").getValue()); + ConfigPayloadBuilder config = testComponent.getUserConfigs().get(new ConfigDefinitionKey("sentence-embedder", "ai.vespa.example.paragraph")); + assertEquals("files/vocab.txt", config.getObject("vocab").getValue()); + assertEquals("files/model.onnx", config.getObject("model").getValue()); + } + + @Test + void testApplicationWithGenericEmbedConfigUsingProvidedModel() throws Exception { + final String emptyPathFileName = "services.xml"; + + Path applicationDir = Path.fromString("src/test/cfg/application/embed_generic_using_provided_model/"); + VespaModel model = loadModel(applicationDir, false); + ApplicationContainerCluster 
containerCluster = model.getContainerClusters().get("container"); + + Component<?, ?> testComponent = containerCluster.getComponentsMap().get(new ComponentId("transformer")); + ConfigPayloadBuilder config = testComponent.getUserConfigs().get(new ConfigDefinitionKey("bert-base-embedder", "embedding")); + assertEquals("test-model-url", config.getObject("transformerModelUrl").getValue()); + assertEquals(emptyPathFileName, config.getObject("transformerModelPath").getValue()); + assertEquals("", config.getObject("tokenizerVocabUrl").getValue()); + assertEquals("files/vocab.txt", config.getObject("tokenizerVocabPath").getValue()); } private VespaModel loadModel(Path path, boolean hosted) throws Exception { |