summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--bundle-plugin/src/main/java/com/yahoo/container/plugin/classanalysis/AnalyzeClassVisitor.java4
-rw-r--r--bundle-plugin/src/test/java/com/yahoo/container/plugin/classanalysis/AnalyzeClassTest.java7
-rw-r--r--bundle-plugin/src/test/java/com/yahoo/container/plugin/classanalysis/sampleclasses/DummyAnnotation.java4
-rw-r--r--bundle-plugin/src/test/java/com/yahoo/container/plugin/classanalysis/sampleclasses/InvisibleAnnotation.java10
-rw-r--r--bundle-plugin/src/test/java/com/yahoo/container/plugin/classanalysis/sampleclasses/InvisibleDummyAnnotation.java13
-rw-r--r--config-model/src/main/java/com/yahoo/vespa/model/container/xml/ContainerModelBuilder.java10
-rw-r--r--config-model/src/main/java/com/yahoo/vespa/model/container/xml/embedder/EmbedderConfig.java79
-rw-r--r--config-model/src/main/java/com/yahoo/vespa/model/container/xml/embedder/EmbedderConfigBertBaseTransformer.java40
-rw-r--r--config-model/src/main/java/com/yahoo/vespa/model/container/xml/embedder/EmbedderConfigTransformer.java102
-rw-r--r--config-model/src/main/java/com/yahoo/vespa/model/container/xml/embedder/EmbedderOption.java184
-rw-r--r--config-model/src/main/resources/schema/containercluster.rnc9
-rw-r--r--config-model/src/test/cfg/application/embed/services.xml26
-rw-r--r--config-model/src/test/java/com/yahoo/vespa/model/container/xml/EmbedderTestCase.java223
-rw-r--r--configdefinitions/src/vespa/embedding.bert-base-embedder.def9
-rw-r--r--model-integration/src/main/java/ai/vespa/embedding/BertBaseEmbedder.java17
-rw-r--r--model-integration/src/test/java/ai/vespa/embedding/BertBaseEmbedderTest.java9
16 files changed, 735 insertions, 11 deletions
diff --git a/bundle-plugin/src/main/java/com/yahoo/container/plugin/classanalysis/AnalyzeClassVisitor.java b/bundle-plugin/src/main/java/com/yahoo/container/plugin/classanalysis/AnalyzeClassVisitor.java
index 307509f0452..c5522c4c96a 100644
--- a/bundle-plugin/src/main/java/com/yahoo/container/plugin/classanalysis/AnalyzeClassVisitor.java
+++ b/bundle-plugin/src/main/java/com/yahoo/container/plugin/classanalysis/AnalyzeClassVisitor.java
@@ -160,7 +160,9 @@ class AnalyzeClassVisitor extends ClassVisitor implements ImportCollector {
if (ExportPackage.class.getName().equals(Type.getType(desc).getClassName())) {
return visitExportPackage();
} else {
- addImportWithTypeDesc(desc);
+ if (visible) {
+ addImportWithTypeDesc(desc);
+ }
return Analyze.visitAnnotationDefault(this);
}
}
diff --git a/bundle-plugin/src/test/java/com/yahoo/container/plugin/classanalysis/AnalyzeClassTest.java b/bundle-plugin/src/test/java/com/yahoo/container/plugin/classanalysis/AnalyzeClassTest.java
index 6b9b0845328..1c061ca49a2 100644
--- a/bundle-plugin/src/test/java/com/yahoo/container/plugin/classanalysis/AnalyzeClassTest.java
+++ b/bundle-plugin/src/test/java/com/yahoo/container/plugin/classanalysis/AnalyzeClassTest.java
@@ -3,8 +3,10 @@ package com.yahoo.container.plugin.classanalysis;
import com.yahoo.container.plugin.classanalysis.sampleclasses.Base;
import com.yahoo.container.plugin.classanalysis.sampleclasses.ClassAnnotation;
+import com.yahoo.container.plugin.classanalysis.sampleclasses.InvisibleAnnotation;
import com.yahoo.container.plugin.classanalysis.sampleclasses.Derived;
import com.yahoo.container.plugin.classanalysis.sampleclasses.DummyAnnotation;
+import com.yahoo.container.plugin.classanalysis.sampleclasses.InvisibleDummyAnnotation;
import com.yahoo.container.plugin.classanalysis.sampleclasses.Fields;
import com.yahoo.container.plugin.classanalysis.sampleclasses.Interface1;
import com.yahoo.container.plugin.classanalysis.sampleclasses.Interface2;
@@ -109,6 +111,11 @@ public class AnalyzeClassTest {
}
@Test
+ public void invisible_annotation_not_included() {
+ assertFalse(analyzeClass(InvisibleAnnotation.class).getReferencedClasses().contains(name(InvisibleDummyAnnotation.class)));
+ }
+
+ @Test
public void method_annotation_is_included() {
assertTrue(analyzeClass(MethodAnnotation.class).getReferencedClasses().contains(name(DummyAnnotation.class)));
}
diff --git a/bundle-plugin/src/test/java/com/yahoo/container/plugin/classanalysis/sampleclasses/DummyAnnotation.java b/bundle-plugin/src/test/java/com/yahoo/container/plugin/classanalysis/sampleclasses/DummyAnnotation.java
index a07b2917bd4..42335da3d62 100644
--- a/bundle-plugin/src/test/java/com/yahoo/container/plugin/classanalysis/sampleclasses/DummyAnnotation.java
+++ b/bundle-plugin/src/test/java/com/yahoo/container/plugin/classanalysis/sampleclasses/DummyAnnotation.java
@@ -1,9 +1,13 @@
// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
package com.yahoo.container.plugin.classanalysis.sampleclasses;
+import java.lang.annotation.Retention;
+import java.lang.annotation.RetentionPolicy;
+
/**
* Input for class analysis tests.
* @author Tony Vaagenes
*/
+@Retention(RetentionPolicy.RUNTIME)
public @interface DummyAnnotation {
}
diff --git a/bundle-plugin/src/test/java/com/yahoo/container/plugin/classanalysis/sampleclasses/InvisibleAnnotation.java b/bundle-plugin/src/test/java/com/yahoo/container/plugin/classanalysis/sampleclasses/InvisibleAnnotation.java
new file mode 100644
index 00000000000..ced7c3305b0
--- /dev/null
+++ b/bundle-plugin/src/test/java/com/yahoo/container/plugin/classanalysis/sampleclasses/InvisibleAnnotation.java
@@ -0,0 +1,10 @@
+// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+package com.yahoo.container.plugin.classanalysis.sampleclasses;
+
+/**
+ * Input for class analysis tests.
+ * @author arnej
+ */
+@InvisibleDummyAnnotation
+public class InvisibleAnnotation {
+}
diff --git a/bundle-plugin/src/test/java/com/yahoo/container/plugin/classanalysis/sampleclasses/InvisibleDummyAnnotation.java b/bundle-plugin/src/test/java/com/yahoo/container/plugin/classanalysis/sampleclasses/InvisibleDummyAnnotation.java
new file mode 100644
index 00000000000..b3cb75df354
--- /dev/null
+++ b/bundle-plugin/src/test/java/com/yahoo/container/plugin/classanalysis/sampleclasses/InvisibleDummyAnnotation.java
@@ -0,0 +1,13 @@
+// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+package com.yahoo.container.plugin.classanalysis.sampleclasses;
+
+import java.lang.annotation.Retention;
+import java.lang.annotation.RetentionPolicy;
+
+/**
+ * Input for class analysis tests.
+ * @author arnej
+ */
+@Retention(RetentionPolicy.CLASS)
+public @interface InvisibleDummyAnnotation {
+}
diff --git a/config-model/src/main/java/com/yahoo/vespa/model/container/xml/ContainerModelBuilder.java b/config-model/src/main/java/com/yahoo/vespa/model/container/xml/ContainerModelBuilder.java
index 0bd93c6d0df..1121a90693b 100644
--- a/config-model/src/main/java/com/yahoo/vespa/model/container/xml/ContainerModelBuilder.java
+++ b/config-model/src/main/java/com/yahoo/vespa/model/container/xml/ContainerModelBuilder.java
@@ -86,6 +86,7 @@ import com.yahoo.vespa.model.container.search.GUIHandler;
import com.yahoo.vespa.model.container.search.PageTemplates;
import com.yahoo.vespa.model.container.search.searchchain.SearchChains;
import com.yahoo.vespa.model.container.xml.document.DocumentFactoryBuilder;
+import com.yahoo.vespa.model.container.xml.embedder.EmbedderConfig;
import com.yahoo.vespa.model.content.StorageGroup;
import org.w3c.dom.Element;
import org.w3c.dom.Node;
@@ -197,9 +198,11 @@ public class ContainerModelBuilder extends ConfigModelBuilder<ContainerModel> {
private void addClusterContent(ApplicationContainerCluster cluster, Element spec, ConfigModelContext context) {
DeployState deployState = context.getDeployState();
DocumentFactoryBuilder.buildDocumentFactories(cluster, spec);
+
addConfiguredComponents(deployState, cluster, spec);
addSecretStore(cluster, spec, deployState);
+ addEmbedderComponents(deployState, cluster, spec);
addModelEvaluation(spec, cluster, context);
addModelEvaluationBundles(cluster);
@@ -382,6 +385,13 @@ public class ContainerModelBuilder extends ConfigModelBuilder<ContainerModel> {
}
}
+ private static void addEmbedderComponents(DeployState deployState, ApplicationContainerCluster cluster, Element spec) {
+ for (Element node : XML.getChildren(spec, "embedder")) {
+ Element transformed = EmbedderConfig.transform(deployState, node);
+ cluster.addComponent(new DomComponentBuilder().build(deployState, cluster, transformed));
+ }
+ }
+
private void addConfiguredComponents(DeployState deployState, ApplicationContainerCluster cluster, Element spec) {
for (Element components : XML.getChildren(spec, "components")) {
addIncludes(components);
diff --git a/config-model/src/main/java/com/yahoo/vespa/model/container/xml/embedder/EmbedderConfig.java b/config-model/src/main/java/com/yahoo/vespa/model/container/xml/embedder/EmbedderConfig.java
new file mode 100644
index 00000000000..a2286647cdd
--- /dev/null
+++ b/config-model/src/main/java/com/yahoo/vespa/model/container/xml/embedder/EmbedderConfig.java
@@ -0,0 +1,79 @@
+package com.yahoo.vespa.model.container.xml.embedder;
+
+import com.yahoo.config.model.deploy.DeployState;
+import org.w3c.dom.Element;
+import org.w3c.dom.Node;
+import org.w3c.dom.NodeList;
+
+/**
+ * Translates config in services.xml of the form
+ *
+ * &lt;embedder id="..." class="..." bundle="..." def="..."&gt;
+ * &lt;!-- options --&gt;
+ * &lt;/embedder&gt;
+ *
+ * to component configuration of the form
+ *
+ * &lt;component id="..." class="..." bundle="..."&gt;
+ * &lt;config name=def&gt;
+ * &lt;!-- options --&gt;
+ * &lt;/config&gt;
+ * &lt;/component&gt;
+ *
+ * with some added interpretations based on recognizing the class.
+ *
+ * @author lesters
+ */
+public class EmbedderConfig {
+
+ static EmbedderConfigTransformer getEmbedderTransformer(Element spec, boolean hosted) {
+ String classId = getEmbedderClass(spec);
+ switch (classId) {
+ case "ai.vespa.embedding.BertBaseEmbedder": return new EmbedderConfigBertBaseTransformer(spec, hosted);
+ }
+ return new EmbedderConfigTransformer(spec, hosted);
+ }
+
+ static String modelIdToUrl(String id) {
+ switch (id) {
+ case "test-model-id":
+ return "test-model-url";
+ case "minilm-l6-v2":
+ return "https://data.vespa.oath.cloud/onnx_models/sentence_all_MiniLM_L6_v2.onnx";
+ case "bert-base-uncased":
+ return "https://data.vespa.oath.cloud/onnx_models/bert-base-uncased-vocab.txt";
+ }
+ throw new IllegalArgumentException("Unknown model id: '" + id + "'");
+ }
+
+ /**
+ * Transforms the &lt;embedder ...&gt; element to component configuration.
+ *
+ * @param deployState the deploy state - as config generation can depend on context
+ * @param embedderSpec the XML element containing the &lt;embedder ...&gt;
+ * @return a new XML element containing the &lt;component ...&gt; configuration
+ */
+ public static Element transform(DeployState deployState, Element embedderSpec) {
+ EmbedderConfigTransformer transformer = getEmbedderTransformer(embedderSpec, deployState.isHosted());
+ NodeList children = embedderSpec.getChildNodes();
+ for (int i = 0; i < children.getLength(); i++) {
+ Node child = children.item(i);
+ if (child instanceof Element) {
+ transformer.addOption((Element) child);
+ }
+ }
+ return transformer.createComponentConfig(deployState);
+ }
+
+ private static String getEmbedderClass(Element spec) {
+ if (spec.hasAttribute("class")) {
+ return spec.getAttribute("class");
+ }
+ if (spec.hasAttribute("id")) {
+ return spec.getAttribute("id");
+ }
+ throw new IllegalArgumentException("Embedder specification does not have a required class attribute");
+ }
+
+
+}
diff --git a/config-model/src/main/java/com/yahoo/vespa/model/container/xml/embedder/EmbedderConfigBertBaseTransformer.java b/config-model/src/main/java/com/yahoo/vespa/model/container/xml/embedder/EmbedderConfigBertBaseTransformer.java
new file mode 100644
index 00000000000..9431926d088
--- /dev/null
+++ b/config-model/src/main/java/com/yahoo/vespa/model/container/xml/embedder/EmbedderConfigBertBaseTransformer.java
@@ -0,0 +1,40 @@
+package com.yahoo.vespa.model.container.xml.embedder;
+
+import org.w3c.dom.Element;
+
+import java.util.Map;
+
+/**
+ * Transforms embedding configuration to component configuration for the
+ * BertBaseEmbedder using embedding.bert-base-embedder.def
+ *
+ * @author lesters
+ */
+public class EmbedderConfigBertBaseTransformer extends EmbedderConfigTransformer {
+
+ private static final String BUNDLE = "model-integration";
+ private static final String DEF = "embedding.bert-base-embedder";
+
+ public EmbedderConfigBertBaseTransformer(Element spec, boolean hosted) {
+ super(spec, hosted, BUNDLE, DEF);
+
+ EmbedderOption.Builder modelOption = new EmbedderOption.Builder()
+ .name("model")
+ .required(true)
+ .optionTransformer(new EmbedderOption.ModelOptionTransformer("transformerModelPath", "transformerModelUrl"));
+ EmbedderOption.Builder vocabOption = new EmbedderOption.Builder()
+ .name("vocab")
+ .required(true)
+ .optionTransformer(new EmbedderOption.ModelOptionTransformer("tokenizerVocabPath", "tokenizerVocabUrl"));
+
+ // Defaults
+ if (hosted) {
+ modelOption.attributes(Map.of("id", "minilm-l6-v2")).value("");
+ vocabOption.attributes(Map.of("id", "bert-base-uncased")).value("");
+ }
+
+ addOption(modelOption.build());
+ addOption(vocabOption.build());
+ }
+
+}
diff --git a/config-model/src/main/java/com/yahoo/vespa/model/container/xml/embedder/EmbedderConfigTransformer.java b/config-model/src/main/java/com/yahoo/vespa/model/container/xml/embedder/EmbedderConfigTransformer.java
new file mode 100644
index 00000000000..30327fdc8af
--- /dev/null
+++ b/config-model/src/main/java/com/yahoo/vespa/model/container/xml/embedder/EmbedderConfigTransformer.java
@@ -0,0 +1,102 @@
+package com.yahoo.vespa.model.container.xml.embedder;
+
+import com.yahoo.config.model.deploy.DeployState;
+import com.yahoo.text.XML;
+import org.w3c.dom.Document;
+import org.w3c.dom.Element;
+
+import java.util.ArrayList;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+
+
+/**
+ * A specific embedder to component configuration transformer.
+ *
+ * @author lesters
+ */
+public class EmbedderConfigTransformer {
+
+ private final Document doc = XML.getDocumentBuilder().newDocument();
+
+ private final String id;
+ private final String className;
+ private final String bundle;
+ private final String def;
+ private final Map<String, EmbedderOption> options = new HashMap<>();
+
+ public EmbedderConfigTransformer(Element spec, boolean hosted) {
+ this(spec, hosted, null, null);
+ }
+
+ public EmbedderConfigTransformer(Element spec, boolean hosted, String defaultBundle, String defaultDef) {
+ id = spec.getAttribute("id");
+ className = spec.hasAttribute("class") ? spec.getAttribute("class") : id;
+ bundle = spec.hasAttribute("bundle") ? spec.getAttribute("bundle") : defaultBundle;
+ def = spec.hasAttribute("def") ? spec.getAttribute("def") : defaultDef;
+
+ if (className == null || className.length() == 0) {
+ throw new IllegalArgumentException("Embedder class is empty");
+ }
+ if (this.bundle == null || this.bundle.length() == 0) {
+ throw new IllegalArgumentException("Embedder configuration requires a bundle name");
+ }
+ if (this.def == null || this.def.length() == 0) {
+ throw new IllegalArgumentException("Embedder configuration requires a config definition name");
+ }
+ }
+
+ Element createComponentConfig(DeployState deployState) {
+ checkRequiredOptions();
+
+ Element component = doc.createElement("component");
+ component.setAttribute("id", id);
+ component.setAttribute("class", className);
+ component.setAttribute("bundle", bundle);
+
+ if (options.size() > 0) {
+ Element config = doc.createElement("config");
+ config.setAttribute("name", def);
+ for (Map.Entry<String, EmbedderOption> entry : options.entrySet()) {
+ entry.getValue().toElement(deployState, config);
+ }
+ component.appendChild(config);
+ }
+
+ return component;
+ }
+
+ // TODO: support nested options
+ void addOption(Element elem) {
+ String name = elem.getTagName();
+
+ EmbedderOption.Builder builder = new EmbedderOption.Builder();
+ builder.name(name);
+ builder.value(elem.getTextContent());
+ builder.attributes(elem);
+
+ if (options.containsKey(name)) {
+ builder.required(options.get(name).required());
+ builder.optionTransformer(options.get(name).optionTransformer());
+ }
+ options.put(name, builder.build());
+ }
+
+ void addOption(EmbedderOption option) {
+ options.put(option.name(), option);
+ }
+
+    /**
+     * Verifies that every option marked as required has been given a value.
+     *
+     * @throws IllegalArgumentException listing the names of the required options that are not set
+     */
+    private void checkRequiredOptions() {
+        List<String> missingOptions = new ArrayList<>();
+        for (EmbedderOption option : options.values()) {
+            // Only options flagged as required are mandatory; an unset optional option is not an error
+            if (option.required() && ! option.isSet()) {
+                missingOptions.add(option.name());
+            }
+        }
+        if ( ! missingOptions.isEmpty()) {
+            throw new IllegalArgumentException("Embedder '" + className + "' requires options for " + missingOptions);
+        }
+    }
+
+}
diff --git a/config-model/src/main/java/com/yahoo/vespa/model/container/xml/embedder/EmbedderOption.java b/config-model/src/main/java/com/yahoo/vespa/model/container/xml/embedder/EmbedderOption.java
new file mode 100644
index 00000000000..206745887d1
--- /dev/null
+++ b/config-model/src/main/java/com/yahoo/vespa/model/container/xml/embedder/EmbedderOption.java
@@ -0,0 +1,184 @@
+package com.yahoo.vespa.model.container.xml.embedder;
+
+import com.yahoo.config.model.deploy.DeployState;
+import org.w3c.dom.Element;
+import org.w3c.dom.NamedNodeMap;
+
+import java.util.HashMap;
+import java.util.Map;
+
+
+/**
+ * Holds options for embedder configuration. This includes code for handling special
+ * options such as model specifiers.
+ *
+ * @author lesters
+ */
+public class EmbedderOption {
+
+ public static final OptionTransformer defaultOptionTransformer = new OptionTransformer();
+
+ private final String name;
+ private final boolean required;
+ private final String value;
+ private final Map<String, String> attributes;
+ private final OptionTransformer optionTransformer;
+ private final boolean set;
+
+ private EmbedderOption(Builder builder) {
+ this.name = builder.name;
+ this.required = builder.required;
+ this.value = builder.value;
+ this.attributes = builder.attributes;
+ this.optionTransformer = builder.optionTransformer;
+ this.set = builder.set;
+ }
+
+ public void toElement(DeployState deployState, Element parent) {
+ optionTransformer.transform(deployState, parent, this);
+ }
+
+ public String name() {
+ return name;
+ }
+
+ public String value() {
+ return value;
+ }
+
+ public boolean required() {
+ return required;
+ }
+
+ public OptionTransformer optionTransformer() {
+ return optionTransformer;
+ }
+
+ public boolean isSet() {
+ return set;
+ }
+
+ /**
+ * Basic option transformer. No special handling of options.
+ */
+ public static class OptionTransformer {
+ public void transform(DeployState deployState, Element parent, EmbedderOption option) {
+ createElement(parent, option.name(), option.value());
+ }
+
+ public static Element createElement(Element parent, String name, String value) {
+ Element element = parent.getOwnerDocument().createElement(name);
+ element.setTextContent(value);
+ parent.appendChild(element);
+ return element;
+ }
+ }
+
+ /**
+ * Transforms model options of type &lt;x id="..." url="..." path="..." /&gt; to the
+ * required fields in the config definition.
+ */
+ public static class ModelOptionTransformer extends OptionTransformer {
+
+ private final String pathField;
+ private final String urlField;
+
+ public ModelOptionTransformer(String pathField, String urlField) {
+ super();
+ this.pathField = pathField;
+ this.urlField = urlField;
+ }
+
+ @Override
+ public void transform(DeployState deployState, Element parent, EmbedderOption option) {
+ String id = option.attributes.get("id");
+ String url = option.attributes.get("url");
+ String path = option.attributes.get("path");
+
+ // Always use path if it is set
+ if (path != null && path.length() > 0) {
+ createElement(parent, pathField, path);
+ createElement(parent, urlField, "");
+ return;
+ }
+
+ // Only use the id if we're on cloud
+ if (deployState.isHosted() && id != null && id.length() > 0) {
+ createElement(parent, urlField, EmbedderConfig.modelIdToUrl(id));
+ createElement(parent, pathField, createDummyPath(deployState));
+ return;
+ }
+
+ // Otherwise, use url
+ if (url != null && url.length() > 0) {
+ createElement(parent, urlField, url);
+ createElement(parent, pathField, createDummyPath(deployState));
+ return;
+ }
+
+ if ( ! deployState.isHosted() && id != null && id.length() > 0) {
+ throw new IllegalArgumentException("Model option 'id' is not valid here");
+ }
+ throw new IllegalArgumentException("Model option requires either a 'path' or a 'url' attribute");
+ }
+
+ private String createDummyPath(DeployState deployState) {
+ // For now, until we have optional config parameters, return services.xml as it is guaranteed to exist
+ return "services.xml";
+ }
+
+ }
+
+ public static class Builder {
+ private String name = "";
+ private boolean required = false;
+ private String value = "";
+ private Map<String, String> attributes = Map.of();
+ private OptionTransformer optionTransformer = defaultOptionTransformer;
+ private boolean set = false;
+
+ public Builder name(String name) {
+ this.name = name;
+ return this;
+ }
+
+ public Builder required(boolean required) {
+ this.required = required;
+ return this;
+ }
+
+ public Builder value(String value) {
+ this.value = value;
+ this.set = true;
+ return this;
+ }
+
+ public Builder attributes(Map<String, String> attributes) {
+ this.attributes = attributes;
+ return this;
+ }
+
+ public Builder attributes(Element element) {
+ NamedNodeMap map = element.getAttributes();
+ if (map.getLength() > 0) {
+ this.attributes = new HashMap<>(map.getLength());
+ for (int i = 0; i < map.getLength(); ++i) {
+ String attr = map.item(i).getNodeName();
+ attributes.put(attr, element.getAttribute(attr));
+ }
+ }
+ return this;
+ }
+
+ public Builder optionTransformer(OptionTransformer optionTransformer) {
+ this.optionTransformer = optionTransformer;
+ return this;
+ }
+
+ public EmbedderOption build() {
+ return new EmbedderOption(this);
+ }
+
+ }
+
+}
diff --git a/config-model/src/main/resources/schema/containercluster.rnc b/config-model/src/main/resources/schema/containercluster.rnc
index e8ff1721397..3fdbff84f6d 100644
--- a/config-model/src/main/resources/schema/containercluster.rnc
+++ b/config-model/src/main/resources/schema/containercluster.rnc
@@ -19,6 +19,7 @@ ContainerServices =
DocumentApi? &
Components* &
Component* &
+ Embedder* &
Handler* &
Client* &
Server* &
@@ -103,6 +104,14 @@ ZooKeeper = element zookeeper {
empty
}
+Embedder = element embedder {
+ attribute id { string }? &
+ attribute class { xsd:Name | JavaId }? &
+ attribute bundle { xsd:Name }? &
+ attribute def { xsd:Name }? &
+ anyElement*
+}
+
ModelEvaluation = element model-evaluation {
element onnx {
element models {
diff --git a/config-model/src/test/cfg/application/embed/services.xml b/config-model/src/test/cfg/application/embed/services.xml
new file mode 100644
index 00000000000..f319d875ed9
--- /dev/null
+++ b/config-model/src/test/cfg/application/embed/services.xml
@@ -0,0 +1,26 @@
+<?xml version="1.0" encoding="utf-8" ?>
+<!-- Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. -->
+<services version="1.0">
+
+ <container version="1.0">
+
+ <embedder id="test" class="ai.vespa.embedding.UndefinedEmbedder" bundle="dummy" def="test.dummy">
+ <num>12</num>
+ <str>some text</str>
+ </embedder>
+
+ <embedder id="transformer" class="ai.vespa.embedding.BertBaseEmbedder">
+ <!-- model specifics -->
+ <model id="test-model-id" url="test-model-url"/>
+ <vocab path="files/vocab.txt"/>
+
+ <!-- tunable parameters: number of threads etc -->
+ <onnxIntraOpThreads>4</onnxIntraOpThreads>
+ </embedder>
+
+ <nodes>
+ <node hostalias="node1" />
+ </nodes>
+ </container>
+
+</services>
diff --git a/config-model/src/test/java/com/yahoo/vespa/model/container/xml/EmbedderTestCase.java b/config-model/src/test/java/com/yahoo/vespa/model/container/xml/EmbedderTestCase.java
new file mode 100644
index 00000000000..0dae86473c8
--- /dev/null
+++ b/config-model/src/test/java/com/yahoo/vespa/model/container/xml/EmbedderTestCase.java
@@ -0,0 +1,223 @@
+// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+package com.yahoo.vespa.model.container.xml;
+
+import com.yahoo.component.ComponentId;
+import com.yahoo.config.model.application.provider.FilesApplicationPackage;
+import com.yahoo.config.model.deploy.DeployState;
+import com.yahoo.config.model.deploy.TestProperties;
+import com.yahoo.path.Path;
+import com.yahoo.text.XML;
+import com.yahoo.vespa.config.ConfigDefinitionKey;
+import com.yahoo.vespa.config.ConfigPayloadBuilder;
+import com.yahoo.vespa.model.VespaModel;
+import com.yahoo.vespa.model.container.ApplicationContainerCluster;
+import com.yahoo.vespa.model.container.component.Component;
+import com.yahoo.vespa.model.container.xml.embedder.EmbedderConfig;
+import org.junit.Test;
+import org.w3c.dom.Document;
+import org.w3c.dom.Element;
+import org.w3c.dom.NamedNodeMap;
+import org.xml.sax.SAXException;
+
+import java.io.ByteArrayInputStream;
+import java.io.IOException;
+import java.nio.charset.StandardCharsets;
+import java.util.List;
+
+import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.fail;
+
+public class EmbedderTestCase {
+
+ private static final String PREDEFINED_EMBEDDER_CLASS = "ai.vespa.embedding.BertBaseEmbedder";
+ private static final String PREDEFINED_EMBEDDER_CONFIG = "embedding.bert-base-embedder";
+
+ @Test
+ public void testGenericEmbedConfig() throws IOException, SAXException {
+ String embedder = "<embedder id=\"test\" class=\"ai.vespa.test\" bundle=\"bundle\" def=\"def.name\">" +
+ " <val>123</val>" +
+ "</embedder>";
+ String component = "<component id=\"test\" class=\"ai.vespa.test\" bundle=\"bundle\">" +
+ " <config name=\"def.name\">" +
+ " <val>123</val>" +
+ " </config>" +
+ "</component>";
+ assertTransform(embedder, component);
+ }
+
+ @Test
+ public void testGenericEmbedConfigRequiresBundleAndDef() throws IOException, SAXException {
+ assertTransformThrows("<embedder id=\"test\" class=\"ai.vespa.test\"></embedder>",
+ "Embedder configuration requires a bundle name");
+ assertTransformThrows("<embedder id=\"test\" class=\"ai.vespa.test\" bundle=\"bundle\"></embedder>",
+ "Embedder configuration requires a config definition name");
+ }
+
+ @Test
+ public void testPredefinedEmbedConfigSelfHosted() throws IOException, SAXException {
+ assertTransformThrows("<embedder id=\"test\" class=\"" + PREDEFINED_EMBEDDER_CLASS + "\"></embedder>",
+ "Embedder '" + PREDEFINED_EMBEDDER_CLASS + "' requires options for [vocab, model]");
+ assertTransformThrows("<embedder id=\"test\" class=\"" + PREDEFINED_EMBEDDER_CLASS + "\">" +
+ " <model />" +
+ " <vocab />" +
+ "</embedder>",
+ "Model option requires either a 'path' or a 'url' attribute");
+ assertTransformThrows("<embedder id=\"test\" class=\"" + PREDEFINED_EMBEDDER_CLASS + "\">" +
+ " <model id=\"my_model_id\" />" +
+ " <vocab id=\"my_vocab_id\" />" +
+ "</embedder>",
+ "Model option 'id' is not valid here");
+
+ String embedder = "<embedder id=\"test\" class=\"" + PREDEFINED_EMBEDDER_CLASS + "\">" +
+ " <model id=\"my_model_id\" url=\"my-model-url\" />" +
+ " <vocab id=\"my_vocab_id\" url=\"my-vocab-url\" />" +
+ "</embedder>";
+ String component = "<component id=\"test\" class=\"" + PREDEFINED_EMBEDDER_CLASS + "\" bundle=\"model-integration\">" +
+ " <config name=\"" + PREDEFINED_EMBEDDER_CONFIG + "\">" +
+ " <tokenizerVocabUrl>my-vocab-url</tokenizerVocabUrl>" +
+ " <tokenizerVocabPath></tokenizerVocabPath>" +
+ " <transformerModelUrl>my-model-url</transformerModelUrl>" +
+ " <transformerModelPath></transformerModelPath>" +
+ " </config>" +
+ "</component>";
+ assertTransform(embedder, component, false);
+
+ // Path has priority:
+ embedder = "<embedder id=\"test\" class=\"" + PREDEFINED_EMBEDDER_CLASS + "\">" +
+ " <model id=\"my_model_id\" url=\"my-model-url\" path=\"files/model.onnx\" />" +
+ " <vocab id=\"my_vocab_id\" url=\"my-vocab-url\" path=\"files/vocab.txt\" />" +
+ "</embedder>";
+ component = "<component id=\"test\" class=\"" + PREDEFINED_EMBEDDER_CLASS + "\" bundle=\"model-integration\">" +
+ " <config name=\"" + PREDEFINED_EMBEDDER_CONFIG + "\">" +
+ " <tokenizerVocabPath>files/vocab.txt</tokenizerVocabPath>" +
+ " <tokenizerVocabUrl></tokenizerVocabUrl>" +
+ " <transformerModelPath>files/model.onnx</transformerModelPath>" +
+ " <transformerModelUrl></transformerModelUrl>" +
+ " </config>" +
+ "</component>";
+ assertTransform(embedder, component, false);
+ }
+
+ @Test
+ public void testPredefinedEmbedConfigCloud() throws IOException, SAXException {
+ String embedder = "<embedder id=\"test\" class=\"" + PREDEFINED_EMBEDDER_CLASS + "\" />";
+ String component = "<component id=\"test\" class=\"" + PREDEFINED_EMBEDDER_CLASS + "\" bundle=\"model-integration\">" +
+ " <config name=\"" + PREDEFINED_EMBEDDER_CONFIG + "\">" +
+ " <tokenizerVocabUrl>some url</tokenizerVocabUrl>" +
+ " <tokenizerVocabPath></tokenizerVocabPath>" +
+ " <transformerModelUrl>some url</transformerModelUrl>" +
+ " <transformerModelPath></transformerModelPath>" +
+ " </config>" +
+ "</component>";
+ assertTransform(embedder, component, true);
+
+ embedder = "<embedder id=\"test\" class=\"" + PREDEFINED_EMBEDDER_CLASS + "\">" +
+ " <model id=\"my_model_id\" />" +
+ " <vocab id=\"my_vocab_id\" />" +
+ "</embedder>";
+ assertTransformThrows(embedder, "Unknown model id: 'my_vocab_id'", true);
+
+ embedder = "<embedder id=\"test\" class=\"" + PREDEFINED_EMBEDDER_CLASS + "\">" +
+ " <model id=\"test-model-id\" />" +
+ " <vocab id=\"test-model-id\" />" +
+ "</embedder>";
+ component = "<component id=\"test\" class=\"" + PREDEFINED_EMBEDDER_CLASS + "\" bundle=\"model-integration\">" +
+ " <config name=\"" + PREDEFINED_EMBEDDER_CONFIG + "\">" +
+ " <tokenizerVocabUrl>test-model-url</tokenizerVocabUrl>" +
+ " <tokenizerVocabPath></tokenizerVocabPath>" +
+ " <transformerModelUrl>test-model-url</transformerModelUrl>" +
+ " <transformerModelPath></transformerModelPath>" +
+ " </config>" +
+ "</component>";
+ assertTransform(embedder, component, true);
+ }
+
+ @Test
+    /**
+     * Loads the application in src/test/cfg/application/embed/ (not hosted) and checks
+     * the user config attached to the components generated from its embedder elements.
+     */
+    public void testEmbedConfig() throws Exception {
+        // Placeholder path the model option transformer writes when a url is used instead of a path
+        final String emptyPathFileName = "services.xml";
+
+        Path applicationDir = Path.fromString("src/test/cfg/application/embed/");
+        VespaModel model = loadModel(applicationDir, false);
+        ApplicationContainerCluster containerCluster = model.getContainerClusters().get("container");
+
+        // assertEquals takes (expected, actual); keeping that order gives correct failure messages
+        Component<?, ?> testComponent = containerCluster.getComponentsMap().get(new ComponentId("test"));
+        ConfigPayloadBuilder testConfig = testComponent.getUserConfigs().get(new ConfigDefinitionKey("dummy", "test"));
+        assertEquals("12", testConfig.getObject("num").getValue());
+        assertEquals("some text", testConfig.getObject("str").getValue());
+
+        Component<?, ?> transformer = containerCluster.getComponentsMap().get(new ComponentId("transformer"));
+        ConfigPayloadBuilder transformerConfig = transformer.getUserConfigs().get(new ConfigDefinitionKey("bert-base-embedder", "embedding"));
+        // The model is given by url, so its path is the placeholder
+        assertEquals("test-model-url", transformerConfig.getObject("transformerModelUrl").getValue());
+        assertEquals(emptyPathFileName, transformerConfig.getObject("transformerModelPath").getValue());
+        // The vocab is given by path, so its url is empty
+        assertEquals("", transformerConfig.getObject("tokenizerVocabUrl").getValue());
+        assertEquals("files/vocab.txt", transformerConfig.getObject("tokenizerVocabPath").getValue());
+    }
+
+ private VespaModel loadModel(Path path, boolean hosted) throws Exception {
+ FilesApplicationPackage applicationPackage = FilesApplicationPackage.fromFile(path.toFile());
+ TestProperties properties = new TestProperties().setHostedVespa(hosted);
+ DeployState state = new DeployState.Builder().properties(properties).applicationPackage(applicationPackage).build();
+ return new VespaModel(state);
+ }
+
+ private void assertTransform(String embedder, String component) throws IOException, SAXException {
+ assertTransform(embedder, component, false);
+ }
+
+ private void assertTransform(String embedder, String component, boolean hosted) throws IOException, SAXException {
+ Element emb = createElement(embedder);
+ Element cmp = createElement(component);
+ Element trans = EmbedderConfig.transform(createEmptyDeployState(hosted), emb);
+ assertSpec(cmp, trans);
+ }
+
+ private void assertSpec(Element e1, Element e2) {
+ assertEquals(e1.getTagName(), e2.getTagName());
+ assertAttributes(e1, e2);
+ assertAttributes(e2, e1);
+ assertChildren(e1, e2);
+ }
+
+ private void assertAttributes(Element e1, Element e2) {
+ NamedNodeMap map = e1.getAttributes();
+ for (int i = 0; i < map.getLength(); ++i) {
+ String attr = map.item(i).getNodeName();
+ assertEquals(e1.getAttribute(attr), e2.getAttribute(attr));
+ }
+ }
+
+ private void assertChildren(Element e1, Element e2) {
+ List<Element> list1 = XML.getChildren(e1);
+ List<Element> list2 = XML.getChildren(e2);
+ assertEquals(list1.size(), list2.size());
+ for (int i = 0; i < list1.size(); ++i) {
+ Element child1 = list1.get(i);
+ Element child2 = list2.get(i);
+ assertSpec(child1, child2);
+ }
+ }
+
+ private void assertTransformThrows(String embedder, String msg) throws IOException, SAXException {
+ assertTransformThrows(embedder, msg, false);
+ }
+
+    /**
+     * Asserts that transforming the given embedder XML throws an
+     * IllegalArgumentException whose message equals the given message.
+     *
+     * @param embedder the embedder XML to transform
+     * @param msg the expected exception message
+     * @param hosted whether the deploy state used for the transform is hosted
+     */
+    private void assertTransformThrows(String embedder, String msg, boolean hosted) throws IOException, SAXException {
+        try {
+            EmbedderConfig.transform(createEmptyDeployState(hosted), createElement(embedder));
+            fail("Expected exception was not thrown: " + msg);
+        } catch (IllegalArgumentException e) {
+            // assertEquals takes (expected, actual); this order gives a correct failure message
+            assertEquals(msg, e.getMessage());
+        }
+    }
+
+ private Element createElement(String xml) throws IOException, SAXException {
+ Document doc = XML.getDocumentBuilder().parse(new ByteArrayInputStream(xml.getBytes(StandardCharsets.UTF_8)));
+ return (Element) doc.getFirstChild();
+ }
+
+ private DeployState createEmptyDeployState(boolean hosted) {
+ TestProperties properties = new TestProperties().setHostedVespa(hosted);
+ return new DeployState.Builder().properties(properties).build();
+ }
+
+}
diff --git a/configdefinitions/src/vespa/embedding.bert-base-embedder.def b/configdefinitions/src/vespa/embedding.bert-base-embedder.def
index a37599de411..115e021972c 100644
--- a/configdefinitions/src/vespa/embedding.bert-base-embedder.def
+++ b/configdefinitions/src/vespa/embedding.bert-base-embedder.def
@@ -1,8 +1,13 @@
namespace=embedding
+# Settings for wordpiece tokenizer
+tokenizerVocabUrl url
+tokenizerVocabPath path
+
# Transformer model settings
-transformerModelUrl url default=https://data.vespa.oath.cloud/onnx_models/sentence_all_MiniLM_L6_v2.onnx
+transformerModelUrl url
+transformerModelPath path
# Max length of token sequence model can handle
transformerMaxTokens int default=384
@@ -23,5 +28,3 @@ onnxExecutionMode enum { parallel, sequential } default=sequential
onnxInterOpThreads int default=1
onnxIntraOpThreads int default=-4 # n=number of threads -> n<0: CPUs/(-n), n==0: CPUs, n>0: n
-# Settings for wordpiece tokenizer
-tokenizerVocabUrl url default=https://data.vespa.oath.cloud/onnx_models/bert-base-uncased-vocab.txt
diff --git a/model-integration/src/main/java/ai/vespa/embedding/BertBaseEmbedder.java b/model-integration/src/main/java/ai/vespa/embedding/BertBaseEmbedder.java
index 1831903d626..bc3f08ce3d6 100644
--- a/model-integration/src/main/java/ai/vespa/embedding/BertBaseEmbedder.java
+++ b/model-integration/src/main/java/ai/vespa/embedding/BertBaseEmbedder.java
@@ -11,6 +11,8 @@ import com.yahoo.tensor.Tensor;
import com.yahoo.tensor.TensorAddress;
import com.yahoo.tensor.TensorType;
+import java.io.File;
+import java.nio.file.Path;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
@@ -58,13 +60,22 @@ public class BertBaseEmbedder implements Embedder {
options.setInterOpThreads(modifyThreadCount(config.onnxInterOpThreads()));
options.setIntraOpThreads(modifyThreadCount(config.onnxIntraOpThreads()));
- // Todo: use either file or url
- tokenizer = new WordPieceEmbedder.Builder(config.tokenizerVocabUrl().getAbsolutePath()).build();
- evaluator = new OnnxEvaluator(config.transformerModelUrl().getAbsolutePath(), options);
+ String tokenizerFile = pathOrUrl(config.tokenizerVocabPath(), config.tokenizerVocabUrl());
+ String modelFile = pathOrUrl(config.transformerModelPath(), config.transformerModelUrl());
+
+ tokenizer = new WordPieceEmbedder.Builder(tokenizerFile).build();
+ evaluator = new OnnxEvaluator(modelFile, options);
validateModel();
}
+ private String pathOrUrl(Path path, File url) {
+     if ( ! path.endsWith("services.xml")) { // NOTE(review): "services.xml" presumably marks an unset path - confirm
+         return path.toAbsolutePath().toString();
+     }
+     return url.getAbsolutePath(); // fall back to the file behind the url setting
+ }
+
private void validateModel() {
Map<String, TensorType> inputs = evaluator.getInputInfo();
validateName(inputs, inputIdsName, "input");
diff --git a/model-integration/src/test/java/ai/vespa/embedding/BertBaseEmbedderTest.java b/model-integration/src/test/java/ai/vespa/embedding/BertBaseEmbedderTest.java
index 464e5941e89..c224b87982d 100644
--- a/model-integration/src/test/java/ai/vespa/embedding/BertBaseEmbedderTest.java
+++ b/model-integration/src/test/java/ai/vespa/embedding/BertBaseEmbedderTest.java
@@ -1,6 +1,7 @@
package ai.vespa.embedding;
import ai.vespa.modelintegration.evaluator.OnnxEvaluator;
+import com.yahoo.config.FileReference;
import com.yahoo.config.UrlReference;
import com.yahoo.embedding.BertBaseEmbedderConfig;
import com.yahoo.tensor.Tensor;
@@ -14,8 +15,6 @@ import static org.junit.Assume.assumeTrue;
public class BertBaseEmbedderTest {
-
-
@Test
public void testEmbedder() {
String vocabPath = "src/test/models/onnx/transformer/dummy_vocab.txt";
@@ -23,8 +22,10 @@ public class BertBaseEmbedderTest {
assumeTrue(OnnxEvaluator.isRuntimeAvailable(modelPath));
BertBaseEmbedderConfig.Builder builder = new BertBaseEmbedderConfig.Builder();
- builder.tokenizerVocabUrl(new UrlReference(vocabPath));
- builder.transformerModelUrl(new UrlReference(modelPath));
+ builder.tokenizerVocabPath(new FileReference(vocabPath));
+ builder.tokenizerVocabUrl(new UrlReference(""));
+ builder.transformerModelPath(new FileReference(modelPath));
+ builder.transformerModelUrl(new UrlReference(""));
BertBaseEmbedder embedder = new BertBaseEmbedder(builder.build());
TensorType destType = TensorType.fromSpec("tensor<float>(x[7])");