diff options
Diffstat (limited to 'config-model/src/main/java/com/yahoo/vespa/model/container/component/HuggingFaceEmbedder.java')
-rw-r--r-- | config-model/src/main/java/com/yahoo/vespa/model/container/component/HuggingFaceEmbedder.java | 33 |
1 files changed, 15 insertions, 18 deletions
diff --git a/config-model/src/main/java/com/yahoo/vespa/model/container/component/HuggingFaceEmbedder.java b/config-model/src/main/java/com/yahoo/vespa/model/container/component/HuggingFaceEmbedder.java index f4017339699..af47bee137a 100644 --- a/config-model/src/main/java/com/yahoo/vespa/model/container/component/HuggingFaceEmbedder.java +++ b/config-model/src/main/java/com/yahoo/vespa/model/container/component/HuggingFaceEmbedder.java @@ -5,12 +5,9 @@ package com.yahoo.vespa.model.container.component; import com.yahoo.config.ModelReference; import com.yahoo.config.model.deploy.DeployState; import com.yahoo.embedding.huggingface.HuggingFaceEmbedderConfig; -import com.yahoo.vespa.model.container.xml.ModelIdResolver; +import com.yahoo.vespa.model.container.ApplicationContainerCluster; import org.w3c.dom.Element; -import java.util.Optional; - -import static com.yahoo.config.model.builder.xml.XmlHelper.getOptionalChild; import static com.yahoo.text.XML.getChildValue; import static com.yahoo.vespa.model.container.ContainerModelEvaluation.INTEGRATION_BUNDLE_NAME; @@ -19,8 +16,8 @@ import static com.yahoo.vespa.model.container.ContainerModelEvaluation.INTEGRATI * @author bjorncs */ public class HuggingFaceEmbedder extends TypedComponent implements HuggingFaceEmbedderConfig.Producer { - private final ModelReference model; - private final ModelReference vocab; + private final ModelReference modelRef; + private final ModelReference vocabRef; private final Integer maxTokens; private final String transformerInputIds; private final String transformerAttentionMask; @@ -33,13 +30,13 @@ public class HuggingFaceEmbedder extends TypedComponent implements HuggingFaceEm private final Integer onnxGpuDevice; private final String poolingStrategy; - public HuggingFaceEmbedder(Element xml, DeployState state) { + public HuggingFaceEmbedder(ApplicationContainerCluster cluster, Element xml, DeployState state) { super("ai.vespa.embedding.huggingface.HuggingFaceEmbedder", INTEGRATION_BUNDLE_NAME, xml); - var transformerModelElem = getOptionalChild(xml, "transformer-model").orElseThrow(); - model = ModelIdResolver.resolveToModelReference(transformerModelElem, state); - vocab = getOptionalChild(xml, "tokenizer-model") - .map(elem -> ModelIdResolver.resolveToModelReference(elem, state)) - .orElseGet(() -> resolveDefaultVocab(transformerModelElem, state)); + var model = Model.fromXml(state, xml, "transformer-model").orElseThrow(); + modelRef = model.modelReference(); + vocabRef = Model.fromXml(state, xml, "tokenizer-model") + .map(Model::modelReference) + .orElseGet(() -> resolveDefaultVocab(model, state)); maxTokens = getChildValue(xml, "max-tokens").map(Integer::parseInt).orElse(null); transformerInputIds = getChildValue(xml, "transformer-input-ids").orElse(null); transformerAttentionMask = getChildValue(xml, "transformer-attention-mask").orElse(null); @@ -51,20 +48,20 @@ public class HuggingFaceEmbedder extends TypedComponent implements HuggingFaceEm onnxIntraopThreads = getChildValue(xml, "onnx-intraop-threads").map(Integer::parseInt).orElse(null); onnxGpuDevice = getChildValue(xml, "onnx-gpu-device").map(Integer::parseInt).orElse(null); poolingStrategy = getChildValue(xml, "pooling-strategy").orElse(null); + model.registerOnnxModelCost(cluster); } - private static ModelReference resolveDefaultVocab(Element model, DeployState state) { - if (state.isHosted() && model.hasAttribute("model-id")) { - var implicitVocabId = model.getAttribute("model-id") + "-vocab"; - return ModelIdResolver.resolveToModelReference( - "tokenizer-model", Optional.of(implicitVocabId), Optional.empty(), Optional.empty(), state); + private static ModelReference resolveDefaultVocab(Model model, DeployState state) { + var modelId = model.modelId().orElse(null); + if (state.isHosted() && modelId != null) { + return Model.fromParams(state, model.name(), modelId + "-vocab", null, null).modelReference(); } throw new IllegalArgumentException("'tokenizer-model' must be specified"); } @Override public void getConfig(HuggingFaceEmbedderConfig.Builder b) { - b.transformerModel(model).tokenizerPath(vocab); + b.transformerModel(modelRef).tokenizerPath(vocabRef); if (maxTokens != null) b.transformerMaxTokens(maxTokens); if (transformerInputIds != null) b.transformerInputIds(transformerInputIds); if (transformerAttentionMask != null) b.transformerAttentionMask(transformerAttentionMask); |