aboutsummaryrefslogtreecommitdiffstats
path: root/config-model/src/main/java/com/yahoo/vespa/model/container/component
diff options
context:
space:
mode:
authorBjørn Christian Seime <bjorncs@yahooinc.com>2023-06-02 12:10:32 +0200
committerBjørn Christian Seime <bjorncs@yahooinc.com>2023-06-02 12:10:32 +0200
commita67788f2b7786a2cfcb9244d1e72a7fb1815425b (patch)
treefa34be2f0f13ef4ea116dd12853c734de3bc2eca /config-model/src/main/java/com/yahoo/vespa/model/container/component
parente757e5ff2e6dadbe31389c7dfeb3f52827a1668b (diff)
Introduce services.xml syntax for configuring HuggingFace embedders
Diffstat (limited to 'config-model/src/main/java/com/yahoo/vespa/model/container/component')
-rw-r--r--config-model/src/main/java/com/yahoo/vespa/model/container/component/HuggingFaceEmbedder.java79
-rw-r--r--config-model/src/main/java/com/yahoo/vespa/model/container/component/HuggingFaceTokenizer.java47
-rw-r--r--config-model/src/main/java/com/yahoo/vespa/model/container/component/TypedComponent.java20
3 files changed, 146 insertions, 0 deletions
diff --git a/config-model/src/main/java/com/yahoo/vespa/model/container/component/HuggingFaceEmbedder.java b/config-model/src/main/java/com/yahoo/vespa/model/container/component/HuggingFaceEmbedder.java
new file mode 100644
index 00000000000..1c36716699e
--- /dev/null
+++ b/config-model/src/main/java/com/yahoo/vespa/model/container/component/HuggingFaceEmbedder.java
@@ -0,0 +1,79 @@
+// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+
+package com.yahoo.vespa.model.container.component;
+
+import com.yahoo.config.ModelReference;
+import com.yahoo.config.model.deploy.DeployState;
+import com.yahoo.embedding.huggingface.HuggingFaceEmbedderConfig;
+import com.yahoo.vespa.model.container.xml.ModelIdResolver;
+import org.w3c.dom.Element;
+
+import java.util.Optional;
+
+import static com.yahoo.config.model.builder.xml.XmlHelper.getOptionalChild;
+import static com.yahoo.config.model.builder.xml.XmlHelper.getOptionalChildValue;
+import static com.yahoo.vespa.model.container.ContainerModelEvaluation.INTEGRATION_BUNDLE_NAME;
+
+
+/**
+ * @author bjorncs
+ */
+public class HuggingFaceEmbedder extends TypedComponent implements HuggingFaceEmbedderConfig.Producer {
+ private final ModelReference model;
+ private final ModelReference vocab;
+ private final Integer maxTokens;
+ private final String transformerInputIds;
+ private final String transformerAttentionMask;
+ private final String transformerTokenTypeIds;
+ private final String transformerOutput;
+ private final Boolean normalize;
+ private final String onnxExecutionMode;
+ private final Integer onnxInteropThreads;
+ private final Integer onnxIntraopThreads;
+ private final Integer onnxGpuDevice;
+
+ public HuggingFaceEmbedder(Element xml, DeployState state) {
+ super("ai.vespa.embedding.huggingface.HuggingFaceEmbedder", INTEGRATION_BUNDLE_NAME, xml);
+ boolean hosted = state.isHosted();
+ var transformerModelElem = getOptionalChild(xml, "transformer-model").orElseThrow();
+ model = ModelIdResolver.resolveToModelReference(transformerModelElem, hosted);
+ vocab = getOptionalChild(xml, "tokenizer-model")
+ .map(elem -> ModelIdResolver.resolveToModelReference(elem, hosted))
+ .orElseGet(() -> resolveDefaultVocab(transformerModelElem, hosted));
+ maxTokens = getOptionalChildValue(xml, "max-tokens").map(Integer::parseInt).orElse(null);
+ transformerInputIds = getOptionalChildValue(xml, "transformer-input-ids").orElse(null);
+ transformerAttentionMask = getOptionalChildValue(xml, "transformer-attention-mask").orElse(null);
+ transformerTokenTypeIds = getOptionalChildValue(xml, "transformer-token-type-ids").orElse(null);
+ transformerOutput = getOptionalChildValue(xml, "transformer-output").orElse(null);
+ normalize = getOptionalChildValue(xml, "normalize").map(Boolean::parseBoolean).orElse(null);
+ onnxExecutionMode = getOptionalChildValue(xml, "onnx-execution-mode").orElse(null);
+ onnxInteropThreads = getOptionalChildValue(xml, "onnx-interop-threads").map(Integer::parseInt).orElse(null);
+ onnxIntraopThreads = getOptionalChildValue(xml, "onnx-intraop-threads").map(Integer::parseInt).orElse(null);
+ onnxGpuDevice = getOptionalChildValue(xml, "onnx-gpu-device").map(Integer::parseInt).orElse(null);
+ }
+
+ private static ModelReference resolveDefaultVocab(Element model, boolean hosted) {
+ if (hosted && model.hasAttribute("model-id")) {
+ var implicitVocabId = model.getAttribute("model-id") + "-vocab";
+ return ModelIdResolver.resolveToModelReference(
+ "tokenizer-model", Optional.of(implicitVocabId), Optional.empty(), Optional.empty(), true);
+ }
+ throw new IllegalArgumentException("'tokenizer-model' must be specified");
+ }
+
+ @Override
+ public void getConfig(HuggingFaceEmbedderConfig.Builder b) {
+ b.transformerModel(model).tokenizerPath(vocab);
+ if (maxTokens != null) b.transformerMaxTokens(maxTokens);
+ if (transformerInputIds != null) b.transformerInputIds(transformerInputIds);
+ if (transformerAttentionMask != null) b.transformerAttentionMask(transformerAttentionMask);
+ if (transformerTokenTypeIds != null) b.transformerTokenTypeIds(transformerTokenTypeIds);
+ if (transformerOutput != null) b.transformerOutput(transformerOutput);
+ if (normalize != null) b.normalize(normalize);
+ if (onnxExecutionMode != null) b.transformerExecutionMode(
+ HuggingFaceEmbedderConfig.TransformerExecutionMode.Enum.valueOf(onnxExecutionMode));
+ if (onnxInteropThreads != null) b.transformerInterOpThreads(onnxInteropThreads);
+ if (onnxIntraopThreads != null) b.transformerIntraOpThreads(onnxIntraopThreads);
+ if (onnxGpuDevice != null) b.transformerGpuDevice(onnxGpuDevice);
+ }
+}
diff --git a/config-model/src/main/java/com/yahoo/vespa/model/container/component/HuggingFaceTokenizer.java b/config-model/src/main/java/com/yahoo/vespa/model/container/component/HuggingFaceTokenizer.java
new file mode 100644
index 00000000000..ba8521a0089
--- /dev/null
+++ b/config-model/src/main/java/com/yahoo/vespa/model/container/component/HuggingFaceTokenizer.java
@@ -0,0 +1,47 @@
+// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+package com.yahoo.vespa.model.container.component;
+
+import com.yahoo.config.ModelReference;
+import com.yahoo.config.model.deploy.DeployState;
+import com.yahoo.language.huggingface.config.HuggingFaceTokenizerConfig;
+import com.yahoo.text.XML;
+import com.yahoo.vespa.model.container.xml.ModelIdResolver;
+import org.w3c.dom.Element;
+
+import java.util.Map;
+import java.util.TreeMap;
+
+import static com.yahoo.config.model.builder.xml.XmlHelper.getOptionalChildValue;
+import static com.yahoo.vespa.model.container.ContainerModelEvaluation.LINGUISTICS_BUNDLE_NAME;
+
+/**
+ * @author bjorncs
+ */
+public class HuggingFaceTokenizer extends TypedComponent implements HuggingFaceTokenizerConfig.Producer {
+
+ private final Map<String, ModelReference> langToModel = new TreeMap<>();
+ private final Boolean specialTokens;
+ private final Integer maxLength;
+ private final Boolean truncation;
+
+ public HuggingFaceTokenizer(Element xml, DeployState state) {
+ super("com.yahoo.language.huggingface.HuggingFaceTokenizer", LINGUISTICS_BUNDLE_NAME, xml);
+ for (Element element : XML.getChildren(xml, "model")) {
+ var lang = element.hasAttribute("language") ? element.getAttribute("language") : "unknown";
+ langToModel.put(lang, ModelIdResolver.resolveToModelReference(element, state.isHosted()));
+ }
+ specialTokens = getOptionalChildValue(xml, "special-tokens").map(Boolean::parseBoolean).orElse(null);
+ maxLength = getOptionalChildValue(xml, "max-length").map(Integer::parseInt).orElse(null);
+ truncation = getOptionalChildValue(xml, "truncation").map(Boolean::parseBoolean).orElse(null);
+ }
+
+ @Override
+ public void getConfig(HuggingFaceTokenizerConfig.Builder builder) {
+ langToModel.forEach((lang, vocab) -> {
+ builder.model.add(new HuggingFaceTokenizerConfig.Model.Builder().language(lang).path(vocab));
+ });
+ if (specialTokens != null) builder.addSpecialTokens(specialTokens);
+ if (maxLength != null) builder.maxLength(maxLength);
+ if (truncation != null) builder.truncation(truncation);
+ }
+}
diff --git a/config-model/src/main/java/com/yahoo/vespa/model/container/component/TypedComponent.java b/config-model/src/main/java/com/yahoo/vespa/model/container/component/TypedComponent.java
new file mode 100644
index 00000000000..522c78f2f25
--- /dev/null
+++ b/config-model/src/main/java/com/yahoo/vespa/model/container/component/TypedComponent.java
@@ -0,0 +1,20 @@
+// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+
+package com.yahoo.vespa.model.container.component;
+
+import com.yahoo.osgi.provider.model.ComponentModel;
+import org.w3c.dom.Element;
+
+/**
+ * @author bjorncs
+ */
+abstract class TypedComponent extends SimpleComponent {
+
+ private final Element xml;
+
+ protected TypedComponent(String className, String bundle, Element xml) {
+ super(new ComponentModel(xml.getAttribute("id"), className, bundle));
+ this.xml = xml;
+ }
+
+}