aboutsummaryrefslogtreecommitdiffstats
path: root/config-model/src
diff options
context:
space:
mode:
Diffstat (limited to 'config-model/src')
-rw-r--r--config-model/src/main/java/com/yahoo/vespa/model/container/xml/embedder/EmbedderOption.java6
-rw-r--r--config-model/src/test/cfg/application/embed_generic/configdefinitions/sentence-embedder.def25
-rw-r--r--config-model/src/test/cfg/application/embed_generic/services.xml24
-rw-r--r--config-model/src/test/java/com/yahoo/vespa/model/container/xml/EmbedderTestCase.java26
4 files changed, 73 insertions, 8 deletions
diff --git a/config-model/src/main/java/com/yahoo/vespa/model/container/xml/embedder/EmbedderOption.java b/config-model/src/main/java/com/yahoo/vespa/model/container/xml/embedder/EmbedderOption.java
index 206745887d1..d8740eecd74 100644
--- a/config-model/src/main/java/com/yahoo/vespa/model/container/xml/embedder/EmbedderOption.java
+++ b/config-model/src/main/java/com/yahoo/vespa/model/container/xml/embedder/EmbedderOption.java
@@ -62,8 +62,12 @@ public class EmbedderOption {
* Basic option transformer. No special handling of options.
*/
public static class OptionTransformer {
+
public void transform(DeployState deployState, Element parent, EmbedderOption option) {
- createElement(parent, option.name(), option.value());
+ if (option.value().isEmpty())
+ createElement(parent, option.name(), option.attributes.get("path")); // always understand path=".."
+ else
+ createElement(parent, option.name(), option.value());
}
public static Element createElement(Element parent, String name, String value) {
diff --git a/config-model/src/test/cfg/application/embed_generic/configdefinitions/sentence-embedder.def b/config-model/src/test/cfg/application/embed_generic/configdefinitions/sentence-embedder.def
new file mode 100644
index 00000000000..f62e2019189
--- /dev/null
+++ b/config-model/src/test/cfg/application/embed_generic/configdefinitions/sentence-embedder.def
@@ -0,0 +1,25 @@
+package=ai.vespa.example.paragraph
+
+# Settings for wordpiece tokenizer
+vocab path
+
+# Transformer model settings
+model path
+
+# Max length of token sequence model can handle
+transforerMaxTokens int default=128
+
+# Pooling strategy
+poolingStrategy enum { cls, mean } default=mean
+
+# Input names
+transformerInputIds string default=input_ids
+transformerAttentionMask string default=attention_mask
+
+# Output name
+transformerOutput string default=last_hidden_state
+
+# Settings for ONNX model evaluation
+onnxExecutionMode enum { parallel, sequential } default=sequential
+onnxInterOpThreads int default=1
+onnxIntraOpThreads int default=-4
diff --git a/config-model/src/test/cfg/application/embed_generic/services.xml b/config-model/src/test/cfg/application/embed_generic/services.xml
new file mode 100644
index 00000000000..7947b8b707c
--- /dev/null
+++ b/config-model/src/test/cfg/application/embed_generic/services.xml
@@ -0,0 +1,24 @@
+<?xml version="1.0" encoding="utf-8" ?>
+<!-- Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. -->
+<services version="1.0">
+
+ <container version="1.0">
+
+ <!--
+ <embedder id='transformer' class='ai.vespa.example.paragraph.SentenceBertEmbedder' bundle='exampleEmbedder' def='ai.vespa.example.paragraph.sentence-embedder'>
+ <model>files/model.onnx</model>
+ <vocab>files/vocab.txt</vocab>
+ </embedder>
+ -->
+
+ <embedder id='transformer' class='ai.vespa.example.paragraph.SentenceBertEmbedder' bundle='exampleEmbedder' def='ai.vespa.example.paragraph.sentence-embedder'>
+ <model path="files/model.onnx" /> <!-- Embedder syntax for file path -->
+ <vocab>files/vocab.txt</vocab> <!-- Generic config syntax for file path -->
+ </embedder>
+
+ <nodes>
+ <node hostalias='node1'/>
+ </nodes>
+ </container>
+
+</services>
diff --git a/config-model/src/test/java/com/yahoo/vespa/model/container/xml/EmbedderTestCase.java b/config-model/src/test/java/com/yahoo/vespa/model/container/xml/EmbedderTestCase.java
index fcccdcf8f23..6aff0a7f003 100644
--- a/config-model/src/test/java/com/yahoo/vespa/model/container/xml/EmbedderTestCase.java
+++ b/config-model/src/test/java/com/yahoo/vespa/model/container/xml/EmbedderTestCase.java
@@ -133,7 +133,7 @@ public class EmbedderTestCase {
}
@Test
- void testEmbedConfig() throws Exception {
+ void testApplicationWithEmbedConfig() throws Exception {
final String emptyPathFileName = "services.xml";
Path applicationDir = Path.fromString("src/test/cfg/application/embed/");
@@ -142,15 +142,27 @@ public class EmbedderTestCase {
Component<?, ?> testComponent = containerCluster.getComponentsMap().get(new ComponentId("test"));
ConfigPayloadBuilder testConfig = testComponent.getUserConfigs().get(new ConfigDefinitionKey("dummy", "test"));
- assertEquals(testConfig.getObject("num").getValue(), "12");
- assertEquals(testConfig.getObject("str").getValue(), "some text");
+ assertEquals("12", testConfig.getObject("num").getValue());
+ assertEquals("some text", testConfig.getObject("str").getValue());
Component<?, ?> transformer = containerCluster.getComponentsMap().get(new ComponentId("transformer"));
ConfigPayloadBuilder transformerConfig = transformer.getUserConfigs().get(new ConfigDefinitionKey("bert-base-embedder", "embedding"));
- assertEquals(transformerConfig.getObject("transformerModelUrl").getValue(), "test-model-url");
- assertEquals(transformerConfig.getObject("transformerModelPath").getValue(), emptyPathFileName);
- assertEquals(transformerConfig.getObject("tokenizerVocabUrl").getValue(), "");
- assertEquals(transformerConfig.getObject("tokenizerVocabPath").getValue(), "files/vocab.txt");
+ assertEquals("test-model-url", transformerConfig.getObject("transformerModelUrl").getValue());
+ assertEquals(emptyPathFileName, transformerConfig.getObject("transformerModelPath").getValue());
+ assertEquals("", transformerConfig.getObject("tokenizerVocabUrl").getValue());
+ assertEquals("files/vocab.txt", transformerConfig.getObject("tokenizerVocabPath").getValue());
+ }
+
+ @Test
+ void testApplicationWithGenericEmbedConfig() throws Exception {
+ Path applicationDir = Path.fromString("src/test/cfg/application/embed_generic/");
+ VespaModel model = loadModel(applicationDir, false);
+ ApplicationContainerCluster containerCluster = model.getContainerClusters().get("container");
+
+ Component<?, ?> testComponent = containerCluster.getComponentsMap().get(new ComponentId("transformer"));
+ ConfigPayloadBuilder testConfig = testComponent.getUserConfigs().get(new ConfigDefinitionKey("sentence-embedder", "ai.vespa.example.paragraph"));
+ assertEquals("files/vocab.txt", testConfig.getObject("vocab").getValue());
+ assertEquals("files/model.onnx", testConfig.getObject("model").getValue());
}
private VespaModel loadModel(Path path, boolean hosted) throws Exception {