diff options
author | Lester Solbakken <lesters@users.noreply.github.com> | 2022-08-23 09:40:18 +0200 |
---|---|---|
committer | GitHub <noreply@github.com> | 2022-08-23 09:40:18 +0200 |
commit | 7507b6700295d0eea7d55836e282e3fc833bfdc8 (patch) | |
tree | 1f84608612d93828b0e1523e4be4305cb433d0bf /config-model | |
parent | 8bba638ca93da728b9604ec828ae0457514c5bc6 (diff) | |
parent | 25d395d3df09d155d26664d1092a80e1412e17a0 (diff) |
Merge pull request #23710 from vespa-engine/bratseth/embedder-syntax
Support path=.. in generic embedders
Diffstat (limited to 'config-model')
6 files changed, 199 insertions, 96 deletions
diff --git a/config-model/src/main/java/com/yahoo/vespa/model/container/xml/embedder/EmbedderConfig.java b/config-model/src/main/java/com/yahoo/vespa/model/container/xml/embedder/EmbedderConfig.java index a2286647cdd..fa531176b9c 100644 --- a/config-model/src/main/java/com/yahoo/vespa/model/container/xml/embedder/EmbedderConfig.java +++ b/config-model/src/main/java/com/yahoo/vespa/model/container/xml/embedder/EmbedderConfig.java @@ -26,26 +26,6 @@ import org.w3c.dom.NodeList; */ public class EmbedderConfig { - static EmbedderConfigTransformer getEmbedderTransformer(Element spec, boolean hosted) { - String classId = getEmbedderClass(spec); - switch (classId) { - case "ai.vespa.embedding.BertBaseEmbedder": return new EmbedderConfigBertBaseTransformer(spec, hosted); - } - return new EmbedderConfigTransformer(spec, hosted); - } - - static String modelIdToUrl(String id) { - switch (id) { - case "test-model-id": - return "test-model-url"; - case "minilm-l6-v2": - return "https://data.vespa.oath.cloud/onnx_models/sentence_all_MiniLM_L6_v2.onnx"; - case "bert-base-uncased": - return "https://data.vespa.oath.cloud/onnx_models/bert-base-uncased-vocab.txt"; - } - throw new IllegalArgumentException("Unknown model id: '" + id + "'"); - } - /** * Transforms the <embedder ...> element to component configuration. * @@ -65,6 +45,36 @@ public class EmbedderConfig { return transformer.createComponentConfig(deployState); } + private static EmbedderConfigTransformer getEmbedderTransformer(Element spec, boolean hosted) { + return switch (embedderConfigFrom(spec)) { + case "embedding.bert-base-embedder" -> new EmbedderConfigBertBaseTransformer(spec, hosted); + default -> new EmbedderConfigTransformer(spec, hosted); + }; + } + + private static String embedderConfigFrom(Element spec) { + String explicitDefinition = spec.getAttribute("def"); + if ( ! explicitDefinition.isEmpty()) return explicitDefinition; + + // Implicit from class name + return switch (getEmbedderClass(spec)) { + case "ai.vespa.embedding.BertBaseEmbedder" -> "embedding.bert-base-embedder"; + default -> ""; + }; + } + + static String modelIdToUrl(String id) { + switch (id) { + case "test-model-id": + return "test-model-url"; + case "minilm-l6-v2": + return "https://data.vespa.oath.cloud/onnx_models/sentence_all_MiniLM_L6_v2.onnx"; + case "bert-base-uncased": + return "https://data.vespa.oath.cloud/onnx_models/bert-base-uncased-vocab.txt"; + } + throw new IllegalArgumentException("Unknown model id: '" + id + "'"); + } + private static String getEmbedderClass(Element spec) { if (spec.hasAttribute("class")) { return spec.getAttribute("class"); @@ -72,7 +82,7 @@ public class EmbedderConfig { if (spec.hasAttribute("id")) { return spec.getAttribute("id"); } - throw new IllegalArgumentException("Embedder specification does not have a required class attribute"); + throw new IllegalArgumentException("An <embedder> element must have a 'class' or 'id' attribute"); } diff --git a/config-model/src/main/java/com/yahoo/vespa/model/container/xml/embedder/EmbedderConfigTransformer.java b/config-model/src/main/java/com/yahoo/vespa/model/container/xml/embedder/EmbedderConfigTransformer.java index 30327fdc8af..efb1aafdbe3 100644 --- a/config-model/src/main/java/com/yahoo/vespa/model/container/xml/embedder/EmbedderConfigTransformer.java +++ b/config-model/src/main/java/com/yahoo/vespa/model/container/xml/embedder/EmbedderConfigTransformer.java @@ -9,6 +9,7 @@ import java.util.ArrayList; import java.util.HashMap; import java.util.List; import java.util.Map; +import java.util.stream.Collectors; /** @@ -88,15 +89,13 @@ public class EmbedderConfigTransformer { } private void checkRequiredOptions() { - List<String> missingOptions = new ArrayList<>(); - for (EmbedderOption option : options.values()) { - if ( ! option.isSet()) { - missingOptions.add(option.name()); - } - } - if (missingOptions.size() > 0) { + var missingOptions = options.values() + .stream() + .filter(option -> ! option.isSet()) + .map(option -> option.name()) + .collect(Collectors.toList()); + if (missingOptions.size() > 0) throw new IllegalArgumentException("Embedder '" + className + "' requires options for " + missingOptions); - } } } diff --git a/config-model/src/main/java/com/yahoo/vespa/model/container/xml/embedder/EmbedderOption.java b/config-model/src/main/java/com/yahoo/vespa/model/container/xml/embedder/EmbedderOption.java index 206745887d1..715c3d9ef34 100644 --- a/config-model/src/main/java/com/yahoo/vespa/model/container/xml/embedder/EmbedderOption.java +++ b/config-model/src/main/java/com/yahoo/vespa/model/container/xml/embedder/EmbedderOption.java @@ -58,12 +58,21 @@ public class EmbedderOption { return set; } + @Override + public String toString() { + return "embedder option '" + name + "'"; + } + /** * Basic option transformer. No special handling of options. */ public static class OptionTransformer { + public void transform(DeployState deployState, Element parent, EmbedderOption option) { - createElement(parent, option.name(), option.value()); + if (option.value().isEmpty()) + createElement(parent, option.name(), option.attributes.get("path")); // always understand path=".." + else + createElement(parent, option.name(), option.value()); } public static Element createElement(Element parent, String name, String value) { diff --git a/config-model/src/test/cfg/application/embed_generic/configdefinitions/sentence-embedder.def b/config-model/src/test/cfg/application/embed_generic/configdefinitions/sentence-embedder.def new file mode 100644 index 00000000000..f62e2019189 --- /dev/null +++ b/config-model/src/test/cfg/application/embed_generic/configdefinitions/sentence-embedder.def @@ -0,0 +1,25 @@ +package=ai.vespa.example.paragraph + +# Settings for wordpiece tokenizer +vocab path + +# Transformer model settings +model path + +# Max length of token sequence model can handle +transforerMaxTokens int default=128 + +# Pooling strategy +poolingStrategy enum { cls, mean } default=mean + +# Input names +transformerInputIds string default=input_ids +transformerAttentionMask string default=attention_mask + +# Output name +transformerOutput string default=last_hidden_state + +# Settings for ONNX model evaluation +onnxExecutionMode enum { parallel, sequential } default=sequential +onnxInterOpThreads int default=1 +onnxIntraOpThreads int default=-4 diff --git a/config-model/src/test/cfg/application/embed_generic/services.xml b/config-model/src/test/cfg/application/embed_generic/services.xml new file mode 100644 index 00000000000..ab2c1be9745 --- /dev/null +++ b/config-model/src/test/cfg/application/embed_generic/services.xml @@ -0,0 +1,20 @@ +<?xml version="1.0" encoding="utf-8" ?> +<!-- Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. --> +<services version="1.0"> + + <container version="1.0"> + + <embedder id='transformer' + class='ai.vespa.example.paragraph.ApplicationSpecificEmbedder' + bundle='exampleEmbedder' + def='ai.vespa.example.paragraph.sentence-embedder'> + <model path="files/model.onnx" /> <!-- Embedder syntax for file path --> + <vocab>files/vocab.txt</vocab> <!-- Generic config syntax for file path --> + </embedder> + + <nodes> + <node hostalias='node1'/> + </nodes> + </container> + +</services> diff --git a/config-model/src/test/java/com/yahoo/vespa/model/container/xml/EmbedderTestCase.java b/config-model/src/test/java/com/yahoo/vespa/model/container/xml/EmbedderTestCase.java index fcccdcf8f23..766c2b11256 100644 --- a/config-model/src/test/java/com/yahoo/vespa/model/container/xml/EmbedderTestCase.java +++ b/config-model/src/test/java/com/yahoo/vespa/model/container/xml/EmbedderTestCase.java @@ -55,85 +55,113 @@ public class EmbedderTestCase { @Test void testPredefinedEmbedConfigSelfHosted() throws IOException, SAXException { + String embedder = "<embedder id=\"test\" class=\"" + PREDEFINED_EMBEDDER_CLASS + "\">" + + " <model id=\"my_model_id\" url=\"my-model-url\" />" + + " <vocab id=\"my_vocab_id\" url=\"my-vocab-url\" />" + + "</embedder>"; + String component = "<component id=\"test\" class=\"" + PREDEFINED_EMBEDDER_CLASS + "\" bundle=\"model-integration\">" + + " <config name=\"" + PREDEFINED_EMBEDDER_CONFIG + "\">" + + " <tokenizerVocabUrl>my-vocab-url</tokenizerVocabUrl>" + + " <tokenizerVocabPath></tokenizerVocabPath>" + + " <transformerModelUrl>my-model-url</transformerModelUrl>" + + " <transformerModelPath></transformerModelPath>" + + " </config>" + + "</component>"; + assertTransform(embedder, component, false); + } + + @Test + void testIncorrectEmbedderOptionsSelfHosted() throws IOException, SAXException { assertTransformThrows("<embedder id=\"test\" class=\"" + PREDEFINED_EMBEDDER_CLASS + "\"></embedder>", - "Embedder '" + PREDEFINED_EMBEDDER_CLASS + "' requires options for [vocab, model]"); + "Embedder '" + PREDEFINED_EMBEDDER_CLASS + "' requires options for [vocab, model]"); assertTransformThrows("<embedder id=\"test\" class=\"" + PREDEFINED_EMBEDDER_CLASS + "\">" + - " <model />" + - " <vocab />" + - "</embedder>", - "Model option requires either a 'path' or a 'url' attribute"); + " <model />" + + " <vocab />" + + "</embedder>", + "Model option requires either a 'path' or a 'url' attribute"); assertTransformThrows("<embedder id=\"test\" class=\"" + PREDEFINED_EMBEDDER_CLASS + "\">" + - " <model id=\"my_model_id\" />" + - " <vocab id=\"my_vocab_id\" />" + - "</embedder>", - "Model option 'id' is not valid here"); + " <model id=\"my_model_id\" />" + + " <vocab id=\"my_vocab_id\" />" + + "</embedder>", + "Model option 'id' is not valid here"); + } + @Test + void testPathHasprioritySelfHosted() throws IOException, SAXException { String embedder = "<embedder id=\"test\" class=\"" + PREDEFINED_EMBEDDER_CLASS + "\">" + - " <model id=\"my_model_id\" url=\"my-model-url\" />" + - " <vocab id=\"my_vocab_id\" url=\"my-vocab-url\" />" + - "</embedder>"; + " <model id=\"my_model_id\" url=\"my-model-url\" path=\"files/model.onnx\" />" + + " <vocab id=\"my_vocab_id\" url=\"my-vocab-url\" path=\"files/vocab.txt\" />" + + "</embedder>"; String component = "<component id=\"test\" class=\"" + PREDEFINED_EMBEDDER_CLASS + "\" bundle=\"model-integration\">" + - " <config name=\"" + PREDEFINED_EMBEDDER_CONFIG + "\">" + - " <tokenizerVocabUrl>my-vocab-url</tokenizerVocabUrl>" + - " <tokenizerVocabPath></tokenizerVocabPath>" + - " <transformerModelUrl>my-model-url</transformerModelUrl>" + - " <transformerModelPath></transformerModelPath>" + - " </config>" + - "</component>"; - assertTransform(embedder, component, false); - - // Path has priority: - embedder = "<embedder id=\"test\" class=\"" + PREDEFINED_EMBEDDER_CLASS + "\">" + - " <model id=\"my_model_id\" url=\"my-model-url\" path=\"files/model.onnx\" />" + - " <vocab id=\"my_vocab_id\" url=\"my-vocab-url\" path=\"files/vocab.txt\" />" + - "</embedder>"; - component = "<component id=\"test\" class=\"" + PREDEFINED_EMBEDDER_CLASS + "\" bundle=\"model-integration\">" + - " <config name=\"" + PREDEFINED_EMBEDDER_CONFIG + "\">" + - " <tokenizerVocabPath>files/vocab.txt</tokenizerVocabPath>" + - " <tokenizerVocabUrl></tokenizerVocabUrl>" + - " <transformerModelPath>files/model.onnx</transformerModelPath>" + - " <transformerModelUrl></transformerModelUrl>" + - " </config>" + - "</component>"; + " <config name=\"" + PREDEFINED_EMBEDDER_CONFIG + "\">" + + " <tokenizerVocabPath>files/vocab.txt</tokenizerVocabPath>" + + " <tokenizerVocabUrl></tokenizerVocabUrl>" + + " <transformerModelPath>files/model.onnx</transformerModelPath>" + + " <transformerModelUrl></transformerModelUrl>" + + " </config>" + + "</component>"; assertTransform(embedder, component, false); } @Test - void testPredefinedEmbedConfigCloud() throws IOException, SAXException { + void testPredefinedEmptyEmbedConfigCloud() throws IOException, SAXException { String embedder = "<embedder id=\"test\" class=\"" + PREDEFINED_EMBEDDER_CLASS + "\" />"; String component = "<component id=\"test\" class=\"" + PREDEFINED_EMBEDDER_CLASS + "\" bundle=\"model-integration\">" + - " <config name=\"" + PREDEFINED_EMBEDDER_CONFIG + "\">" + - " <tokenizerVocabUrl>some url</tokenizerVocabUrl>" + - " <tokenizerVocabPath></tokenizerVocabPath>" + - " <transformerModelUrl>some url</transformerModelUrl>" + - " <transformerModelPath></transformerModelPath>" + - " </config>" + - "</component>"; + " <config name=\"" + PREDEFINED_EMBEDDER_CONFIG + "\">" + + " <tokenizerVocabUrl>some url</tokenizerVocabUrl>" + + " <tokenizerVocabPath></tokenizerVocabPath>" + + " <transformerModelUrl>some url</transformerModelUrl>" + + " <transformerModelPath></transformerModelPath>" + + " </config>" + + "</component>"; assertTransform(embedder, component, true); + } - embedder = "<embedder id=\"test\" class=\"" + PREDEFINED_EMBEDDER_CLASS + "\">" + - " <model id=\"my_model_id\" />" + - " <vocab id=\"my_vocab_id\" />" + - "</embedder>"; - assertTransformThrows(embedder, "Unknown model id: 'my_vocab_id'", true); + @Test + void testPredefinedEmbedConfigCloud() throws IOException, SAXException { + String embedder = "<embedder id=\"test\" class=\"" + PREDEFINED_EMBEDDER_CLASS + "\">" + + " <model id=\"test-model-id\" />" + + " <vocab id=\"test-model-id\" />" + + "</embedder>"; + String component = "<component id=\"test\" class=\"" + PREDEFINED_EMBEDDER_CLASS + "\" bundle=\"model-integration\">" + + " <config name=\"" + PREDEFINED_EMBEDDER_CONFIG + "\">" + + " <tokenizerVocabUrl>test-model-url</tokenizerVocabUrl>" + + " <tokenizerVocabPath></tokenizerVocabPath>" + + " <transformerModelUrl>test-model-url</transformerModelUrl>" + + " <transformerModelPath></transformerModelPath>" + + " </config>" + + "</component>"; + assertTransform(embedder, component, true); + } - embedder = "<embedder id=\"test\" class=\"" + PREDEFINED_EMBEDDER_CLASS + "\">" + - " <model id=\"test-model-id\" />" + - " <vocab id=\"test-model-id\" />" + - "</embedder>"; - component = "<component id=\"test\" class=\"" + PREDEFINED_EMBEDDER_CLASS + "\" bundle=\"model-integration\">" + - " <config name=\"" + PREDEFINED_EMBEDDER_CONFIG + "\">" + - " <tokenizerVocabUrl>test-model-url</tokenizerVocabUrl>" + - " <tokenizerVocabPath></tokenizerVocabPath>" + - " <transformerModelUrl>test-model-url</transformerModelUrl>" + - " <transformerModelPath></transformerModelPath>" + - " </config>" + - "</component>"; + @Test + void testCustomEmbedderWithPredefinedConfigCloud() throws IOException, SAXException { + String embedder = "<embedder id=\"test\" class=\"ApplicationSpecificEmbedder\" def=\"" + PREDEFINED_EMBEDDER_CONFIG + "\">" + + " <model id=\"test-model-id\" />" + + " <vocab id=\"test-model-id\" />" + + "</embedder>"; + String component = "<component id=\"test\" class=\"ApplicationSpecificEmbedder\" bundle=\"model-integration\">" + + " <config name=\"" + PREDEFINED_EMBEDDER_CONFIG + "\">" + + " <tokenizerVocabUrl>test-model-url</tokenizerVocabUrl>" + + " <tokenizerVocabPath></tokenizerVocabPath>" + + " <transformerModelUrl>test-model-url</transformerModelUrl>" + + " <transformerModelPath></transformerModelPath>" + + " </config>" + + "</component>"; assertTransform(embedder, component, true); } @Test - void testEmbedConfig() throws Exception { + void testUnknownModelIdCloud() throws IOException, SAXException { + String embedder = "<embedder id=\"test\" class=\"" + PREDEFINED_EMBEDDER_CLASS + "\">" + + " <model id=\"my_model_id\" />" + + " <vocab id=\"my_vocab_id\" />" + + "</embedder>"; + assertTransformThrows(embedder, "Unknown model id: 'my_vocab_id'", true); + } + + @Test + void testApplicationWithEmbedConfig() throws Exception { final String emptyPathFileName = "services.xml"; Path applicationDir = Path.fromString("src/test/cfg/application/embed/"); @@ -142,15 +170,27 @@ public class EmbedderTestCase { Component<?, ?> testComponent = containerCluster.getComponentsMap().get(new ComponentId("test")); ConfigPayloadBuilder testConfig = testComponent.getUserConfigs().get(new ConfigDefinitionKey("dummy", "test")); - assertEquals(testConfig.getObject("num").getValue(), "12"); - assertEquals(testConfig.getObject("str").getValue(), "some text"); + assertEquals("12", testConfig.getObject("num").getValue()); + assertEquals("some text", testConfig.getObject("str").getValue()); Component<?, ?> transformer = containerCluster.getComponentsMap().get(new ComponentId("transformer")); ConfigPayloadBuilder transformerConfig = transformer.getUserConfigs().get(new ConfigDefinitionKey("bert-base-embedder", "embedding")); - assertEquals(transformerConfig.getObject("transformerModelUrl").getValue(), "test-model-url"); - assertEquals(transformerConfig.getObject("transformerModelPath").getValue(), emptyPathFileName); - assertEquals(transformerConfig.getObject("tokenizerVocabUrl").getValue(), ""); - assertEquals(transformerConfig.getObject("tokenizerVocabPath").getValue(), "files/vocab.txt"); + assertEquals("test-model-url", transformerConfig.getObject("transformerModelUrl").getValue()); + assertEquals(emptyPathFileName, transformerConfig.getObject("transformerModelPath").getValue()); + assertEquals("", transformerConfig.getObject("tokenizerVocabUrl").getValue()); + assertEquals("files/vocab.txt", transformerConfig.getObject("tokenizerVocabPath").getValue()); + } + + @Test + void testApplicationWithGenericEmbedConfig() throws Exception { + Path applicationDir = Path.fromString("src/test/cfg/application/embed_generic/"); + VespaModel model = loadModel(applicationDir, false); + ApplicationContainerCluster containerCluster = model.getContainerClusters().get("container"); + + Component<?, ?> testComponent = containerCluster.getComponentsMap().get(new ComponentId("transformer")); + ConfigPayloadBuilder config = testComponent.getUserConfigs().get(new ConfigDefinitionKey("sentence-embedder", "ai.vespa.example.paragraph")); + assertEquals("files/vocab.txt", config.getObject("vocab").getValue()); + assertEquals("files/model.onnx", config.getObject("model").getValue()); } private VespaModel loadModel(Path path, boolean hosted) throws Exception { |