summaryrefslogtreecommitdiffstats
path: root/config-model
diff options
context:
space:
mode:
authorJon Bratseth <bratseth@gmail.com>2022-08-31 22:50:14 +0200
committerJon Bratseth <bratseth@gmail.com>2022-08-31 22:50:14 +0200
commitadcb1d4d55e71d78c662f798b033d3abea0d4b9e (patch)
tree5867c3ac85792c1578d6ce463e8e24dd2aea7fb0 /config-model
parent2b83da619a3ee2f38a1a3b05576f44d7451b3daf (diff)
Add 'model' config type
Diffstat (limited to 'config-model')
-rw-r--r--config-model/src/main/java/com/yahoo/vespa/model/builder/xml/dom/DomConfigPayloadBuilder.java41
-rw-r--r--config-model/src/main/java/com/yahoo/vespa/model/container/xml/ContainerModelBuilder.java2
-rw-r--r--config-model/src/main/java/com/yahoo/vespa/model/container/xml/ModelConfigTransformer.java73
-rw-r--r--config-model/src/main/java/com/yahoo/vespa/model/container/xml/ModelIdResolver.java51
-rw-r--r--config-model/src/test/cfg/application/embed/configdefinitions/embedding.bert-base-embedder.def30
-rw-r--r--config-model/src/test/cfg/application/embed/services.xml2
-rw-r--r--config-model/src/test/cfg/application/embed_generic/configdefinitions/sentence-embedder.def9
-rw-r--r--config-model/src/test/cfg/application/embed_generic/services.xml2
-rw-r--r--config-model/src/test/java/com/yahoo/vespa/model/builder/xml/dom/DomConfigPayloadBuilderTest.java4
-rw-r--r--config-model/src/test/java/com/yahoo/vespa/model/container/xml/EmbedderTestCase.java105
10 files changed, 143 insertions, 176 deletions
diff --git a/config-model/src/main/java/com/yahoo/vespa/model/builder/xml/dom/DomConfigPayloadBuilder.java b/config-model/src/main/java/com/yahoo/vespa/model/builder/xml/dom/DomConfigPayloadBuilder.java
index 7ff01cbf82e..9390986c0c4 100644
--- a/config-model/src/main/java/com/yahoo/vespa/model/builder/xml/dom/DomConfigPayloadBuilder.java
+++ b/config-model/src/main/java/com/yahoo/vespa/model/builder/xml/dom/DomConfigPayloadBuilder.java
@@ -52,19 +52,19 @@ public class DomConfigPayloadBuilder {
public static ConfigDefinitionKey parseConfigName(Element configE) {
if (!configE.getNodeName().equals("config")) {
- throw new ConfigurationRuntimeException("The root element must be 'config', but was '" + configE.getNodeName() + "'.");
+ throw new ConfigurationRuntimeException("The root element must be 'config', but was '" + configE.getNodeName() + "'");
}
if (!configE.hasAttribute("name")) {
throw new ConfigurationRuntimeException
- ("The 'config' element must have a 'name' attribute that matches the name of the config definition.");
+ ("The 'config' element must have a 'name' attribute that matches the name of the config definition");
}
String elementString = configE.getAttribute("name");
if (!elementString.contains(".")) {
throw new ConfigurationRuntimeException("The config name '" + elementString +
- "' contains illegal characters. Only names with the pattern " +
- namespacePattern.pattern() + "." + namePattern.pattern() + " are legal.");
+ "' contains illegal characters. Only names with the pattern " +
+ namespacePattern.pattern() + "." + namePattern.pattern() + " are legal.");
}
Tuple2<String, String> t = ConfigUtils.getNameAndNamespaceFromString(elementString);
@@ -73,28 +73,26 @@ public class DomConfigPayloadBuilder {
if (!validName(xmlName)) {
throw new ConfigurationRuntimeException("The config name '" + xmlName +
- "' contains illegal characters. Only names with the pattern " + namePattern.toString() + " are legal.");
+ "' contains illegal characters. Only names with the pattern " +
+ namePattern.toString() + " are legal.");
}
if (!validNamespace(xmlNamespace)) {
throw new ConfigurationRuntimeException("The config namespace '" + xmlNamespace +
- "' contains illegal characters. Only namespaces with the pattern " + namespacePattern.toString() + " are legal.");
+ "' contains illegal characters. Only namespaces with the pattern " +
+ namespacePattern.toString() + " are legal.");
}
return new ConfigDefinitionKey(xmlName, xmlNamespace);
}
private static boolean validName(String name) {
if (name == null) return false;
-
- Matcher m = namePattern.matcher(name);
- return m.matches();
+ return namePattern.matcher(name).matches();
}
private static boolean validNamespace(String namespace) {
if (namespace == null) return false;
-
- Matcher m = namespacePattern.matcher(namespace);
- return m.matches();
+ return namespacePattern.matcher(namespace).matches();
}
private String extractName(Element element) {
@@ -118,12 +116,11 @@ public class DomConfigPayloadBuilder {
return buf.toString();
}
- /**
- * Parse leaf value in an xml tree
- */
+ /** Parse leaf value in an xml tree. */
private void parseLeaf(Element element, ConfigPayloadBuilder payloadBuilder, String parentName) {
String name = extractName(element);
String value = XML.getValue(element);
+ var definition = payloadBuilder.getConfigDefinition();
if (value == null) {
throw new ConfigurationRuntimeException("Element '" + name + "' must have either children or a value");
}
@@ -136,8 +133,14 @@ public class DomConfigPayloadBuilder {
} else {
payloadBuilder.getArray(parentName).append(value);
}
- } else {
- // leaf scalar, e.g. <intVal>3</intVal>
+ }
+ else if (definition != null && definition.getModelDefs().containsKey(name)) { // model field special syntax
+ String modelString = XML.attribute("model-id", element).orElse("\"\"");
+ modelString += " " + XML.attribute("url", element).orElse("\"\"");
+ modelString += " " + XML.attribute("path", element).orElse("\"\"");
+ payloadBuilder.setField(name, modelString);
+ }
+ else { // leaf value: <myValueName>value</myValue>
payloadBuilder.setField(name, value);
}
}
@@ -196,8 +199,8 @@ public class DomConfigPayloadBuilder {
parseComplex(currElem, children, payloadBuilder, parentName);
}
} catch (Exception exception) {
- throw new ConfigurationRuntimeException("Error parsing element at " + XML.getNodePath(currElem, " > ") + ": " +
- Exceptions.toMessageString(exception));
+ throw new ConfigurationRuntimeException("Error parsing element at " + XML.getNodePath(currElem, " > ") +
+ ": " + Exceptions.toMessageString(exception));
}
}
diff --git a/config-model/src/main/java/com/yahoo/vespa/model/container/xml/ContainerModelBuilder.java b/config-model/src/main/java/com/yahoo/vespa/model/container/xml/ContainerModelBuilder.java
index fc8a542b81c..acd8b5cbbc2 100644
--- a/config-model/src/main/java/com/yahoo/vespa/model/container/xml/ContainerModelBuilder.java
+++ b/config-model/src/main/java/com/yahoo/vespa/model/container/xml/ContainerModelBuilder.java
@@ -957,7 +957,7 @@ public class ContainerModelBuilder extends ConfigModelBuilder<ContainerModel> {
private static void addConfiguredComponents(DeployState deployState, ContainerCluster<? extends Container> cluster,
Element parent, String componentName) {
for (Element component : XML.getChildren(parent, componentName)) {
- component = ModelConfigTransformer.transform(deployState, component);
+ ModelIdResolver.resolveModelIds(component, deployState.isHosted());
cluster.addComponent(new DomComponentBuilder().build(deployState, cluster, component));
}
}
diff --git a/config-model/src/main/java/com/yahoo/vespa/model/container/xml/ModelConfigTransformer.java b/config-model/src/main/java/com/yahoo/vespa/model/container/xml/ModelConfigTransformer.java
deleted file mode 100644
index 0065a582145..00000000000
--- a/config-model/src/main/java/com/yahoo/vespa/model/container/xml/ModelConfigTransformer.java
+++ /dev/null
@@ -1,73 +0,0 @@
-// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
-package com.yahoo.vespa.model.container.xml;
-
-import com.yahoo.config.model.deploy.DeployState;
-import com.yahoo.text.XML;
-import org.w3c.dom.Element;
-
-import java.util.Map;
-import java.util.stream.Collectors;
-
-/**
- * Translates model references in component configs.
- *
- * @author lesters
- * @author bratseth
- */
-public class ModelConfigTransformer {
-
- private static final Map<String, String> providedModels =
- Map.of("minilm-l6-v2", "https://data.vespa.oath.cloud/onnx_models/sentence_all_MiniLM_L6_v2.onnx",
- "bert-base-uncased", "https://data.vespa.oath.cloud/onnx_models/bert-base-uncased-vocab.txt");
-
- // Until we have optional path parameters, use services.xml as it is guaranteed to exist
- private final static String dummyPath = "services.xml";
-
- /**
- * Transforms the &lt;embedder ...&gt; element to component configuration.
- *
- * @param deployState the deploy state - as config generation can depend on context
- * @param component the XML element containing the &lt;embedder ...&gt;
- * @return a new XML element containting the &lt;component ...&gt; configuration
- */
- public static Element transform(DeployState deployState, Element component) {
- for (Element config : XML.getChildren(component, "config")) {
- for (Element value : XML.getChildren(config))
- transformModelValue(value, config, deployState.isHosted());
- }
- return component;
- }
-
- /** Expans a model config value into regular config values. */
- private static void transformModelValue(Element value, Element config, boolean hosted) {
- if (value.hasAttribute("path")) {
- addChild(value.getTagName() + "Url", "", config);
- addChild(value.getTagName() + "Path", value.getAttribute("path"), config);
- config.removeChild(value);
- }
- else if (value.hasAttribute("id") && hosted) {
- addChild(value.getTagName() + "Url", modelIdToUrl(value.getAttribute("id")), config);
- addChild(value.getTagName() + "Path", dummyPath, config);
- config.removeChild(value);
- }
- else if (value.hasAttribute("url")) {
- addChild(value.getTagName() + "Url", value.getAttribute("url"), config);
- addChild(value.getTagName() + "Path", dummyPath, config);
- config.removeChild(value);
- }
- }
-
- private static void addChild(String name, String value, Element parent) {
- Element element = parent.getOwnerDocument().createElement(name);
- element.setTextContent(value);
- parent.appendChild(element);
- }
-
- private static String modelIdToUrl(String id) {
- if ( ! providedModels.containsKey(id))
- throw new IllegalArgumentException("Unknown embedder model '" + id + "'. Available models are [" +
- providedModels.keySet().stream().sorted().collect(Collectors.joining(", ")) + "]");
- return providedModels.get(id);
- }
-
-}
diff --git a/config-model/src/main/java/com/yahoo/vespa/model/container/xml/ModelIdResolver.java b/config-model/src/main/java/com/yahoo/vespa/model/container/xml/ModelIdResolver.java
new file mode 100644
index 00000000000..be696832dd7
--- /dev/null
+++ b/config-model/src/main/java/com/yahoo/vespa/model/container/xml/ModelIdResolver.java
@@ -0,0 +1,51 @@
+// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+package com.yahoo.vespa.model.container.xml;
+
+import com.yahoo.config.model.deploy.DeployState;
+import com.yahoo.text.XML;
+import org.w3c.dom.Element;
+
+import java.util.Map;
+import java.util.stream.Collectors;
+
+/**
+ * Replaces model id references in configs by their url.
+ *
+ * @author lesters
+ * @author bratseth
+ */
+public class ModelIdResolver {
+
+ private static final Map<String, String> providedModels =
+ Map.of("minilm-l6-v2", "https://data.vespa.oath.cloud/onnx_models/sentence_all_MiniLM_L6_v2.onnx",
+ "bert-base-uncased", "https://data.vespa.oath.cloud/onnx_models/bert-base-uncased-vocab.txt");
+
+ /**
+ * Finds any config values of type 'model' below the given component element and
+ * supplies the url attribute of them if a model id is specified and hosted is true
+ * (regardless of whether a url is already specified).
+ *
+ * @param component the XML element of any component
+ */
+ public static void resolveModelIds(Element component, boolean hosted) {
+ if ( ! hosted) return;
+ for (Element config : XML.getChildren(component, "config")) {
+ for (Element value : XML.getChildren(config))
+ transformModelValue(value);
+ }
+ }
+
+ /** Expands a model config value into regular config values. */
+ private static void transformModelValue(Element value) {
+ if ( ! value.hasAttribute("model-id")) return;
+ value.setAttribute("url", modelIdToUrl(value.getTagName(), value.getAttribute("model-id")));
+ }
+
+ private static String modelIdToUrl(String valueName, String modelId) {
+ if ( ! providedModels.containsKey(modelId))
+ throw new IllegalArgumentException("Unknown model id '" + modelId + "' on '" + valueName + "'. Available models are [" +
+ providedModels.keySet().stream().sorted().collect(Collectors.joining(", ")) + "]");
+ return providedModels.get(modelId);
+ }
+
+}
diff --git a/config-model/src/test/cfg/application/embed/configdefinitions/embedding.bert-base-embedder.def b/config-model/src/test/cfg/application/embed/configdefinitions/embedding.bert-base-embedder.def
new file mode 100644
index 00000000000..a6544187140
--- /dev/null
+++ b/config-model/src/test/cfg/application/embed/configdefinitions/embedding.bert-base-embedder.def
@@ -0,0 +1,30 @@
+# Copy of this Vespa config stored here because Vespa config definitions are not
+# available in unit tests, and are needed (by DomConfigPayloadBuilder.parseLeaf)
+# Alternatively, we could make that not need it as it is not strictly necessary.
+
+namespace=embedding
+
+# Wordpiece tokenizer
+tokenizerVocab model
+
+transformerModel model
+
+# Max length of token sequence model can handle
+transformerMaxTokens int default=384
+
+# Pooling strategy
+poolingStrategy enum { cls, mean } default=mean
+
+# Input names
+transformerInputIds string default=input_ids
+transformerAttentionMask string default=attention_mask
+transformerTokenTypeIds string default=token_type_ids
+
+# Output name
+transformerOutput string default=output_0
+
+# Settings for ONNX model evaluation
+onnxExecutionMode enum { parallel, sequential } default=sequential
+onnxInterOpThreads int default=1
+onnxIntraOpThreads int default=-4 # n=number of threads -> n<0: CPUs/(-n), n==0: CPUs, n>0: n
+
diff --git a/config-model/src/test/cfg/application/embed/services.xml b/config-model/src/test/cfg/application/embed/services.xml
index 88558ace4bf..cdbcfd67f02 100644
--- a/config-model/src/test/cfg/application/embed/services.xml
+++ b/config-model/src/test/cfg/application/embed/services.xml
@@ -7,7 +7,7 @@
<component id="transformer" class="ai.vespa.embedding.BertBaseEmbedder" bindle="model-integration">
<config name="embedding.bert-base-embedder">
<!-- model specifics -->
- <transformerModel id="minilm-l6-v2" url="application-url"/>
+ <transformerModel model-id="minilm-l6-v2" url="application-url"/>
<tokenizerVocab path="files/vocab.txt"/>
<!-- tunable parameters: number of threads etc -->
diff --git a/config-model/src/test/cfg/application/embed_generic/configdefinitions/sentence-embedder.def b/config-model/src/test/cfg/application/embed_generic/configdefinitions/sentence-embedder.def
index 81fc88dbf01..87b80f1051a 100644
--- a/config-model/src/test/cfg/application/embed_generic/configdefinitions/sentence-embedder.def
+++ b/config-model/src/test/cfg/application/embed_generic/configdefinitions/sentence-embedder.def
@@ -1,12 +1,9 @@
package=ai.vespa.example.paragraph
-# Settings for wordpiece tokenizer
-vocabPath path
-vocabUrl string
+# WordPiece tokenizer vocabulary
+vocab model
-# Transformer model settings
-modelPath path
-modelUrl string
+model model
myValue string
diff --git a/config-model/src/test/cfg/application/embed_generic/services.xml b/config-model/src/test/cfg/application/embed_generic/services.xml
index ea430f24e2f..d2c22c03343 100644
--- a/config-model/src/test/cfg/application/embed_generic/services.xml
+++ b/config-model/src/test/cfg/application/embed_generic/services.xml
@@ -8,7 +8,7 @@
class='ai.vespa.example.paragraph.ApplicationSpecificEmbedder'
bundle='exampleEmbedder'>
<config name='ai.vespa.example.paragraph.sentence-embedder'>
- <model id="minilm-l6-v2" url="application-url" />
+ <model model-id="minilm-l6-v2" url="application-url" />
<vocab path="files/vocab.txt"/>
<myValue>foo</myValue>
</config>
diff --git a/config-model/src/test/java/com/yahoo/vespa/model/builder/xml/dom/DomConfigPayloadBuilderTest.java b/config-model/src/test/java/com/yahoo/vespa/model/builder/xml/dom/DomConfigPayloadBuilderTest.java
index 88af584de90..e788fe5fc54 100644
--- a/config-model/src/test/java/com/yahoo/vespa/model/builder/xml/dom/DomConfigPayloadBuilderTest.java
+++ b/config-model/src/test/java/com/yahoo/vespa/model/builder/xml/dom/DomConfigPayloadBuilderTest.java
@@ -130,7 +130,7 @@ public class DomConfigPayloadBuilderTest {
new DomConfigPayloadBuilder(null).build(configRoot);
fail("Expected exception for wrong tag name.");
} catch (ConfigurationRuntimeException e) {
- assertEquals("The root element must be 'config', but was 'configs'.", e.getMessage());
+ assertEquals("The root element must be 'config', but was 'configs'", e.getMessage());
}
}
@@ -142,7 +142,7 @@ public class DomConfigPayloadBuilderTest {
new DomConfigPayloadBuilder(null).build(configRoot);
fail("Expected exception for mismatch between def-name and xml name attribute.");
} catch (ConfigurationRuntimeException e) {
- assertEquals("The 'config' element must have a 'name' attribute that matches the name of the config definition.", e.getMessage());
+ assertEquals("The 'config' element must have a 'name' attribute that matches the name of the config definition", e.getMessage());
}
}
diff --git a/config-model/src/test/java/com/yahoo/vespa/model/container/xml/EmbedderTestCase.java b/config-model/src/test/java/com/yahoo/vespa/model/container/xml/EmbedderTestCase.java
index ffa7e52136f..60386be17db 100644
--- a/config-model/src/test/java/com/yahoo/vespa/model/container/xml/EmbedderTestCase.java
+++ b/config-model/src/test/java/com/yahoo/vespa/model/container/xml/EmbedderTestCase.java
@@ -28,7 +28,6 @@ import static org.junit.jupiter.api.Assertions.fail;
public class EmbedderTestCase {
- private static final String emptyPathFileName = "services.xml";
private static final String BUNDLED_EMBEDDER_CLASS = "ai.vespa.embedding.BertBaseEmbedder";
private static final String BUNDLED_EMBEDDER_CONFIG = "embedding.bert-base-embedder";
@@ -42,29 +41,8 @@ public class EmbedderTestCase {
"</component>";
String component = "<component id='test' class='" + BUNDLED_EMBEDDER_CLASS + "' bundle='model-integration'>" +
" <config name='" + BUNDLED_EMBEDDER_CONFIG + "'>" +
- " <transformerModelUrl>my-model-url</transformerModelUrl>" +
- " <transformerModelPath>services.xml</transformerModelPath>" +
- " <tokenizerVocabUrl>my-vocab-url</tokenizerVocabUrl>" +
- " <tokenizerVocabPath>services.xml</tokenizerVocabPath>" +
- " </config>" +
- "</component>";
- assertTransform(input, component, false);
- }
-
- @Test
- void testPathHasPriority_selfhosted() throws IOException, SAXException {
- String input = "<component id='test' class='" + BUNDLED_EMBEDDER_CLASS + "' bundle='model-integration'>" +
- " <config name='" + BUNDLED_EMBEDDER_CONFIG + "'>" +
- " <transformerModel id='my_model_id' url='my-model-url' path='files/model.onnx' />" +
- " <tokenizerVocab id='my_vocab_id' url='my-vocab-url' path='files/vocab.txt' />" +
- " </config>" +
- "</component>";
- String component = "<component id='test' class='" + BUNDLED_EMBEDDER_CLASS + "' bundle='model-integration'>" +
- " <config name='" + BUNDLED_EMBEDDER_CONFIG + "'>" +
- " <transformerModelUrl></transformerModelUrl>" +
- " <transformerModelPath>files/model.onnx</transformerModelPath>" +
- " <tokenizerVocabUrl></tokenizerVocabUrl>" +
- " <tokenizerVocabPath>files/vocab.txt</tokenizerVocabPath>" +
+ " <transformerModel id='my_model_id' url='my-model-url' />" +
+ " <tokenizerVocab id='my_vocab_id' url='my-vocab-url' />" +
" </config>" +
"</component>";
assertTransform(input, component, false);
@@ -74,35 +52,31 @@ public class EmbedderTestCase {
void testBundledEmbedder_hosted() throws IOException, SAXException {
String input = "<component id='test' class='" + BUNDLED_EMBEDDER_CLASS + "' bundle='model-integration'>" +
" <config name='" + BUNDLED_EMBEDDER_CONFIG + "'>" +
- " <transformerModel id='minilm-l6-v2' />" +
- " <tokenizerVocab id='bert-base-uncased' />" +
+ " <transformerModel model-id='minilm-l6-v2' />" +
+ " <tokenizerVocab model-id='bert-base-uncased' />" +
" </config>" +
"</component>";
String component = "<component id='test' class='" + BUNDLED_EMBEDDER_CLASS + "' bundle='model-integration'>" +
" <config name='" + BUNDLED_EMBEDDER_CONFIG + "'>" +
- " <transformerModelUrl>https://data.vespa.oath.cloud/onnx_models/sentence_all_MiniLM_L6_v2.onnx</transformerModelUrl>" +
- " <transformerModelPath>services.xml</transformerModelPath>" +
- " <tokenizerVocabUrl>https://data.vespa.oath.cloud/onnx_models/bert-base-uncased-vocab.txt</tokenizerVocabUrl>" +
- " <tokenizerVocabPath>services.xml</tokenizerVocabPath>" +
+ " <transformerModel model-id='minilm-l6-v2' url='https://data.vespa.oath.cloud/onnx_models/sentence_all_MiniLM_L6_v2.onnx' />" +
+ " <tokenizerVocab model-id='bert-base-uncased' url='https://data.vespa.oath.cloud/onnx_models/bert-base-uncased-vocab.txt' />" +
" </config>" +
"</component>";
assertTransform(input, component, true);
}
@Test
- void testApplicationEmbedderWithBundledConfig_hosted() throws IOException, SAXException {
+ void testApplicationComponentWithModelReference_hosted() throws IOException, SAXException {
String input = "<component id='test' class='ApplicationSpecificEmbedder' bundle='model-integration'>" +
" <config name='" + BUNDLED_EMBEDDER_CONFIG + "'>" +
- " <transformerModel id='minilm-l6-v2' />" +
- " <tokenizerVocab id='bert-base-uncased' />" +
+ " <transformerModel model-id='minilm-l6-v2' />" +
+ " <tokenizerVocab model-id='bert-base-uncased' />" +
" </config>" +
"</component>";
String component = "<component id='test' class='ApplicationSpecificEmbedder' bundle='model-integration'>" +
" <config name='" + BUNDLED_EMBEDDER_CONFIG + "'>" +
- " <transformerModelUrl>https://data.vespa.oath.cloud/onnx_models/sentence_all_MiniLM_L6_v2.onnx</transformerModelUrl>" +
- " <transformerModelPath>services.xml</transformerModelPath>" +
- " <tokenizerVocabUrl>https://data.vespa.oath.cloud/onnx_models/bert-base-uncased-vocab.txt</tokenizerVocabUrl>" +
- " <tokenizerVocabPath>services.xml</tokenizerVocabPath>" +
+ " <transformerModel model-id='minilm-l6-v2' url='https://data.vespa.oath.cloud/onnx_models/sentence_all_MiniLM_L6_v2.onnx' />" +
+ " <tokenizerVocab model-id='bert-base-uncased' url='https://data.vespa.oath.cloud/onnx_models/bert-base-uncased-vocab.txt' />" +
" </config>" +
"</component>";
assertTransform(input, component, true);
@@ -112,12 +86,12 @@ public class EmbedderTestCase {
void testUnknownModelId_hosted() throws IOException, SAXException {
String embedder = "<component id='test' class='" + BUNDLED_EMBEDDER_CLASS + "'>" +
" <config name='" + BUNDLED_EMBEDDER_CONFIG + "'>" +
- " <transformerModel id='my_model_id' />" +
- " <tokenizerVocab id='my_vocab_id' />" +
+ " <transformerModel model-id='my_model_id' />" +
+ " <tokenizerVocab model-id='my_vocab_id' />" +
" </config>" +
"</component>";
assertTransformThrows(embedder,
- "Unknown embedder model 'my_model_id'. " +
+ "Unknown model id 'my_model_id' on 'transformerModel'. " +
"Available models are [bert-base-uncased, minilm-l6-v2]",
true);
}
@@ -130,10 +104,8 @@ public class EmbedderTestCase {
Component<?, ?> transformer = containerCluster.getComponentsMap().get(new ComponentId("transformer"));
ConfigPayloadBuilder config = transformer.getUserConfigs().get(new ConfigDefinitionKey("bert-base-embedder", "embedding"));
- assertEquals("application-url", config.getObject("transformerModelUrl").getValue());
- assertEquals(emptyPathFileName, config.getObject("transformerModelPath").getValue());
- assertEquals("", config.getObject("tokenizerVocabUrl").getValue());
- assertEquals("files/vocab.txt", config.getObject("tokenizerVocabPath").getValue());
+ assertEquals("minilm-l6-v2 application-url \"\"", config.getObject("transformerModel").getValue());
+ assertEquals("\"\" \"\" files/vocab.txt", config.getObject("tokenizerVocab").getValue());
assertEquals("4", config.getObject("onnxIntraOpThreads").getValue());
}
@@ -145,11 +117,9 @@ public class EmbedderTestCase {
Component<?, ?> transformer = containerCluster.getComponentsMap().get(new ComponentId("transformer"));
ConfigPayloadBuilder config = transformer.getUserConfigs().get(new ConfigDefinitionKey("bert-base-embedder", "embedding"));
- assertEquals("https://data.vespa.oath.cloud/onnx_models/sentence_all_MiniLM_L6_v2.onnx",
- config.getObject("transformerModelUrl").getValue());
- assertEquals(emptyPathFileName, config.getObject("transformerModelPath").getValue());
- assertEquals("", config.getObject("tokenizerVocabUrl").getValue());
- assertEquals("files/vocab.txt", config.getObject("tokenizerVocabPath").getValue());
+ assertEquals("minilm-l6-v2 https://data.vespa.oath.cloud/onnx_models/sentence_all_MiniLM_L6_v2.onnx \"\"",
+ config.getObject("transformerModel").getValue());
+ assertEquals("\"\" \"\" files/vocab.txt", config.getObject("tokenizerVocab").getValue());
assertEquals("4", config.getObject("onnxIntraOpThreads").getValue());
}
@@ -161,10 +131,8 @@ public class EmbedderTestCase {
Component<?, ?> testComponent = containerCluster.getComponentsMap().get(new ComponentId("transformer"));
ConfigPayloadBuilder config = testComponent.getUserConfigs().get(new ConfigDefinitionKey("sentence-embedder", "ai.vespa.example.paragraph"));
- assertEquals("application-url", config.getObject("modelUrl").getValue());
- assertEquals(emptyPathFileName, config.getObject("modelPath").getValue());
- assertEquals("files/vocab.txt", config.getObject("vocabPath").getValue());
- assertEquals("foo", config.getObject("myValue").getValue());
+ assertEquals("minilm-l6-v2 application-url \"\"", config.getObject("model").getValue());
+ assertEquals("\"\" \"\" files/vocab.txt", config.getObject("vocab").getValue());
}
@Test
@@ -175,11 +143,9 @@ public class EmbedderTestCase {
Component<?, ?> testComponent = containerCluster.getComponentsMap().get(new ComponentId("transformer"));
ConfigPayloadBuilder config = testComponent.getUserConfigs().get(new ConfigDefinitionKey("sentence-embedder", "ai.vespa.example.paragraph"));
- assertEquals("https://data.vespa.oath.cloud/onnx_models/sentence_all_MiniLM_L6_v2.onnx",
- config.getObject("modelUrl").getValue());
- assertEquals(emptyPathFileName, config.getObject("modelPath").getValue());
- assertEquals("files/vocab.txt", config.getObject("vocabPath").getValue());
- assertEquals("foo", config.getObject("myValue").getValue());
+ assertEquals("minilm-l6-v2 https://data.vespa.oath.cloud/onnx_models/sentence_all_MiniLM_L6_v2.onnx \"\"",
+ config.getObject("model").getValue());
+ assertEquals("\"\" \"\" files/vocab.txt", config.getObject("vocab").getValue());
}
private VespaModel loadModel(Path path, boolean hosted) throws Exception {
@@ -189,13 +155,10 @@ public class EmbedderTestCase {
return new VespaModel(state);
}
- private void assertTransform(String embedder, String component) throws IOException, SAXException {
- assertTransform(embedder, component, false);
- }
-
- private void assertTransform(String embedder, String expectedComponent, boolean hosted) throws IOException, SAXException {
- assertSpec(createElement(expectedComponent),
- ModelConfigTransformer.transform(createEmptyDeployState(hosted), createElement(embedder)));
+ private void assertTransform(String inputComponent, String expectedComponent, boolean hosted) throws IOException, SAXException {
+ Element component = createElement(inputComponent);
+ ModelIdResolver.resolveModelIds(component, hosted);
+ assertSpec(createElement(expectedComponent), component);
}
private void assertSpec(Element e1, Element e2) {
@@ -209,8 +172,9 @@ public class EmbedderTestCase {
private void assertAttributes(Element e1, Element e2) {
NamedNodeMap map = e1.getAttributes();
for (int i = 0; i < map.getLength(); ++i) {
- String attr = map.item(i).getNodeName();
- assertEquals(e1.getAttribute(attr), e2.getAttribute(attr));
+ String attribute = map.item(i).getNodeName();
+ assertEquals(e1.getAttribute(attribute), e2.getAttribute(attribute),
+ "Attribute '" + attribute + "' is equal");
}
}
@@ -227,7 +191,7 @@ public class EmbedderTestCase {
private void assertTransformThrows(String embedder, String expectedMessage, boolean hosted) throws IOException, SAXException {
try {
- ModelConfigTransformer.transform(createEmptyDeployState(hosted), createElement(embedder));
+ ModelIdResolver.resolveModelIds(createElement(embedder), hosted);
fail("Expected exception was not thrown: " + expectedMessage);
} catch (IllegalArgumentException e) {
assertEquals(expectedMessage, e.getMessage());
@@ -239,9 +203,4 @@ public class EmbedderTestCase {
return (Element) doc.getFirstChild();
}
- private DeployState createEmptyDeployState(boolean hosted) {
- TestProperties properties = new TestProperties().setHostedVespa(hosted);
- return new DeployState.Builder().properties(properties).build();
- }
-
}