diff options
8 files changed, 112 insertions, 58 deletions
diff --git a/config-model-api/src/main/java/com/yahoo/config/model/api/OnnxModelCost.java b/config-model-api/src/main/java/com/yahoo/config/model/api/OnnxModelCost.java index 595cd97e6b6..33ed55ecaef 100644 --- a/config-model-api/src/main/java/com/yahoo/config/model/api/OnnxModelCost.java +++ b/config-model-api/src/main/java/com/yahoo/config/model/api/OnnxModelCost.java @@ -7,6 +7,8 @@ import com.yahoo.config.application.api.ApplicationFile; import com.yahoo.config.application.api.ApplicationPackage; import com.yahoo.config.application.api.DeployLogger; +import java.net.URI; + /** * @author bjorncs */ @@ -17,14 +19,16 @@ public interface OnnxModelCost { interface Calculator { long aggregatedModelCostInBytes(); void registerModel(ApplicationFile path); - void registerModel(ModelReference ref); + @Deprecated(forRemoval = true) void registerModel(ModelReference ref); // TODO(bjorncs): remove once no longer in use by old config models + void registerModel(URI uri); } static OnnxModelCost disabled() { return (__, ___) -> new Calculator() { @Override public long aggregatedModelCostInBytes() { return 0; } @Override public void registerModel(ApplicationFile path) {} - @Override public void registerModel(ModelReference ref) {} + @SuppressWarnings("removal") @Override public void registerModel(ModelReference ref) {} + @Override public void registerModel(URI uri) {} }; } } diff --git a/config-model/src/main/java/com/yahoo/vespa/model/container/component/BertEmbedder.java b/config-model/src/main/java/com/yahoo/vespa/model/container/component/BertEmbedder.java index 76bb1a9e02a..d02b7d0de5f 100644 --- a/config-model/src/main/java/com/yahoo/vespa/model/container/component/BertEmbedder.java +++ b/config-model/src/main/java/com/yahoo/vespa/model/container/component/BertEmbedder.java @@ -6,10 +6,8 @@ import com.yahoo.config.ModelReference; import com.yahoo.config.model.deploy.DeployState; import com.yahoo.embedding.BertBaseEmbedderConfig; import 
com.yahoo.vespa.model.container.ApplicationContainerCluster; -import com.yahoo.vespa.model.container.xml.ModelIdResolver; import org.w3c.dom.Element; -import static com.yahoo.text.XML.getChild; import static com.yahoo.text.XML.getChildValue; import static com.yahoo.vespa.model.container.ContainerModelEvaluation.INTEGRATION_BUNDLE_NAME; @@ -18,8 +16,8 @@ import static com.yahoo.vespa.model.container.ContainerModelEvaluation.INTEGRATI */ public class BertEmbedder extends TypedComponent implements BertBaseEmbedderConfig.Producer { - private final ModelReference model; - private final ModelReference vocab; + private final ModelReference modelRef; + private final ModelReference vocabRef; private final Integer maxTokens; private final String transformerInputIds; private final String transformerAttentionMask; @@ -36,8 +34,9 @@ public class BertEmbedder extends TypedComponent implements BertBaseEmbedderConf public BertEmbedder(ApplicationContainerCluster cluster, Element xml, DeployState state) { super("ai.vespa.embedding.BertBaseEmbedder", INTEGRATION_BUNDLE_NAME, xml); - model = ModelIdResolver.resolveToModelReference(getChild(xml, "transformer-model"), state); - vocab = ModelIdResolver.resolveToModelReference(getChild(xml, "tokenizer-vocab"), state); + var model = Model.fromXml(state, xml, "transformer-model").orElseThrow(); + modelRef = model.modelReference(); + vocabRef = Model.fromXml(state, xml, "tokenizer-vocab").orElseThrow().modelReference(); maxTokens = getChildValue(xml, "max-tokens").map(Integer::parseInt).orElse(null); transformerInputIds = getChildValue(xml, "transformer-input-ids").orElse(null); transformerAttentionMask = getChildValue(xml, "transformer-attention-mask").orElse(null); @@ -50,12 +49,12 @@ public class BertEmbedder extends TypedComponent implements BertBaseEmbedderConf onnxInteropThreads = getChildValue(xml, "onnx-interop-threads").map(Integer::parseInt).orElse(null); onnxIntraopThreads = getChildValue(xml, 
"onnx-intraop-threads").map(Integer::parseInt).orElse(null); onnxGpuDevice = getChildValue(xml, "onnx-gpu-device").map(Integer::parseInt).orElse(null); - cluster.onnxModelCost().registerModel(model); + model.registerOnnxModelCost(cluster); } @Override public void getConfig(BertBaseEmbedderConfig.Builder b) { - b.transformerModel(model).tokenizerVocab(vocab); + b.transformerModel(modelRef).tokenizerVocab(vocabRef); if (maxTokens != null) b.transformerMaxTokens(maxTokens); if (transformerInputIds != null) b.transformerInputIds(transformerInputIds); if (transformerAttentionMask != null) b.transformerAttentionMask(transformerAttentionMask); diff --git a/config-model/src/main/java/com/yahoo/vespa/model/container/component/ColBertEmbedder.java b/config-model/src/main/java/com/yahoo/vespa/model/container/component/ColBertEmbedder.java index 63096ebcbe2..66e3b1c9dfd 100644 --- a/config-model/src/main/java/com/yahoo/vespa/model/container/component/ColBertEmbedder.java +++ b/config-model/src/main/java/com/yahoo/vespa/model/container/component/ColBertEmbedder.java @@ -6,12 +6,8 @@ import com.yahoo.config.ModelReference; import com.yahoo.config.model.deploy.DeployState; import com.yahoo.embedding.ColBertEmbedderConfig; import com.yahoo.vespa.model.container.ApplicationContainerCluster; -import com.yahoo.vespa.model.container.xml.ModelIdResolver; import org.w3c.dom.Element; -import java.util.Optional; - -import static com.yahoo.config.model.builder.xml.XmlHelper.getOptionalChild; import static com.yahoo.text.XML.getChildValue; import static com.yahoo.vespa.model.container.ContainerModelEvaluation.INTEGRATION_BUNDLE_NAME; @@ -20,8 +16,8 @@ import static com.yahoo.vespa.model.container.ContainerModelEvaluation.INTEGRATI * @author bergum */ public class ColBertEmbedder extends TypedComponent implements ColBertEmbedderConfig.Producer { - private final ModelReference model; - private final ModelReference vocab; + private final ModelReference modelRef; + private final ModelReference 
vocabRef; private final Integer maxQueryTokens; @@ -42,11 +38,11 @@ public class ColBertEmbedder extends TypedComponent implements ColBertEmbedderCo public ColBertEmbedder(ApplicationContainerCluster cluster, Element xml, DeployState state) { super("ai.vespa.embedding.ColBertEmbedder", INTEGRATION_BUNDLE_NAME, xml); - var transformerModelElem = getOptionalChild(xml, "transformer-model").orElseThrow(); - model = ModelIdResolver.resolveToModelReference(transformerModelElem, state); - vocab = getOptionalChild(xml, "tokenizer-model") - .map(elem -> ModelIdResolver.resolveToModelReference(elem, state)) - .orElseGet(() -> resolveDefaultVocab(transformerModelElem, state)); + var model = Model.fromXml(state, xml, "transformer-model").orElseThrow(); + modelRef = model.modelReference(); + vocabRef = Model.fromXml(state, xml, "tokenizer-model") + .map(Model::modelReference) + .orElseGet(() -> resolveDefaultVocab(model, state)); maxTokens = getChildValue(xml, "max-tokens").map(Integer::parseInt).orElse(null); maxQueryTokens = getChildValue(xml, "max-query-tokens").map(Integer::parseInt).orElse(null); maxDocumentTokens = getChildValue(xml, "max-document-tokens").map(Integer::parseInt).orElse(null); @@ -60,21 +56,20 @@ public class ColBertEmbedder extends TypedComponent implements ColBertEmbedderCo onnxInteropThreads = getChildValue(xml, "onnx-interop-threads").map(Integer::parseInt).orElse(null); onnxIntraopThreads = getChildValue(xml, "onnx-intraop-threads").map(Integer::parseInt).orElse(null); onnxGpuDevice = getChildValue(xml, "onnx-gpu-device").map(Integer::parseInt).orElse(null); - cluster.onnxModelCost().registerModel(model); + model.registerOnnxModelCost(cluster); } - private static ModelReference resolveDefaultVocab(Element model, DeployState state) { - if (state.isHosted() && model.hasAttribute("model-id")) { - var implicitVocabId = model.getAttribute("model-id") + "-vocab"; - return ModelIdResolver.resolveToModelReference( - "tokenizer-model", 
Optional.of(implicitVocabId), Optional.empty(), Optional.empty(), state); + private static ModelReference resolveDefaultVocab(Model model, DeployState state) { + var modelId = model.modelId().orElse(null); + if (state.isHosted() && modelId != null) { + return Model.fromParams(state, model.name(), modelId + "-vocab", null, null).modelReference(); } throw new IllegalArgumentException("'tokenizer-model' must be specified"); } @Override public void getConfig(ColBertEmbedderConfig.Builder b) { - b.transformerModel(model).tokenizerPath(vocab); + b.transformerModel(modelRef).tokenizerPath(vocabRef); if (maxTokens != null) b.transformerMaxTokens(maxTokens); if (transformerInputIds != null) b.transformerInputIds(transformerInputIds); if (transformerAttentionMask != null) b.transformerAttentionMask(transformerAttentionMask); diff --git a/config-model/src/main/java/com/yahoo/vespa/model/container/component/HuggingFaceEmbedder.java b/config-model/src/main/java/com/yahoo/vespa/model/container/component/HuggingFaceEmbedder.java index 41b80bf1cb2..af47bee137a 100644 --- a/config-model/src/main/java/com/yahoo/vespa/model/container/component/HuggingFaceEmbedder.java +++ b/config-model/src/main/java/com/yahoo/vespa/model/container/component/HuggingFaceEmbedder.java @@ -6,12 +6,8 @@ import com.yahoo.config.ModelReference; import com.yahoo.config.model.deploy.DeployState; import com.yahoo.embedding.huggingface.HuggingFaceEmbedderConfig; import com.yahoo.vespa.model.container.ApplicationContainerCluster; -import com.yahoo.vespa.model.container.xml.ModelIdResolver; import org.w3c.dom.Element; -import java.util.Optional; - -import static com.yahoo.config.model.builder.xml.XmlHelper.getOptionalChild; import static com.yahoo.text.XML.getChildValue; import static com.yahoo.vespa.model.container.ContainerModelEvaluation.INTEGRATION_BUNDLE_NAME; @@ -20,8 +16,8 @@ import static com.yahoo.vespa.model.container.ContainerModelEvaluation.INTEGRATI * @author bjorncs */ public class 
HuggingFaceEmbedder extends TypedComponent implements HuggingFaceEmbedderConfig.Producer { - private final ModelReference model; - private final ModelReference vocab; + private final ModelReference modelRef; + private final ModelReference vocabRef; private final Integer maxTokens; private final String transformerInputIds; private final String transformerAttentionMask; @@ -36,11 +32,11 @@ public class HuggingFaceEmbedder extends TypedComponent implements HuggingFaceEm public HuggingFaceEmbedder(ApplicationContainerCluster cluster, Element xml, DeployState state) { super("ai.vespa.embedding.huggingface.HuggingFaceEmbedder", INTEGRATION_BUNDLE_NAME, xml); - var transformerModelElem = getOptionalChild(xml, "transformer-model").orElseThrow(); - model = ModelIdResolver.resolveToModelReference(transformerModelElem, state); - vocab = getOptionalChild(xml, "tokenizer-model") - .map(elem -> ModelIdResolver.resolveToModelReference(elem, state)) - .orElseGet(() -> resolveDefaultVocab(transformerModelElem, state)); + var model = Model.fromXml(state, xml, "transformer-model").orElseThrow(); + modelRef = model.modelReference(); + vocabRef = Model.fromXml(state, xml, "tokenizer-model") + .map(Model::modelReference) + .orElseGet(() -> resolveDefaultVocab(model, state)); maxTokens = getChildValue(xml, "max-tokens").map(Integer::parseInt).orElse(null); transformerInputIds = getChildValue(xml, "transformer-input-ids").orElse(null); transformerAttentionMask = getChildValue(xml, "transformer-attention-mask").orElse(null); @@ -52,21 +48,20 @@ public class HuggingFaceEmbedder extends TypedComponent implements HuggingFaceEm onnxIntraopThreads = getChildValue(xml, "onnx-intraop-threads").map(Integer::parseInt).orElse(null); onnxGpuDevice = getChildValue(xml, "onnx-gpu-device").map(Integer::parseInt).orElse(null); poolingStrategy = getChildValue(xml, "pooling-strategy").orElse(null); - cluster.onnxModelCost().registerModel(model); + model.registerOnnxModelCost(cluster); } - private static 
ModelReference resolveDefaultVocab(Element model, DeployState state) { - if (state.isHosted() && model.hasAttribute("model-id")) { - var implicitVocabId = model.getAttribute("model-id") + "-vocab"; - return ModelIdResolver.resolveToModelReference( - "tokenizer-model", Optional.of(implicitVocabId), Optional.empty(), Optional.empty(), state); + private static ModelReference resolveDefaultVocab(Model model, DeployState state) { + var modelId = model.modelId().orElse(null); + if (state.isHosted() && modelId != null) { + return Model.fromParams(state, model.name(), modelId + "-vocab", null, null).modelReference(); } throw new IllegalArgumentException("'tokenizer-model' must be specified"); } @Override public void getConfig(HuggingFaceEmbedderConfig.Builder b) { - b.transformerModel(model).tokenizerPath(vocab); + b.transformerModel(modelRef).tokenizerPath(vocabRef); if (maxTokens != null) b.transformerMaxTokens(maxTokens); if (transformerInputIds != null) b.transformerInputIds(transformerInputIds); if (transformerAttentionMask != null) b.transformerAttentionMask(transformerAttentionMask); diff --git a/config-model/src/main/java/com/yahoo/vespa/model/container/component/HuggingFaceTokenizer.java b/config-model/src/main/java/com/yahoo/vespa/model/container/component/HuggingFaceTokenizer.java index 0bf5491e872..e9ac93caa68 100644 --- a/config-model/src/main/java/com/yahoo/vespa/model/container/component/HuggingFaceTokenizer.java +++ b/config-model/src/main/java/com/yahoo/vespa/model/container/component/HuggingFaceTokenizer.java @@ -7,7 +7,6 @@ import com.yahoo.language.huggingface.config.HuggingFaceTokenizerConfig; import com.yahoo.language.huggingface.config.HuggingFaceTokenizerConfig.Padding; import com.yahoo.language.huggingface.config.HuggingFaceTokenizerConfig.Truncation; import com.yahoo.text.XML; -import com.yahoo.vespa.model.container.xml.ModelIdResolver; import org.w3c.dom.Element; import java.util.Map; @@ -26,7 +25,7 @@ public class HuggingFaceTokenizer extends 
TypedComponent implements HuggingFaceT super("com.yahoo.language.huggingface.HuggingFaceTokenizer", LINGUISTICS_BUNDLE_NAME, xml); for (Element element : XML.getChildren(xml, "model")) { var lang = element.hasAttribute("language") ? element.getAttribute("language") : "unknown"; - langToModel.put(lang, ModelIdResolver.resolveToModelReference(element, state)); + langToModel.put(lang, Model.fromXml(state, element).modelReference()); } } diff --git a/config-model/src/main/java/com/yahoo/vespa/model/container/component/Model.java b/config-model/src/main/java/com/yahoo/vespa/model/container/component/Model.java new file mode 100644 index 00000000000..084258ce1a6 --- /dev/null +++ b/config-model/src/main/java/com/yahoo/vespa/model/container/component/Model.java @@ -0,0 +1,67 @@ +// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. + +package com.yahoo.vespa.model.container.component; + +import com.yahoo.config.ModelReference; +import com.yahoo.config.application.api.ApplicationFile; +import com.yahoo.config.model.builder.xml.XmlHelper; +import com.yahoo.config.model.deploy.DeployState; +import com.yahoo.path.Path; +import com.yahoo.vespa.model.container.ApplicationContainerCluster; +import com.yahoo.vespa.model.container.xml.ModelIdResolver; +import org.w3c.dom.Element; + +import java.net.URI; +import java.util.Objects; +import java.util.Optional; + +/** + * Represents a model, e.g ONNX model for an embedder. 
+ * + * @author bjorncs + */ +class Model { + private final String paramName; + private final String modelId; + private final URI url; + private final ApplicationFile file; + private final ModelReference ref; + + private Model(DeployState ds, String paramName, String modelId, URI url, Path file) { + this.paramName = Objects.requireNonNull(paramName); + if (modelId == null && url == null && file == null) + throw new IllegalArgumentException("At least one of 'model-id', 'url' or 'path' must be specified"); + this.modelId = modelId; + this.url = url; + this.file = file != null ? ds.getApplicationPackage().getFile(file) : null; + this.ref = ModelIdResolver.resolveToModelReference( + paramName, Optional.ofNullable(modelId), Optional.ofNullable(url).map(URI::toString), + Optional.ofNullable(file).map(Path::toString), ds); + } + + static Model fromParams(DeployState ds, String paramName, String modelId, URI url, Path file) { + return new Model(ds, paramName, modelId, url, file); + } + + static Optional<Model> fromXml(DeployState ds, Element parent, String name) { + return XmlHelper.getOptionalChild(parent, name).map(e -> fromXml(ds, e)); + } + + static Model fromXml(DeployState ds, Element model) { + var modelId = XmlHelper.getOptionalAttribute(model, "model-id").orElse(null); + var url = XmlHelper.getOptionalAttribute(model, "url").map(URI::create).orElse(null); + var path = XmlHelper.getOptionalAttribute(model, "path").map(Path::fromString).orElse(null); + return new Model(ds, model.getTagName(), modelId, url, path); + } + + void registerOnnxModelCost(ApplicationContainerCluster c) { /* Registers this model with the cluster's ONNX model cost calculator: the application file when one is set, otherwise the URL of the resolved reference (if any). */ + if (file != null) c.onnxModelCost().registerModel(file); + else ref.url().map(resolved -> URI.create(resolved.value())).ifPresent(uri -> c.onnxModelCost().registerModel(uri)); /* the 'url' field is null when only 'model-id' was given, so read the URL from the resolved reference and never pass null */ + } + + String name() { return paramName; } + Optional<String> modelId() { return Optional.ofNullable(modelId); } + Optional<URI> url() { return Optional.ofNullable(url); } + Optional<ApplicationFile> file() { return Optional.ofNullable(file); } + ModelReference modelReference() { return 
ref; } +} diff --git a/config-model/src/main/java/com/yahoo/vespa/model/container/xml/ModelIdResolver.java b/config-model/src/main/java/com/yahoo/vespa/model/container/xml/ModelIdResolver.java index be3ca0b8aa9..14216dd8855 100644 --- a/config-model/src/main/java/com/yahoo/vespa/model/container/xml/ModelIdResolver.java +++ b/config-model/src/main/java/com/yahoo/vespa/model/container/xml/ModelIdResolver.java @@ -3,7 +3,6 @@ package com.yahoo.vespa.model.container.xml; import com.yahoo.config.ModelReference; import com.yahoo.config.UrlReference; -import com.yahoo.config.model.builder.xml.XmlHelper; import com.yahoo.config.model.deploy.DeployState; import com.yahoo.text.XML; import org.w3c.dom.Element; @@ -88,13 +87,6 @@ public class ModelIdResolver { } } - - public static ModelReference resolveToModelReference(Element elem, DeployState state) { - return resolveToModelReference( - elem.getTagName(), XmlHelper.getOptionalAttribute(elem, "model-id"), - XmlHelper.getOptionalAttribute(elem, "url"), XmlHelper.getOptionalAttribute(elem, "path"), state); - } - public static ModelReference resolveToModelReference( String paramName, Optional<String> id, Optional<String> url, Optional<String> path, DeployState state) { if (id.isEmpty()) return createModelReference(Optional.empty(), url, path, state); diff --git a/config-model/src/test/java/com/yahoo/vespa/model/application/validation/JvmHeapSizeValidatorTest.java b/config-model/src/test/java/com/yahoo/vespa/model/application/validation/JvmHeapSizeValidatorTest.java index 447614b8396..9b3b659c252 100644 --- a/config-model/src/test/java/com/yahoo/vespa/model/application/validation/JvmHeapSizeValidatorTest.java +++ b/config-model/src/test/java/com/yahoo/vespa/model/application/validation/JvmHeapSizeValidatorTest.java @@ -18,6 +18,7 @@ import org.junit.jupiter.api.Test; import org.xml.sax.SAXException; import java.io.IOException; +import java.net.URI; import java.util.concurrent.atomic.AtomicLong; import static 
org.junit.jupiter.api.Assertions.assertDoesNotThrow; @@ -117,9 +118,11 @@ class JvmHeapSizeValidatorTest { @Override public long aggregatedModelCostInBytes() { return totalCost.get(); } @Override public void registerModel(ApplicationFile path) {} + @SuppressWarnings("removal") @Override public void registerModel(ModelReference ref) {} + @Override - public void registerModel(ModelReference ref) { - assertEquals("https://my/url/model.onnx", ref.url().orElseThrow().value().toString()); + public void registerModel(URI uri) { + assertEquals("https://my/url/model.onnx", uri.toString()); totalCost.addAndGet(modelCost); } } |