about | summary | refs | log | tree | commit | diff | stats
path: root/config-model/src/main/java/com/yahoo/vespa/model/container/component
diff options
context:
space:
mode:
author: bjormel <bjormel@yahooinc.com> 2023-10-01 12:23:12 +0000
committer: bjormel <bjormel@yahooinc.com> 2023-10-01 12:23:12 +0000
commit: e9058b555d4dfea2f6c872d9a677e8678b569569 (patch)
tree: fa1b67c6e39712c1e0d9f308b0dd55573b43f913 /config-model/src/main/java/com/yahoo/vespa/model/container/component
parent: 0ad931fa86658904fe9212b014d810236b0e00e4 (diff)
parent: 16030193ec04ee41e98779a3d7ee6a6c1d0d0d6f (diff)
Merge branch 'master' into bjormel/aws-main-controller
Diffstat (limited to 'config-model/src/main/java/com/yahoo/vespa/model/container/component')
-rw-r--r-- config-model/src/main/java/com/yahoo/vespa/model/container/component/BertEmbedder.java | 17
-rw-r--r-- config-model/src/main/java/com/yahoo/vespa/model/container/component/ColBertEmbedder.java | 35
-rw-r--r-- config-model/src/main/java/com/yahoo/vespa/model/container/component/Handler.java | 6
-rw-r--r-- config-model/src/main/java/com/yahoo/vespa/model/container/component/HuggingFaceEmbedder.java | 33
-rw-r--r-- config-model/src/main/java/com/yahoo/vespa/model/container/component/HuggingFaceTokenizer.java | 3
-rw-r--r-- config-model/src/main/java/com/yahoo/vespa/model/container/component/Model.java | 69
-rw-r--r-- config-model/src/main/java/com/yahoo/vespa/model/container/component/chain/ChainedComponent.java | 4
7 files changed, 114 insertions, 53 deletions
diff --git a/config-model/src/main/java/com/yahoo/vespa/model/container/component/BertEmbedder.java b/config-model/src/main/java/com/yahoo/vespa/model/container/component/BertEmbedder.java
index 205848e1b67..d02b7d0de5f 100644
--- a/config-model/src/main/java/com/yahoo/vespa/model/container/component/BertEmbedder.java
+++ b/config-model/src/main/java/com/yahoo/vespa/model/container/component/BertEmbedder.java
@@ -5,10 +5,9 @@ package com.yahoo.vespa.model.container.component;
import com.yahoo.config.ModelReference;
import com.yahoo.config.model.deploy.DeployState;
import com.yahoo.embedding.BertBaseEmbedderConfig;
-import com.yahoo.vespa.model.container.xml.ModelIdResolver;
+import com.yahoo.vespa.model.container.ApplicationContainerCluster;
import org.w3c.dom.Element;
-import static com.yahoo.text.XML.getChild;
import static com.yahoo.text.XML.getChildValue;
import static com.yahoo.vespa.model.container.ContainerModelEvaluation.INTEGRATION_BUNDLE_NAME;
@@ -17,8 +16,8 @@ import static com.yahoo.vespa.model.container.ContainerModelEvaluation.INTEGRATI
*/
public class BertEmbedder extends TypedComponent implements BertBaseEmbedderConfig.Producer {
- private final ModelReference model;
- private final ModelReference vocab;
+ private final ModelReference modelRef;
+ private final ModelReference vocabRef;
private final Integer maxTokens;
private final String transformerInputIds;
private final String transformerAttentionMask;
@@ -33,10 +32,11 @@ public class BertEmbedder extends TypedComponent implements BertBaseEmbedderConf
private final Integer onnxGpuDevice;
- public BertEmbedder(Element xml, DeployState state) {
+ public BertEmbedder(ApplicationContainerCluster cluster, Element xml, DeployState state) {
super("ai.vespa.embedding.BertBaseEmbedder", INTEGRATION_BUNDLE_NAME, xml);
- model = ModelIdResolver.resolveToModelReference(getChild(xml, "transformer-model"), state);
- vocab = ModelIdResolver.resolveToModelReference(getChild(xml, "tokenizer-vocab"), state);
+ var model = Model.fromXml(state, xml, "transformer-model").orElseThrow();
+ modelRef = model.modelReference();
+ vocabRef = Model.fromXml(state, xml, "tokenizer-vocab").orElseThrow().modelReference();
maxTokens = getChildValue(xml, "max-tokens").map(Integer::parseInt).orElse(null);
transformerInputIds = getChildValue(xml, "transformer-input-ids").orElse(null);
transformerAttentionMask = getChildValue(xml, "transformer-attention-mask").orElse(null);
@@ -49,11 +49,12 @@ public class BertEmbedder extends TypedComponent implements BertBaseEmbedderConf
onnxInteropThreads = getChildValue(xml, "onnx-interop-threads").map(Integer::parseInt).orElse(null);
onnxIntraopThreads = getChildValue(xml, "onnx-intraop-threads").map(Integer::parseInt).orElse(null);
onnxGpuDevice = getChildValue(xml, "onnx-gpu-device").map(Integer::parseInt).orElse(null);
+ model.registerOnnxModelCost(cluster);
}
@Override
public void getConfig(BertBaseEmbedderConfig.Builder b) {
- b.transformerModel(model).tokenizerVocab(vocab);
+ b.transformerModel(modelRef).tokenizerVocab(vocabRef);
if (maxTokens != null) b.transformerMaxTokens(maxTokens);
if (transformerInputIds != null) b.transformerInputIds(transformerInputIds);
if (transformerAttentionMask != null) b.transformerAttentionMask(transformerAttentionMask);
diff --git a/config-model/src/main/java/com/yahoo/vespa/model/container/component/ColBertEmbedder.java b/config-model/src/main/java/com/yahoo/vespa/model/container/component/ColBertEmbedder.java
index c0fdfe3dc64..66e3b1c9dfd 100644
--- a/config-model/src/main/java/com/yahoo/vespa/model/container/component/ColBertEmbedder.java
+++ b/config-model/src/main/java/com/yahoo/vespa/model/container/component/ColBertEmbedder.java
@@ -5,13 +5,9 @@ package com.yahoo.vespa.model.container.component;
import com.yahoo.config.ModelReference;
import com.yahoo.config.model.deploy.DeployState;
import com.yahoo.embedding.ColBertEmbedderConfig;
-import com.yahoo.embedding.huggingface.HuggingFaceEmbedderConfig;
-import com.yahoo.vespa.model.container.xml.ModelIdResolver;
+import com.yahoo.vespa.model.container.ApplicationContainerCluster;
import org.w3c.dom.Element;
-import java.util.Optional;
-
-import static com.yahoo.config.model.builder.xml.XmlHelper.getOptionalChild;
import static com.yahoo.text.XML.getChildValue;
import static com.yahoo.vespa.model.container.ContainerModelEvaluation.INTEGRATION_BUNDLE_NAME;
@@ -20,8 +16,8 @@ import static com.yahoo.vespa.model.container.ContainerModelEvaluation.INTEGRATI
* @author bergum
*/
public class ColBertEmbedder extends TypedComponent implements ColBertEmbedderConfig.Producer {
- private final ModelReference model;
- private final ModelReference vocab;
+ private final ModelReference modelRef;
+ private final ModelReference vocabRef;
private final Integer maxQueryTokens;
@@ -40,13 +36,13 @@ public class ColBertEmbedder extends TypedComponent implements ColBertEmbedderCo
private final Integer onnxIntraopThreads;
private final Integer onnxGpuDevice;
- public ColBertEmbedder(Element xml, DeployState state) {
+ public ColBertEmbedder(ApplicationContainerCluster cluster, Element xml, DeployState state) {
super("ai.vespa.embedding.ColBertEmbedder", INTEGRATION_BUNDLE_NAME, xml);
- var transformerModelElem = getOptionalChild(xml, "transformer-model").orElseThrow();
- model = ModelIdResolver.resolveToModelReference(transformerModelElem, state);
- vocab = getOptionalChild(xml, "tokenizer-model")
- .map(elem -> ModelIdResolver.resolveToModelReference(elem, state))
- .orElseGet(() -> resolveDefaultVocab(transformerModelElem, state));
+ var model = Model.fromXml(state, xml, "transformer-model").orElseThrow();
+ modelRef = model.modelReference();
+ vocabRef = Model.fromXml(state, xml, "tokenizer-model")
+ .map(Model::modelReference)
+ .orElseGet(() -> resolveDefaultVocab(model, state));
maxTokens = getChildValue(xml, "max-tokens").map(Integer::parseInt).orElse(null);
maxQueryTokens = getChildValue(xml, "max-query-tokens").map(Integer::parseInt).orElse(null);
maxDocumentTokens = getChildValue(xml, "max-document-tokens").map(Integer::parseInt).orElse(null);
@@ -60,21 +56,20 @@ public class ColBertEmbedder extends TypedComponent implements ColBertEmbedderCo
onnxInteropThreads = getChildValue(xml, "onnx-interop-threads").map(Integer::parseInt).orElse(null);
onnxIntraopThreads = getChildValue(xml, "onnx-intraop-threads").map(Integer::parseInt).orElse(null);
onnxGpuDevice = getChildValue(xml, "onnx-gpu-device").map(Integer::parseInt).orElse(null);
-
+ model.registerOnnxModelCost(cluster);
}
- private static ModelReference resolveDefaultVocab(Element model, DeployState state) {
- if (state.isHosted() && model.hasAttribute("model-id")) {
- var implicitVocabId = model.getAttribute("model-id") + "-vocab";
- return ModelIdResolver.resolveToModelReference(
- "tokenizer-model", Optional.of(implicitVocabId), Optional.empty(), Optional.empty(), state);
+ private static ModelReference resolveDefaultVocab(Model model, DeployState state) {
+ var modelId = model.modelId().orElse(null);
+ if (state.isHosted() && modelId != null) {
+ return Model.fromParams(state, model.name(), modelId + "-vocab", null, null).modelReference();
}
throw new IllegalArgumentException("'tokenizer-model' must be specified");
}
@Override
public void getConfig(ColBertEmbedderConfig.Builder b) {
- b.transformerModel(model).tokenizerPath(vocab);
+ b.transformerModel(modelRef).tokenizerPath(vocabRef);
if (maxTokens != null) b.transformerMaxTokens(maxTokens);
if (transformerInputIds != null) b.transformerInputIds(transformerInputIds);
if (transformerAttentionMask != null) b.transformerAttentionMask(transformerAttentionMask);
diff --git a/config-model/src/main/java/com/yahoo/vespa/model/container/component/Handler.java b/config-model/src/main/java/com/yahoo/vespa/model/container/component/Handler.java
index 31031aa5bf2..969db6553e6 100644
--- a/config-model/src/main/java/com/yahoo/vespa/model/container/component/Handler.java
+++ b/config-model/src/main/java/com/yahoo/vespa/model/container/component/Handler.java
@@ -1,9 +1,11 @@
// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
package com.yahoo.vespa.model.container.component;
+import com.yahoo.config.model.deploy.DeployState;
import com.yahoo.container.handler.threadpool.ContainerThreadpoolConfig;
import com.yahoo.osgi.provider.model.ComponentModel;
import com.yahoo.vespa.model.container.ContainerThreadpool;
+import org.w3c.dom.Element;
import java.util.ArrayList;
import java.util.Arrays;
@@ -76,8 +78,8 @@ public class Handler extends Component<Component<?, ?>, ComponentModel> {
*/
public static class DefaultHandlerThreadpool extends ContainerThreadpool {
- public DefaultHandlerThreadpool() {
- super("default-handler-common", null);
+ public DefaultHandlerThreadpool(DeployState ds, Element options) {
+ super(ds, "default-handler-common", options);
}
@Override
diff --git a/config-model/src/main/java/com/yahoo/vespa/model/container/component/HuggingFaceEmbedder.java b/config-model/src/main/java/com/yahoo/vespa/model/container/component/HuggingFaceEmbedder.java
index f4017339699..af47bee137a 100644
--- a/config-model/src/main/java/com/yahoo/vespa/model/container/component/HuggingFaceEmbedder.java
+++ b/config-model/src/main/java/com/yahoo/vespa/model/container/component/HuggingFaceEmbedder.java
@@ -5,12 +5,9 @@ package com.yahoo.vespa.model.container.component;
import com.yahoo.config.ModelReference;
import com.yahoo.config.model.deploy.DeployState;
import com.yahoo.embedding.huggingface.HuggingFaceEmbedderConfig;
-import com.yahoo.vespa.model.container.xml.ModelIdResolver;
+import com.yahoo.vespa.model.container.ApplicationContainerCluster;
import org.w3c.dom.Element;
-import java.util.Optional;
-
-import static com.yahoo.config.model.builder.xml.XmlHelper.getOptionalChild;
import static com.yahoo.text.XML.getChildValue;
import static com.yahoo.vespa.model.container.ContainerModelEvaluation.INTEGRATION_BUNDLE_NAME;
@@ -19,8 +16,8 @@ import static com.yahoo.vespa.model.container.ContainerModelEvaluation.INTEGRATI
* @author bjorncs
*/
public class HuggingFaceEmbedder extends TypedComponent implements HuggingFaceEmbedderConfig.Producer {
- private final ModelReference model;
- private final ModelReference vocab;
+ private final ModelReference modelRef;
+ private final ModelReference vocabRef;
private final Integer maxTokens;
private final String transformerInputIds;
private final String transformerAttentionMask;
@@ -33,13 +30,13 @@ public class HuggingFaceEmbedder extends TypedComponent implements HuggingFaceEm
private final Integer onnxGpuDevice;
private final String poolingStrategy;
- public HuggingFaceEmbedder(Element xml, DeployState state) {
+ public HuggingFaceEmbedder(ApplicationContainerCluster cluster, Element xml, DeployState state) {
super("ai.vespa.embedding.huggingface.HuggingFaceEmbedder", INTEGRATION_BUNDLE_NAME, xml);
- var transformerModelElem = getOptionalChild(xml, "transformer-model").orElseThrow();
- model = ModelIdResolver.resolveToModelReference(transformerModelElem, state);
- vocab = getOptionalChild(xml, "tokenizer-model")
- .map(elem -> ModelIdResolver.resolveToModelReference(elem, state))
- .orElseGet(() -> resolveDefaultVocab(transformerModelElem, state));
+ var model = Model.fromXml(state, xml, "transformer-model").orElseThrow();
+ modelRef = model.modelReference();
+ vocabRef = Model.fromXml(state, xml, "tokenizer-model")
+ .map(Model::modelReference)
+ .orElseGet(() -> resolveDefaultVocab(model, state));
maxTokens = getChildValue(xml, "max-tokens").map(Integer::parseInt).orElse(null);
transformerInputIds = getChildValue(xml, "transformer-input-ids").orElse(null);
transformerAttentionMask = getChildValue(xml, "transformer-attention-mask").orElse(null);
@@ -51,20 +48,20 @@ public class HuggingFaceEmbedder extends TypedComponent implements HuggingFaceEm
onnxIntraopThreads = getChildValue(xml, "onnx-intraop-threads").map(Integer::parseInt).orElse(null);
onnxGpuDevice = getChildValue(xml, "onnx-gpu-device").map(Integer::parseInt).orElse(null);
poolingStrategy = getChildValue(xml, "pooling-strategy").orElse(null);
+ model.registerOnnxModelCost(cluster);
}
- private static ModelReference resolveDefaultVocab(Element model, DeployState state) {
- if (state.isHosted() && model.hasAttribute("model-id")) {
- var implicitVocabId = model.getAttribute("model-id") + "-vocab";
- return ModelIdResolver.resolveToModelReference(
- "tokenizer-model", Optional.of(implicitVocabId), Optional.empty(), Optional.empty(), state);
+ private static ModelReference resolveDefaultVocab(Model model, DeployState state) {
+ var modelId = model.modelId().orElse(null);
+ if (state.isHosted() && modelId != null) {
+ return Model.fromParams(state, model.name(), modelId + "-vocab", null, null).modelReference();
}
throw new IllegalArgumentException("'tokenizer-model' must be specified");
}
@Override
public void getConfig(HuggingFaceEmbedderConfig.Builder b) {
- b.transformerModel(model).tokenizerPath(vocab);
+ b.transformerModel(modelRef).tokenizerPath(vocabRef);
if (maxTokens != null) b.transformerMaxTokens(maxTokens);
if (transformerInputIds != null) b.transformerInputIds(transformerInputIds);
if (transformerAttentionMask != null) b.transformerAttentionMask(transformerAttentionMask);
diff --git a/config-model/src/main/java/com/yahoo/vespa/model/container/component/HuggingFaceTokenizer.java b/config-model/src/main/java/com/yahoo/vespa/model/container/component/HuggingFaceTokenizer.java
index 0bf5491e872..e9ac93caa68 100644
--- a/config-model/src/main/java/com/yahoo/vespa/model/container/component/HuggingFaceTokenizer.java
+++ b/config-model/src/main/java/com/yahoo/vespa/model/container/component/HuggingFaceTokenizer.java
@@ -7,7 +7,6 @@ import com.yahoo.language.huggingface.config.HuggingFaceTokenizerConfig;
import com.yahoo.language.huggingface.config.HuggingFaceTokenizerConfig.Padding;
import com.yahoo.language.huggingface.config.HuggingFaceTokenizerConfig.Truncation;
import com.yahoo.text.XML;
-import com.yahoo.vespa.model.container.xml.ModelIdResolver;
import org.w3c.dom.Element;
import java.util.Map;
@@ -26,7 +25,7 @@ public class HuggingFaceTokenizer extends TypedComponent implements HuggingFaceT
super("com.yahoo.language.huggingface.HuggingFaceTokenizer", LINGUISTICS_BUNDLE_NAME, xml);
for (Element element : XML.getChildren(xml, "model")) {
var lang = element.hasAttribute("language") ? element.getAttribute("language") : "unknown";
- langToModel.put(lang, ModelIdResolver.resolveToModelReference(element, state));
+ langToModel.put(lang, Model.fromXml(state, element).modelReference());
}
}
diff --git a/config-model/src/main/java/com/yahoo/vespa/model/container/component/Model.java b/config-model/src/main/java/com/yahoo/vespa/model/container/component/Model.java
new file mode 100644
index 00000000000..76d93c38aee
--- /dev/null
+++ b/config-model/src/main/java/com/yahoo/vespa/model/container/component/Model.java
@@ -0,0 +1,69 @@
+// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+
+package com.yahoo.vespa.model.container.component;
+
+import com.yahoo.config.ModelReference;
+import com.yahoo.config.application.api.ApplicationFile;
+import com.yahoo.config.model.builder.xml.XmlHelper;
+import com.yahoo.config.model.deploy.DeployState;
+import com.yahoo.path.Path;
+import com.yahoo.vespa.model.container.ApplicationContainerCluster;
+import com.yahoo.vespa.model.container.xml.ModelIdResolver;
+import org.w3c.dom.Element;
+
+import java.net.URI;
+import java.util.Objects;
+import java.util.Optional;
+
+/**
+ * Represents a model (e.g. an ONNX model for an embedder) specified by 'model-id', 'url' and/or 'path',
+ * together with the {@link ModelReference} that {@link ModelIdResolver} resolves those values to.
+ * @author bjorncs
+ */
+class Model {
+ private final String paramName; // name of the model parameter; fromXml(ds, element) uses the element's tag name
+ private final String modelId; // explicit model id, or null when not given
+ private final URI url; // explicit URL, or null when not given
+ private final ApplicationFile file; // application package file looked up from the given path, or null when no path
+ private final ModelReference ref; // reference resolved by ModelIdResolver from the model id, url and path
+
+ private Model(DeployState ds, String paramName, String modelId, URI url, Path file) { // rejects input where modelId, url and file are all null
+ this.paramName = Objects.requireNonNull(paramName); // NullPointerException when paramName is null
+ if (modelId == null && url == null && file == null)
+ throw new IllegalArgumentException("At least one of 'model-id', 'url' or 'path' must be specified");
+ this.modelId = modelId;
+ this.url = url;
+ this.file = file != null ? ds.getApplicationPackage().getFile(file) : null; // only look up the file when a path was given
+ this.ref = ModelIdResolver.resolveToModelReference( // absent values are passed on as empty Optionals
+ paramName, Optional.ofNullable(modelId), Optional.ofNullable(url).map(URI::toString),
+ Optional.ofNullable(file).map(Path::toString), ds);
+ }
+
+ static Model fromParams(DeployState ds, String paramName, String modelId, URI url, Path file) { // factory for explicit values; modelId, url and file may each be null, but not all three
+ return new Model(ds, paramName, modelId, url, file);
+ }
+
+ static Optional<Model> fromXml(DeployState ds, Element parent, String name) { // presumably empty when 'parent' has no child element called 'name' - confirm against XmlHelper.getOptionalChild
+ return XmlHelper.getOptionalChild(parent, name).map(e -> fromXml(ds, e));
+ }
+
+ static Model fromXml(DeployState ds, Element model) { // builds a Model from the optional 'model-id', 'url' and 'path' attributes of the element
+ var modelId = XmlHelper.getOptionalAttribute(model, "model-id").orElse(null);
+ var url = XmlHelper.getOptionalAttribute(model, "url").map(URI::create).orElse(null);
+ var path = XmlHelper.getOptionalAttribute(model, "path").map(Path::fromString).orElse(null);
+ return new Model(ds, model.getTagName(), modelId, url, path); // the tag name doubles as the parameter name
+ }
+
+ void registerOnnxModelCost(ApplicationContainerCluster c) { // registers the application file if set, otherwise the resolved URL if any; otherwise does nothing
+ var resolvedUrl = resolvedUrl().orElse(null);
+ if (file != null) c.onnxModelCost().registerModel(file); // the application file takes precedence over the URL
+ else if (resolvedUrl != null) c.onnxModelCost().registerModel(resolvedUrl);
+ }
+
+ String name() { return paramName; }
+ Optional<String> modelId() { return Optional.ofNullable(modelId); }
+ Optional<URI> url() { return Optional.ofNullable(url); } // the URL as given to the constructor
+ Optional<URI> resolvedUrl() { return ref.url().map(u -> URI.create(u.value())); } // the URL carried by the resolved reference, not the constructor argument
+ Optional<ApplicationFile> file() { return Optional.ofNullable(file); }
+ ModelReference modelReference() { return ref; }
+}
diff --git a/config-model/src/main/java/com/yahoo/vespa/model/container/component/chain/ChainedComponent.java b/config-model/src/main/java/com/yahoo/vespa/model/container/component/chain/ChainedComponent.java
index c0431d01784..2354298779d 100644
--- a/config-model/src/main/java/com/yahoo/vespa/model/container/component/chain/ChainedComponent.java
+++ b/config-model/src/main/java/com/yahoo/vespa/model/container/component/chain/ChainedComponent.java
@@ -29,8 +29,6 @@ public class ChainedComponent<T extends ChainedComponentModel> extends Component
private ComponentId namespace() {
var owner = getParent().getParent();
- return (owner instanceof Chain) ?
- ((Chain) owner).getGlobalComponentId() :
- null;
+ return (owner instanceof Chain<?> chain) ? chain.getGlobalComponentId() : null;
}
}