diff options
author | Jon Bratseth <bratseth@gmail.com> | 2022-08-26 09:58:10 +0200 |
---|---|---|
committer | GitHub <noreply@github.com> | 2022-08-26 09:58:10 +0200 |
commit | 40bb8680dbef01e603b8947a194c86e9acc14e30 (patch) | |
tree | 3badd5c97ff41449514805921e567c218661ab79 /config-model | |
parent | d227d62f0cef26ebdb30c0d5280a2462cd39767d (diff) | |
parent | ffab68b3f5c28034eaf3a606c1b220c14f7204fa (diff) |
Merge pull request #23770 from vespa-engine/bratseth/embedder-syntax-5
Bratseth/embedder syntax 5
Diffstat (limited to 'config-model')
11 files changed, 214 insertions, 251 deletions
diff --git a/config-model/src/main/java/com/yahoo/vespa/model/container/xml/ContainerModelBuilder.java b/config-model/src/main/java/com/yahoo/vespa/model/container/xml/ContainerModelBuilder.java index 213c2ca5df0..fc8a542b81c 100644 --- a/config-model/src/main/java/com/yahoo/vespa/model/container/xml/ContainerModelBuilder.java +++ b/config-model/src/main/java/com/yahoo/vespa/model/container/xml/ContainerModelBuilder.java @@ -184,7 +184,6 @@ public class ContainerModelBuilder extends ConfigModelBuilder<ContainerModel> { addConfiguredComponents(deployState, cluster, spec); addSecretStore(cluster, spec, deployState); - addEmbedderComponents(deployState, cluster, spec); addModelEvaluation(spec, cluster, context); addModelEvaluationBundles(cluster); @@ -352,19 +351,12 @@ public class ContainerModelBuilder extends ConfigModelBuilder<ContainerModel> { container.setProp("rotations", String.join(",", rotationsProperty)); } - private static void addEmbedderComponents(DeployState deployState, ApplicationContainerCluster cluster, Element spec) { - for (Element node : XML.getChildren(spec, "embedder")) { - Element transformed = EmbedderConfigTransformer.transform(deployState, node); - cluster.addComponent(new DomComponentBuilder().build(deployState, cluster, transformed)); - } - } - - private void addConfiguredComponents(DeployState deployState, ApplicationContainerCluster cluster, Element spec) { - for (Element components : XML.getChildren(spec, "components")) { + private void addConfiguredComponents(DeployState deployState, ApplicationContainerCluster cluster, Element parent) { + for (Element components : XML.getChildren(parent, "components")) { addIncludes(components); addConfiguredComponents(deployState, cluster, components, "component"); } - addConfiguredComponents(deployState, cluster, spec, "component"); + addConfiguredComponents(deployState, cluster, parent, "component"); } protected void addStatusHandlers(ApplicationContainerCluster cluster, boolean isHostedVespa) { @@ 
-963,9 +955,10 @@ public class ContainerModelBuilder extends ConfigModelBuilder<ContainerModel> { } private static void addConfiguredComponents(DeployState deployState, ContainerCluster<? extends Container> cluster, - Element spec, String componentName) { - for (Element node : XML.getChildren(spec, componentName)) { - cluster.addComponent(new DomComponentBuilder().build(deployState, cluster, node)); + Element parent, String componentName) { + for (Element component : XML.getChildren(parent, componentName)) { + component = ModelConfigTransformer.transform(deployState, component); + cluster.addComponent(new DomComponentBuilder().build(deployState, cluster, component)); } } diff --git a/config-model/src/main/java/com/yahoo/vespa/model/container/xml/EmbedderConfigTransformer.java b/config-model/src/main/java/com/yahoo/vespa/model/container/xml/EmbedderConfigTransformer.java deleted file mode 100644 index 82ce8070c29..00000000000 --- a/config-model/src/main/java/com/yahoo/vespa/model/container/xml/EmbedderConfigTransformer.java +++ /dev/null @@ -1,117 +0,0 @@ -// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. -package com.yahoo.vespa.model.container.xml; - -import com.yahoo.config.model.deploy.DeployState; -import com.yahoo.text.XML; -import org.w3c.dom.Element; - -/** - * Translates config in services.xml of the form - * - * <embedder id="..." class="..." bundle="..." def="..."> - * <!-- options --> - * </embedder> - * - * to component configuration of the form - * - * <component id="..." class="..." bundle="..."> - * <config name=def> - * <!-- options --> - * </config> - * </component> - * - * with some added interpretations based on recognizing the class. 
- * - * @author lesters - * @author bratseth - */ -public class EmbedderConfigTransformer { - - // Until we have optional path parameters, use services.xml as it is guaranteed to exist - private final static String dummyPath = "services.xml"; - - /** - * Transforms the <embedder ...> element to component configuration. - * - * @param deployState the deploy state - as config generation can depend on context - * @param embedder the XML element containing the <embedder ...> - * @return a new XML element containting the <component ...> configuration - */ - public static Element transform(DeployState deployState, Element embedder) { - Element component = XML.getDocumentBuilder().newDocument().createElement("component"); - component.setAttribute("id", embedder.getAttribute("id")); - component.setAttribute("class", embedderClassFrom(embedder)); - component.setAttribute("bundle", embedder.hasAttribute("bundle") ? embedder.getAttribute("bundle") : "model-integration"); - - String configDef = embedderConfigFrom(embedder); - if ( ! configDef.isEmpty()) { - Element config = component.getOwnerDocument().createElement("config"); - config.setAttribute("name", configDef); - for (Element child : XML.getChildren(embedder)) - addConfigValue(child, config, deployState.isHosted()); - component.appendChild(config); - } - - return component; - } - - /** Adds a config value from an embedder element into a regular config. 
*/ - private static void addConfigValue(Element value, Element config, boolean hosted) { - if (value.hasAttribute("path")) { - addChild(value.getTagName() + "Url", "", config); - addChild(value.getTagName() + "Path", value.getAttribute("path"), config); - } - else if (value.hasAttribute("id") && hosted) { - addChild(value.getTagName() + "Url", modelIdToUrl(value.getAttribute("id")), config); - addChild(value.getTagName() + "Path", dummyPath, config); - } - else if (value.hasAttribute("url")) { - addChild(value.getTagName() + "Url", value.getAttribute("url"), config); - addChild(value.getTagName() + "Path", dummyPath, config); - } - else { - addChild(value.getTagName(), value.getTextContent(), config); - } - } - - private static void addChild(String name, String value, Element parent) { - Element element = parent.getOwnerDocument().createElement(name); - element.setTextContent(value); - parent.appendChild(element); - } - - private static String embedderConfigFrom(Element spec) { - String explicitDefinition = spec.getAttribute("def"); - if ( ! 
explicitDefinition.isEmpty()) return explicitDefinition; - - // Implicit from class name - return switch (embedderClassFrom(spec)) { - case "ai.vespa.embedding.BertBaseEmbedder" -> "embedding.bert-base-embedder"; - default -> ""; - }; - } - - private static String modelIdToUrl(String id) { - switch (id) { - case "test-model-id": - return "test-model-url"; - case "minilm-l6-v2": - return "https://data.vespa.oath.cloud/onnx_models/sentence_all_MiniLM_L6_v2.onnx"; - case "bert-base-uncased": - return "https://data.vespa.oath.cloud/onnx_models/bert-base-uncased-vocab.txt"; - } - throw new IllegalArgumentException("Unknown model id '" + id + "'"); - } - - private static String embedderClassFrom(Element spec) { - if (spec.hasAttribute("class")) { - return spec.getAttribute("class"); - } - if (spec.hasAttribute("id")) { - return spec.getAttribute("id"); - } - throw new IllegalArgumentException("An <embedder> element must have a 'class' or 'id' attribute"); - } - - -} diff --git a/config-model/src/main/java/com/yahoo/vespa/model/container/xml/ModelConfigTransformer.java b/config-model/src/main/java/com/yahoo/vespa/model/container/xml/ModelConfigTransformer.java new file mode 100644 index 00000000000..0065a582145 --- /dev/null +++ b/config-model/src/main/java/com/yahoo/vespa/model/container/xml/ModelConfigTransformer.java @@ -0,0 +1,73 @@ +// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package com.yahoo.vespa.model.container.xml; + +import com.yahoo.config.model.deploy.DeployState; +import com.yahoo.text.XML; +import org.w3c.dom.Element; + +import java.util.Map; +import java.util.stream.Collectors; + +/** + * Translates model references in component configs. 
+ * + * @author lesters + * @author bratseth + */ +public class ModelConfigTransformer { + + private static final Map<String, String> providedModels = + Map.of("minilm-l6-v2", "https://data.vespa.oath.cloud/onnx_models/sentence_all_MiniLM_L6_v2.onnx", + "bert-base-uncased", "https://data.vespa.oath.cloud/onnx_models/bert-base-uncased-vocab.txt"); + + // Until we have optional path parameters, use services.xml as it is guaranteed to exist + private final static String dummyPath = "services.xml"; + + /** + * Expands model references in the config of a <component ...> element into regular config values. + * + * @param deployState the deploy state - as config generation can depend on context + * @param component the XML element containing the <component ...> + * @return the given component element, with its model config values transformed in place + */ + public static Element transform(DeployState deployState, Element component) { + for (Element config : XML.getChildren(component, "config")) { + for (Element value : XML.getChildren(config)) + transformModelValue(value, config, deployState.isHosted()); + } + return component; + } + + /** Expands a model config value into regular config values. 
*/ + private static void transformModelValue(Element value, Element config, boolean hosted) { + if (value.hasAttribute("path")) { + addChild(value.getTagName() + "Url", "", config); + addChild(value.getTagName() + "Path", value.getAttribute("path"), config); + config.removeChild(value); + } + else if (value.hasAttribute("id") && hosted) { + addChild(value.getTagName() + "Url", modelIdToUrl(value.getAttribute("id")), config); + addChild(value.getTagName() + "Path", dummyPath, config); + config.removeChild(value); + } + else if (value.hasAttribute("url")) { + addChild(value.getTagName() + "Url", value.getAttribute("url"), config); + addChild(value.getTagName() + "Path", dummyPath, config); + config.removeChild(value); + } + } + + private static void addChild(String name, String value, Element parent) { + Element element = parent.getOwnerDocument().createElement(name); + element.setTextContent(value); + parent.appendChild(element); + } + + private static String modelIdToUrl(String id) { + if ( ! providedModels.containsKey(id)) + throw new IllegalArgumentException("Unknown embedder model '" + id + "'. Available models are [" + + providedModels.keySet().stream().sorted().collect(Collectors.joining(", ")) + "]"); + return providedModels.get(id); + } + +} diff --git a/config-model/src/main/resources/schema/common.rnc b/config-model/src/main/resources/schema/common.rnc index de44f9cc071..27f3b37b78b 100644 --- a/config-model/src/main/resources/schema/common.rnc +++ b/config-model/src/main/resources/schema/common.rnc @@ -65,15 +65,4 @@ ComponentDefinition = ComponentId & BundleSpec & GenericConfig* & - Component* & - Embedder* - -Embedder = element embedder { - attribute id { string } & - attribute class { xsd:Name | JavaId }? & - attribute bundle { xsd:Name }? & - attribute def { xsd:Name }? 
& - anyElement* -} - - + Component* diff --git a/config-model/src/main/resources/schema/containercluster.rnc b/config-model/src/main/resources/schema/containercluster.rnc index 4cc55ad75d8..9012462d2eb 100644 --- a/config-model/src/main/resources/schema/containercluster.rnc +++ b/config-model/src/main/resources/schema/containercluster.rnc @@ -17,7 +17,6 @@ ContainerServices = DocumentApi? & Components* & Component* & - Embedder* & Handler* & Server* & Http? & @@ -31,8 +30,7 @@ ClientAuthorize = element client-authorize { empty } Components = element components { Include* & - Component* & - Embedder* + Component* } Include = element \include { diff --git a/config-model/src/main/resources/schema/docproc.rnc b/config-model/src/main/resources/schema/docproc.rnc index b4db09f2fb8..11f8e14fb2d 100644 --- a/config-model/src/main/resources/schema/docproc.rnc +++ b/config-model/src/main/resources/schema/docproc.rnc @@ -50,7 +50,6 @@ ClusterV3 = element cluster { GenericConfig* & SchemaMapping? & Component* & - Embedder* & Handler* & DocprocChainsV3? 
} diff --git a/config-model/src/test/cfg/application/embed/services.xml b/config-model/src/test/cfg/application/embed/services.xml index 9a05337f954..88558ace4bf 100644 --- a/config-model/src/test/cfg/application/embed/services.xml +++ b/config-model/src/test/cfg/application/embed/services.xml @@ -4,19 +4,16 @@ <container version="1.0"> - <embedder id="test" class="ai.vespa.embedding.UndefinedEmbedder" bundle="dummy" def="test.dummy"> - <num>12</num> - <str>some text</str> - </embedder> + <component id="transformer" class="ai.vespa.embedding.BertBaseEmbedder" bindle="model-integration"> + <config name="embedding.bert-base-embedder"> + <!-- model specifics --> + <transformerModel id="minilm-l6-v2" url="application-url"/> + <tokenizerVocab path="files/vocab.txt"/> - <embedder id="transformer" class="ai.vespa.embedding.BertBaseEmbedder"> - <!-- model specifics --> - <transformerModel id="test-model-id" url="test-model-url"/> - <tokenizerVocab path="files/vocab.txt"/> - - <!-- tunable parameters: number of threads etc --> - <onnxIntraOpThreads>4</onnxIntraOpThreads> - </embedder> + <!-- tunable parameters: number of threads etc --> + <onnxIntraOpThreads>4</onnxIntraOpThreads> + </config> + </component> <nodes> <node hostalias="node1" /> diff --git a/config-model/src/test/cfg/application/embed_generic/configdefinitions/sentence-embedder.def b/config-model/src/test/cfg/application/embed_generic/configdefinitions/sentence-embedder.def index ac5c79d2714..81fc88dbf01 100644 --- a/config-model/src/test/cfg/application/embed_generic/configdefinitions/sentence-embedder.def +++ b/config-model/src/test/cfg/application/embed_generic/configdefinitions/sentence-embedder.def @@ -1,12 +1,15 @@ package=ai.vespa.example.paragraph # Settings for wordpiece tokenizer -vocab path +vocabPath path +vocabUrl string # Transformer model settings modelPath path modelUrl string +myValue string + # Max length of token sequence model can handle transforerMaxTokens int default=128 diff --git 
a/config-model/src/test/cfg/application/embed_generic/services.xml b/config-model/src/test/cfg/application/embed_generic/services.xml index ab2c1be9745..ea430f24e2f 100644 --- a/config-model/src/test/cfg/application/embed_generic/services.xml +++ b/config-model/src/test/cfg/application/embed_generic/services.xml @@ -4,13 +4,15 @@ <container version="1.0"> - <embedder id='transformer' - class='ai.vespa.example.paragraph.ApplicationSpecificEmbedder' - bundle='exampleEmbedder' - def='ai.vespa.example.paragraph.sentence-embedder'> - <model path="files/model.onnx" /> <!-- Embedder syntax for file path --> - <vocab>files/vocab.txt</vocab> <!-- Generic config syntax for file path --> - </embedder> + <component id='transformer' + class='ai.vespa.example.paragraph.ApplicationSpecificEmbedder' + bundle='exampleEmbedder'> + <config name='ai.vespa.example.paragraph.sentence-embedder'> + <model id="minilm-l6-v2" url="application-url" /> + <vocab path="files/vocab.txt"/> + <myValue>foo</myValue> + </config> + </component> <nodes> <node hostalias='node1'/> diff --git a/config-model/src/test/java/com/yahoo/vespa/model/container/xml/EmbedderTestCase.java b/config-model/src/test/java/com/yahoo/vespa/model/container/xml/EmbedderTestCase.java index d64e726eb6a..ffa7e52136f 100644 --- a/config-model/src/test/java/com/yahoo/vespa/model/container/xml/EmbedderTestCase.java +++ b/config-model/src/test/java/com/yahoo/vespa/model/container/xml/EmbedderTestCase.java @@ -28,130 +28,158 @@ import static org.junit.jupiter.api.Assertions.fail; public class EmbedderTestCase { - private static final String PREDEFINED_EMBEDDER_CLASS = "ai.vespa.embedding.BertBaseEmbedder"; - private static final String PREDEFINED_EMBEDDER_CONFIG = "embedding.bert-base-embedder"; + private static final String emptyPathFileName = "services.xml"; + private static final String BUNDLED_EMBEDDER_CLASS = "ai.vespa.embedding.BertBaseEmbedder"; + private static final String BUNDLED_EMBEDDER_CONFIG = 
"embedding.bert-base-embedder"; @Test - void testGenericEmbedConfig() throws IOException, SAXException { - String embedder = "<embedder id='test' class='ai.vespa.test' bundle='bundle' def='def.name'>" + - " <val>123</val>" + - "</embedder>"; - String component = "<component id='test' class='ai.vespa.test' bundle='bundle'>" + - " <config name='def.name'>" + - " <val>123</val>" + - " </config>" + - "</component>"; - assertTransform(embedder, component); - } - - @Test - void testPredefinedEmbedConfigSelfHosted() throws IOException, SAXException { - String embedder = "<embedder id='test' class='" + PREDEFINED_EMBEDDER_CLASS + "'>" + - " <transformerModel id='my_model_id' url='my-model-url' />" + - " <tokenizerVocab id='my_vocab_id' url='my-vocab-url' />" + - "</embedder>"; - String component = "<component id='test' class='" + PREDEFINED_EMBEDDER_CLASS + "' bundle='model-integration'>" + - " <config name='" + PREDEFINED_EMBEDDER_CONFIG + "'>" + + void testBundledEmbedder_selfhosted() throws IOException, SAXException { + String input = "<component id='test' class='" + BUNDLED_EMBEDDER_CLASS + "' bundle='model-integration'>" + + " <config name='" + BUNDLED_EMBEDDER_CONFIG + "'>" + + " <transformerModel id='my_model_id' url='my-model-url' />" + + " <tokenizerVocab id='my_vocab_id' url='my-vocab-url' />" + + " </config>" + + "</component>"; + String component = "<component id='test' class='" + BUNDLED_EMBEDDER_CLASS + "' bundle='model-integration'>" + + " <config name='" + BUNDLED_EMBEDDER_CONFIG + "'>" + " <transformerModelUrl>my-model-url</transformerModelUrl>" + - " <transformerModelPath></transformerModelPath>" + + " <transformerModelPath>services.xml</transformerModelPath>" + " <tokenizerVocabUrl>my-vocab-url</tokenizerVocabUrl>" + - " <tokenizerVocabPath></tokenizerVocabPath>" + + " <tokenizerVocabPath>services.xml</tokenizerVocabPath>" + " </config>" + "</component>"; - assertTransform(embedder, component, false); + assertTransform(input, component, false); } @Test - 
void testPathHasPrioritySelfHosted() throws IOException, SAXException { - String embedder = "<embedder id='test' class='" + PREDEFINED_EMBEDDER_CLASS + "'>" + - " <transformerModel id='my_model_id' url='my-model-url' path='files/model.onnx' />" + - " <tokenizerVocab id='my_vocab_id' url='my-vocab-url' path='files/vocab.txt' />" + - "</embedder>"; - String component = "<component id='test' class='" + PREDEFINED_EMBEDDER_CLASS + "' bundle='model-integration'>" + - " <config name='" + PREDEFINED_EMBEDDER_CONFIG + "'>" + + void testPathHasPriority_selfhosted() throws IOException, SAXException { + String input = "<component id='test' class='" + BUNDLED_EMBEDDER_CLASS + "' bundle='model-integration'>" + + " <config name='" + BUNDLED_EMBEDDER_CONFIG + "'>" + + " <transformerModel id='my_model_id' url='my-model-url' path='files/model.onnx' />" + + " <tokenizerVocab id='my_vocab_id' url='my-vocab-url' path='files/vocab.txt' />" + + " </config>" + + "</component>"; + String component = "<component id='test' class='" + BUNDLED_EMBEDDER_CLASS + "' bundle='model-integration'>" + + " <config name='" + BUNDLED_EMBEDDER_CONFIG + "'>" + " <transformerModelUrl></transformerModelUrl>" + " <transformerModelPath>files/model.onnx</transformerModelPath>" + " <tokenizerVocabUrl></tokenizerVocabUrl>" + " <tokenizerVocabPath>files/vocab.txt</tokenizerVocabPath>" + " </config>" + "</component>"; - assertTransform(embedder, component, false); + assertTransform(input, component, false); } @Test - void testPredefinedEmbedConfigCloud() throws IOException, SAXException { - String embedder = "<embedder id='test' class='" + PREDEFINED_EMBEDDER_CLASS + "'>" + - " <transformerModel id='test-model-id' />" + - " <tokenizerVocab id='test-model-id' />" + - "</embedder>"; - String component = "<component id='test' class='" + PREDEFINED_EMBEDDER_CLASS + "' bundle='model-integration'>" + - " <config name='" + PREDEFINED_EMBEDDER_CONFIG + "'>" + - " <transformerModelUrl>test-model-url</transformerModelUrl>" 
+ - " <transformerModelPath></transformerModelPath>" + - " <tokenizerVocabUrl>test-model-url</tokenizerVocabUrl>" + - " <tokenizerVocabPath></tokenizerVocabPath>" + + void testBundledEmbedder_hosted() throws IOException, SAXException { + String input = "<component id='test' class='" + BUNDLED_EMBEDDER_CLASS + "' bundle='model-integration'>" + + " <config name='" + BUNDLED_EMBEDDER_CONFIG + "'>" + + " <transformerModel id='minilm-l6-v2' />" + + " <tokenizerVocab id='bert-base-uncased' />" + + " </config>" + + "</component>"; + String component = "<component id='test' class='" + BUNDLED_EMBEDDER_CLASS + "' bundle='model-integration'>" + + " <config name='" + BUNDLED_EMBEDDER_CONFIG + "'>" + + " <transformerModelUrl>https://data.vespa.oath.cloud/onnx_models/sentence_all_MiniLM_L6_v2.onnx</transformerModelUrl>" + + " <transformerModelPath>services.xml</transformerModelPath>" + + " <tokenizerVocabUrl>https://data.vespa.oath.cloud/onnx_models/bert-base-uncased-vocab.txt</tokenizerVocabUrl>" + + " <tokenizerVocabPath>services.xml</tokenizerVocabPath>" + " </config>" + "</component>"; - assertTransform(embedder, component, true); + assertTransform(input, component, true); } @Test - void testCustomEmbedderWithPredefinedConfigCloud() throws IOException, SAXException { - String embedder = "<embedder id='test' class='ApplicationSpecificEmbedder' def='" + PREDEFINED_EMBEDDER_CONFIG + "'>" + - " <transformerModel id='test-model-id' />" + - " <tokenizerVocab id='test-model-id' />" + - "</embedder>"; + void testApplicationEmbedderWithBundledConfig_hosted() throws IOException, SAXException { + String input = "<component id='test' class='ApplicationSpecificEmbedder' bundle='model-integration'>" + + " <config name='" + BUNDLED_EMBEDDER_CONFIG + "'>" + + " <transformerModel id='minilm-l6-v2' />" + + " <tokenizerVocab id='bert-base-uncased' />" + + " </config>" + + "</component>"; String component = "<component id='test' class='ApplicationSpecificEmbedder' bundle='model-integration'>" 
+ - " <config name='" + PREDEFINED_EMBEDDER_CONFIG + "'>" + - " <transformerModelUrl>test-model-url</transformerModelUrl>" + - " <transformerModelPath></transformerModelPath>" + - " <tokenizerVocabUrl>test-model-url</tokenizerVocabUrl>" + - " <tokenizerVocabPath></tokenizerVocabPath>" + + " <config name='" + BUNDLED_EMBEDDER_CONFIG + "'>" + + " <transformerModelUrl>https://data.vespa.oath.cloud/onnx_models/sentence_all_MiniLM_L6_v2.onnx</transformerModelUrl>" + + " <transformerModelPath>services.xml</transformerModelPath>" + + " <tokenizerVocabUrl>https://data.vespa.oath.cloud/onnx_models/bert-base-uncased-vocab.txt</tokenizerVocabUrl>" + + " <tokenizerVocabPath>services.xml</tokenizerVocabPath>" + " </config>" + "</component>"; - assertTransform(embedder, component, true); + assertTransform(input, component, true); } @Test - void testUnknownModelIdCloud() throws IOException, SAXException { - String embedder = "<embedder id='test' class='" + PREDEFINED_EMBEDDER_CLASS + "'>" + - " <transformerModel id='my_model_id' />" + - " <tokenizerVocab id='my_vocab_id' />" + - "</embedder>"; - assertTransformThrows(embedder, "Unknown model id 'my_model_id'", true); + void testUnknownModelId_hosted() throws IOException, SAXException { + String embedder = "<component id='test' class='" + BUNDLED_EMBEDDER_CLASS + "'>" + + " <config name='" + BUNDLED_EMBEDDER_CONFIG + "'>" + + " <transformerModel id='my_model_id' />" + + " <tokenizerVocab id='my_vocab_id' />" + + " </config>" + + "</component>"; + assertTransformThrows(embedder, + "Unknown embedder model 'my_model_id'. 
" + + "Available models are [bert-base-uncased, minilm-l6-v2]", + true); } @Test - void testApplicationWithEmbedConfig() throws Exception { - final String emptyPathFileName = "services.xml"; - + void testApplicationPackageWithEmbedder_selfhosted() throws Exception { Path applicationDir = Path.fromString("src/test/cfg/application/embed/"); VespaModel model = loadModel(applicationDir, false); ApplicationContainerCluster containerCluster = model.getContainerClusters().get("container"); - Component<?, ?> testComponent = containerCluster.getComponentsMap().get(new ComponentId("test")); - ConfigPayloadBuilder testConfig = testComponent.getUserConfigs().get(new ConfigDefinitionKey("dummy", "test")); - assertEquals("12", testConfig.getObject("num").getValue()); - assertEquals("some text", testConfig.getObject("str").getValue()); + Component<?, ?> transformer = containerCluster.getComponentsMap().get(new ComponentId("transformer")); + ConfigPayloadBuilder config = transformer.getUserConfigs().get(new ConfigDefinitionKey("bert-base-embedder", "embedding")); + assertEquals("application-url", config.getObject("transformerModelUrl").getValue()); + assertEquals(emptyPathFileName, config.getObject("transformerModelPath").getValue()); + assertEquals("", config.getObject("tokenizerVocabUrl").getValue()); + assertEquals("files/vocab.txt", config.getObject("tokenizerVocabPath").getValue()); + assertEquals("4", config.getObject("onnxIntraOpThreads").getValue()); + } + + @Test + void testApplicationPackageWithEmbedder_hosted() throws Exception { + Path applicationDir = Path.fromString("src/test/cfg/application/embed/"); + VespaModel model = loadModel(applicationDir, true); + ApplicationContainerCluster containerCluster = model.getContainerClusters().get("container"); Component<?, ?> transformer = containerCluster.getComponentsMap().get(new ComponentId("transformer")); - ConfigPayloadBuilder transformerConfig = transformer.getUserConfigs().get(new 
ConfigDefinitionKey("bert-base-embedder", "embedding")); - assertEquals("test-model-url", transformerConfig.getObject("transformerModelUrl").getValue()); - assertEquals(emptyPathFileName, transformerConfig.getObject("transformerModelPath").getValue()); - assertEquals("", transformerConfig.getObject("tokenizerVocabUrl").getValue()); - assertEquals("files/vocab.txt", transformerConfig.getObject("tokenizerVocabPath").getValue()); + ConfigPayloadBuilder config = transformer.getUserConfigs().get(new ConfigDefinitionKey("bert-base-embedder", "embedding")); + assertEquals("https://data.vespa.oath.cloud/onnx_models/sentence_all_MiniLM_L6_v2.onnx", + config.getObject("transformerModelUrl").getValue()); + assertEquals(emptyPathFileName, config.getObject("transformerModelPath").getValue()); + assertEquals("", config.getObject("tokenizerVocabUrl").getValue()); + assertEquals("files/vocab.txt", config.getObject("tokenizerVocabPath").getValue()); + assertEquals("4", config.getObject("onnxIntraOpThreads").getValue()); } @Test - void testApplicationWithGenericEmbedConfig() throws Exception { + void testApplicationPackageWithApplicationEmbedder_selfhosted() throws Exception { Path applicationDir = Path.fromString("src/test/cfg/application/embed_generic/"); VespaModel model = loadModel(applicationDir, false); ApplicationContainerCluster containerCluster = model.getContainerClusters().get("container"); Component<?, ?> testComponent = containerCluster.getComponentsMap().get(new ComponentId("transformer")); ConfigPayloadBuilder config = testComponent.getUserConfigs().get(new ConfigDefinitionKey("sentence-embedder", "ai.vespa.example.paragraph")); - assertEquals("files/vocab.txt", config.getObject("vocab").getValue()); - assertEquals("files/model.onnx", config.getObject("modelPath").getValue()); + assertEquals("application-url", config.getObject("modelUrl").getValue()); + assertEquals(emptyPathFileName, config.getObject("modelPath").getValue()); + assertEquals("files/vocab.txt", 
config.getObject("vocabPath").getValue()); + assertEquals("foo", config.getObject("myValue").getValue()); + } + + @Test + void testApplicationPackageWithApplicationEmbedder_hosted() throws Exception { + Path applicationDir = Path.fromString("src/test/cfg/application/embed_generic/"); + VespaModel model = loadModel(applicationDir, true); + ApplicationContainerCluster containerCluster = model.getContainerClusters().get("container"); + + Component<?, ?> testComponent = containerCluster.getComponentsMap().get(new ComponentId("transformer")); + ConfigPayloadBuilder config = testComponent.getUserConfigs().get(new ConfigDefinitionKey("sentence-embedder", "ai.vespa.example.paragraph")); + assertEquals("https://data.vespa.oath.cloud/onnx_models/sentence_all_MiniLM_L6_v2.onnx", + config.getObject("modelUrl").getValue()); + assertEquals(emptyPathFileName, config.getObject("modelPath").getValue()); + assertEquals("files/vocab.txt", config.getObject("vocabPath").getValue()); + assertEquals("foo", config.getObject("myValue").getValue()); } private VespaModel loadModel(Path path, boolean hosted) throws Exception { @@ -165,17 +193,16 @@ public class EmbedderTestCase { assertTransform(embedder, component, false); } - private void assertTransform(String embedder, String component, boolean hosted) throws IOException, SAXException { - Element emb = createElement(embedder); - Element cmp = createElement(component); - Element trans = EmbedderConfigTransformer.transform(createEmptyDeployState(hosted), emb); - assertSpec(cmp, trans); + private void assertTransform(String embedder, String expectedComponent, boolean hosted) throws IOException, SAXException { + assertSpec(createElement(expectedComponent), + ModelConfigTransformer.transform(createEmptyDeployState(hosted), createElement(embedder))); } private void assertSpec(Element e1, Element e2) { assertEquals(e1.getTagName(), e2.getTagName()); assertAttributes(e1, e2); assertAttributes(e2, e1); + assertEquals(XML.getValue(e1).trim(), 
XML.getValue(e2).trim(), "Content of " + e1.getTagName() + "' is identical"); assertChildren(e1, e2); } @@ -200,7 +227,7 @@ public class EmbedderTestCase { private void assertTransformThrows(String embedder, String expectedMessage, boolean hosted) throws IOException, SAXException { try { - EmbedderConfigTransformer.transform(createEmptyDeployState(hosted), createElement(embedder)); + ModelConfigTransformer.transform(createEmptyDeployState(hosted), createElement(embedder)); fail("Expected exception was not thrown: " + expectedMessage); } catch (IllegalArgumentException e) { assertEquals(expectedMessage, e.getMessage()); diff --git a/config-model/src/test/schema-test-files/services.xml b/config-model/src/test/schema-test-files/services.xml index ffb28726d9a..b32849bb55f 100644 --- a/config-model/src/test/schema-test-files/services.xml +++ b/config-model/src/test/schema-test-files/services.xml @@ -196,7 +196,6 @@ <component id="injected-to-handler"> <config name="foo"/> </component> - <embedder id="transformer" class="ai.vespa.example.SomeEmbedder" bundle="myBundle" def="my.def-file"/> </handler> <server id="server-provider"> |