diff options
author | bjormel <bjormel@yahooinc.com> | 2023-10-01 12:23:12 +0000 |
---|---|---|
committer | bjormel <bjormel@yahooinc.com> | 2023-10-01 12:23:12 +0000 |
commit | e9058b555d4dfea2f6c872d9a677e8678b569569 (patch) | |
tree | fa1b67c6e39712c1e0d9f308b0dd55573b43f913 /config-model/src/main/java/com/yahoo/vespa/model/container/component | |
parent | 0ad931fa86658904fe9212b014d810236b0e00e4 (diff) | |
parent | 16030193ec04ee41e98779a3d7ee6a6c1d0d0d6f (diff) |
Merge branch 'master' into bjormel/aws-main-controller
Diffstat (limited to 'config-model/src/main/java/com/yahoo/vespa/model/container/component')
7 files changed, 114 insertions, 53 deletions
diff --git a/config-model/src/main/java/com/yahoo/vespa/model/container/component/BertEmbedder.java b/config-model/src/main/java/com/yahoo/vespa/model/container/component/BertEmbedder.java index 205848e1b67..d02b7d0de5f 100644 --- a/config-model/src/main/java/com/yahoo/vespa/model/container/component/BertEmbedder.java +++ b/config-model/src/main/java/com/yahoo/vespa/model/container/component/BertEmbedder.java @@ -5,10 +5,9 @@ package com.yahoo.vespa.model.container.component; import com.yahoo.config.ModelReference; import com.yahoo.config.model.deploy.DeployState; import com.yahoo.embedding.BertBaseEmbedderConfig; -import com.yahoo.vespa.model.container.xml.ModelIdResolver; +import com.yahoo.vespa.model.container.ApplicationContainerCluster; import org.w3c.dom.Element; -import static com.yahoo.text.XML.getChild; import static com.yahoo.text.XML.getChildValue; import static com.yahoo.vespa.model.container.ContainerModelEvaluation.INTEGRATION_BUNDLE_NAME; @@ -17,8 +16,8 @@ import static com.yahoo.vespa.model.container.ContainerModelEvaluation.INTEGRATI */ public class BertEmbedder extends TypedComponent implements BertBaseEmbedderConfig.Producer { - private final ModelReference model; - private final ModelReference vocab; + private final ModelReference modelRef; + private final ModelReference vocabRef; private final Integer maxTokens; private final String transformerInputIds; private final String transformerAttentionMask; @@ -33,10 +32,11 @@ public class BertEmbedder extends TypedComponent implements BertBaseEmbedderConf private final Integer onnxGpuDevice; - public BertEmbedder(Element xml, DeployState state) { + public BertEmbedder(ApplicationContainerCluster cluster, Element xml, DeployState state) { super("ai.vespa.embedding.BertBaseEmbedder", INTEGRATION_BUNDLE_NAME, xml); - model = ModelIdResolver.resolveToModelReference(getChild(xml, "transformer-model"), state); - vocab = ModelIdResolver.resolveToModelReference(getChild(xml, "tokenizer-vocab"), state); + var model = Model.fromXml(state, xml, "transformer-model").orElseThrow(); + modelRef = model.modelReference(); + vocabRef = Model.fromXml(state, xml, "tokenizer-vocab").orElseThrow().modelReference(); maxTokens = getChildValue(xml, "max-tokens").map(Integer::parseInt).orElse(null); transformerInputIds = getChildValue(xml, "transformer-input-ids").orElse(null); transformerAttentionMask = getChildValue(xml, "transformer-attention-mask").orElse(null); @@ -49,11 +49,12 @@ public class BertEmbedder extends TypedComponent implements BertBaseEmbedderConf onnxInteropThreads = getChildValue(xml, "onnx-interop-threads").map(Integer::parseInt).orElse(null); onnxIntraopThreads = getChildValue(xml, "onnx-intraop-threads").map(Integer::parseInt).orElse(null); onnxGpuDevice = getChildValue(xml, "onnx-gpu-device").map(Integer::parseInt).orElse(null); + model.registerOnnxModelCost(cluster); } @Override public void getConfig(BertBaseEmbedderConfig.Builder b) { - b.transformerModel(model).tokenizerVocab(vocab); + b.transformerModel(modelRef).tokenizerVocab(vocabRef); if (maxTokens != null) b.transformerMaxTokens(maxTokens); if (transformerInputIds != null) b.transformerInputIds(transformerInputIds); if (transformerAttentionMask != null) b.transformerAttentionMask(transformerAttentionMask); diff --git a/config-model/src/main/java/com/yahoo/vespa/model/container/component/ColBertEmbedder.java b/config-model/src/main/java/com/yahoo/vespa/model/container/component/ColBertEmbedder.java index c0fdfe3dc64..66e3b1c9dfd 100644 --- a/config-model/src/main/java/com/yahoo/vespa/model/container/component/ColBertEmbedder.java +++ b/config-model/src/main/java/com/yahoo/vespa/model/container/component/ColBertEmbedder.java @@ -5,13 +5,9 @@ package com.yahoo.vespa.model.container.component; import com.yahoo.config.ModelReference; import com.yahoo.config.model.deploy.DeployState; import com.yahoo.embedding.ColBertEmbedderConfig; -import com.yahoo.embedding.huggingface.HuggingFaceEmbedderConfig; -import com.yahoo.vespa.model.container.xml.ModelIdResolver; +import com.yahoo.vespa.model.container.ApplicationContainerCluster; import org.w3c.dom.Element; -import java.util.Optional; - -import static com.yahoo.config.model.builder.xml.XmlHelper.getOptionalChild; import static com.yahoo.text.XML.getChildValue; import static com.yahoo.vespa.model.container.ContainerModelEvaluation.INTEGRATION_BUNDLE_NAME; @@ -20,8 +16,8 @@ import static com.yahoo.vespa.model.container.ContainerModelEvaluation.INTEGRATI * @author bergum */ public class ColBertEmbedder extends TypedComponent implements ColBertEmbedderConfig.Producer { - private final ModelReference model; - private final ModelReference vocab; + private final ModelReference modelRef; + private final ModelReference vocabRef; private final Integer maxQueryTokens; @@ -40,13 +36,13 @@ public class ColBertEmbedder extends TypedComponent implements ColBertEmbedderCo private final Integer onnxIntraopThreads; private final Integer onnxGpuDevice; - public ColBertEmbedder(Element xml, DeployState state) { + public ColBertEmbedder(ApplicationContainerCluster cluster, Element xml, DeployState state) { super("ai.vespa.embedding.ColBertEmbedder", INTEGRATION_BUNDLE_NAME, xml); - var transformerModelElem = getOptionalChild(xml, "transformer-model").orElseThrow(); - model = ModelIdResolver.resolveToModelReference(transformerModelElem, state); - vocab = getOptionalChild(xml, "tokenizer-model") - .map(elem -> ModelIdResolver.resolveToModelReference(elem, state)) - .orElseGet(() -> resolveDefaultVocab(transformerModelElem, state)); + var model = Model.fromXml(state, xml, "transformer-model").orElseThrow(); + modelRef = model.modelReference(); + vocabRef = Model.fromXml(state, xml, "tokenizer-model") + .map(Model::modelReference) + .orElseGet(() -> resolveDefaultVocab(model, state)); maxTokens = getChildValue(xml, "max-tokens").map(Integer::parseInt).orElse(null); maxQueryTokens = getChildValue(xml, "max-query-tokens").map(Integer::parseInt).orElse(null); maxDocumentTokens = getChildValue(xml, "max-document-tokens").map(Integer::parseInt).orElse(null); @@ -60,21 +56,20 @@ public class ColBertEmbedder extends TypedComponent implements ColBertEmbedderCo onnxInteropThreads = getChildValue(xml, "onnx-interop-threads").map(Integer::parseInt).orElse(null); onnxIntraopThreads = getChildValue(xml, "onnx-intraop-threads").map(Integer::parseInt).orElse(null); onnxGpuDevice = getChildValue(xml, "onnx-gpu-device").map(Integer::parseInt).orElse(null); - + model.registerOnnxModelCost(cluster); } - private static ModelReference resolveDefaultVocab(Element model, DeployState state) { - if (state.isHosted() && model.hasAttribute("model-id")) { - var implicitVocabId = model.getAttribute("model-id") + "-vocab"; - return ModelIdResolver.resolveToModelReference( - "tokenizer-model", Optional.of(implicitVocabId), Optional.empty(), Optional.empty(), state); + private static ModelReference resolveDefaultVocab(Model model, DeployState state) { + var modelId = model.modelId().orElse(null); + if (state.isHosted() && modelId != null) { + return Model.fromParams(state, model.name(), modelId + "-vocab", null, null).modelReference(); } throw new IllegalArgumentException("'tokenizer-model' must be specified"); } @Override public void getConfig(ColBertEmbedderConfig.Builder b) { - b.transformerModel(model).tokenizerPath(vocab); + b.transformerModel(modelRef).tokenizerPath(vocabRef); if (maxTokens != null) b.transformerMaxTokens(maxTokens); if (transformerInputIds != null) b.transformerInputIds(transformerInputIds); if (transformerAttentionMask != null) b.transformerAttentionMask(transformerAttentionMask); diff --git a/config-model/src/main/java/com/yahoo/vespa/model/container/component/Handler.java b/config-model/src/main/java/com/yahoo/vespa/model/container/component/Handler.java index 31031aa5bf2..969db6553e6 100644 --- a/config-model/src/main/java/com/yahoo/vespa/model/container/component/Handler.java +++ b/config-model/src/main/java/com/yahoo/vespa/model/container/component/Handler.java @@ -1,9 +1,11 @@ // Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. package com.yahoo.vespa.model.container.component; +import com.yahoo.config.model.deploy.DeployState; import com.yahoo.container.handler.threadpool.ContainerThreadpoolConfig; import com.yahoo.osgi.provider.model.ComponentModel; import com.yahoo.vespa.model.container.ContainerThreadpool; +import org.w3c.dom.Element; import java.util.ArrayList; import java.util.Arrays; @@ -76,8 +78,8 @@ public class Handler extends Component<Component<?, ?>, ComponentModel> { */ public static class DefaultHandlerThreadpool extends ContainerThreadpool { - public DefaultHandlerThreadpool() { - super("default-handler-common", null); + public DefaultHandlerThreadpool(DeployState ds, Element options) { + super(ds, "default-handler-common", options); } @Override diff --git a/config-model/src/main/java/com/yahoo/vespa/model/container/component/HuggingFaceEmbedder.java b/config-model/src/main/java/com/yahoo/vespa/model/container/component/HuggingFaceEmbedder.java index f4017339699..af47bee137a 100644 --- a/config-model/src/main/java/com/yahoo/vespa/model/container/component/HuggingFaceEmbedder.java +++ b/config-model/src/main/java/com/yahoo/vespa/model/container/component/HuggingFaceEmbedder.java @@ -5,12 +5,9 @@ package com.yahoo.vespa.model.container.component; import com.yahoo.config.ModelReference; import com.yahoo.config.model.deploy.DeployState; import com.yahoo.embedding.huggingface.HuggingFaceEmbedderConfig; -import com.yahoo.vespa.model.container.xml.ModelIdResolver; +import com.yahoo.vespa.model.container.ApplicationContainerCluster; import org.w3c.dom.Element; -import java.util.Optional; - -import static com.yahoo.config.model.builder.xml.XmlHelper.getOptionalChild; import static com.yahoo.text.XML.getChildValue; import static com.yahoo.vespa.model.container.ContainerModelEvaluation.INTEGRATION_BUNDLE_NAME; @@ -19,8 +16,8 @@ import static com.yahoo.vespa.model.container.ContainerModelEvaluation.INTEGRATI * @author bjorncs */ public class HuggingFaceEmbedder extends TypedComponent implements HuggingFaceEmbedderConfig.Producer { - private final ModelReference model; - private final ModelReference vocab; + private final ModelReference modelRef; + private final ModelReference vocabRef; private final Integer maxTokens; private final String transformerInputIds; private final String transformerAttentionMask; @@ -33,13 +30,13 @@ public class HuggingFaceEmbedder extends TypedComponent implements HuggingFaceEm private final Integer onnxGpuDevice; private final String poolingStrategy; - public HuggingFaceEmbedder(Element xml, DeployState state) { + public HuggingFaceEmbedder(ApplicationContainerCluster cluster, Element xml, DeployState state) { super("ai.vespa.embedding.huggingface.HuggingFaceEmbedder", INTEGRATION_BUNDLE_NAME, xml); - var transformerModelElem = getOptionalChild(xml, "transformer-model").orElseThrow(); - model = ModelIdResolver.resolveToModelReference(transformerModelElem, state); - vocab = getOptionalChild(xml, "tokenizer-model") - .map(elem -> ModelIdResolver.resolveToModelReference(elem, state)) - .orElseGet(() -> resolveDefaultVocab(transformerModelElem, state)); + var model = Model.fromXml(state, xml, "transformer-model").orElseThrow(); + modelRef = model.modelReference(); + vocabRef = Model.fromXml(state, xml, "tokenizer-model") + .map(Model::modelReference) + .orElseGet(() -> resolveDefaultVocab(model, state)); maxTokens = getChildValue(xml, "max-tokens").map(Integer::parseInt).orElse(null); transformerInputIds = getChildValue(xml, "transformer-input-ids").orElse(null); transformerAttentionMask = getChildValue(xml, "transformer-attention-mask").orElse(null); @@ -51,20 +48,20 @@ public class HuggingFaceEmbedder extends TypedComponent implements HuggingFaceEm onnxIntraopThreads = getChildValue(xml, "onnx-intraop-threads").map(Integer::parseInt).orElse(null); onnxGpuDevice = getChildValue(xml, "onnx-gpu-device").map(Integer::parseInt).orElse(null); poolingStrategy = getChildValue(xml, "pooling-strategy").orElse(null); + model.registerOnnxModelCost(cluster); } - private static ModelReference resolveDefaultVocab(Element model, DeployState state) { - if (state.isHosted() && model.hasAttribute("model-id")) { - var implicitVocabId = model.getAttribute("model-id") + "-vocab"; - return ModelIdResolver.resolveToModelReference( - "tokenizer-model", Optional.of(implicitVocabId), Optional.empty(), Optional.empty(), state); + private static ModelReference resolveDefaultVocab(Model model, DeployState state) { + var modelId = model.modelId().orElse(null); + if (state.isHosted() && modelId != null) { + return Model.fromParams(state, model.name(), modelId + "-vocab", null, null).modelReference(); } throw new IllegalArgumentException("'tokenizer-model' must be specified"); } @Override public void getConfig(HuggingFaceEmbedderConfig.Builder b) { - b.transformerModel(model).tokenizerPath(vocab); + b.transformerModel(modelRef).tokenizerPath(vocabRef); if (maxTokens != null) b.transformerMaxTokens(maxTokens); if (transformerInputIds != null) b.transformerInputIds(transformerInputIds); if (transformerAttentionMask != null) b.transformerAttentionMask(transformerAttentionMask); diff --git a/config-model/src/main/java/com/yahoo/vespa/model/container/component/HuggingFaceTokenizer.java b/config-model/src/main/java/com/yahoo/vespa/model/container/component/HuggingFaceTokenizer.java index 0bf5491e872..e9ac93caa68 100644 --- a/config-model/src/main/java/com/yahoo/vespa/model/container/component/HuggingFaceTokenizer.java +++ b/config-model/src/main/java/com/yahoo/vespa/model/container/component/HuggingFaceTokenizer.java @@ -7,7 +7,6 @@ import com.yahoo.language.huggingface.config.HuggingFaceTokenizerConfig; import com.yahoo.language.huggingface.config.HuggingFaceTokenizerConfig.Padding; import com.yahoo.language.huggingface.config.HuggingFaceTokenizerConfig.Truncation; import com.yahoo.text.XML; -import com.yahoo.vespa.model.container.xml.ModelIdResolver; import org.w3c.dom.Element; import java.util.Map; @@ -26,7 +25,7 @@ public class HuggingFaceTokenizer extends TypedComponent implements HuggingFaceT super("com.yahoo.language.huggingface.HuggingFaceTokenizer", LINGUISTICS_BUNDLE_NAME, xml); for (Element element : XML.getChildren(xml, "model")) { var lang = element.hasAttribute("language") ? element.getAttribute("language") : "unknown"; - langToModel.put(lang, ModelIdResolver.resolveToModelReference(element, state)); + langToModel.put(lang, Model.fromXml(state, element).modelReference()); } } diff --git a/config-model/src/main/java/com/yahoo/vespa/model/container/component/Model.java b/config-model/src/main/java/com/yahoo/vespa/model/container/component/Model.java new file mode 100644 index 00000000000..76d93c38aee --- /dev/null +++ b/config-model/src/main/java/com/yahoo/vespa/model/container/component/Model.java @@ -0,0 +1,69 @@ +// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. + +package com.yahoo.vespa.model.container.component; + +import com.yahoo.config.ModelReference; +import com.yahoo.config.application.api.ApplicationFile; +import com.yahoo.config.model.builder.xml.XmlHelper; +import com.yahoo.config.model.deploy.DeployState; +import com.yahoo.path.Path; +import com.yahoo.vespa.model.container.ApplicationContainerCluster; +import com.yahoo.vespa.model.container.xml.ModelIdResolver; +import org.w3c.dom.Element; + +import java.net.URI; +import java.util.Objects; +import java.util.Optional; + +/** + * Represents a model, e.g ONNX model for an embedder. + * + * @author bjorncs + */ +class Model { + private final String paramName; + private final String modelId; + private final URI url; + private final ApplicationFile file; + private final ModelReference ref; + + private Model(DeployState ds, String paramName, String modelId, URI url, Path file) { + this.paramName = Objects.requireNonNull(paramName); + if (modelId == null && url == null && file == null) + throw new IllegalArgumentException("At least one of 'model-id', 'url' or 'path' must be specified"); + this.modelId = modelId; + this.url = url; + this.file = file != null ? ds.getApplicationPackage().getFile(file) : null; + this.ref = ModelIdResolver.resolveToModelReference( + paramName, Optional.ofNullable(modelId), Optional.ofNullable(url).map(URI::toString), + Optional.ofNullable(file).map(Path::toString), ds); + } + + static Model fromParams(DeployState ds, String paramName, String modelId, URI url, Path file) { + return new Model(ds, paramName, modelId, url, file); + } + + static Optional<Model> fromXml(DeployState ds, Element parent, String name) { + return XmlHelper.getOptionalChild(parent, name).map(e -> fromXml(ds, e)); + } + + static Model fromXml(DeployState ds, Element model) { + var modelId = XmlHelper.getOptionalAttribute(model, "model-id").orElse(null); + var url = XmlHelper.getOptionalAttribute(model, "url").map(URI::create).orElse(null); + var path = XmlHelper.getOptionalAttribute(model, "path").map(Path::fromString).orElse(null); + return new Model(ds, model.getTagName(), modelId, url, path); + } + + void registerOnnxModelCost(ApplicationContainerCluster c) { + var resolvedUrl = resolvedUrl().orElse(null); + if (file != null) c.onnxModelCost().registerModel(file); + else if (resolvedUrl != null) c.onnxModelCost().registerModel(resolvedUrl); + } + + String name() { return paramName; } + Optional<String> modelId() { return Optional.ofNullable(modelId); } + Optional<URI> url() { return Optional.ofNullable(url); } + Optional<URI> resolvedUrl() { return ref.url().map(u -> URI.create(u.value())); } + Optional<ApplicationFile> file() { return Optional.ofNullable(file); } + ModelReference modelReference() { return ref; } +} diff --git a/config-model/src/main/java/com/yahoo/vespa/model/container/component/chain/ChainedComponent.java b/config-model/src/main/java/com/yahoo/vespa/model/container/component/chain/ChainedComponent.java index c0431d01784..2354298779d 100644 --- a/config-model/src/main/java/com/yahoo/vespa/model/container/component/chain/ChainedComponent.java +++ b/config-model/src/main/java/com/yahoo/vespa/model/container/component/chain/ChainedComponent.java @@ -29,8 +29,6 @@ public class ChainedComponent<T extends ChainedComponentModel> extends Component private ComponentId namespace() { var owner = getParent().getParent(); - return (owner instanceof Chain) ? - ((Chain) owner).getGlobalComponentId() : - null; + return (owner instanceof Chain<?> chain) ? chain.getGlobalComponentId() : null; } } |