diff options
author | Jon Bratseth <bratseth@gmail.com> | 2022-08-18 13:41:11 +0200 |
---|---|---|
committer | Jon Bratseth <bratseth@gmail.com> | 2022-08-18 13:41:11 +0200 |
commit | 8cc8437937bcaff5b9fb8338044ce447a5c51b32 (patch) | |
tree | bbe4fe9e96685a668d93274c32ecab0c26f37a4e /config-model/src | |
parent | c6d747e7f832bd24b21216af6437bb76f62c51ef (diff) |
Support path=.. in generic embedders
Diffstat (limited to 'config-model/src')
4 files changed, 73 insertions, 8 deletions
diff --git a/config-model/src/main/java/com/yahoo/vespa/model/container/xml/embedder/EmbedderOption.java b/config-model/src/main/java/com/yahoo/vespa/model/container/xml/embedder/EmbedderOption.java index 206745887d1..d8740eecd74 100644 --- a/config-model/src/main/java/com/yahoo/vespa/model/container/xml/embedder/EmbedderOption.java +++ b/config-model/src/main/java/com/yahoo/vespa/model/container/xml/embedder/EmbedderOption.java @@ -62,8 +62,12 @@ public class EmbedderOption { * Basic option transformer. No special handling of options. */ public static class OptionTransformer { + public void transform(DeployState deployState, Element parent, EmbedderOption option) { - createElement(parent, option.name(), option.value()); + if (option.value().isEmpty()) + createElement(parent, option.name(), option.attributes.get("path")); // always understand path=".." + else + createElement(parent, option.name(), option.value()); } public static Element createElement(Element parent, String name, String value) { diff --git a/config-model/src/test/cfg/application/embed_generic/configdefinitions/sentence-embedder.def b/config-model/src/test/cfg/application/embed_generic/configdefinitions/sentence-embedder.def new file mode 100644 index 00000000000..f62e2019189 --- /dev/null +++ b/config-model/src/test/cfg/application/embed_generic/configdefinitions/sentence-embedder.def @@ -0,0 +1,25 @@ +package=ai.vespa.example.paragraph + +# Settings for wordpiece tokenizer +vocab path + +# Transformer model settings +model path + +# Max length of token sequence model can handle +transforerMaxTokens int default=128 + +# Pooling strategy +poolingStrategy enum { cls, mean } default=mean + +# Input names +transformerInputIds string default=input_ids +transformerAttentionMask string default=attention_mask + +# Output name +transformerOutput string default=last_hidden_state + +# Settings for ONNX model evaluation +onnxExecutionMode enum { parallel, sequential } default=sequential +onnxInterOpThreads int default=1 +onnxIntraOpThreads int default=-4 diff --git a/config-model/src/test/cfg/application/embed_generic/services.xml b/config-model/src/test/cfg/application/embed_generic/services.xml new file mode 100644 index 00000000000..7947b8b707c --- /dev/null +++ b/config-model/src/test/cfg/application/embed_generic/services.xml @@ -0,0 +1,24 @@ +<?xml version="1.0" encoding="utf-8" ?> +<!-- Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. --> +<services version="1.0"> + + <container version="1.0"> + + <!-- + <embedder id='transformer' class='ai.vespa.example.paragraph.SentenceBertEmbedder' bundle='exampleEmbedder' def='ai.vespa.example.paragraph.sentence-embedder'> + <model>files/model.onnx</model> + <vocab>files/vocab.txt</vocab> + </embedder> + --> + + <embedder id='transformer' class='ai.vespa.example.paragraph.SentenceBertEmbedder' bundle='exampleEmbedder' def='ai.vespa.example.paragraph.sentence-embedder'> + <model path="files/model.onnx" /> <!-- Embedder syntax for file path --> + <vocab>files/vocab.txt</vocab> <!-- Generic config syntax for file path --> + </embedder> + + <nodes> + <node hostalias='node1'/> + </nodes> + </container> + +</services> diff --git a/config-model/src/test/java/com/yahoo/vespa/model/container/xml/EmbedderTestCase.java b/config-model/src/test/java/com/yahoo/vespa/model/container/xml/EmbedderTestCase.java index fcccdcf8f23..6aff0a7f003 100644 --- a/config-model/src/test/java/com/yahoo/vespa/model/container/xml/EmbedderTestCase.java +++ b/config-model/src/test/java/com/yahoo/vespa/model/container/xml/EmbedderTestCase.java @@ -133,7 +133,7 @@ public class EmbedderTestCase { } @Test - void testEmbedConfig() throws Exception { + void testApplicationWithEmbedConfig() throws Exception { final String emptyPathFileName = "services.xml"; Path applicationDir = Path.fromString("src/test/cfg/application/embed/"); @@ -142,15 +142,27 @@ public class EmbedderTestCase { Component<?, ?> testComponent = containerCluster.getComponentsMap().get(new ComponentId("test")); ConfigPayloadBuilder testConfig = testComponent.getUserConfigs().get(new ConfigDefinitionKey("dummy", "test")); - assertEquals(testConfig.getObject("num").getValue(), "12"); - assertEquals(testConfig.getObject("str").getValue(), "some text"); + assertEquals("12", testConfig.getObject("num").getValue()); + assertEquals("some text", testConfig.getObject("str").getValue()); Component<?, ?> transformer = containerCluster.getComponentsMap().get(new ComponentId("transformer")); ConfigPayloadBuilder transformerConfig = transformer.getUserConfigs().get(new ConfigDefinitionKey("bert-base-embedder", "embedding")); - assertEquals(transformerConfig.getObject("transformerModelUrl").getValue(), "test-model-url"); - assertEquals(transformerConfig.getObject("transformerModelPath").getValue(), emptyPathFileName); - assertEquals(transformerConfig.getObject("tokenizerVocabUrl").getValue(), ""); - assertEquals(transformerConfig.getObject("tokenizerVocabPath").getValue(), "files/vocab.txt"); + assertEquals("test-model-url", transformerConfig.getObject("transformerModelUrl").getValue()); + assertEquals(emptyPathFileName, transformerConfig.getObject("transformerModelPath").getValue()); + assertEquals("", transformerConfig.getObject("tokenizerVocabUrl").getValue()); + assertEquals("files/vocab.txt", transformerConfig.getObject("tokenizerVocabPath").getValue()); + } + + @Test + void testApplicationWithGenericEmbedConfig() throws Exception { + Path applicationDir = Path.fromString("src/test/cfg/application/embed_generic/"); + VespaModel model = loadModel(applicationDir, false); + ApplicationContainerCluster containerCluster = model.getContainerClusters().get("container"); + + Component<?, ?> testComponent = containerCluster.getComponentsMap().get(new ComponentId("transformer")); + ConfigPayloadBuilder testConfig = testComponent.getUserConfigs().get(new ConfigDefinitionKey("sentence-embedder", "ai.vespa.example.paragraph")); + assertEquals("files/vocab.txt", testConfig.getObject("vocab").getValue()); + assertEquals("files/model.onnx", testConfig.getObject("model").getValue()); } private VespaModel loadModel(Path path, boolean hosted) throws Exception { |