diff options
13 files changed, 433 insertions, 537 deletions
diff --git a/config-model/src/main/java/com/yahoo/vespa/model/container/xml/ContainerModelBuilder.java b/config-model/src/main/java/com/yahoo/vespa/model/container/xml/ContainerModelBuilder.java index 213c2ca5df0..fc8a542b81c 100644 --- a/config-model/src/main/java/com/yahoo/vespa/model/container/xml/ContainerModelBuilder.java +++ b/config-model/src/main/java/com/yahoo/vespa/model/container/xml/ContainerModelBuilder.java @@ -184,7 +184,6 @@ public class ContainerModelBuilder extends ConfigModelBuilder<ContainerModel> { addConfiguredComponents(deployState, cluster, spec); addSecretStore(cluster, spec, deployState); - addEmbedderComponents(deployState, cluster, spec); addModelEvaluation(spec, cluster, context); addModelEvaluationBundles(cluster); @@ -352,19 +351,12 @@ public class ContainerModelBuilder extends ConfigModelBuilder<ContainerModel> { container.setProp("rotations", String.join(",", rotationsProperty)); } - private static void addEmbedderComponents(DeployState deployState, ApplicationContainerCluster cluster, Element spec) { - for (Element node : XML.getChildren(spec, "embedder")) { - Element transformed = EmbedderConfigTransformer.transform(deployState, node); - cluster.addComponent(new DomComponentBuilder().build(deployState, cluster, transformed)); - } - } - - private void addConfiguredComponents(DeployState deployState, ApplicationContainerCluster cluster, Element spec) { - for (Element components : XML.getChildren(spec, "components")) { + private void addConfiguredComponents(DeployState deployState, ApplicationContainerCluster cluster, Element parent) { + for (Element components : XML.getChildren(parent, "components")) { addIncludes(components); addConfiguredComponents(deployState, cluster, components, "component"); } - addConfiguredComponents(deployState, cluster, spec, "component"); + addConfiguredComponents(deployState, cluster, parent, "component"); } protected void addStatusHandlers(ApplicationContainerCluster cluster, boolean isHostedVespa) { @@ -963,9 +955,10 @@ public class ContainerModelBuilder extends ConfigModelBuilder<ContainerModel> { } private static void addConfiguredComponents(DeployState deployState, ContainerCluster<? extends Container> cluster, - Element spec, String componentName) { - for (Element node : XML.getChildren(spec, componentName)) { - cluster.addComponent(new DomComponentBuilder().build(deployState, cluster, node)); + Element parent, String componentName) { + for (Element component : XML.getChildren(parent, componentName)) { + component = ModelConfigTransformer.transform(deployState, component); + cluster.addComponent(new DomComponentBuilder().build(deployState, cluster, component)); } } diff --git a/config-model/src/main/java/com/yahoo/vespa/model/container/xml/EmbedderConfigTransformer.java b/config-model/src/main/java/com/yahoo/vespa/model/container/xml/EmbedderConfigTransformer.java deleted file mode 100644 index 82ce8070c29..00000000000 --- a/config-model/src/main/java/com/yahoo/vespa/model/container/xml/EmbedderConfigTransformer.java +++ /dev/null @@ -1,117 +0,0 @@ -// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. -package com.yahoo.vespa.model.container.xml; - -import com.yahoo.config.model.deploy.DeployState; -import com.yahoo.text.XML; -import org.w3c.dom.Element; - -/** - * Translates config in services.xml of the form - * - * <embedder id="..." class="..." bundle="..." def="..."> - * <!-- options --> - * </embedder> - * - * to component configuration of the form - * - * <component id="..." class="..." bundle="..."> - * <config name=def> - * <!-- options --> - * </config> - * </component> - * - * with some added interpretations based on recognizing the class. - * - * @author lesters - * @author bratseth - */ -public class EmbedderConfigTransformer { - - // Until we have optional path parameters, use services.xml as it is guaranteed to exist - private final static String dummyPath = "services.xml"; - - /** - * Transforms the <embedder ...> element to component configuration. - * - * @param deployState the deploy state - as config generation can depend on context - * @param embedder the XML element containing the <embedder ...> - * @return a new XML element containting the <component ...> configuration - */ - public static Element transform(DeployState deployState, Element embedder) { - Element component = XML.getDocumentBuilder().newDocument().createElement("component"); - component.setAttribute("id", embedder.getAttribute("id")); - component.setAttribute("class", embedderClassFrom(embedder)); - component.setAttribute("bundle", embedder.hasAttribute("bundle") ? embedder.getAttribute("bundle") : "model-integration"); - - String configDef = embedderConfigFrom(embedder); - if ( ! configDef.isEmpty()) { - Element config = component.getOwnerDocument().createElement("config"); - config.setAttribute("name", configDef); - for (Element child : XML.getChildren(embedder)) - addConfigValue(child, config, deployState.isHosted()); - component.appendChild(config); - } - - return component; - } - - /** Adds a config value from an embedder element into a regular config. */ - private static void addConfigValue(Element value, Element config, boolean hosted) { - if (value.hasAttribute("path")) { - addChild(value.getTagName() + "Url", "", config); - addChild(value.getTagName() + "Path", value.getAttribute("path"), config); - } - else if (value.hasAttribute("id") && hosted) { - addChild(value.getTagName() + "Url", modelIdToUrl(value.getAttribute("id")), config); - addChild(value.getTagName() + "Path", dummyPath, config); - } - else if (value.hasAttribute("url")) { - addChild(value.getTagName() + "Url", value.getAttribute("url"), config); - addChild(value.getTagName() + "Path", dummyPath, config); - } - else { - addChild(value.getTagName(), value.getTextContent(), config); - } - } - - private static void addChild(String name, String value, Element parent) { - Element element = parent.getOwnerDocument().createElement(name); - element.setTextContent(value); - parent.appendChild(element); - } - - private static String embedderConfigFrom(Element spec) { - String explicitDefinition = spec.getAttribute("def"); - if ( ! explicitDefinition.isEmpty()) return explicitDefinition; - - // Implicit from class name - return switch (embedderClassFrom(spec)) { - case "ai.vespa.embedding.BertBaseEmbedder" -> "embedding.bert-base-embedder"; - default -> ""; - }; - } - - private static String modelIdToUrl(String id) { - switch (id) { - case "test-model-id": - return "test-model-url"; - case "minilm-l6-v2": - return "https://data.vespa.oath.cloud/onnx_models/sentence_all_MiniLM_L6_v2.onnx"; - case "bert-base-uncased": - return "https://data.vespa.oath.cloud/onnx_models/bert-base-uncased-vocab.txt"; - } - throw new IllegalArgumentException("Unknown model id '" + id + "'"); - } - - private static String embedderClassFrom(Element spec) { - if (spec.hasAttribute("class")) { - return spec.getAttribute("class"); - } - if (spec.hasAttribute("id")) { - return spec.getAttribute("id"); - } - throw new IllegalArgumentException("An <embedder> element must have a 'class' or 'id' attribute"); - } - - -} diff --git a/config-model/src/main/java/com/yahoo/vespa/model/container/xml/ModelConfigTransformer.java b/config-model/src/main/java/com/yahoo/vespa/model/container/xml/ModelConfigTransformer.java new file mode 100644 index 00000000000..0065a582145 --- /dev/null +++ b/config-model/src/main/java/com/yahoo/vespa/model/container/xml/ModelConfigTransformer.java @@ -0,0 +1,73 @@ +// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package com.yahoo.vespa.model.container.xml; + +import com.yahoo.config.model.deploy.DeployState; +import com.yahoo.text.XML; +import org.w3c.dom.Element; + +import java.util.Map; +import java.util.stream.Collectors; + +/** + * Translates model references in component configs. + * + * @author lesters + * @author bratseth + */ +public class ModelConfigTransformer { + + private static final Map<String, String> providedModels = + Map.of("minilm-l6-v2", "https://data.vespa.oath.cloud/onnx_models/sentence_all_MiniLM_L6_v2.onnx", + "bert-base-uncased", "https://data.vespa.oath.cloud/onnx_models/bert-base-uncased-vocab.txt"); + + // Until we have optional path parameters, use services.xml as it is guaranteed to exist + private final static String dummyPath = "services.xml"; + + /** + * Transforms the <embedder ...> element to component configuration. + * + * @param deployState the deploy state - as config generation can depend on context + * @param component the XML element containing the <embedder ...> + * @return a new XML element containting the <component ...> configuration + */ + public static Element transform(DeployState deployState, Element component) { + for (Element config : XML.getChildren(component, "config")) { + for (Element value : XML.getChildren(config)) + transformModelValue(value, config, deployState.isHosted()); + } + return component; + } + + /** Expans a model config value into regular config values. */ + private static void transformModelValue(Element value, Element config, boolean hosted) { + if (value.hasAttribute("path")) { + addChild(value.getTagName() + "Url", "", config); + addChild(value.getTagName() + "Path", value.getAttribute("path"), config); + config.removeChild(value); + } + else if (value.hasAttribute("id") && hosted) { + addChild(value.getTagName() + "Url", modelIdToUrl(value.getAttribute("id")), config); + addChild(value.getTagName() + "Path", dummyPath, config); + config.removeChild(value); + } + else if (value.hasAttribute("url")) { + addChild(value.getTagName() + "Url", value.getAttribute("url"), config); + addChild(value.getTagName() + "Path", dummyPath, config); + config.removeChild(value); + } + } + + private static void addChild(String name, String value, Element parent) { + Element element = parent.getOwnerDocument().createElement(name); + element.setTextContent(value); + parent.appendChild(element); + } + + private static String modelIdToUrl(String id) { + if ( ! providedModels.containsKey(id)) + throw new IllegalArgumentException("Unknown embedder model '" + id + "'. Available models are [" + + providedModels.keySet().stream().sorted().collect(Collectors.joining(", ")) + "]"); + return providedModels.get(id); + } + +} diff --git a/config-model/src/main/resources/schema/common.rnc b/config-model/src/main/resources/schema/common.rnc index de44f9cc071..27f3b37b78b 100644 --- a/config-model/src/main/resources/schema/common.rnc +++ b/config-model/src/main/resources/schema/common.rnc @@ -65,15 +65,4 @@ ComponentDefinition = ComponentId & BundleSpec & GenericConfig* & - Component* & - Embedder* - -Embedder = element embedder { - attribute id { string } & - attribute class { xsd:Name | JavaId }? & - attribute bundle { xsd:Name }? & - attribute def { xsd:Name }? & - anyElement* -} - - + Component* diff --git a/config-model/src/main/resources/schema/containercluster.rnc b/config-model/src/main/resources/schema/containercluster.rnc index 4cc55ad75d8..9012462d2eb 100644 --- a/config-model/src/main/resources/schema/containercluster.rnc +++ b/config-model/src/main/resources/schema/containercluster.rnc @@ -17,7 +17,6 @@ ContainerServices = DocumentApi? & Components* & Component* & - Embedder* & Handler* & Server* & Http? & @@ -31,8 +30,7 @@ ClientAuthorize = element client-authorize { empty } Components = element components { Include* & - Component* & - Embedder* + Component* } Include = element \include { diff --git a/config-model/src/main/resources/schema/docproc.rnc b/config-model/src/main/resources/schema/docproc.rnc index b4db09f2fb8..11f8e14fb2d 100644 --- a/config-model/src/main/resources/schema/docproc.rnc +++ b/config-model/src/main/resources/schema/docproc.rnc @@ -50,7 +50,6 @@ ClusterV3 = element cluster { GenericConfig* & SchemaMapping? & Component* & - Embedder* & Handler* & DocprocChainsV3? } diff --git a/config-model/src/test/cfg/application/embed/services.xml b/config-model/src/test/cfg/application/embed/services.xml index 9a05337f954..88558ace4bf 100644 --- a/config-model/src/test/cfg/application/embed/services.xml +++ b/config-model/src/test/cfg/application/embed/services.xml @@ -4,19 +4,16 @@ <container version="1.0"> - <embedder id="test" class="ai.vespa.embedding.UndefinedEmbedder" bundle="dummy" def="test.dummy"> - <num>12</num> - <str>some text</str> - </embedder> + <component id="transformer" class="ai.vespa.embedding.BertBaseEmbedder" bindle="model-integration"> + <config name="embedding.bert-base-embedder"> + <!-- model specifics --> + <transformerModel id="minilm-l6-v2" url="application-url"/> + <tokenizerVocab path="files/vocab.txt"/> - <embedder id="transformer" class="ai.vespa.embedding.BertBaseEmbedder"> - <!-- model specifics --> - <transformerModel id="test-model-id" url="test-model-url"/> - <tokenizerVocab path="files/vocab.txt"/> - - <!-- tunable parameters: number of threads etc --> - <onnxIntraOpThreads>4</onnxIntraOpThreads> - </embedder> + <!-- tunable parameters: number of threads etc --> + <onnxIntraOpThreads>4</onnxIntraOpThreads> + </config> + </component> <nodes> <node hostalias="node1" /> diff --git a/config-model/src/test/cfg/application/embed_generic/configdefinitions/sentence-embedder.def b/config-model/src/test/cfg/application/embed_generic/configdefinitions/sentence-embedder.def index ac5c79d2714..81fc88dbf01 100644 --- a/config-model/src/test/cfg/application/embed_generic/configdefinitions/sentence-embedder.def +++ b/config-model/src/test/cfg/application/embed_generic/configdefinitions/sentence-embedder.def @@ -1,12 +1,15 @@ package=ai.vespa.example.paragraph # Settings for wordpiece tokenizer -vocab path +vocabPath path +vocabUrl string # Transformer model settings modelPath path modelUrl string +myValue string + # Max length of token sequence model can handle transforerMaxTokens int default=128 diff --git a/config-model/src/test/cfg/application/embed_generic/services.xml b/config-model/src/test/cfg/application/embed_generic/services.xml index ab2c1be9745..ea430f24e2f 100644 --- a/config-model/src/test/cfg/application/embed_generic/services.xml +++ b/config-model/src/test/cfg/application/embed_generic/services.xml @@ -4,13 +4,15 @@ <container version="1.0"> - <embedder id='transformer' - class='ai.vespa.example.paragraph.ApplicationSpecificEmbedder' - bundle='exampleEmbedder' - def='ai.vespa.example.paragraph.sentence-embedder'> - <model path="files/model.onnx" /> <!-- Embedder syntax for file path --> - <vocab>files/vocab.txt</vocab> <!-- Generic config syntax for file path --> - </embedder> + <component id='transformer' + class='ai.vespa.example.paragraph.ApplicationSpecificEmbedder' + bundle='exampleEmbedder'> + <config name='ai.vespa.example.paragraph.sentence-embedder'> + <model id="minilm-l6-v2" url="application-url" /> + <vocab path="files/vocab.txt"/> + <myValue>foo</myValue> + </config> + </component> <nodes> <node hostalias='node1'/> diff --git a/config-model/src/test/java/com/yahoo/vespa/model/container/xml/EmbedderTestCase.java b/config-model/src/test/java/com/yahoo/vespa/model/container/xml/EmbedderTestCase.java index d64e726eb6a..ffa7e52136f 100644 --- a/config-model/src/test/java/com/yahoo/vespa/model/container/xml/EmbedderTestCase.java +++ b/config-model/src/test/java/com/yahoo/vespa/model/container/xml/EmbedderTestCase.java @@ -28,130 +28,158 @@ import static org.junit.jupiter.api.Assertions.fail; public class EmbedderTestCase { - private static final String PREDEFINED_EMBEDDER_CLASS = "ai.vespa.embedding.BertBaseEmbedder"; - private static final String PREDEFINED_EMBEDDER_CONFIG = "embedding.bert-base-embedder"; + private static final String emptyPathFileName = "services.xml"; + private static final String BUNDLED_EMBEDDER_CLASS = "ai.vespa.embedding.BertBaseEmbedder"; + private static final String BUNDLED_EMBEDDER_CONFIG = "embedding.bert-base-embedder"; @Test - void testGenericEmbedConfig() throws IOException, SAXException { - String embedder = "<embedder id='test' class='ai.vespa.test' bundle='bundle' def='def.name'>" + - " <val>123</val>" + - "</embedder>"; - String component = "<component id='test' class='ai.vespa.test' bundle='bundle'>" + - " <config name='def.name'>" + - " <val>123</val>" + - " </config>" + - "</component>"; - assertTransform(embedder, component); - } - - @Test - void testPredefinedEmbedConfigSelfHosted() throws IOException, SAXException { - String embedder = "<embedder id='test' class='" + PREDEFINED_EMBEDDER_CLASS + "'>" + - " <transformerModel id='my_model_id' url='my-model-url' />" + - " <tokenizerVocab id='my_vocab_id' url='my-vocab-url' />" + - "</embedder>"; - String component = "<component id='test' class='" + PREDEFINED_EMBEDDER_CLASS + "' bundle='model-integration'>" + - " <config name='" + PREDEFINED_EMBEDDER_CONFIG + "'>" + + void testBundledEmbedder_selfhosted() throws IOException, SAXException { + String input = "<component id='test' class='" + BUNDLED_EMBEDDER_CLASS + "' bundle='model-integration'>" + + " <config name='" + BUNDLED_EMBEDDER_CONFIG + "'>" + + " <transformerModel id='my_model_id' url='my-model-url' />" + + " <tokenizerVocab id='my_vocab_id' url='my-vocab-url' />" + + " </config>" + + "</component>"; + String component = "<component id='test' class='" + BUNDLED_EMBEDDER_CLASS + "' bundle='model-integration'>" + + " <config name='" + BUNDLED_EMBEDDER_CONFIG + "'>" + " <transformerModelUrl>my-model-url</transformerModelUrl>" + - " <transformerModelPath></transformerModelPath>" + + " <transformerModelPath>services.xml</transformerModelPath>" + " <tokenizerVocabUrl>my-vocab-url</tokenizerVocabUrl>" + - " <tokenizerVocabPath></tokenizerVocabPath>" + + " <tokenizerVocabPath>services.xml</tokenizerVocabPath>" + " </config>" + "</component>"; - assertTransform(embedder, component, false); + assertTransform(input, component, false); } @Test - void testPathHasPrioritySelfHosted() throws IOException, SAXException { - String embedder = "<embedder id='test' class='" + PREDEFINED_EMBEDDER_CLASS + "'>" + - " <transformerModel id='my_model_id' url='my-model-url' path='files/model.onnx' />" + - " <tokenizerVocab id='my_vocab_id' url='my-vocab-url' path='files/vocab.txt' />" + - "</embedder>"; - String component = "<component id='test' class='" + PREDEFINED_EMBEDDER_CLASS + "' bundle='model-integration'>" + - " <config name='" + PREDEFINED_EMBEDDER_CONFIG + "'>" + + void testPathHasPriority_selfhosted() throws IOException, SAXException { + String input = "<component id='test' class='" + BUNDLED_EMBEDDER_CLASS + "' bundle='model-integration'>" + + " <config name='" + BUNDLED_EMBEDDER_CONFIG + "'>" + + " <transformerModel id='my_model_id' url='my-model-url' path='files/model.onnx' />" + + " <tokenizerVocab id='my_vocab_id' url='my-vocab-url' path='files/vocab.txt' />" + + " </config>" + + "</component>"; + String component = "<component id='test' class='" + BUNDLED_EMBEDDER_CLASS + "' bundle='model-integration'>" + + " <config name='" + BUNDLED_EMBEDDER_CONFIG + "'>" + " <transformerModelUrl></transformerModelUrl>" + " <transformerModelPath>files/model.onnx</transformerModelPath>" + " <tokenizerVocabUrl></tokenizerVocabUrl>" + " <tokenizerVocabPath>files/vocab.txt</tokenizerVocabPath>" + " </config>" + "</component>"; - assertTransform(embedder, component, false); + assertTransform(input, component, false); } @Test - void testPredefinedEmbedConfigCloud() throws IOException, SAXException { - String embedder = "<embedder id='test' class='" + PREDEFINED_EMBEDDER_CLASS + "'>" + - " <transformerModel id='test-model-id' />" + - " <tokenizerVocab id='test-model-id' />" + - "</embedder>"; - String component = "<component id='test' class='" + PREDEFINED_EMBEDDER_CLASS + "' bundle='model-integration'>" + - " <config name='" + PREDEFINED_EMBEDDER_CONFIG + "'>" + - " <transformerModelUrl>test-model-url</transformerModelUrl>" + - " <transformerModelPath></transformerModelPath>" + - " <tokenizerVocabUrl>test-model-url</tokenizerVocabUrl>" + - " <tokenizerVocabPath></tokenizerVocabPath>" + + void testBundledEmbedder_hosted() throws IOException, SAXException { + String input = "<component id='test' class='" + BUNDLED_EMBEDDER_CLASS + "' bundle='model-integration'>" + + " <config name='" + BUNDLED_EMBEDDER_CONFIG + "'>" + + " <transformerModel id='minilm-l6-v2' />" + + " <tokenizerVocab id='bert-base-uncased' />" + + " </config>" + + "</component>"; + String component = "<component id='test' class='" + BUNDLED_EMBEDDER_CLASS + "' bundle='model-integration'>" + + " <config name='" + BUNDLED_EMBEDDER_CONFIG + "'>" + + " <transformerModelUrl>https://data.vespa.oath.cloud/onnx_models/sentence_all_MiniLM_L6_v2.onnx</transformerModelUrl>" + + " <transformerModelPath>services.xml</transformerModelPath>" + + " <tokenizerVocabUrl>https://data.vespa.oath.cloud/onnx_models/bert-base-uncased-vocab.txt</tokenizerVocabUrl>" + + " <tokenizerVocabPath>services.xml</tokenizerVocabPath>" + " </config>" + "</component>"; - assertTransform(embedder, component, true); + assertTransform(input, component, true); } @Test - void testCustomEmbedderWithPredefinedConfigCloud() throws IOException, SAXException { - String embedder = "<embedder id='test' class='ApplicationSpecificEmbedder' def='" + PREDEFINED_EMBEDDER_CONFIG + "'>" + - " <transformerModel id='test-model-id' />" + - " <tokenizerVocab id='test-model-id' />" + - "</embedder>"; + void testApplicationEmbedderWithBundledConfig_hosted() throws IOException, SAXException { + String input = "<component id='test' class='ApplicationSpecificEmbedder' bundle='model-integration'>" + + " <config name='" + BUNDLED_EMBEDDER_CONFIG + "'>" + + " <transformerModel id='minilm-l6-v2' />" + + " <tokenizerVocab id='bert-base-uncased' />" + + " </config>" + + "</component>"; String component = "<component id='test' class='ApplicationSpecificEmbedder' bundle='model-integration'>" + - " <config name='" + PREDEFINED_EMBEDDER_CONFIG + "'>" + - " <transformerModelUrl>test-model-url</transformerModelUrl>" + - " <transformerModelPath></transformerModelPath>" + - " <tokenizerVocabUrl>test-model-url</tokenizerVocabUrl>" + - " <tokenizerVocabPath></tokenizerVocabPath>" + + " <config name='" + BUNDLED_EMBEDDER_CONFIG + "'>" + + " <transformerModelUrl>https://data.vespa.oath.cloud/onnx_models/sentence_all_MiniLM_L6_v2.onnx</transformerModelUrl>" + + " <transformerModelPath>services.xml</transformerModelPath>" + + " <tokenizerVocabUrl>https://data.vespa.oath.cloud/onnx_models/bert-base-uncased-vocab.txt</tokenizerVocabUrl>" + + " <tokenizerVocabPath>services.xml</tokenizerVocabPath>" + " </config>" + "</component>"; - assertTransform(embedder, component, true); + assertTransform(input, component, true); } @Test - void testUnknownModelIdCloud() throws IOException, SAXException { - String embedder = "<embedder id='test' class='" + PREDEFINED_EMBEDDER_CLASS + "'>" + - " <transformerModel id='my_model_id' />" + - " <tokenizerVocab id='my_vocab_id' />" + - "</embedder>"; - assertTransformThrows(embedder, "Unknown model id 'my_model_id'", true); + void testUnknownModelId_hosted() throws IOException, SAXException { + String embedder = "<component id='test' class='" + BUNDLED_EMBEDDER_CLASS + "'>" + + " <config name='" + BUNDLED_EMBEDDER_CONFIG + "'>" + + " <transformerModel id='my_model_id' />" + + " <tokenizerVocab id='my_vocab_id' />" + + " </config>" + + "</component>"; + assertTransformThrows(embedder, + "Unknown embedder model 'my_model_id'. " + + "Available models are [bert-base-uncased, minilm-l6-v2]", + true); } @Test - void testApplicationWithEmbedConfig() throws Exception { - final String emptyPathFileName = "services.xml"; - + void testApplicationPackageWithEmbedder_selfhosted() throws Exception { Path applicationDir = Path.fromString("src/test/cfg/application/embed/"); VespaModel model = loadModel(applicationDir, false); ApplicationContainerCluster containerCluster = model.getContainerClusters().get("container"); - Component<?, ?> testComponent = containerCluster.getComponentsMap().get(new ComponentId("test")); - ConfigPayloadBuilder testConfig = testComponent.getUserConfigs().get(new ConfigDefinitionKey("dummy", "test")); - assertEquals("12", testConfig.getObject("num").getValue()); - assertEquals("some text", testConfig.getObject("str").getValue()); + Component<?, ?> transformer = containerCluster.getComponentsMap().get(new ComponentId("transformer")); + ConfigPayloadBuilder config = transformer.getUserConfigs().get(new ConfigDefinitionKey("bert-base-embedder", "embedding")); + assertEquals("application-url", config.getObject("transformerModelUrl").getValue()); + assertEquals(emptyPathFileName, config.getObject("transformerModelPath").getValue()); + assertEquals("", config.getObject("tokenizerVocabUrl").getValue()); + assertEquals("files/vocab.txt", config.getObject("tokenizerVocabPath").getValue()); + assertEquals("4", config.getObject("onnxIntraOpThreads").getValue()); + } + + @Test + void testApplicationPackageWithEmbedder_hosted() throws Exception { + Path applicationDir = Path.fromString("src/test/cfg/application/embed/"); + VespaModel model = loadModel(applicationDir, true); + ApplicationContainerCluster containerCluster = model.getContainerClusters().get("container"); Component<?, ?> transformer = containerCluster.getComponentsMap().get(new ComponentId("transformer")); - ConfigPayloadBuilder transformerConfig = transformer.getUserConfigs().get(new ConfigDefinitionKey("bert-base-embedder", "embedding")); - assertEquals("test-model-url", transformerConfig.getObject("transformerModelUrl").getValue()); - assertEquals(emptyPathFileName, transformerConfig.getObject("transformerModelPath").getValue()); - assertEquals("", transformerConfig.getObject("tokenizerVocabUrl").getValue()); - assertEquals("files/vocab.txt", transformerConfig.getObject("tokenizerVocabPath").getValue()); + ConfigPayloadBuilder config = transformer.getUserConfigs().get(new ConfigDefinitionKey("bert-base-embedder", "embedding")); + assertEquals("https://data.vespa.oath.cloud/onnx_models/sentence_all_MiniLM_L6_v2.onnx", + config.getObject("transformerModelUrl").getValue()); + assertEquals(emptyPathFileName, config.getObject("transformerModelPath").getValue()); + assertEquals("", config.getObject("tokenizerVocabUrl").getValue()); + assertEquals("files/vocab.txt", config.getObject("tokenizerVocabPath").getValue()); + assertEquals("4", config.getObject("onnxIntraOpThreads").getValue()); } @Test - void testApplicationWithGenericEmbedConfig() throws Exception { + void testApplicationPackageWithApplicationEmbedder_selfhosted() throws Exception { Path applicationDir = Path.fromString("src/test/cfg/application/embed_generic/"); VespaModel model = loadModel(applicationDir, false); ApplicationContainerCluster containerCluster = model.getContainerClusters().get("container"); Component<?, ?> testComponent = containerCluster.getComponentsMap().get(new ComponentId("transformer")); ConfigPayloadBuilder config = testComponent.getUserConfigs().get(new ConfigDefinitionKey("sentence-embedder", "ai.vespa.example.paragraph")); - assertEquals("files/vocab.txt", config.getObject("vocab").getValue()); - assertEquals("files/model.onnx", config.getObject("modelPath").getValue()); + assertEquals("application-url", config.getObject("modelUrl").getValue()); + assertEquals(emptyPathFileName, config.getObject("modelPath").getValue()); + assertEquals("files/vocab.txt", config.getObject("vocabPath").getValue()); + assertEquals("foo", config.getObject("myValue").getValue()); + } + + @Test + void testApplicationPackageWithApplicationEmbedder_hosted() throws Exception { + Path applicationDir = Path.fromString("src/test/cfg/application/embed_generic/"); + VespaModel model = loadModel(applicationDir, true); + ApplicationContainerCluster containerCluster = model.getContainerClusters().get("container"); + + Component<?, ?> testComponent = containerCluster.getComponentsMap().get(new ComponentId("transformer")); + ConfigPayloadBuilder config = testComponent.getUserConfigs().get(new ConfigDefinitionKey("sentence-embedder", "ai.vespa.example.paragraph")); + assertEquals("https://data.vespa.oath.cloud/onnx_models/sentence_all_MiniLM_L6_v2.onnx", + config.getObject("modelUrl").getValue()); + assertEquals(emptyPathFileName, config.getObject("modelPath").getValue()); + assertEquals("files/vocab.txt", config.getObject("vocabPath").getValue()); + assertEquals("foo", config.getObject("myValue").getValue()); } private VespaModel loadModel(Path path, boolean hosted) throws Exception { @@ -165,17 +193,16 @@ public class EmbedderTestCase { assertTransform(embedder, component, false); } - private void assertTransform(String embedder, String component, boolean hosted) throws IOException, SAXException { - Element emb = createElement(embedder); - Element cmp = createElement(component); - Element trans = EmbedderConfigTransformer.transform(createEmptyDeployState(hosted), emb); - assertSpec(cmp, trans); + private void assertTransform(String embedder, String expectedComponent, boolean hosted) throws IOException, SAXException { + assertSpec(createElement(expectedComponent), + ModelConfigTransformer.transform(createEmptyDeployState(hosted), createElement(embedder))); } private void assertSpec(Element e1, Element e2) { assertEquals(e1.getTagName(), e2.getTagName()); assertAttributes(e1, e2); assertAttributes(e2, e1); + assertEquals(XML.getValue(e1).trim(), XML.getValue(e2).trim(), "Content of " + e1.getTagName() + "' is identical"); assertChildren(e1, e2); } @@ -200,7 +227,7 @@ public class EmbedderTestCase { private void assertTransformThrows(String embedder, String expectedMessage, boolean hosted) throws IOException, SAXException { try { - EmbedderConfigTransformer.transform(createEmptyDeployState(hosted), createElement(embedder)); + ModelConfigTransformer.transform(createEmptyDeployState(hosted), createElement(embedder)); fail("Expected exception was not thrown: " + expectedMessage); } catch (IllegalArgumentException e) { assertEquals(expectedMessage, e.getMessage()); diff --git a/config-model/src/test/schema-test-files/services.xml b/config-model/src/test/schema-test-files/services.xml index ffb28726d9a..b32849bb55f 100644 --- a/config-model/src/test/schema-test-files/services.xml +++ b/config-model/src/test/schema-test-files/services.xml @@ -196,7 +196,6 @@ <component id="injected-to-handler"> <config name="foo"/> </component> - <embedder id="transformer" class="ai.vespa.example.SomeEmbedder" bundle="myBundle" def="my.def-file"/> </handler> <server id="server-provider"> diff --git a/vespajlib/abi-spec.json b/vespajlib/abi-spec.json index fbf1203acdf..8714285acd8 100644 --- a/vespajlib/abi-spec.json +++ b/vespajlib/abi-spec.json @@ -3362,6 +3362,7 @@ "public static javax.xml.parsers.DocumentBuilder getDocumentBuilder(java.lang.String, java.lang.ClassLoader, boolean)", "public static java.util.List getChildren(org.w3c.dom.Element)", "public static java.util.List getChildren(org.w3c.dom.Element, java.lang.String)", + "public static java.util.Optional attribute(java.lang.String, org.w3c.dom.Element)", "public static java.lang.String getValue(org.w3c.dom.Element)", "public static org.w3c.dom.Element getChild(org.w3c.dom.Element, java.lang.String)", "public static java.lang.String getNodePath(org.w3c.dom.Node, java.lang.String)", diff --git a/vespajlib/src/main/java/com/yahoo/text/XML.java b/vespajlib/src/main/java/com/yahoo/text/XML.java index 6aa42773ac0..255e6a67429 100644 --- a/vespajlib/src/main/java/com/yahoo/text/XML.java +++ b/vespajlib/src/main/java/com/yahoo/text/XML.java @@ -18,6 +18,7 @@ import java.io.Reader; import java.io.StringReader; import java.util.ArrayList; import java.util.List; +import java.util.Optional; /** * Static XML utility methods @@ -29,280 +30,39 @@ import java.util.List; */ public class XML { - /** - * The point of this weird class and the jumble of abstract methods is - * linking the scan for characters that must be quoted into the quoting - * table, and making it actual work to make them go out of sync again. - */ - private static abstract class LegalCharacters { - - // To quote http://www.w3.org/TR/REC-xml/ : - // Char ::= #x9 | #xA | #xD | [#x20-#xD7FF] | [#xE000-#xFFFD] | - // [#x10000-#x10FFFF] - final boolean isLegal(int codepoint, boolean escapeLow, int stripCodePoint, boolean isAttribute) { - if (codepoint == stripCodePoint) { - return removeCodePoint(); - } else if (codepoint < ' ') { - if (!escapeLow) { - return true; - } - switch (codepoint) { - case 0x09: - case 0x0a: - case 0x0d: - return true; - default: - return ctrlEscapeCodePoint(codepoint); - } - } else if (codepoint >= 0x20 && codepoint <= 0xd7ff) { - switch (codepoint) { - case '&': - return ampCodePoint(); - case '<': - return ltCodePoint(); - case '>': - return gtCodePoint(); - case '"': - return quotCodePoint(isAttribute); - default: - return true; - } - } else if ((codepoint >= 0xe000 && codepoint <= 0xfffd) - || (codepoint >= 0x10000 && codepoint <= 0x10ffff)) { - return true; - } else { - return filterCodePoint(codepoint); - - } - } - - private boolean quotCodePoint(boolean isAttribute) { - if (isAttribute) { - quoteQuot(); - return false; - } else { - return true; - } - } - - private boolean filterCodePoint(int codepoint) { - replace(codepoint); - return false; - } - - private boolean gtCodePoint() { - quoteGt(); - return false; - } - - private boolean ltCodePoint() { - quoteLt(); - return false; - } - - private boolean ampCodePoint() { - quoteAmp(); - return false; - } - - private boolean ctrlEscapeCodePoint(int codepoint) { - ctrlEscape(codepoint); - return false; - } - - private boolean removeCodePoint() { - remove(); - return false; - } - - protected abstract void quoteQuot(); - - protected abstract void quoteGt(); - - protected abstract void quoteLt(); - - protected abstract void quoteAmp(); - - protected abstract void remove(); - - protected abstract void ctrlEscape(int codepoint); - - protected abstract void replace(int codepoint); - } - - private static final class Quote extends LegalCharacters { - - char[] lastQuoted; - private static final char[] EMPTY = new char[0]; - private static final char[] REPLACEMENT_CHARACTER = "\ufffd".toCharArray(); - private static final char[] AMP = "&".toCharArray(); - private static final char[] LT = "<".toCharArray(); - private static final char[] GT = ">".toCharArray(); - private static final char[] QUOT = """.toCharArray(); - - @Override - protected void remove() { - lastQuoted = EMPTY; - } - - @Override - protected void replace(final int codepoint) { - lastQuoted = REPLACEMENT_CHARACTER; - } - - @Override - protected void quoteQuot() { - lastQuoted = QUOT; - } - - @Override - protected void quoteGt() { - lastQuoted = GT; - } - - @Override - protected void quoteLt() { - lastQuoted = LT; - } - - @Override - protected void quoteAmp() { - lastQuoted = AMP; - } - - @Override - protected void ctrlEscape(final int codepoint) { - lastQuoted = REPLACEMENT_CHARACTER; - } - } - - private static final class Scan extends LegalCharacters { - - @Override - protected void quoteQuot() { - } - - @Override - protected void quoteGt() { - } - - @Override - protected void quoteLt() { - } - - @Override - protected void quoteAmp() { - } - - @Override - protected void remove() { - } - - @Override - protected void ctrlEscape(final int codepoint) { - } - - @Override - protected void replace(final int codepoint) { - } - } - private static final Scan scanner = new Scan(); - /** - * Replaces the characters that need to be escaped with their corresponding - * character entities. - * - * @param s1 - * String possibly containing characters that need to be escaped - * in XML - * - * @return Returns the input string with special characters that need to be - * escaped replaced by character entities. - */ - public static String xmlEscape(String s1) { - return xmlEscape(s1, true, true, null, -1); + /** Replaces the characters that need to be escaped with their corresponding character entities. */ + public static String xmlEscape(String string) { + return xmlEscape(string, true, true, null, -1); } - /** - * Replaces the characters that need to be escaped with their corresponding - * character entities. - * - * @param s1 - * String possibly containing characters that need to be escaped - * in XML - * @param isAttribute - * Is the input string to be used as an attribute? - * - * @return Returns the input string with special characters that need to be - * escaped replaced by character entities - */ - public static String xmlEscape(String s1, boolean isAttribute) { - return xmlEscape(s1, isAttribute, true, null, -1); + /** Replaces the characters that need to be escaped with their corresponding character entities. */ + public static String xmlEscape(String string, boolean isAttribute) { + return xmlEscape(string, isAttribute, true, null, -1); } - /** - * Replaces the characters that need to be escaped with their corresponding - * character entities. - * - * @param s1 - * String possibly containing characters that need to be escaped - * in XML - * @param isAttribute - * Is the input string to be used as an attribute? - * - * - * @param stripCharacter - * any occurrence of this character is removed from the string - * - * @return Returns the input string with special characters that need to be - * escaped replaced by character entities - */ - public static String xmlEscape(String s1, boolean isAttribute, char stripCharacter) { - return xmlEscape(s1, isAttribute, true, null, (int) stripCharacter); + /** Replaces the characters that need to be escaped with their corresponding character entities. */ + public static String xmlEscape(String string, boolean isAttribute, char stripCharacter) { + return xmlEscape(string, isAttribute, true, null, (int) stripCharacter); } - /** - * Replaces the characters that need to be escaped with their corresponding - * character entities. - * - * @param s1 - * String possibly containing characters that need to be escaped - * in XML - * @param isAttribute - * Is the input string to be used as an attribute? - * - * @param escapeLowAscii - * Should ascii characters below 32 be escaped as well - * - * @return Returns the input string with special characters that need to be - * escaped replaced by character entities - */ - public static String xmlEscape(String s1, boolean isAttribute, boolean escapeLowAscii) { - return xmlEscape(s1, isAttribute, escapeLowAscii, null, -1); + /** Replaces the characters that need to be escaped with their corresponding character entities. */ + public static String xmlEscape(String string, boolean isAttribute, boolean escapeLowAscii) { + return xmlEscape(string, isAttribute, escapeLowAscii, null, -1); } /** - * Replaces the characters that need to be escaped with their corresponding - * character entities. - * - * @param s1 - * String possibly containing characters that need to be escaped - * in XML - * @param isAttribute - * Is the input string to be used as an attribute? - * - * @param escapeLowAscii - * Should ascii characters below 32 be escaped as well + * Replaces the characters that need to be escaped with their corresponding character entities. * - * @param stripCharacter - * any occurrence of this character is removed from the string - * - * @return Returns the input string with special characters that need to be - * escaped replaced by character entities + * @param string the string possibly containing characters that need to be escaped in XML + * @param isAttribute whether the input string to be used as an attribute + * @param escapeLowAscii whether ascii characters below 32 should be escaped as well + * @param stripCharacter any occurrence of this character is removed from the string + * @return the input string with special characters that need to be escaped replaced by character entities */ - public static String xmlEscape(String s1, boolean isAttribute, boolean escapeLowAscii, char stripCharacter) { - return xmlEscape(s1, isAttribute, escapeLowAscii, null, (int) stripCharacter); + public static String xmlEscape(String string, boolean isAttribute, boolean escapeLowAscii, char stripCharacter) { + return xmlEscape(string, isAttribute, escapeLowAscii, null, (int) stripCharacter); } /** @@ -315,7 +75,6 @@ public class XML { * <li>double quotes (") if isAttribute is <code>true</code> * </ul> * with character entities. - * */ public static String xmlEscape(String string, boolean isAttribute, StringBuilder buffer) { return xmlEscape(string, isAttribute, true, buffer, -1); @@ -332,7 +91,6 @@ public class XML { * <li>double quotes (") if isAttribute is <code>true</code> * </ul> * with character entities. - * */ public static String xmlEscape(String string, boolean isAttribute, boolean escapeLowAscii, StringBuilder buffer) { return xmlEscape(string, isAttribute, escapeLowAscii, buffer, -1); @@ -438,15 +196,13 @@ public class XML { } } - /** - * Returns the Document of the string XML payload - */ + /** Returns the Document of the string XML payload. */ public static Document getDocument(String xmlString) { return getDocument(new StringReader(xmlString)); } /** - * Creates a new XML DocumentBuilder + * Creates a new XML DocumentBuilder. * * @return a DocumentBuilder * @throws RuntimeException if we fail to create one @@ -456,7 +212,7 @@ public class XML { } /** - * Creates a new XML DocumentBuilder + * Creates a new XML DocumentBuilder. * * @param implementation which jaxp implementation should be used * @param classLoader which class loader should be used when getting a new DocumentBuilder @@ -468,7 +224,7 @@ public class XML { } /** - * Creates a new XML DocumentBuilder + * Creates a new XML DocumentBuilder. * * @return a DocumentBuilder * @throws RuntimeException if we fail to create one @@ -479,7 +235,7 @@ public class XML { } /** - * Creates a new XML DocumentBuilder + * Creates a new XML DocumentBuilder. * * @param implementation which jaxp implementation should be used * @param classLoader which class loader should be used when getting a new DocumentBuilder @@ -508,7 +264,7 @@ public class XML { } /** - * Returns the child Element objects from a w3c dom spec + * Returns the child Element objects from a w3c dom spec. * * @return List of elements. Empty list (never null) if none found or if the given element is null */ @@ -554,18 +310,21 @@ public class XML { return ret; } + /** Returns the given attribute name from element, or empty if the element does not have it. */ + public static Optional<String> attribute(String name, Element element) { + if ( ! element.hasAttribute(name)) return Optional.empty(); + return Optional.of(element.getAttribute(name)); + } + /** * Gets the string contents of the given Element. Returns "", never null if * the element is null, or has no content */ public static String getValue(Element e) { - if (e == null) { - return ""; - } + if (e == null) return ""; Node child = e.getFirstChild(); - if (child == null) { - return ""; - } + if (child == null) return ""; + if (child.getNodeValue() == null) return ""; return child.getNodeValue(); } @@ -575,14 +334,11 @@ public class XML { } /** - * Returns the path to the given xml node, where each node name is separated - * by the given separator string. + * Returns the path to the given xml node, where each node name is separated by the given separator string. * - * @param n - * The xml node to find path to - * @param sep - * The separator string - * @return The path to the xml node as a String + * @param n the xml node to find path to + * @param sep the separator string + * @return the path to the xml node as a String */ public static String getNodePath(Node n, String sep) { if (n == null) { @@ -657,8 +413,7 @@ public class XML { * 1.1 (Second Edition)</a>. This does not check against reserved names, it * only checks the set of characters used. * - * @param possibleName - * a possibly valid XML name + * @param possibleName a possibly valid XML name * @return true if the name may be used as an XML tag or attribute name */ public static boolean isName(CharSequence possibleName) { @@ -683,4 +438,181 @@ public class XML { return valid; } + /** + * The point of this weird class and the jumble of abstract methods is + * linking the scan for characters that must be quoted into the quoting + * table, and making it actual work to make them go out of sync again. + */ + private static abstract class LegalCharacters { + // To quote http://www.w3.org/TR/REC-xml/ : + // Char ::= #x9 | #xA | #xD | [#x20-#xD7FF] | [#xE000-#xFFFD] | + // [#x10000-#x10FFFF] + final boolean isLegal(int codepoint, boolean escapeLow, int stripCodePoint, boolean isAttribute) { + if (codepoint == stripCodePoint) { + return removeCodePoint(); + } else if (codepoint < ' ') { + if (!escapeLow) { + return true; + } + switch (codepoint) { + case 0x09: + case 0x0a: + case 0x0d: + return true; + default: + return ctrlEscapeCodePoint(codepoint); + } + } else if (codepoint >= 0x20 && codepoint <= 0xd7ff) { + switch (codepoint) { + case '&': + return ampCodePoint(); + case '<': + return ltCodePoint(); + case '>': + return gtCodePoint(); + case '"': + return quotCodePoint(isAttribute); + default: + return true; + } + } else if ((codepoint >= 0xe000 && codepoint <= 0xfffd) + || (codepoint >= 0x10000 && codepoint <= 0x10ffff)) { + return true; + } else { + return filterCodePoint(codepoint); + + } + } + + private boolean quotCodePoint(boolean isAttribute) { + if (isAttribute) { + quoteQuot(); + return false; + } else { + return true; + } + } + + private boolean filterCodePoint(int codepoint) { + replace(codepoint); + return false; + } + + private boolean gtCodePoint() { + quoteGt(); + return false; + } + + private boolean ltCodePoint() { + quoteLt(); + return false; + } + + private boolean ampCodePoint() { + quoteAmp(); + return false; + } + + private boolean ctrlEscapeCodePoint(int codepoint) { + ctrlEscape(codepoint); + return false; + } + + private boolean removeCodePoint() { + remove(); + return false; + } + + protected abstract void quoteQuot(); + + protected abstract void quoteGt(); + + protected abstract void quoteLt(); + + protected abstract void quoteAmp(); + + protected abstract void remove(); + + protected abstract void ctrlEscape(int codepoint); + + protected abstract void replace(int codepoint); + } + + private static final class Quote extends LegalCharacters { + + char[] lastQuoted; + private static final char[] EMPTY = new char[0]; + private static final char[] REPLACEMENT_CHARACTER = "\ufffd".toCharArray(); + private static final char[] AMP = "&".toCharArray(); + private static final char[] LT = "<".toCharArray(); + private static final char[] GT = ">".toCharArray(); + private static final char[] QUOT = """.toCharArray(); + + @Override + protected void remove() { + lastQuoted = EMPTY; + } + + @Override + protected void replace(final int codepoint) { + lastQuoted = REPLACEMENT_CHARACTER; + } + + @Override + protected void quoteQuot() { + lastQuoted = QUOT; + } + + @Override + protected void quoteGt() { + lastQuoted = GT; + } + + @Override + protected void quoteLt() { + lastQuoted = LT; + } + + @Override + protected void quoteAmp() { + lastQuoted = AMP; + } + + @Override + protected void ctrlEscape(final int codepoint) { + lastQuoted = REPLACEMENT_CHARACTER; + } + } + + private static final class Scan extends LegalCharacters { + + @Override + protected void quoteQuot() { + } + + @Override + protected void quoteGt() { + } + + @Override + protected void quoteLt() { + } + + @Override + protected void quoteAmp() { + } + + @Override + protected void remove() { + } + + @Override + protected void ctrlEscape(final int codepoint) { + } + + @Override + protected void replace(final int codepoint) { + } + } + } |