about | summary | refs | log | tree | commit | diff | stats
path: root/config-model/src/main
diff options
context:
space:
mode:
author: Bjørn Christian Seime <bjorncs@yahooinc.com> 2023-06-05 18:46:50 +0200
committer: Bjørn Christian Seime <bjorncs@yahooinc.com> 2023-06-05 18:55:35 +0200
commit: 6c664b24186756021e6b39801b9694d1815311bf (patch)
tree: 8d9ae567404eb4329ac9c22df67970f62ae33f14 /config-model/src/main
parent: b2b7293c58d60ad87e337868e3c4c7c576cc0b79 (diff)
Ensure model is registered with file registry
Diffstat (limited to 'config-model/src/main')
-rw-r--r-- config-model/src/main/java/com/yahoo/vespa/model/container/component/BertEmbedder.java | 4
-rw-r--r-- config-model/src/main/java/com/yahoo/vespa/model/container/component/HuggingFaceEmbedder.java | 13
-rw-r--r-- config-model/src/main/java/com/yahoo/vespa/model/container/component/HuggingFaceTokenizer.java | 2
-rw-r--r-- config-model/src/main/java/com/yahoo/vespa/model/container/xml/ModelIdResolver.java | 29
4 files changed, 23 insertions, 25 deletions
diff --git a/config-model/src/main/java/com/yahoo/vespa/model/container/component/BertEmbedder.java b/config-model/src/main/java/com/yahoo/vespa/model/container/component/BertEmbedder.java
index 980dbcf0a76..56aa974da48 100644
--- a/config-model/src/main/java/com/yahoo/vespa/model/container/component/BertEmbedder.java
+++ b/config-model/src/main/java/com/yahoo/vespa/model/container/component/BertEmbedder.java
@@ -35,8 +35,8 @@ public class BertEmbedder extends TypedComponent implements BertBaseEmbedderConf
public BertEmbedder(Element xml, DeployState state) {
super("ai.vespa.embedding.BertBaseEmbedder", INTEGRATION_BUNDLE_NAME, xml);
- model = ModelIdResolver.resolveToModelReference(getChild(xml, "transformer-model"), state.isHosted());
- vocab = ModelIdResolver.resolveToModelReference(getChild(xml, "tokenizer-vocab"), state.isHosted());
+ model = ModelIdResolver.resolveToModelReference(getChild(xml, "transformer-model"), state);
+ vocab = ModelIdResolver.resolveToModelReference(getChild(xml, "tokenizer-vocab"), state);
maxTokens = getOptionalChildValue(xml, "max-tokens").map(Integer::parseInt).orElse(null);
transformerInputIds = getOptionalChildValue(xml, "transformer-input-ids").orElse(null);
transformerAttentionMask = getOptionalChildValue(xml, "transformer-attention-mask").orElse(null);
diff --git a/config-model/src/main/java/com/yahoo/vespa/model/container/component/HuggingFaceEmbedder.java b/config-model/src/main/java/com/yahoo/vespa/model/container/component/HuggingFaceEmbedder.java
index bb26b7e4fd7..6e7a1cc31dd 100644
--- a/config-model/src/main/java/com/yahoo/vespa/model/container/component/HuggingFaceEmbedder.java
+++ b/config-model/src/main/java/com/yahoo/vespa/model/container/component/HuggingFaceEmbedder.java
@@ -35,12 +35,11 @@ public class HuggingFaceEmbedder extends TypedComponent implements HuggingFaceEm
public HuggingFaceEmbedder(Element xml, DeployState state) {
super("ai.vespa.embedding.huggingface.HuggingFaceEmbedder", INTEGRATION_BUNDLE_NAME, xml);
- boolean hosted = state.isHosted();
var transformerModelElem = getOptionalChild(xml, "transformer-model").orElseThrow();
- model = ModelIdResolver.resolveToModelReference(transformerModelElem, hosted);
+ model = ModelIdResolver.resolveToModelReference(transformerModelElem, state);
vocab = getOptionalChild(xml, "tokenizer-model")
- .map(elem -> ModelIdResolver.resolveToModelReference(elem, hosted))
- .orElseGet(() -> resolveDefaultVocab(transformerModelElem, hosted));
+ .map(elem -> ModelIdResolver.resolveToModelReference(elem, state))
+ .orElseGet(() -> resolveDefaultVocab(transformerModelElem, state));
maxTokens = getOptionalChildValue(xml, "max-tokens").map(Integer::parseInt).orElse(null);
transformerInputIds = getOptionalChildValue(xml, "transformer-input-ids").orElse(null);
transformerAttentionMask = getOptionalChildValue(xml, "transformer-attention-mask").orElse(null);
@@ -54,11 +53,11 @@ public class HuggingFaceEmbedder extends TypedComponent implements HuggingFaceEm
poolingStrategy = getOptionalChildValue(xml, "pooling-strategy").orElse(null);
}
- private static ModelReference resolveDefaultVocab(Element model, boolean hosted) {
- if (hosted && model.hasAttribute("model-id")) {
+ private static ModelReference resolveDefaultVocab(Element model, DeployState state) {
+ if (state.isHosted() && model.hasAttribute("model-id")) {
var implicitVocabId = model.getAttribute("model-id") + "-vocab";
return ModelIdResolver.resolveToModelReference(
- "tokenizer-model", Optional.of(implicitVocabId), Optional.empty(), Optional.empty(), true);
+ "tokenizer-model", Optional.of(implicitVocabId), Optional.empty(), Optional.empty(), state);
}
throw new IllegalArgumentException("'tokenizer-model' must be specified");
}
diff --git a/config-model/src/main/java/com/yahoo/vespa/model/container/component/HuggingFaceTokenizer.java b/config-model/src/main/java/com/yahoo/vespa/model/container/component/HuggingFaceTokenizer.java
index ba8521a0089..966dbe8260a 100644
--- a/config-model/src/main/java/com/yahoo/vespa/model/container/component/HuggingFaceTokenizer.java
+++ b/config-model/src/main/java/com/yahoo/vespa/model/container/component/HuggingFaceTokenizer.java
@@ -28,7 +28,7 @@ public class HuggingFaceTokenizer extends TypedComponent implements HuggingFaceT
super("com.yahoo.language.huggingface.HuggingFaceTokenizer", LINGUISTICS_BUNDLE_NAME, xml);
for (Element element : XML.getChildren(xml, "model")) {
var lang = element.hasAttribute("language") ? element.getAttribute("language") : "unknown";
- langToModel.put(lang, ModelIdResolver.resolveToModelReference(element, state.isHosted()));
+ langToModel.put(lang, ModelIdResolver.resolveToModelReference(element, state));
}
specialTokens = getOptionalChildValue(xml, "special-tokens").map(Boolean::parseBoolean).orElse(null);
maxLength = getOptionalChildValue(xml, "max-length").map(Integer::parseInt).orElse(null);
diff --git a/config-model/src/main/java/com/yahoo/vespa/model/container/xml/ModelIdResolver.java b/config-model/src/main/java/com/yahoo/vespa/model/container/xml/ModelIdResolver.java
index c0f49f3148d..96f653bf793 100644
--- a/config-model/src/main/java/com/yahoo/vespa/model/container/xml/ModelIdResolver.java
+++ b/config-model/src/main/java/com/yahoo/vespa/model/container/xml/ModelIdResolver.java
@@ -1,10 +1,10 @@
// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
package com.yahoo.vespa.model.container.xml;
-import com.yahoo.config.FileReference;
import com.yahoo.config.ModelReference;
import com.yahoo.config.UrlReference;
import com.yahoo.config.model.builder.xml.XmlHelper;
+import com.yahoo.config.model.deploy.DeployState;
import com.yahoo.text.XML;
import org.w3c.dom.Element;
@@ -80,25 +80,24 @@ public class ModelIdResolver {
}
- public static ModelReference resolveToModelReference(Element elem, boolean hosted) {
+ public static ModelReference resolveToModelReference(Element elem, DeployState state) {
return resolveToModelReference(
elem.getTagName(), XmlHelper.getOptionalAttribute(elem, "model-id"),
- XmlHelper.getOptionalAttribute(elem, "url"), XmlHelper.getOptionalAttribute(elem, "path"), hosted);
+ XmlHelper.getOptionalAttribute(elem, "url"), XmlHelper.getOptionalAttribute(elem, "path"), state);
}
public static ModelReference resolveToModelReference(
- String paramName, Optional<String> id, Optional<String> url, Optional<String> path, boolean hosted) {
- if (id.isEmpty()) return ModelReference.unresolved(
- Optional.empty(), url.map(UrlReference::valueOf), path.map(FileReference::new));
- else if (hosted) {
- return ModelReference.unresolved(
- id, Optional.of(UrlReference.valueOf(modelIdToUrl(paramName, id.get()))), Optional.empty());
- } else if (url.isEmpty() && path.isEmpty()) {
- throw onlyModelIdInHostedException(paramName);
- } else {
- return ModelReference.unresolved(
- Optional.empty(), url.map(UrlReference::valueOf), path.map(FileReference::new));
- }
+ String paramName, Optional<String> id, Optional<String> url, Optional<String> path, DeployState state) {
+ if (id.isEmpty()) return createModelReference(Optional.empty(), url, path, state);
+ else if (state.isHosted())
+ return createModelReference(id, Optional.of(modelIdToUrl(paramName, id.get())), Optional.empty(), state);
+ else if (url.isEmpty() && path.isEmpty()) throw onlyModelIdInHostedException(paramName);
+ else return createModelReference(id, url, path, state);
+ }
+
+ private static ModelReference createModelReference(Optional<String> id, Optional<String> url, Optional<String> path, DeployState state) {
+ var fileRef = path.map(p -> state.getFileRegistry().addFile(p));
+ return ModelReference.unresolved(id, url.map(UrlReference::valueOf), fileRef);
}
private static IllegalArgumentException onlyModelIdInHostedException(String paramName) {