summary refs log tree commit diff stats
path: root/config-model/src
diff options
context:
space:
mode:
author Lester Solbakken <lesters@users.noreply.github.com> 2022-08-23 09:40:18 +0200
committer GitHub <noreply@github.com> 2022-08-23 09:40:18 +0200
commit 7507b6700295d0eea7d55836e282e3fc833bfdc8 (patch)
tree 1f84608612d93828b0e1523e4be4305cb433d0bf /config-model/src
parent 8bba638ca93da728b9604ec828ae0457514c5bc6 (diff)
parent 25d395d3df09d155d26664d1092a80e1412e17a0 (diff)
Merge pull request #23710 from vespa-engine/bratseth/embedder-syntax
Support path=.. in generic embedders
Diffstat (limited to 'config-model/src')
-rw-r--r-- config-model/src/main/java/com/yahoo/vespa/model/container/xml/embedder/EmbedderConfig.java 52
-rw-r--r-- config-model/src/main/java/com/yahoo/vespa/model/container/xml/embedder/EmbedderConfigTransformer.java 15
-rw-r--r-- config-model/src/main/java/com/yahoo/vespa/model/container/xml/embedder/EmbedderOption.java 11
-rw-r--r-- config-model/src/test/cfg/application/embed_generic/configdefinitions/sentence-embedder.def 25
-rw-r--r-- config-model/src/test/cfg/application/embed_generic/services.xml 20
-rw-r--r-- config-model/src/test/java/com/yahoo/vespa/model/container/xml/EmbedderTestCase.java 172
6 files changed, 199 insertions, 96 deletions
diff --git a/config-model/src/main/java/com/yahoo/vespa/model/container/xml/embedder/EmbedderConfig.java b/config-model/src/main/java/com/yahoo/vespa/model/container/xml/embedder/EmbedderConfig.java
index a2286647cdd..fa531176b9c 100644
--- a/config-model/src/main/java/com/yahoo/vespa/model/container/xml/embedder/EmbedderConfig.java
+++ b/config-model/src/main/java/com/yahoo/vespa/model/container/xml/embedder/EmbedderConfig.java
@@ -26,26 +26,6 @@ import org.w3c.dom.NodeList;
*/
public class EmbedderConfig {
- static EmbedderConfigTransformer getEmbedderTransformer(Element spec, boolean hosted) {
- String classId = getEmbedderClass(spec);
- switch (classId) {
- case "ai.vespa.embedding.BertBaseEmbedder": return new EmbedderConfigBertBaseTransformer(spec, hosted);
- }
- return new EmbedderConfigTransformer(spec, hosted);
- }
-
- static String modelIdToUrl(String id) {
- switch (id) {
- case "test-model-id":
- return "test-model-url";
- case "minilm-l6-v2":
- return "https://data.vespa.oath.cloud/onnx_models/sentence_all_MiniLM_L6_v2.onnx";
- case "bert-base-uncased":
- return "https://data.vespa.oath.cloud/onnx_models/bert-base-uncased-vocab.txt";
- }
- throw new IllegalArgumentException("Unknown model id: '" + id + "'");
- }
-
/**
* Transforms the &lt;embedder ...&gt; element to component configuration.
*
@@ -65,6 +45,36 @@ public class EmbedderConfig {
return transformer.createComponentConfig(deployState);
}
+ private static EmbedderConfigTransformer getEmbedderTransformer(Element spec, boolean hosted) {
+ return switch (embedderConfigFrom(spec)) {
+ case "embedding.bert-base-embedder" -> new EmbedderConfigBertBaseTransformer(spec, hosted);
+ default -> new EmbedderConfigTransformer(spec, hosted);
+ };
+ }
+
+ private static String embedderConfigFrom(Element spec) {
+ String explicitDefinition = spec.getAttribute("def");
+ if ( ! explicitDefinition.isEmpty()) return explicitDefinition;
+
+ // Implicit from class name
+ return switch (getEmbedderClass(spec)) {
+ case "ai.vespa.embedding.BertBaseEmbedder" -> "embedding.bert-base-embedder";
+ default -> "";
+ };
+ }
+
+ static String modelIdToUrl(String id) {
+ switch (id) {
+ case "test-model-id":
+ return "test-model-url";
+ case "minilm-l6-v2":
+ return "https://data.vespa.oath.cloud/onnx_models/sentence_all_MiniLM_L6_v2.onnx";
+ case "bert-base-uncased":
+ return "https://data.vespa.oath.cloud/onnx_models/bert-base-uncased-vocab.txt";
+ }
+ throw new IllegalArgumentException("Unknown model id: '" + id + "'");
+ }
+
private static String getEmbedderClass(Element spec) {
if (spec.hasAttribute("class")) {
return spec.getAttribute("class");
@@ -72,7 +82,7 @@ public class EmbedderConfig {
if (spec.hasAttribute("id")) {
return spec.getAttribute("id");
}
- throw new IllegalArgumentException("Embedder specification does not have a required class attribute");
+ throw new IllegalArgumentException("An <embedder> element must have a 'class' or 'id' attribute");
}
diff --git a/config-model/src/main/java/com/yahoo/vespa/model/container/xml/embedder/EmbedderConfigTransformer.java b/config-model/src/main/java/com/yahoo/vespa/model/container/xml/embedder/EmbedderConfigTransformer.java
index 30327fdc8af..efb1aafdbe3 100644
--- a/config-model/src/main/java/com/yahoo/vespa/model/container/xml/embedder/EmbedderConfigTransformer.java
+++ b/config-model/src/main/java/com/yahoo/vespa/model/container/xml/embedder/EmbedderConfigTransformer.java
@@ -9,6 +9,7 @@ import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
+import java.util.stream.Collectors;
/**
@@ -88,15 +89,13 @@ public class EmbedderConfigTransformer {
}
private void checkRequiredOptions() {
- List<String> missingOptions = new ArrayList<>();
- for (EmbedderOption option : options.values()) {
- if ( ! option.isSet()) {
- missingOptions.add(option.name());
- }
- }
- if (missingOptions.size() > 0) {
+ var missingOptions = options.values()
+ .stream()
+ .filter(option -> ! option.isSet())
+ .map(option -> option.name())
+ .collect(Collectors.toList());
+ if (missingOptions.size() > 0)
throw new IllegalArgumentException("Embedder '" + className + "' requires options for " + missingOptions);
- }
}
}
diff --git a/config-model/src/main/java/com/yahoo/vespa/model/container/xml/embedder/EmbedderOption.java b/config-model/src/main/java/com/yahoo/vespa/model/container/xml/embedder/EmbedderOption.java
index 206745887d1..715c3d9ef34 100644
--- a/config-model/src/main/java/com/yahoo/vespa/model/container/xml/embedder/EmbedderOption.java
+++ b/config-model/src/main/java/com/yahoo/vespa/model/container/xml/embedder/EmbedderOption.java
@@ -58,12 +58,21 @@ public class EmbedderOption {
return set;
}
+ @Override
+ public String toString() {
+ return "embedder option '" + name + "'";
+ }
+
/**
* Basic option transformer. No special handling of options.
*/
public static class OptionTransformer {
+
public void transform(DeployState deployState, Element parent, EmbedderOption option) {
- createElement(parent, option.name(), option.value());
+ if (option.value().isEmpty())
+ createElement(parent, option.name(), option.attributes.get("path")); // always understand path=".."
+ else
+ createElement(parent, option.name(), option.value());
}
public static Element createElement(Element parent, String name, String value) {
diff --git a/config-model/src/test/cfg/application/embed_generic/configdefinitions/sentence-embedder.def b/config-model/src/test/cfg/application/embed_generic/configdefinitions/sentence-embedder.def
new file mode 100644
index 00000000000..f62e2019189
--- /dev/null
+++ b/config-model/src/test/cfg/application/embed_generic/configdefinitions/sentence-embedder.def
@@ -0,0 +1,25 @@
+package=ai.vespa.example.paragraph
+
+# Settings for wordpiece tokenizer
+vocab path
+
+# Transformer model settings
+model path
+
+# Max length of token sequence model can handle
+transforerMaxTokens int default=128
+
+# Pooling strategy
+poolingStrategy enum { cls, mean } default=mean
+
+# Input names
+transformerInputIds string default=input_ids
+transformerAttentionMask string default=attention_mask
+
+# Output name
+transformerOutput string default=last_hidden_state
+
+# Settings for ONNX model evaluation
+onnxExecutionMode enum { parallel, sequential } default=sequential
+onnxInterOpThreads int default=1
+onnxIntraOpThreads int default=-4
diff --git a/config-model/src/test/cfg/application/embed_generic/services.xml b/config-model/src/test/cfg/application/embed_generic/services.xml
new file mode 100644
index 00000000000..ab2c1be9745
--- /dev/null
+++ b/config-model/src/test/cfg/application/embed_generic/services.xml
@@ -0,0 +1,20 @@
+<?xml version="1.0" encoding="utf-8" ?>
+<!-- Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. -->
+<services version="1.0">
+
+ <container version="1.0">
+
+ <embedder id='transformer'
+ class='ai.vespa.example.paragraph.ApplicationSpecificEmbedder'
+ bundle='exampleEmbedder'
+ def='ai.vespa.example.paragraph.sentence-embedder'>
+ <model path="files/model.onnx" /> <!-- Embedder syntax for file path -->
+ <vocab>files/vocab.txt</vocab> <!-- Generic config syntax for file path -->
+ </embedder>
+
+ <nodes>
+ <node hostalias='node1'/>
+ </nodes>
+ </container>
+
+</services>
diff --git a/config-model/src/test/java/com/yahoo/vespa/model/container/xml/EmbedderTestCase.java b/config-model/src/test/java/com/yahoo/vespa/model/container/xml/EmbedderTestCase.java
index fcccdcf8f23..766c2b11256 100644
--- a/config-model/src/test/java/com/yahoo/vespa/model/container/xml/EmbedderTestCase.java
+++ b/config-model/src/test/java/com/yahoo/vespa/model/container/xml/EmbedderTestCase.java
@@ -55,85 +55,113 @@ public class EmbedderTestCase {
@Test
void testPredefinedEmbedConfigSelfHosted() throws IOException, SAXException {
+ String embedder = "<embedder id=\"test\" class=\"" + PREDEFINED_EMBEDDER_CLASS + "\">" +
+ " <model id=\"my_model_id\" url=\"my-model-url\" />" +
+ " <vocab id=\"my_vocab_id\" url=\"my-vocab-url\" />" +
+ "</embedder>";
+ String component = "<component id=\"test\" class=\"" + PREDEFINED_EMBEDDER_CLASS + "\" bundle=\"model-integration\">" +
+ " <config name=\"" + PREDEFINED_EMBEDDER_CONFIG + "\">" +
+ " <tokenizerVocabUrl>my-vocab-url</tokenizerVocabUrl>" +
+ " <tokenizerVocabPath></tokenizerVocabPath>" +
+ " <transformerModelUrl>my-model-url</transformerModelUrl>" +
+ " <transformerModelPath></transformerModelPath>" +
+ " </config>" +
+ "</component>";
+ assertTransform(embedder, component, false);
+ }
+
+ @Test
+ void testIncorrectEmbedderOptionsSelfHosted() throws IOException, SAXException {
assertTransformThrows("<embedder id=\"test\" class=\"" + PREDEFINED_EMBEDDER_CLASS + "\"></embedder>",
- "Embedder '" + PREDEFINED_EMBEDDER_CLASS + "' requires options for [vocab, model]");
+ "Embedder '" + PREDEFINED_EMBEDDER_CLASS + "' requires options for [vocab, model]");
assertTransformThrows("<embedder id=\"test\" class=\"" + PREDEFINED_EMBEDDER_CLASS + "\">" +
- " <model />" +
- " <vocab />" +
- "</embedder>",
- "Model option requires either a 'path' or a 'url' attribute");
+ " <model />" +
+ " <vocab />" +
+ "</embedder>",
+ "Model option requires either a 'path' or a 'url' attribute");
assertTransformThrows("<embedder id=\"test\" class=\"" + PREDEFINED_EMBEDDER_CLASS + "\">" +
- " <model id=\"my_model_id\" />" +
- " <vocab id=\"my_vocab_id\" />" +
- "</embedder>",
- "Model option 'id' is not valid here");
+ " <model id=\"my_model_id\" />" +
+ " <vocab id=\"my_vocab_id\" />" +
+ "</embedder>",
+ "Model option 'id' is not valid here");
+ }
+ @Test
+ void testPathHasprioritySelfHosted() throws IOException, SAXException {
String embedder = "<embedder id=\"test\" class=\"" + PREDEFINED_EMBEDDER_CLASS + "\">" +
- " <model id=\"my_model_id\" url=\"my-model-url\" />" +
- " <vocab id=\"my_vocab_id\" url=\"my-vocab-url\" />" +
- "</embedder>";
+ " <model id=\"my_model_id\" url=\"my-model-url\" path=\"files/model.onnx\" />" +
+ " <vocab id=\"my_vocab_id\" url=\"my-vocab-url\" path=\"files/vocab.txt\" />" +
+ "</embedder>";
String component = "<component id=\"test\" class=\"" + PREDEFINED_EMBEDDER_CLASS + "\" bundle=\"model-integration\">" +
- " <config name=\"" + PREDEFINED_EMBEDDER_CONFIG + "\">" +
- " <tokenizerVocabUrl>my-vocab-url</tokenizerVocabUrl>" +
- " <tokenizerVocabPath></tokenizerVocabPath>" +
- " <transformerModelUrl>my-model-url</transformerModelUrl>" +
- " <transformerModelPath></transformerModelPath>" +
- " </config>" +
- "</component>";
- assertTransform(embedder, component, false);
-
- // Path has priority:
- embedder = "<embedder id=\"test\" class=\"" + PREDEFINED_EMBEDDER_CLASS + "\">" +
- " <model id=\"my_model_id\" url=\"my-model-url\" path=\"files/model.onnx\" />" +
- " <vocab id=\"my_vocab_id\" url=\"my-vocab-url\" path=\"files/vocab.txt\" />" +
- "</embedder>";
- component = "<component id=\"test\" class=\"" + PREDEFINED_EMBEDDER_CLASS + "\" bundle=\"model-integration\">" +
- " <config name=\"" + PREDEFINED_EMBEDDER_CONFIG + "\">" +
- " <tokenizerVocabPath>files/vocab.txt</tokenizerVocabPath>" +
- " <tokenizerVocabUrl></tokenizerVocabUrl>" +
- " <transformerModelPath>files/model.onnx</transformerModelPath>" +
- " <transformerModelUrl></transformerModelUrl>" +
- " </config>" +
- "</component>";
+ " <config name=\"" + PREDEFINED_EMBEDDER_CONFIG + "\">" +
+ " <tokenizerVocabPath>files/vocab.txt</tokenizerVocabPath>" +
+ " <tokenizerVocabUrl></tokenizerVocabUrl>" +
+ " <transformerModelPath>files/model.onnx</transformerModelPath>" +
+ " <transformerModelUrl></transformerModelUrl>" +
+ " </config>" +
+ "</component>";
assertTransform(embedder, component, false);
}
@Test
- void testPredefinedEmbedConfigCloud() throws IOException, SAXException {
+ void testPredefinedEmptyEmbedConfigCloud() throws IOException, SAXException {
String embedder = "<embedder id=\"test\" class=\"" + PREDEFINED_EMBEDDER_CLASS + "\" />";
String component = "<component id=\"test\" class=\"" + PREDEFINED_EMBEDDER_CLASS + "\" bundle=\"model-integration\">" +
- " <config name=\"" + PREDEFINED_EMBEDDER_CONFIG + "\">" +
- " <tokenizerVocabUrl>some url</tokenizerVocabUrl>" +
- " <tokenizerVocabPath></tokenizerVocabPath>" +
- " <transformerModelUrl>some url</transformerModelUrl>" +
- " <transformerModelPath></transformerModelPath>" +
- " </config>" +
- "</component>";
+ " <config name=\"" + PREDEFINED_EMBEDDER_CONFIG + "\">" +
+ " <tokenizerVocabUrl>some url</tokenizerVocabUrl>" +
+ " <tokenizerVocabPath></tokenizerVocabPath>" +
+ " <transformerModelUrl>some url</transformerModelUrl>" +
+ " <transformerModelPath></transformerModelPath>" +
+ " </config>" +
+ "</component>";
assertTransform(embedder, component, true);
+ }
- embedder = "<embedder id=\"test\" class=\"" + PREDEFINED_EMBEDDER_CLASS + "\">" +
- " <model id=\"my_model_id\" />" +
- " <vocab id=\"my_vocab_id\" />" +
- "</embedder>";
- assertTransformThrows(embedder, "Unknown model id: 'my_vocab_id'", true);
+ @Test
+ void testPredefinedEmbedConfigCloud() throws IOException, SAXException {
+ String embedder = "<embedder id=\"test\" class=\"" + PREDEFINED_EMBEDDER_CLASS + "\">" +
+ " <model id=\"test-model-id\" />" +
+ " <vocab id=\"test-model-id\" />" +
+ "</embedder>";
+ String component = "<component id=\"test\" class=\"" + PREDEFINED_EMBEDDER_CLASS + "\" bundle=\"model-integration\">" +
+ " <config name=\"" + PREDEFINED_EMBEDDER_CONFIG + "\">" +
+ " <tokenizerVocabUrl>test-model-url</tokenizerVocabUrl>" +
+ " <tokenizerVocabPath></tokenizerVocabPath>" +
+ " <transformerModelUrl>test-model-url</transformerModelUrl>" +
+ " <transformerModelPath></transformerModelPath>" +
+ " </config>" +
+ "</component>";
+ assertTransform(embedder, component, true);
+ }
- embedder = "<embedder id=\"test\" class=\"" + PREDEFINED_EMBEDDER_CLASS + "\">" +
- " <model id=\"test-model-id\" />" +
- " <vocab id=\"test-model-id\" />" +
- "</embedder>";
- component = "<component id=\"test\" class=\"" + PREDEFINED_EMBEDDER_CLASS + "\" bundle=\"model-integration\">" +
- " <config name=\"" + PREDEFINED_EMBEDDER_CONFIG + "\">" +
- " <tokenizerVocabUrl>test-model-url</tokenizerVocabUrl>" +
- " <tokenizerVocabPath></tokenizerVocabPath>" +
- " <transformerModelUrl>test-model-url</transformerModelUrl>" +
- " <transformerModelPath></transformerModelPath>" +
- " </config>" +
- "</component>";
+ @Test
+ void testCustomEmbedderWithPredefinedConfigCloud() throws IOException, SAXException {
+ String embedder = "<embedder id=\"test\" class=\"ApplicationSpecificEmbedder\" def=\"" + PREDEFINED_EMBEDDER_CONFIG + "\">" +
+ " <model id=\"test-model-id\" />" +
+ " <vocab id=\"test-model-id\" />" +
+ "</embedder>";
+ String component = "<component id=\"test\" class=\"ApplicationSpecificEmbedder\" bundle=\"model-integration\">" +
+ " <config name=\"" + PREDEFINED_EMBEDDER_CONFIG + "\">" +
+ " <tokenizerVocabUrl>test-model-url</tokenizerVocabUrl>" +
+ " <tokenizerVocabPath></tokenizerVocabPath>" +
+ " <transformerModelUrl>test-model-url</transformerModelUrl>" +
+ " <transformerModelPath></transformerModelPath>" +
+ " </config>" +
+ "</component>";
assertTransform(embedder, component, true);
}
@Test
- void testEmbedConfig() throws Exception {
+ void testUnknownModelIdCloud() throws IOException, SAXException {
+ String embedder = "<embedder id=\"test\" class=\"" + PREDEFINED_EMBEDDER_CLASS + "\">" +
+ " <model id=\"my_model_id\" />" +
+ " <vocab id=\"my_vocab_id\" />" +
+ "</embedder>";
+ assertTransformThrows(embedder, "Unknown model id: 'my_vocab_id'", true);
+ }
+
+ @Test
+ void testApplicationWithEmbedConfig() throws Exception {
final String emptyPathFileName = "services.xml";
Path applicationDir = Path.fromString("src/test/cfg/application/embed/");
@@ -142,15 +170,27 @@ public class EmbedderTestCase {
Component<?, ?> testComponent = containerCluster.getComponentsMap().get(new ComponentId("test"));
ConfigPayloadBuilder testConfig = testComponent.getUserConfigs().get(new ConfigDefinitionKey("dummy", "test"));
- assertEquals(testConfig.getObject("num").getValue(), "12");
- assertEquals(testConfig.getObject("str").getValue(), "some text");
+ assertEquals("12", testConfig.getObject("num").getValue());
+ assertEquals("some text", testConfig.getObject("str").getValue());
Component<?, ?> transformer = containerCluster.getComponentsMap().get(new ComponentId("transformer"));
ConfigPayloadBuilder transformerConfig = transformer.getUserConfigs().get(new ConfigDefinitionKey("bert-base-embedder", "embedding"));
- assertEquals(transformerConfig.getObject("transformerModelUrl").getValue(), "test-model-url");
- assertEquals(transformerConfig.getObject("transformerModelPath").getValue(), emptyPathFileName);
- assertEquals(transformerConfig.getObject("tokenizerVocabUrl").getValue(), "");
- assertEquals(transformerConfig.getObject("tokenizerVocabPath").getValue(), "files/vocab.txt");
+ assertEquals("test-model-url", transformerConfig.getObject("transformerModelUrl").getValue());
+ assertEquals(emptyPathFileName, transformerConfig.getObject("transformerModelPath").getValue());
+ assertEquals("", transformerConfig.getObject("tokenizerVocabUrl").getValue());
+ assertEquals("files/vocab.txt", transformerConfig.getObject("tokenizerVocabPath").getValue());
+ }
+
+ @Test
+ void testApplicationWithGenericEmbedConfig() throws Exception {
+ Path applicationDir = Path.fromString("src/test/cfg/application/embed_generic/");
+ VespaModel model = loadModel(applicationDir, false);
+ ApplicationContainerCluster containerCluster = model.getContainerClusters().get("container");
+
+ Component<?, ?> testComponent = containerCluster.getComponentsMap().get(new ComponentId("transformer"));
+ ConfigPayloadBuilder config = testComponent.getUserConfigs().get(new ConfigDefinitionKey("sentence-embedder", "ai.vespa.example.paragraph"));
+ assertEquals("files/vocab.txt", config.getObject("vocab").getValue());
+ assertEquals("files/model.onnx", config.getObject("model").getValue());
}
private VespaModel loadModel(Path path, boolean hosted) throws Exception {