diff options
author | Bjørn Christian Seime <bjorncs@yahooinc.com> | 2023-06-02 12:10:32 +0200 |
---|---|---|
committer | Bjørn Christian Seime <bjorncs@yahooinc.com> | 2023-06-02 12:10:32 +0200 |
commit | a67788f2b7786a2cfcb9244d1e72a7fb1815425b (patch) | |
tree | fa34be2f0f13ef4ea116dd12853c734de3bc2eca /config-model/src/main/java/com/yahoo/vespa/model/container/component | |
parent | e757e5ff2e6dadbe31389c7dfeb3f52827a1668b (diff) |
Introduce services.xml syntax for configuring HuggingFace embedders
Diffstat (limited to 'config-model/src/main/java/com/yahoo/vespa/model/container/component')
3 files changed, 146 insertions, 0 deletions
# HuggingFaceEmbedder.java (new file, 79 lines).
# Component class extending TypedComponent and implementing HuggingFaceEmbedderConfig.Producer.
# Constructor: reads child elements of the component's XML element.
#   - 'transformer-model' is required: its Optional is unwrapped with orElseThrow().
#   - 'tokenizer-model' is optional. When it is absent, resolveDefaultVocab() derives the id
#     "<model-id>-vocab" from the 'model-id' attribute of 'transformer-model', but only when
#     state.isHosted() is true and that attribute is present; otherwise it throws
#     IllegalArgumentException("'tokenizer-model' must be specified").
#   - All remaining child elements are optional; an absent element is stored as null.
#     Integer-valued ones go through Integer.parseInt, 'normalize' through Boolean.parseBoolean.
# getConfig(): always sets transformerModel and tokenizerPath; every other builder setter is
#   called only when the corresponding field is non-null. 'onnx-execution-mode' is kept as a
#   String and converted with TransformerExecutionMode.Enum.valueOf() here, not in the constructor.
diff --git a/config-model/src/main/java/com/yahoo/vespa/model/container/component/HuggingFaceEmbedder.java b/config-model/src/main/java/com/yahoo/vespa/model/container/component/HuggingFaceEmbedder.java
new file mode 100644
index 00000000000..1c36716699e
--- /dev/null
+++ b/config-model/src/main/java/com/yahoo/vespa/model/container/component/HuggingFaceEmbedder.java
@@ -0,0 +1,79 @@
+// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+
+package com.yahoo.vespa.model.container.component;
+
+import com.yahoo.config.ModelReference;
+import com.yahoo.config.model.deploy.DeployState;
+import com.yahoo.embedding.huggingface.HuggingFaceEmbedderConfig;
+import com.yahoo.vespa.model.container.xml.ModelIdResolver;
+import org.w3c.dom.Element;
+
+import java.util.Optional;
+
+import static com.yahoo.config.model.builder.xml.XmlHelper.getOptionalChild;
+import static com.yahoo.config.model.builder.xml.XmlHelper.getOptionalChildValue;
+import static com.yahoo.vespa.model.container.ContainerModelEvaluation.INTEGRATION_BUNDLE_NAME;
+
+
+/**
+ * @author bjorncs
+ */
+public class HuggingFaceEmbedder extends TypedComponent implements HuggingFaceEmbedderConfig.Producer {
+    private final ModelReference model;
+    private final ModelReference vocab;
+    private final Integer maxTokens;
+    private final String transformerInputIds;
+    private final String transformerAttentionMask;
+    private final String transformerTokenTypeIds;
+    private final String transformerOutput;
+    private final Boolean normalize;
+    private final String onnxExecutionMode;
+    private final Integer onnxInteropThreads;
+    private final Integer onnxIntraopThreads;
+    private final Integer onnxGpuDevice;
+
+    public HuggingFaceEmbedder(Element xml, DeployState state) {
+        super("ai.vespa.embedding.huggingface.HuggingFaceEmbedder", INTEGRATION_BUNDLE_NAME, xml);
+        boolean hosted = state.isHosted();
+        var transformerModelElem = getOptionalChild(xml, "transformer-model").orElseThrow();
+        model = ModelIdResolver.resolveToModelReference(transformerModelElem, hosted);
+        vocab = getOptionalChild(xml, "tokenizer-model")
+                .map(elem -> ModelIdResolver.resolveToModelReference(elem, hosted))
+                .orElseGet(() -> resolveDefaultVocab(transformerModelElem, hosted));
+        maxTokens = getOptionalChildValue(xml, "max-tokens").map(Integer::parseInt).orElse(null);
+        transformerInputIds = getOptionalChildValue(xml, "transformer-input-ids").orElse(null);
+        transformerAttentionMask = getOptionalChildValue(xml, "transformer-attention-mask").orElse(null);
+        transformerTokenTypeIds = getOptionalChildValue(xml, "transformer-token-type-ids").orElse(null);
+        transformerOutput = getOptionalChildValue(xml, "transformer-output").orElse(null);
+        normalize = getOptionalChildValue(xml, "normalize").map(Boolean::parseBoolean).orElse(null);
+        onnxExecutionMode = getOptionalChildValue(xml, "onnx-execution-mode").orElse(null);
+        onnxInteropThreads = getOptionalChildValue(xml, "onnx-interop-threads").map(Integer::parseInt).orElse(null);
+        onnxIntraopThreads = getOptionalChildValue(xml, "onnx-intraop-threads").map(Integer::parseInt).orElse(null);
+        onnxGpuDevice = getOptionalChildValue(xml, "onnx-gpu-device").map(Integer::parseInt).orElse(null);
+    }
+
+    private static ModelReference resolveDefaultVocab(Element model, boolean hosted) {
+        if (hosted && model.hasAttribute("model-id")) {
+            var implicitVocabId = model.getAttribute("model-id") + "-vocab";
+            return ModelIdResolver.resolveToModelReference(
+                    "tokenizer-model", Optional.of(implicitVocabId), Optional.empty(), Optional.empty(), true);
+        }
+        throw new IllegalArgumentException("'tokenizer-model' must be specified");
+    }
+
+    @Override
+    public void getConfig(HuggingFaceEmbedderConfig.Builder b) {
+        b.transformerModel(model).tokenizerPath(vocab);
+        if (maxTokens != null) b.transformerMaxTokens(maxTokens);
+        if (transformerInputIds != null) b.transformerInputIds(transformerInputIds);
+        if (transformerAttentionMask != null) b.transformerAttentionMask(transformerAttentionMask);
+        if (transformerTokenTypeIds != null) b.transformerTokenTypeIds(transformerTokenTypeIds);
+        if (transformerOutput != null) b.transformerOutput(transformerOutput);
+        if (normalize != null) b.normalize(normalize);
+        if (onnxExecutionMode != null) b.transformerExecutionMode(
+                HuggingFaceEmbedderConfig.TransformerExecutionMode.Enum.valueOf(onnxExecutionMode));
+        if (onnxInteropThreads != null) b.transformerInterOpThreads(onnxInteropThreads);
+        if (onnxIntraopThreads != null) b.transformerIntraOpThreads(onnxIntraopThreads);
+        if (onnxGpuDevice != null) b.transformerGpuDevice(onnxGpuDevice);
+    }
+}
# HuggingFaceTokenizer.java (new file, 47 lines).
# Component class extending TypedComponent and implementing HuggingFaceTokenizerConfig.Producer.
# Constructor: resolves every 'model' child element to a ModelReference and stores it in a TreeMap
#   keyed by the element's 'language' attribute, using "unknown" when the attribute is absent.
#   Because the map is keyed by language, a later 'model' element with the same language
#   replaces an earlier one.
#   'special-tokens', 'max-length' and 'truncation' are optional and stored as null when absent.
# getConfig(): adds one Model entry (language, path) per map entry, then calls addSpecialTokens,
#   maxLength and truncation only for the values that are non-null.
diff --git a/config-model/src/main/java/com/yahoo/vespa/model/container/component/HuggingFaceTokenizer.java b/config-model/src/main/java/com/yahoo/vespa/model/container/component/HuggingFaceTokenizer.java
new file mode 100644
index 00000000000..ba8521a0089
--- /dev/null
+++ b/config-model/src/main/java/com/yahoo/vespa/model/container/component/HuggingFaceTokenizer.java
@@ -0,0 +1,47 @@
+// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+package com.yahoo.vespa.model.container.component;
+
+import com.yahoo.config.ModelReference;
+import com.yahoo.config.model.deploy.DeployState;
+import com.yahoo.language.huggingface.config.HuggingFaceTokenizerConfig;
+import com.yahoo.text.XML;
+import com.yahoo.vespa.model.container.xml.ModelIdResolver;
+import org.w3c.dom.Element;
+
+import java.util.Map;
+import java.util.TreeMap;
+
+import static com.yahoo.config.model.builder.xml.XmlHelper.getOptionalChildValue;
+import static com.yahoo.vespa.model.container.ContainerModelEvaluation.LINGUISTICS_BUNDLE_NAME;
+
+/**
+ * @author bjorncs
+ */
+public class HuggingFaceTokenizer extends TypedComponent implements HuggingFaceTokenizerConfig.Producer {
+
+    private final Map<String, ModelReference> langToModel = new TreeMap<>();
+    private final Boolean specialTokens;
+    private final Integer maxLength;
+    private final Boolean truncation;
+
+    public HuggingFaceTokenizer(Element xml, DeployState state) {
+        super("com.yahoo.language.huggingface.HuggingFaceTokenizer", LINGUISTICS_BUNDLE_NAME, xml);
+        for (Element element : XML.getChildren(xml, "model")) {
+            var lang = element.hasAttribute("language") ? element.getAttribute("language") : "unknown";
+            langToModel.put(lang, ModelIdResolver.resolveToModelReference(element, state.isHosted()));
+        }
+        specialTokens = getOptionalChildValue(xml, "special-tokens").map(Boolean::parseBoolean).orElse(null);
+        maxLength = getOptionalChildValue(xml, "max-length").map(Integer::parseInt).orElse(null);
+        truncation = getOptionalChildValue(xml, "truncation").map(Boolean::parseBoolean).orElse(null);
+    }
+
+    @Override
+    public void getConfig(HuggingFaceTokenizerConfig.Builder builder) {
+        langToModel.forEach((lang, vocab) -> {
+            builder.model.add(new HuggingFaceTokenizerConfig.Model.Builder().language(lang).path(vocab));
+        });
+        if (specialTokens != null) builder.addSpecialTokens(specialTokens);
+        if (maxLength != null) builder.maxLength(maxLength);
+        if (truncation != null) builder.truncation(truncation);
+    }
+}
# TypedComponent.java (new file, 20 lines).
# Package-private abstract base class extending SimpleComponent; HuggingFaceEmbedder and
# HuggingFaceTokenizer in this commit extend it.
# Constructor: builds a ComponentModel from the XML element's 'id' attribute plus the given class
#   name and bundle, and passes it to the SimpleComponent constructor.
#   The element is also kept in the private field 'xml'; nothing in this class reads that field.
diff --git a/config-model/src/main/java/com/yahoo/vespa/model/container/component/TypedComponent.java b/config-model/src/main/java/com/yahoo/vespa/model/container/component/TypedComponent.java
new file mode 100644
index 00000000000..522c78f2f25
--- /dev/null
+++ b/config-model/src/main/java/com/yahoo/vespa/model/container/component/TypedComponent.java
@@ -0,0 +1,20 @@
+// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+
+package com.yahoo.vespa.model.container.component;
+
+import com.yahoo.osgi.provider.model.ComponentModel;
+import org.w3c.dom.Element;
+
+/**
+ * @author bjorncs
+ */
+abstract class TypedComponent extends SimpleComponent {
+
+    private final Element xml;
+
+    protected TypedComponent(String className, String bundle, Element xml) {
+        super(new ComponentModel(xml.getAttribute("id"), className, bundle));
+        this.xml = xml;
+    }
+
+}
|