diff options
author | Bjørn Christian Seime <bjorncs@yahooinc.com> | 2023-06-05 18:46:50 +0200 |
---|---|---|
committer | Bjørn Christian Seime <bjorncs@yahooinc.com> | 2023-06-05 18:55:35 +0200 |
commit | 6c664b24186756021e6b39801b9694d1815311bf (patch) | |
tree | 8d9ae567404eb4329ac9c22df67970f62ae33f14 /config-model/src/main | |
parent | b2b7293c58d60ad87e337868e3c4c7c576cc0b79 (diff) |
Ensure model is registered with file registry
Diffstat (limited to 'config-model/src/main')
4 files changed, 23 insertions, 25 deletions
diff --git a/config-model/src/main/java/com/yahoo/vespa/model/container/component/BertEmbedder.java b/config-model/src/main/java/com/yahoo/vespa/model/container/component/BertEmbedder.java index 980dbcf0a76..56aa974da48 100644 --- a/config-model/src/main/java/com/yahoo/vespa/model/container/component/BertEmbedder.java +++ b/config-model/src/main/java/com/yahoo/vespa/model/container/component/BertEmbedder.java @@ -35,8 +35,8 @@ public class BertEmbedder extends TypedComponent implements BertBaseEmbedderConf public BertEmbedder(Element xml, DeployState state) { super("ai.vespa.embedding.BertBaseEmbedder", INTEGRATION_BUNDLE_NAME, xml); - model = ModelIdResolver.resolveToModelReference(getChild(xml, "transformer-model"), state.isHosted()); - vocab = ModelIdResolver.resolveToModelReference(getChild(xml, "tokenizer-vocab"), state.isHosted()); + model = ModelIdResolver.resolveToModelReference(getChild(xml, "transformer-model"), state); + vocab = ModelIdResolver.resolveToModelReference(getChild(xml, "tokenizer-vocab"), state); maxTokens = getOptionalChildValue(xml, "max-tokens").map(Integer::parseInt).orElse(null); transformerInputIds = getOptionalChildValue(xml, "transformer-input-ids").orElse(null); transformerAttentionMask = getOptionalChildValue(xml, "transformer-attention-mask").orElse(null); diff --git a/config-model/src/main/java/com/yahoo/vespa/model/container/component/HuggingFaceEmbedder.java b/config-model/src/main/java/com/yahoo/vespa/model/container/component/HuggingFaceEmbedder.java index bb26b7e4fd7..6e7a1cc31dd 100644 --- a/config-model/src/main/java/com/yahoo/vespa/model/container/component/HuggingFaceEmbedder.java +++ b/config-model/src/main/java/com/yahoo/vespa/model/container/component/HuggingFaceEmbedder.java @@ -35,12 +35,11 @@ public class HuggingFaceEmbedder extends TypedComponent implements HuggingFaceEm public HuggingFaceEmbedder(Element xml, DeployState state) { super("ai.vespa.embedding.huggingface.HuggingFaceEmbedder", INTEGRATION_BUNDLE_NAME, xml); - boolean hosted = state.isHosted(); var transformerModelElem = getOptionalChild(xml, "transformer-model").orElseThrow(); - model = ModelIdResolver.resolveToModelReference(transformerModelElem, hosted); + model = ModelIdResolver.resolveToModelReference(transformerModelElem, state); vocab = getOptionalChild(xml, "tokenizer-model") - .map(elem -> ModelIdResolver.resolveToModelReference(elem, hosted)) - .orElseGet(() -> resolveDefaultVocab(transformerModelElem, hosted)); + .map(elem -> ModelIdResolver.resolveToModelReference(elem, state)) + .orElseGet(() -> resolveDefaultVocab(transformerModelElem, state)); maxTokens = getOptionalChildValue(xml, "max-tokens").map(Integer::parseInt).orElse(null); transformerInputIds = getOptionalChildValue(xml, "transformer-input-ids").orElse(null); transformerAttentionMask = getOptionalChildValue(xml, "transformer-attention-mask").orElse(null); @@ -54,11 +53,11 @@ public class HuggingFaceEmbedder extends TypedComponent implements HuggingFaceEm poolingStrategy = getOptionalChildValue(xml, "pooling-strategy").orElse(null); } - private static ModelReference resolveDefaultVocab(Element model, boolean hosted) { - if (hosted && model.hasAttribute("model-id")) { + private static ModelReference resolveDefaultVocab(Element model, DeployState state) { + if (state.isHosted() && model.hasAttribute("model-id")) { var implicitVocabId = model.getAttribute("model-id") + "-vocab"; return ModelIdResolver.resolveToModelReference( - "tokenizer-model", Optional.of(implicitVocabId), Optional.empty(), Optional.empty(), true); + "tokenizer-model", Optional.of(implicitVocabId), Optional.empty(), Optional.empty(), state); } throw new IllegalArgumentException("'tokenizer-model' must be specified"); } diff --git a/config-model/src/main/java/com/yahoo/vespa/model/container/component/HuggingFaceTokenizer.java b/config-model/src/main/java/com/yahoo/vespa/model/container/component/HuggingFaceTokenizer.java index ba8521a0089..966dbe8260a 100644 --- a/config-model/src/main/java/com/yahoo/vespa/model/container/component/HuggingFaceTokenizer.java +++ b/config-model/src/main/java/com/yahoo/vespa/model/container/component/HuggingFaceTokenizer.java @@ -28,7 +28,7 @@ public class HuggingFaceTokenizer extends TypedComponent implements HuggingFaceT super("com.yahoo.language.huggingface.HuggingFaceTokenizer", LINGUISTICS_BUNDLE_NAME, xml); for (Element element : XML.getChildren(xml, "model")) { var lang = element.hasAttribute("language") ? element.getAttribute("language") : "unknown"; - langToModel.put(lang, ModelIdResolver.resolveToModelReference(element, state.isHosted())); + langToModel.put(lang, ModelIdResolver.resolveToModelReference(element, state)); } specialTokens = getOptionalChildValue(xml, "special-tokens").map(Boolean::parseBoolean).orElse(null); maxLength = getOptionalChildValue(xml, "max-length").map(Integer::parseInt).orElse(null); diff --git a/config-model/src/main/java/com/yahoo/vespa/model/container/xml/ModelIdResolver.java b/config-model/src/main/java/com/yahoo/vespa/model/container/xml/ModelIdResolver.java index c0f49f3148d..96f653bf793 100644 --- a/config-model/src/main/java/com/yahoo/vespa/model/container/xml/ModelIdResolver.java +++ b/config-model/src/main/java/com/yahoo/vespa/model/container/xml/ModelIdResolver.java @@ -1,10 +1,10 @@ // Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. package com.yahoo.vespa.model.container.xml; -import com.yahoo.config.FileReference; import com.yahoo.config.ModelReference; import com.yahoo.config.UrlReference; import com.yahoo.config.model.builder.xml.XmlHelper; +import com.yahoo.config.model.deploy.DeployState; import com.yahoo.text.XML; import org.w3c.dom.Element; @@ -80,25 +80,24 @@ public class ModelIdResolver { } - public static ModelReference resolveToModelReference(Element elem, boolean hosted) { + public static ModelReference resolveToModelReference(Element elem, DeployState state) { return resolveToModelReference( elem.getTagName(), XmlHelper.getOptionalAttribute(elem, "model-id"), - XmlHelper.getOptionalAttribute(elem, "url"), XmlHelper.getOptionalAttribute(elem, "path"), hosted); + XmlHelper.getOptionalAttribute(elem, "url"), XmlHelper.getOptionalAttribute(elem, "path"), state); } public static ModelReference resolveToModelReference( - String paramName, Optional<String> id, Optional<String> url, Optional<String> path, boolean hosted) { - if (id.isEmpty()) return ModelReference.unresolved( - Optional.empty(), url.map(UrlReference::valueOf), path.map(FileReference::new)); - else if (hosted) { - return ModelReference.unresolved( - id, Optional.of(UrlReference.valueOf(modelIdToUrl(paramName, id.get()))), Optional.empty()); - } else if (url.isEmpty() && path.isEmpty()) { - throw onlyModelIdInHostedException(paramName); - } else { - return ModelReference.unresolved( - Optional.empty(), url.map(UrlReference::valueOf), path.map(FileReference::new)); - } + String paramName, Optional<String> id, Optional<String> url, Optional<String> path, DeployState state) { + if (id.isEmpty()) return createModelReference(Optional.empty(), url, path, state); + else if (state.isHosted()) + return createModelReference(id, Optional.of(modelIdToUrl(paramName, id.get())), Optional.empty(), state); + else if (url.isEmpty() && path.isEmpty()) throw onlyModelIdInHostedException(paramName); + else return createModelReference(id, url, path, state); + } + + private static ModelReference createModelReference(Optional<String> id, Optional<String> url, Optional<String> path, DeployState state) { + var fileRef = path.map(p -> state.getFileRegistry().addFile(p)); + return ModelReference.unresolved(id, url.map(UrlReference::valueOf), fileRef); } private static IllegalArgumentException onlyModelIdInHostedException(String paramName) { |