summaryrefslogtreecommitdiffstats
path: root/config-model
diff options
context:
space:
mode:
authorBjørn Christian Seime <bjorncs@yahooinc.com>2023-06-05 15:32:23 +0200
committerBjørn Christian Seime <bjorncs@yahooinc.com>2023-06-05 17:50:41 +0200
commita3321b37773fe969652d37e5b3b26d07bfddd259 (patch)
tree5264601e3dbd9f7996891a275f99e04fc425af43 /config-model
parentf944b96338725a0a75bbe52922f98f9342abcdd4 (diff)
Add typed component definition for Bert embedder
Diffstat (limited to 'config-model')
-rw-r--r--config-model/src/main/java/com/yahoo/vespa/model/builder/xml/dom/DomComponentBuilder.java2
-rw-r--r--config-model/src/main/java/com/yahoo/vespa/model/container/component/BertEmbedder.java70
-rw-r--r--config-model/src/main/resources/schema/common.rnc29
3 files changed, 95 insertions, 6 deletions
diff --git a/config-model/src/main/java/com/yahoo/vespa/model/builder/xml/dom/DomComponentBuilder.java b/config-model/src/main/java/com/yahoo/vespa/model/builder/xml/dom/DomComponentBuilder.java
index fa0ee3f9857..d0e1ede2cfa 100644
--- a/config-model/src/main/java/com/yahoo/vespa/model/builder/xml/dom/DomComponentBuilder.java
+++ b/config-model/src/main/java/com/yahoo/vespa/model/builder/xml/dom/DomComponentBuilder.java
@@ -7,6 +7,7 @@ import com.yahoo.config.model.producer.AnyConfigProducer;
import com.yahoo.config.model.producer.TreeConfigProducer;
import com.yahoo.osgi.provider.model.ComponentModel;
import com.yahoo.text.XML;
+import com.yahoo.vespa.model.container.component.BertEmbedder;
import com.yahoo.vespa.model.container.component.Component;
import com.yahoo.vespa.model.container.component.HuggingFaceEmbedder;
import com.yahoo.vespa.model.container.component.HuggingFaceTokenizer;
@@ -44,6 +45,7 @@ public class DomComponentBuilder extends VespaDomBuilder.DomConfigProducerBuilde
return switch (type) {
case "hugging-face-embedder" -> new HuggingFaceEmbedder(spec, state);
case "hugging-face-tokenizer" -> new HuggingFaceTokenizer(spec, state);
+ case "bert-embedder" -> new BertEmbedder(spec, state);
default -> throw new IllegalArgumentException("Unknown component type '%s'".formatted(type));
};
} else {
diff --git a/config-model/src/main/java/com/yahoo/vespa/model/container/component/BertEmbedder.java b/config-model/src/main/java/com/yahoo/vespa/model/container/component/BertEmbedder.java
new file mode 100644
index 00000000000..980dbcf0a76
--- /dev/null
+++ b/config-model/src/main/java/com/yahoo/vespa/model/container/component/BertEmbedder.java
@@ -0,0 +1,70 @@
+// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+
+package com.yahoo.vespa.model.container.component;
+
+import com.yahoo.config.ModelReference;
+import com.yahoo.config.model.deploy.DeployState;
+import com.yahoo.embedding.BertBaseEmbedderConfig;
+import com.yahoo.vespa.model.container.xml.ModelIdResolver;
+import org.w3c.dom.Element;
+
+import static com.yahoo.config.model.builder.xml.XmlHelper.getOptionalChildValue;
+import static com.yahoo.text.XML.getChild;
+import static com.yahoo.vespa.model.container.ContainerModelEvaluation.INTEGRATION_BUNDLE_NAME;
+
+/**
+ * Model component for the {@code bert-embedder} typed component. Reads the embedder settings from
+ * the component's XML element and emits them as {@link BertBaseEmbedderConfig}.
+ *
+ * Only {@code transformer-model} and {@code tokenizer-vocab} are always written to the config;
+ * every other setting is optional and is left unset on the config builder when its element is absent.
+ *
+ * @author bjorncs
+ */
+public class BertEmbedder extends TypedComponent implements BertBaseEmbedderConfig.Producer {
+
+    private final ModelReference model;
+    private final ModelReference vocab;
+
+    // Optional settings: null means the corresponding XML element was not present.
+    private final Integer maxTokens;
+    private final String transformerInputIds;
+    private final String transformerAttentionMask;
+    private final String transformerTokenTypeIds;
+    private final String transformerOutput;
+    private final Integer transformerStartSequenceToken;
+    private final Integer transformerEndSequenceToken;
+    private final String poolingStrategy;
+    private final String onnxExecutionMode;
+    private final Integer onnxInteropThreads;
+    private final Integer onnxIntraopThreads;
+    private final Integer onnxGpuDevice;
+
+    /**
+     * Creates the component from its XML definition.
+     *
+     * @param xml the component element; its child elements carry the embedder settings
+     * @param state deploy state; {@code state.isHosted()} is passed on when resolving model references
+     * @throws NumberFormatException if an integer-valued child element does not contain a valid int
+     */
+    public BertEmbedder(Element xml, DeployState state) {
+        super("ai.vespa.embedding.BertBaseEmbedder", INTEGRATION_BUNDLE_NAME, xml);
+        model = ModelIdResolver.resolveToModelReference(getChild(xml, "transformer-model"), state.isHosted());
+        vocab = ModelIdResolver.resolveToModelReference(getChild(xml, "tokenizer-vocab"), state.isHosted());
+        maxTokens = optionalInt(xml, "max-tokens");
+        transformerInputIds = getOptionalChildValue(xml, "transformer-input-ids").orElse(null);
+        transformerAttentionMask = getOptionalChildValue(xml, "transformer-attention-mask").orElse(null);
+        transformerTokenTypeIds = getOptionalChildValue(xml, "transformer-token-type-ids").orElse(null);
+        transformerOutput = getOptionalChildValue(xml, "transformer-output").orElse(null);
+        transformerStartSequenceToken = optionalInt(xml, "transformer-start-sequence-token");
+        transformerEndSequenceToken = optionalInt(xml, "transformer-end-sequence-token");
+        poolingStrategy = getOptionalChildValue(xml, "pooling-strategy").orElse(null);
+        onnxExecutionMode = getOptionalChildValue(xml, "onnx-execution-mode").orElse(null);
+        onnxInteropThreads = optionalInt(xml, "onnx-interop-threads");
+        onnxIntraopThreads = optionalInt(xml, "onnx-intraop-threads");
+        onnxGpuDevice = optionalInt(xml, "onnx-gpu-device");
+    }
+
+    /** Returns the int value of the named child element, or null if the element is absent. */
+    private static Integer optionalInt(Element xml, String name) {
+        return getOptionalChildValue(xml, name).map(Integer::parseInt).orElse(null);
+    }
+
+    /**
+     * Writes the embedder settings to the given config builder. The model and vocabulary references
+     * are always set; each optional setting is only set when it was present in the XML.
+     *
+     * @param b the builder to populate
+     */
+    @Override
+    public void getConfig(BertBaseEmbedderConfig.Builder b) {
+        b.transformerModel(model).tokenizerVocab(vocab);
+        if (maxTokens != null) b.transformerMaxTokens(maxTokens);
+        if (transformerInputIds != null) b.transformerInputIds(transformerInputIds);
+        if (transformerAttentionMask != null) b.transformerAttentionMask(transformerAttentionMask);
+        if (transformerTokenTypeIds != null) b.transformerTokenTypeIds(transformerTokenTypeIds);
+        if (transformerOutput != null) b.transformerOutput(transformerOutput);
+        if (transformerStartSequenceToken != null) b.transformerStartSequenceToken(transformerStartSequenceToken);
+        if (transformerEndSequenceToken != null) b.transformerEndSequenceToken(transformerEndSequenceToken);
+        // The XML values are used verbatim as enum constant names.
+        if (poolingStrategy != null) b.poolingStrategy(BertBaseEmbedderConfig.PoolingStrategy.Enum.valueOf(poolingStrategy));
+        if (onnxExecutionMode != null) b.onnxExecutionMode(BertBaseEmbedderConfig.OnnxExecutionMode.Enum.valueOf(onnxExecutionMode));
+        if (onnxInteropThreads != null) b.onnxInterOpThreads(onnxInteropThreads);
+        if (onnxIntraopThreads != null) b.onnxIntraOpThreads(onnxIntraopThreads);
+        if (onnxGpuDevice != null) b.onnxGpuDevice(onnxGpuDevice);
+    }
+}
diff --git a/config-model/src/main/resources/schema/common.rnc b/config-model/src/main/resources/schema/common.rnc
index 4e7cb526efb..b2b71950a0c 100644
--- a/config-model/src/main/resources/schema/common.rnc
+++ b/config-model/src/main/resources/schema/common.rnc
@@ -80,7 +80,7 @@ ComponentDefinition =
TypedComponentDefinition =
attribute id { xsd:Name } &
- (HuggingFaceEmbedder | HuggingFaceTokenizer) &
+ (HuggingFaceEmbedder | HuggingFaceTokenizer | BertBaseEmbedder) &
GenericConfig* &
Component*
@@ -94,14 +94,31 @@ HuggingFaceEmbedder =
element transformer-token-type-ids { xsd:string }? &
element transformer-output { xsd:string }? &
element normalize { xsd:boolean }? &
- element onnx-execution-mode { "parallel" | "sequential" }? &
- element onnx-interop-threads { xsd:integer }? &
- element onnx-intraop-threads { xsd:integer }? &
- element onnx-gpu-device { xsd:integer }?
+ OnnxModelExecutionParams
HuggingFaceTokenizer =
attribute type { "hugging-face-tokenizer" } &
element model { attribute language { xsd:string }? & ModelReference }+ &
element special-tokens { xsd:boolean }? &
element max-length { xsd:integer }? &
- element truncation { xsd:boolean }? \ No newline at end of file
+ element truncation { xsd:boolean }?
+
+BertBaseEmbedder = # typed component selected with type="bert-embedder"
+ attribute type { "bert-embedder" } &
+ element transformer-model { ModelReference } & # required
+ element tokenizer-vocab { ModelReference } & # required
+ element max-tokens { xsd:nonNegativeInteger }? &
+ element pooling-strategy { "cls" | "mean" }? &
+ element transformer-input-ids { xsd:string }? &
+ element transformer-attention-mask { xsd:string }? &
+ element transformer-token-type-ids { xsd:string }? &
+ element transformer-output { xsd:string }? &
+ element transformer-start-sequence-token { xsd:integer }? &
+ element transformer-end-sequence-token { xsd:integer }? &
+ OnnxModelExecutionParams # optional ONNX execution settings, defined below
+
+OnnxModelExecutionParams = # optional ONNX execution settings shared by HuggingFaceEmbedder and BertBaseEmbedder
+ element onnx-execution-mode { "parallel" | "sequential" }? &
+ element onnx-interop-threads { xsd:integer }? &
+ element onnx-intraop-threads { xsd:integer }? &
+ element onnx-gpu-device { xsd:integer }? \ No newline at end of file