diff options
16 files changed, 735 insertions, 11 deletions
diff --git a/bundle-plugin/src/main/java/com/yahoo/container/plugin/classanalysis/AnalyzeClassVisitor.java b/bundle-plugin/src/main/java/com/yahoo/container/plugin/classanalysis/AnalyzeClassVisitor.java index 307509f0452..c5522c4c96a 100644 --- a/bundle-plugin/src/main/java/com/yahoo/container/plugin/classanalysis/AnalyzeClassVisitor.java +++ b/bundle-plugin/src/main/java/com/yahoo/container/plugin/classanalysis/AnalyzeClassVisitor.java @@ -160,7 +160,9 @@ class AnalyzeClassVisitor extends ClassVisitor implements ImportCollector { if (ExportPackage.class.getName().equals(Type.getType(desc).getClassName())) { return visitExportPackage(); } else { - addImportWithTypeDesc(desc); + if (visible) { + addImportWithTypeDesc(desc); + } return Analyze.visitAnnotationDefault(this); } } diff --git a/bundle-plugin/src/test/java/com/yahoo/container/plugin/classanalysis/AnalyzeClassTest.java b/bundle-plugin/src/test/java/com/yahoo/container/plugin/classanalysis/AnalyzeClassTest.java index 6b9b0845328..1c061ca49a2 100644 --- a/bundle-plugin/src/test/java/com/yahoo/container/plugin/classanalysis/AnalyzeClassTest.java +++ b/bundle-plugin/src/test/java/com/yahoo/container/plugin/classanalysis/AnalyzeClassTest.java @@ -3,8 +3,10 @@ package com.yahoo.container.plugin.classanalysis; import com.yahoo.container.plugin.classanalysis.sampleclasses.Base; import com.yahoo.container.plugin.classanalysis.sampleclasses.ClassAnnotation; +import com.yahoo.container.plugin.classanalysis.sampleclasses.InvisibleAnnotation; import com.yahoo.container.plugin.classanalysis.sampleclasses.Derived; import com.yahoo.container.plugin.classanalysis.sampleclasses.DummyAnnotation; +import com.yahoo.container.plugin.classanalysis.sampleclasses.InvisibleDummyAnnotation; import com.yahoo.container.plugin.classanalysis.sampleclasses.Fields; import com.yahoo.container.plugin.classanalysis.sampleclasses.Interface1; import com.yahoo.container.plugin.classanalysis.sampleclasses.Interface2; @@ -109,6 +111,11 @@ public class AnalyzeClassTest { } @Test + public void invisible_annotation_not_included() { + assertFalse(analyzeClass(InvisibleAnnotation.class).getReferencedClasses().contains(name(InvisibleDummyAnnotation.class))); + } + + @Test public void method_annotation_is_included() { assertTrue(analyzeClass(MethodAnnotation.class).getReferencedClasses().contains(name(DummyAnnotation.class))); } diff --git a/bundle-plugin/src/test/java/com/yahoo/container/plugin/classanalysis/sampleclasses/DummyAnnotation.java b/bundle-plugin/src/test/java/com/yahoo/container/plugin/classanalysis/sampleclasses/DummyAnnotation.java index a07b2917bd4..42335da3d62 100644 --- a/bundle-plugin/src/test/java/com/yahoo/container/plugin/classanalysis/sampleclasses/DummyAnnotation.java +++ b/bundle-plugin/src/test/java/com/yahoo/container/plugin/classanalysis/sampleclasses/DummyAnnotation.java @@ -1,9 +1,13 @@ // Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. package com.yahoo.container.plugin.classanalysis.sampleclasses; +import java.lang.annotation.Retention; +import java.lang.annotation.RetentionPolicy; + /** * Input for class analysis tests. * @author Tony Vaagenes */ +@Retention(RetentionPolicy.RUNTIME) public @interface DummyAnnotation { } diff --git a/bundle-plugin/src/test/java/com/yahoo/container/plugin/classanalysis/sampleclasses/InvisibleAnnotation.java b/bundle-plugin/src/test/java/com/yahoo/container/plugin/classanalysis/sampleclasses/InvisibleAnnotation.java new file mode 100644 index 00000000000..ced7c3305b0 --- /dev/null +++ b/bundle-plugin/src/test/java/com/yahoo/container/plugin/classanalysis/sampleclasses/InvisibleAnnotation.java @@ -0,0 +1,10 @@ +// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package com.yahoo.container.plugin.classanalysis.sampleclasses; + +/** + * Input for class analysis tests.* + * @author arnej + */ +@InvisibleDummyAnnotation +public class InvisibleAnnotation { +} diff --git a/bundle-plugin/src/test/java/com/yahoo/container/plugin/classanalysis/sampleclasses/InvisibleDummyAnnotation.java b/bundle-plugin/src/test/java/com/yahoo/container/plugin/classanalysis/sampleclasses/InvisibleDummyAnnotation.java new file mode 100644 index 00000000000..b3cb75df354 --- /dev/null +++ b/bundle-plugin/src/test/java/com/yahoo/container/plugin/classanalysis/sampleclasses/InvisibleDummyAnnotation.java @@ -0,0 +1,13 @@ +// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package com.yahoo.container.plugin.classanalysis.sampleclasses; + +import java.lang.annotation.Retention; +import java.lang.annotation.RetentionPolicy; + +/** + * Input for class analysis tests. + * @author arnej + */ +@Retention(RetentionPolicy.CLASS) +public @interface InvisibleDummyAnnotation { +} diff --git a/config-model/src/main/java/com/yahoo/vespa/model/container/xml/ContainerModelBuilder.java b/config-model/src/main/java/com/yahoo/vespa/model/container/xml/ContainerModelBuilder.java index 0bd93c6d0df..1121a90693b 100644 --- a/config-model/src/main/java/com/yahoo/vespa/model/container/xml/ContainerModelBuilder.java +++ b/config-model/src/main/java/com/yahoo/vespa/model/container/xml/ContainerModelBuilder.java @@ -86,6 +86,7 @@ import com.yahoo.vespa.model.container.search.GUIHandler; import com.yahoo.vespa.model.container.search.PageTemplates; import com.yahoo.vespa.model.container.search.searchchain.SearchChains; import com.yahoo.vespa.model.container.xml.document.DocumentFactoryBuilder; +import com.yahoo.vespa.model.container.xml.embedder.EmbedderConfig; import com.yahoo.vespa.model.content.StorageGroup; import org.w3c.dom.Element; import org.w3c.dom.Node; @@ -197,9 +198,11 @@ public class ContainerModelBuilder extends ConfigModelBuilder<ContainerModel> { private void addClusterContent(ApplicationContainerCluster cluster, Element spec, ConfigModelContext context) { DeployState deployState = context.getDeployState(); DocumentFactoryBuilder.buildDocumentFactories(cluster, spec); + addConfiguredComponents(deployState, cluster, spec); addSecretStore(cluster, spec, deployState); + addEmbedderComponents(deployState, cluster, spec); addModelEvaluation(spec, cluster, context); addModelEvaluationBundles(cluster); @@ -382,6 +385,13 @@ public class ContainerModelBuilder extends ConfigModelBuilder<ContainerModel> { } } + private static void addEmbedderComponents(DeployState deployState, ApplicationContainerCluster cluster, Element spec) { + for (Element node : XML.getChildren(spec, "embedder")) { + Element transformed = EmbedderConfig.transform(deployState, node); + cluster.addComponent(new DomComponentBuilder().build(deployState, cluster, transformed)); + } + } + private void addConfiguredComponents(DeployState deployState, ApplicationContainerCluster cluster, Element spec) { for (Element components : XML.getChildren(spec, "components")) { addIncludes(components); diff --git a/config-model/src/main/java/com/yahoo/vespa/model/container/xml/embedder/EmbedderConfig.java b/config-model/src/main/java/com/yahoo/vespa/model/container/xml/embedder/EmbedderConfig.java new file mode 100644 index 00000000000..a2286647cdd --- /dev/null +++ b/config-model/src/main/java/com/yahoo/vespa/model/container/xml/embedder/EmbedderConfig.java @@ -0,0 +1,79 @@ +package com.yahoo.vespa.model.container.xml.embedder; + +import com.yahoo.config.model.deploy.DeployState; +import org.w3c.dom.Element; +import org.w3c.dom.Node; +import org.w3c.dom.NodeList; + +/** + * Translates config in services.xml of the form + * + * <embedder id="..." class="..." bundle="..." def="..."> + * <!-- options --> + * </embedder> + * + * to component configuration of the form + * + * <component id="..." class="..." bundle="..."> + * <config name=def> + * <!-- options --> + * </config> + * </component> + * + * with some added interpretations based on recognizing the class. + * + * @author lesters + */ +public class EmbedderConfig { + + static EmbedderConfigTransformer getEmbedderTransformer(Element spec, boolean hosted) { + String classId = getEmbedderClass(spec); + switch (classId) { + case "ai.vespa.embedding.BertBaseEmbedder": return new EmbedderConfigBertBaseTransformer(spec, hosted); + } + return new EmbedderConfigTransformer(spec, hosted); + } + + static String modelIdToUrl(String id) { + switch (id) { + case "test-model-id": + return "test-model-url"; + case "minilm-l6-v2": + return "https://data.vespa.oath.cloud/onnx_models/sentence_all_MiniLM_L6_v2.onnx"; + case "bert-base-uncased": + return "https://data.vespa.oath.cloud/onnx_models/bert-base-uncased-vocab.txt"; + } + throw new IllegalArgumentException("Unknown model id: '" + id + "'"); + } + + /** + * Transforms the <embedder ...> element to component configuration. + * + * @param deployState the deploy state - as config generation can depend on context + * @param embedderSpec the XML element containing the <embedder ...> + * @return a new XML element containting the <component ...> configuration + */ + public static Element transform(DeployState deployState, Element embedderSpec) { + EmbedderConfigTransformer transformer = getEmbedderTransformer(embedderSpec, deployState.isHosted()); + NodeList children = embedderSpec.getChildNodes(); + for (int i = 0; i < children.getLength(); i++) { + Node child = children.item(i); + if (child instanceof Element) { + transformer.addOption((Element) child); + } + } + return transformer.createComponentConfig(deployState); + } + + private static String getEmbedderClass(Element spec) { + if (spec.hasAttribute("class")) { + return spec.getAttribute("class"); + } + if (spec.hasAttribute("id")) { + return spec.getAttribute("id"); + } + throw new IllegalArgumentException("Embedder specification does not have a required class attribute"); + } + + +} diff --git a/config-model/src/main/java/com/yahoo/vespa/model/container/xml/embedder/EmbedderConfigBertBaseTransformer.java b/config-model/src/main/java/com/yahoo/vespa/model/container/xml/embedder/EmbedderConfigBertBaseTransformer.java new file mode 100644 index 00000000000..9431926d088 --- /dev/null +++ b/config-model/src/main/java/com/yahoo/vespa/model/container/xml/embedder/EmbedderConfigBertBaseTransformer.java @@ -0,0 +1,40 @@ +package com.yahoo.vespa.model.container.xml.embedder; + +import org.w3c.dom.Element; + +import java.util.Map; + +/** + * Transforms embedding configuration to component configuration for the + * BertBaseEmbedder using embedder.bert-base-embedder.def + * + * @author lesters + */ +public class EmbedderConfigBertBaseTransformer extends EmbedderConfigTransformer { + + private static final String BUNDLE = "model-integration"; + private static final String DEF = "embedding.bert-base-embedder"; + + public EmbedderConfigBertBaseTransformer(Element spec, boolean hosted) { + super(spec, hosted, BUNDLE, DEF); + + EmbedderOption.Builder modelOption = new EmbedderOption.Builder() + .name("model") + .required(true) + .optionTransformer(new EmbedderOption.ModelOptionTransformer("transformerModelPath", "transformerModelUrl")); + EmbedderOption.Builder vocabOption = new EmbedderOption.Builder() + .name("vocab") + .required(true) + .optionTransformer(new EmbedderOption.ModelOptionTransformer("tokenizerVocabPath", "tokenizerVocabUrl")); + + // Defaults + if (hosted) { + modelOption.attributes(Map.of("id", "minilm-l6-v2")).value(""); + vocabOption.attributes(Map.of("id", "bert-base-uncased")).value(""); + } + + addOption(modelOption.build()); + addOption(vocabOption.build()); + } + +} diff --git a/config-model/src/main/java/com/yahoo/vespa/model/container/xml/embedder/EmbedderConfigTransformer.java b/config-model/src/main/java/com/yahoo/vespa/model/container/xml/embedder/EmbedderConfigTransformer.java new file mode 100644 index 00000000000..30327fdc8af --- /dev/null +++ b/config-model/src/main/java/com/yahoo/vespa/model/container/xml/embedder/EmbedderConfigTransformer.java @@ -0,0 +1,102 @@ +package com.yahoo.vespa.model.container.xml.embedder; + +import com.yahoo.config.model.deploy.DeployState; +import com.yahoo.text.XML; +import org.w3c.dom.Document; +import org.w3c.dom.Element; + +import java.util.ArrayList; +import java.util.HashMap; +import java.util.List; +import java.util.Map; + + +/** + * A specific embedder to component configuration transformer. + * + * @author lesters + */ +public class EmbedderConfigTransformer { + + private final Document doc = XML.getDocumentBuilder().newDocument(); + + private final String id; + private final String className; + private final String bundle; + private final String def; + private final Map<String, EmbedderOption> options = new HashMap<>(); + + public EmbedderConfigTransformer(Element spec, boolean hosted) { + this(spec, hosted, null, null); + } + + public EmbedderConfigTransformer(Element spec, boolean hosted, String defaultBundle, String defaultDef) { + id = spec.getAttribute("id"); + className = spec.hasAttribute("class") ? spec.getAttribute("class") : id; + bundle = spec.hasAttribute("bundle") ? spec.getAttribute("bundle") : defaultBundle; + def = spec.hasAttribute("def") ? spec.getAttribute("def") : defaultDef; + + if (className == null || className.length() == 0) { + throw new IllegalArgumentException("Embedder class is empty"); + } + if (this.bundle == null || this.bundle.length() == 0) { + throw new IllegalArgumentException("Embedder configuration requires a bundle name"); + } + if (this.def == null || this.def.length() == 0) { + throw new IllegalArgumentException("Embedder configuration requires a config definition name"); + } + } + + Element createComponentConfig(DeployState deployState) { + checkRequiredOptions(); + + Element component = doc.createElement("component"); + component.setAttribute("id", id); + component.setAttribute("class", className); + component.setAttribute("bundle", bundle); + + if (options.size() > 0) { + Element config = doc.createElement("config"); + config.setAttribute("name", def); + for (Map.Entry<String, EmbedderOption> entry : options.entrySet()) { + entry.getValue().toElement(deployState, config); + } + component.appendChild(config); + } + + return component; + } + + // TODO: support nested options + void addOption(Element elem) { + String name = elem.getTagName(); + + EmbedderOption.Builder builder = new EmbedderOption.Builder(); + builder.name(name); + builder.value(elem.getTextContent()); + builder.attributes(elem); + + if (options.containsKey(name)) { + builder.required(options.get(name).required()); + builder.optionTransformer(options.get(name).optionTransformer()); + } + options.put(name, builder.build()); + } + + void addOption(EmbedderOption option) { + options.put(option.name(), option); + } + + private void checkRequiredOptions() { + List<String> missingOptions = new ArrayList<>(); + for (EmbedderOption option : options.values()) { + if ( ! option.isSet()) { + missingOptions.add(option.name()); + } + } + if (missingOptions.size() > 0) { + throw new IllegalArgumentException("Embedder '" + className + "' requires options for " + missingOptions); + } + } + +} diff --git a/config-model/src/main/java/com/yahoo/vespa/model/container/xml/embedder/EmbedderOption.java b/config-model/src/main/java/com/yahoo/vespa/model/container/xml/embedder/EmbedderOption.java new file mode 100644 index 00000000000..206745887d1 --- /dev/null +++ b/config-model/src/main/java/com/yahoo/vespa/model/container/xml/embedder/EmbedderOption.java @@ -0,0 +1,184 @@ +package com.yahoo.vespa.model.container.xml.embedder; + +import com.yahoo.config.model.deploy.DeployState; +import org.w3c.dom.Element; +import org.w3c.dom.NamedNodeMap; + +import java.util.HashMap; +import java.util.Map; + + +/** + * Holds options for embedder configuration. This includes code for handling special + * options such as model specifiers. + * + * @author lesters + */ +public class EmbedderOption { + + public static final OptionTransformer defaultOptionTransformer = new OptionTransformer(); + + private final String name; + private final boolean required; + private final String value; + private final Map<String, String> attributes; + private final OptionTransformer optionTransformer; + private final boolean set; + + private EmbedderOption(Builder builder) { + this.name = builder.name; + this.required = builder.required; + this.value = builder.value; + this.attributes = builder.attributes; + this.optionTransformer = builder.optionTransformer; + this.set = builder.set; + } + + public void toElement(DeployState deployState, Element parent) { + optionTransformer.transform(deployState, parent, this); + } + + public String name() { + return name; + } + + public String value() { + return value; + } + + public boolean required() { + return required; + } + + public OptionTransformer optionTransformer() { + return optionTransformer; + } + + public boolean isSet() { + return set; + } + + /** + * Basic option transformer. No special handling of options. + */ + public static class OptionTransformer { + public void transform(DeployState deployState, Element parent, EmbedderOption option) { + createElement(parent, option.name(), option.value()); + } + + public static Element createElement(Element parent, String name, String value) { + Element element = parent.getOwnerDocument().createElement(name); + element.setTextContent(value); + parent.appendChild(element); + return element; + } + } + + /** + * Transforms model options of type <x id="..." url="..." path="..." /> to the + * required fields in the config definition. + */ + public static class ModelOptionTransformer extends OptionTransformer { + + private final String pathField; + private final String urlField; + + public ModelOptionTransformer(String pathField, String urlField) { + super(); + this.pathField = pathField; + this.urlField = urlField; + } + + @Override + public void transform(DeployState deployState, Element parent, EmbedderOption option) { + String id = option.attributes.get("id"); + String url = option.attributes.get("url"); + String path = option.attributes.get("path"); + + // Always use path if it is set + if (path != null && path.length() > 0) { + createElement(parent, pathField, path); + createElement(parent, urlField, ""); + return; + } + + // Only use the id if we're on cloud + if (deployState.isHosted() && id != null && id.length() > 0) { + createElement(parent, urlField, EmbedderConfig.modelIdToUrl(id)); + createElement(parent, pathField, createDummyPath(deployState)); + return; + } + + // Otherwise, use url + if (url != null && url.length() > 0) { + createElement(parent, urlField, url); + createElement(parent, pathField, createDummyPath(deployState)); + return; + } + + if ( ! deployState.isHosted() && id != null && id.length() > 0) { + throw new IllegalArgumentException("Model option 'id' is not valid here"); + } + throw new IllegalArgumentException("Model option requires either a 'path' or a 'url' attribute"); + } + + private String createDummyPath(DeployState deployState) { + // For now, until we have optional config parameters, return services.xml as it is guaranteed to exist + return "services.xml"; + } + + } + + public static class Builder { + private String name = ""; + private boolean required = false; + private String value = ""; + private Map<String, String> attributes = Map.of(); + private OptionTransformer optionTransformer = defaultOptionTransformer; + private boolean set = false; + + public Builder name(String name) { + this.name = name; + return this; + } + + public Builder required(boolean required) { + this.required = required; + return this; + } + + public Builder value(String value) { + this.value = value; + this.set = true; + return this; + } + + public Builder attributes(Map<String, String> attributes) { + this.attributes = attributes; + return this; + } + + public Builder attributes(Element element) { + NamedNodeMap map = element.getAttributes(); + if (map.getLength() > 0) { + this.attributes = new HashMap<>(map.getLength()); + for (int i = 0; i < map.getLength(); ++i) { + String attr = map.item(i).getNodeName(); + attributes.put(attr, element.getAttribute(attr)); + } + } + return this; + } + + public Builder optionTransformer(OptionTransformer optionTransformer) { + this.optionTransformer = optionTransformer; + return this; + } + + public EmbedderOption build() { + return new EmbedderOption(this); + } + + } + +} diff --git a/config-model/src/main/resources/schema/containercluster.rnc b/config-model/src/main/resources/schema/containercluster.rnc index e8ff1721397..3fdbff84f6d 100644 --- a/config-model/src/main/resources/schema/containercluster.rnc +++ b/config-model/src/main/resources/schema/containercluster.rnc @@ -19,6 +19,7 @@ ContainerServices = DocumentApi? & Components* & Component* & + Embedder* & Handler* & Client* & Server* & @@ -103,6 +104,14 @@ ZooKeeper = element zookeeper { empty } +Embedder = element embedder { + attribute id { string }? & + attribute class { xsd:Name | JavaId }? & + attribute bundle { xsd:Name }? & + attribute def { xsd:Name }? & + anyElement* +} + ModelEvaluation = element model-evaluation { element onnx { element models { diff --git a/config-model/src/test/cfg/application/embed/services.xml b/config-model/src/test/cfg/application/embed/services.xml new file mode 100644 index 00000000000..f319d875ed9 --- /dev/null +++ b/config-model/src/test/cfg/application/embed/services.xml @@ -0,0 +1,26 @@ +<?xml version="1.0" encoding="utf-8" ?> +<!-- Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. --> +<services version="1.0"> + + <container version="1.0"> + + <embedder id="test" class="ai.vespa.embedding.UndefinedEmbedder" bundle="dummy" def="test.dummy"> + <num>12</num> + <str>some text</str> + </embedder> + + <embedder id="transformer" class="ai.vespa.embedding.BertBaseEmbedder"> + <!-- model specifics --> + <model id="test-model-id" url="test-model-url"/> + <vocab path="files/vocab.txt"/> + + <!-- tunable parameters: number of threads etc --> + <onnxIntraOpThreads>4</onnxIntraOpThreads> + </embedder> + + <nodes> + <node hostalias="node1" /> + </nodes> + </container> + +</services> diff --git a/config-model/src/test/java/com/yahoo/vespa/model/container/xml/EmbedderTestCase.java b/config-model/src/test/java/com/yahoo/vespa/model/container/xml/EmbedderTestCase.java new file mode 100644 index 00000000000..0dae86473c8 --- /dev/null +++ b/config-model/src/test/java/com/yahoo/vespa/model/container/xml/EmbedderTestCase.java @@ -0,0 +1,223 @@ +// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package com.yahoo.vespa.model.container.xml; + +import com.yahoo.component.ComponentId; +import com.yahoo.config.model.application.provider.FilesApplicationPackage; +import com.yahoo.config.model.deploy.DeployState; +import com.yahoo.config.model.deploy.TestProperties; +import com.yahoo.path.Path; +import com.yahoo.text.XML; +import com.yahoo.vespa.config.ConfigDefinitionKey; +import com.yahoo.vespa.config.ConfigPayloadBuilder; +import com.yahoo.vespa.model.VespaModel; +import com.yahoo.vespa.model.container.ApplicationContainerCluster; +import com.yahoo.vespa.model.container.component.Component; +import com.yahoo.vespa.model.container.xml.embedder.EmbedderConfig; +import org.junit.Test; +import org.w3c.dom.Document; +import org.w3c.dom.Element; +import org.w3c.dom.NamedNodeMap; +import org.xml.sax.SAXException; + +import java.io.ByteArrayInputStream; +import java.io.IOException; +import java.nio.charset.StandardCharsets; +import java.util.List; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.fail; + +public class EmbedderTestCase { + + private static final String PREDEFINED_EMBEDDER_CLASS = "ai.vespa.embedding.BertBaseEmbedder"; + private static final String PREDEFINED_EMBEDDER_CONFIG = "embedding.bert-base-embedder"; + + @Test + public void testGenericEmbedConfig() throws IOException, SAXException { + String embedder = "<embedder id=\"test\" class=\"ai.vespa.test\" bundle=\"bundle\" def=\"def.name\">" + + " <val>123</val>" + + "</embedder>"; + String component = "<component id=\"test\" class=\"ai.vespa.test\" bundle=\"bundle\">" + + " <config name=\"def.name\">" + + " <val>123</val>" + + " </config>" + + "</component>"; + assertTransform(embedder, component); + } + + @Test + public void testGenericEmbedConfigRequiresBundleAndDef() throws IOException, SAXException { + assertTransformThrows("<embedder id=\"test\" class=\"ai.vespa.test\"></embedder>", + "Embedder configuration requires a bundle name"); + assertTransformThrows("<embedder id=\"test\" class=\"ai.vespa.test\" bundle=\"bundle\"></embedder>", + "Embedder configuration requires a config definition name"); + } + + @Test + public void testPredefinedEmbedConfigSelfHosted() throws IOException, SAXException { + assertTransformThrows("<embedder id=\"test\" class=\"" + PREDEFINED_EMBEDDER_CLASS + "\"></embedder>", + "Embedder '" + PREDEFINED_EMBEDDER_CLASS + "' requires options for [vocab, model]"); + assertTransformThrows("<embedder id=\"test\" class=\"" + PREDEFINED_EMBEDDER_CLASS + "\">" + + " <model />" + + " <vocab />" + + "</embedder>", + "Model option requires either a 'path' or a 'url' attribute"); + assertTransformThrows("<embedder id=\"test\" class=\"" + PREDEFINED_EMBEDDER_CLASS + "\">" + + " <model id=\"my_model_id\" />" + + " <vocab id=\"my_vocab_id\" />" + + "</embedder>", + "Model option 'id' is not valid here"); + + String embedder = "<embedder id=\"test\" class=\"" + PREDEFINED_EMBEDDER_CLASS + "\">" + + " <model id=\"my_model_id\" url=\"my-model-url\" />" + + " <vocab id=\"my_vocab_id\" url=\"my-vocab-url\" />" + + "</embedder>"; + String component = "<component id=\"test\" class=\"" + PREDEFINED_EMBEDDER_CLASS + "\" bundle=\"model-integration\">" + + " <config name=\"" + PREDEFINED_EMBEDDER_CONFIG + "\">" + + " <tokenizerVocabUrl>my-vocab-url</tokenizerVocabUrl>" + + " <tokenizerVocabPath></tokenizerVocabPath>" + + " <transformerModelUrl>my-model-url</transformerModelUrl>" + + " <transformerModelPath></transformerModelPath>" + + " </config>" + + "</component>"; + assertTransform(embedder, component, false); + + // Path has priority: + embedder = "<embedder id=\"test\" class=\"" + PREDEFINED_EMBEDDER_CLASS + "\">" + + " <model id=\"my_model_id\" url=\"my-model-url\" path=\"files/model.onnx\" />" + + " <vocab id=\"my_vocab_id\" url=\"my-vocab-url\" path=\"files/vocab.txt\" />" + + "</embedder>"; + component = "<component id=\"test\" class=\"" + PREDEFINED_EMBEDDER_CLASS + "\" bundle=\"model-integration\">" + + " <config name=\"" + PREDEFINED_EMBEDDER_CONFIG + "\">" + + " <tokenizerVocabPath>files/vocab.txt</tokenizerVocabPath>" + + " <tokenizerVocabUrl></tokenizerVocabUrl>" + + " <transformerModelPath>files/model.onnx</transformerModelPath>" + + " <transformerModelUrl></transformerModelUrl>" + + " </config>" + + "</component>"; + assertTransform(embedder, component, false); + } + + @Test + public void testPredefinedEmbedConfigCloud() throws IOException, SAXException { + String embedder = "<embedder id=\"test\" class=\"" + PREDEFINED_EMBEDDER_CLASS + "\" />"; + String component = "<component id=\"test\" class=\"" + PREDEFINED_EMBEDDER_CLASS + "\" bundle=\"model-integration\">" + + " <config name=\"" + PREDEFINED_EMBEDDER_CONFIG + "\">" + + " <tokenizerVocabUrl>some url</tokenizerVocabUrl>" + + " <tokenizerVocabPath></tokenizerVocabPath>" + + " <transformerModelUrl>some url</transformerModelUrl>" + + " <transformerModelPath></transformerModelPath>" + + " </config>" + + "</component>"; + assertTransform(embedder, component, true); + + embedder = "<embedder id=\"test\" class=\"" + PREDEFINED_EMBEDDER_CLASS + "\">" + + " <model id=\"my_model_id\" />" + + " <vocab id=\"my_vocab_id\" />" + + "</embedder>"; + assertTransformThrows(embedder, "Unknown model id: 'my_vocab_id'", true); + + embedder = "<embedder id=\"test\" class=\"" + PREDEFINED_EMBEDDER_CLASS + "\">" + + " <model id=\"test-model-id\" />" + + " <vocab id=\"test-model-id\" />" + + "</embedder>"; + component = "<component id=\"test\" class=\"" + PREDEFINED_EMBEDDER_CLASS + "\" bundle=\"model-integration\">" + + " <config name=\"" + PREDEFINED_EMBEDDER_CONFIG + "\">" + + " <tokenizerVocabUrl>test-model-url</tokenizerVocabUrl>" + + " <tokenizerVocabPath></tokenizerVocabPath>" + + " <transformerModelUrl>test-model-url</transformerModelUrl>" + + " <transformerModelPath></transformerModelPath>" + + " </config>" + + "</component>"; + assertTransform(embedder, component, true); + } + + @Test + public void testEmbedConfig() throws Exception { + final String emptyPathFileName = "services.xml"; + + Path applicationDir = Path.fromString("src/test/cfg/application/embed/"); + VespaModel model = loadModel(applicationDir, false); + ApplicationContainerCluster containerCluster = model.getContainerClusters().get("container"); + + Component<?, ?> testComponent = containerCluster.getComponentsMap().get(new ComponentId("test")); + ConfigPayloadBuilder testConfig = testComponent.getUserConfigs().get(new ConfigDefinitionKey("dummy", "test")); + assertEquals(testConfig.getObject("num").getValue(), "12"); + assertEquals(testConfig.getObject("str").getValue(), "some text"); + + Component<?, ?> transformer = containerCluster.getComponentsMap().get(new ComponentId("transformer")); + ConfigPayloadBuilder transformerConfig = transformer.getUserConfigs().get(new ConfigDefinitionKey("bert-base-embedder", "embedding")); + assertEquals(transformerConfig.getObject("transformerModelUrl").getValue(), "test-model-url"); + assertEquals(transformerConfig.getObject("transformerModelPath").getValue(), emptyPathFileName); + assertEquals(transformerConfig.getObject("tokenizerVocabUrl").getValue(), ""); + assertEquals(transformerConfig.getObject("tokenizerVocabPath").getValue(), "files/vocab.txt"); + } + + private VespaModel loadModel(Path path, boolean hosted) throws Exception { + FilesApplicationPackage applicationPackage = FilesApplicationPackage.fromFile(path.toFile()); + TestProperties properties = new TestProperties().setHostedVespa(hosted); + DeployState state = new DeployState.Builder().properties(properties).applicationPackage(applicationPackage).build(); + return new VespaModel(state); + } + + private void assertTransform(String embedder, String component) throws IOException, SAXException { + assertTransform(embedder, component, false); + } + + private void assertTransform(String embedder, String component, boolean hosted) throws IOException, SAXException { + Element emb = createElement(embedder); + Element cmp = createElement(component); + Element trans = EmbedderConfig.transform(createEmptyDeployState(hosted), emb); + assertSpec(cmp, trans); + } + + private void assertSpec(Element e1, Element e2) { + assertEquals(e1.getTagName(), e2.getTagName()); + assertAttributes(e1, e2); + assertAttributes(e2, e1); + assertChildren(e1, e2); + } + + private void assertAttributes(Element e1, Element e2) { + NamedNodeMap map = e1.getAttributes(); + for (int i = 0; i < map.getLength(); ++i) { + String attr = map.item(i).getNodeName(); + assertEquals(e1.getAttribute(attr), e2.getAttribute(attr)); + } + } + + private void assertChildren(Element e1, Element e2) { + List<Element> list1 = XML.getChildren(e1); + List<Element> list2 = XML.getChildren(e2); + assertEquals(list1.size(), list2.size()); + for (int i = 0; i < list1.size(); ++i) { + Element child1 = list1.get(i); + Element child2 = list2.get(i); + assertSpec(child1, child2); + } + } + + private void assertTransformThrows(String embedder, String msg) throws IOException, SAXException { + assertTransformThrows(embedder, msg, false); + } + + private void assertTransformThrows(String embedder, String msg, boolean hosted) throws IOException, SAXException { + try { + EmbedderConfig.transform(createEmptyDeployState(hosted), createElement(embedder)); + fail("Expected exception was not thrown: " + msg); + } catch (IllegalArgumentException e) { + assertEquals(e.getMessage(), msg); + } + } + + private Element createElement(String xml) throws IOException, SAXException { + Document doc = XML.getDocumentBuilder().parse(new ByteArrayInputStream(xml.getBytes(StandardCharsets.UTF_8))); + return (Element) doc.getFirstChild(); + } + + private DeployState createEmptyDeployState(boolean hosted) { + TestProperties properties = new TestProperties().setHostedVespa(hosted); + return new DeployState.Builder().properties(properties).build(); + } + +} diff --git a/configdefinitions/src/vespa/embedding.bert-base-embedder.def b/configdefinitions/src/vespa/embedding.bert-base-embedder.def index a37599de411..115e021972c 100644 --- a/configdefinitions/src/vespa/embedding.bert-base-embedder.def +++ b/configdefinitions/src/vespa/embedding.bert-base-embedder.def @@ -1,8 +1,13 @@ namespace=embedding +# Settings for wordpiece tokenizer +tokenizerVocabUrl url +tokenizerVocabPath path + # Transformer model settings -transformerModelUrl url default=https://data.vespa.oath.cloud/onnx_models/sentence_all_MiniLM_L6_v2.onnx +transformerModelUrl url +transformerModelPath path # Max length of token sequence model can handle transformerMaxTokens int default=384 @@ -23,5 +28,3 @@ onnxExecutionMode enum { parallel, sequential } default=sequential onnxInterOpThreads int default=1 onnxIntraOpThreads int default=-4 # n=number of threads -> n<0: CPUs/(-n), n==0: CPUs, n>0: n -# Settings for wordpiece tokenizer -tokenizerVocabUrl url default=https://data.vespa.oath.cloud/onnx_models/bert-base-uncased-vocab.txt diff --git a/model-integration/src/main/java/ai/vespa/embedding/BertBaseEmbedder.java b/model-integration/src/main/java/ai/vespa/embedding/BertBaseEmbedder.java index 1831903d626..bc3f08ce3d6 100644 --- a/model-integration/src/main/java/ai/vespa/embedding/BertBaseEmbedder.java +++ b/model-integration/src/main/java/ai/vespa/embedding/BertBaseEmbedder.java @@ -11,6 +11,8 @@ import com.yahoo.tensor.Tensor; import com.yahoo.tensor.TensorAddress; import com.yahoo.tensor.TensorType; +import java.io.File; +import java.nio.file.Path; import java.util.ArrayList; import java.util.List; import java.util.Map; @@ -58,13 +60,22 @@ public class BertBaseEmbedder implements Embedder { options.setInterOpThreads(modifyThreadCount(config.onnxInterOpThreads())); options.setIntraOpThreads(modifyThreadCount(config.onnxIntraOpThreads())); - // Todo: use either file or url - tokenizer = new WordPieceEmbedder.Builder(config.tokenizerVocabUrl().getAbsolutePath()).build(); - evaluator = new OnnxEvaluator(config.transformerModelUrl().getAbsolutePath(), options); + String tokenizerFile = pathOrUrl(config.tokenizerVocabPath(), config.tokenizerVocabUrl()); + String modelFile = pathOrUrl(config.transformerModelPath(), config.transformerModelUrl()); + + tokenizer = new WordPieceEmbedder.Builder(tokenizerFile).build(); + evaluator = new OnnxEvaluator(modelFile, options); validateModel(); } + private String pathOrUrl(Path path, File url) { + if (path.endsWith("services.xml")) { + return url.getAbsolutePath(); + } + return path.toAbsolutePath().toString(); + } + private void validateModel() { Map<String, TensorType> inputs = evaluator.getInputInfo(); validateName(inputs, inputIdsName, "input"); diff --git a/model-integration/src/test/java/ai/vespa/embedding/BertBaseEmbedderTest.java b/model-integration/src/test/java/ai/vespa/embedding/BertBaseEmbedderTest.java index 464e5941e89..c224b87982d 100644 --- a/model-integration/src/test/java/ai/vespa/embedding/BertBaseEmbedderTest.java +++ b/model-integration/src/test/java/ai/vespa/embedding/BertBaseEmbedderTest.java @@ -1,6 +1,7 @@ package ai.vespa.embedding; import ai.vespa.modelintegration.evaluator.OnnxEvaluator; +import com.yahoo.config.FileReference; import com.yahoo.config.UrlReference; import com.yahoo.embedding.BertBaseEmbedderConfig; import com.yahoo.tensor.Tensor; @@ -14,8 +15,6 @@ import static org.junit.Assume.assumeTrue; public class BertBaseEmbedderTest { - - @Test public void testEmbedder() { String vocabPath = "src/test/models/onnx/transformer/dummy_vocab.txt"; @@ -23,8 +22,10 @@ public class BertBaseEmbedderTest { assumeTrue(OnnxEvaluator.isRuntimeAvailable(modelPath)); BertBaseEmbedderConfig.Builder builder = new BertBaseEmbedderConfig.Builder(); - builder.tokenizerVocabUrl(new UrlReference(vocabPath)); - builder.transformerModelUrl(new UrlReference(modelPath)); + builder.tokenizerVocabPath(new FileReference(vocabPath)); + builder.tokenizerVocabUrl(new UrlReference("")); + builder.transformerModelPath(new FileReference(modelPath)); + builder.transformerModelUrl(new UrlReference("")); BertBaseEmbedder embedder = new BertBaseEmbedder(builder.build()); TensorType destType = TensorType.fromSpec("tensor<float>(x[7])"); |