From a3321b37773fe969652d37e5b3b26d07bfddd259 Mon Sep 17 00:00:00 2001 From: Bjørn Christian Seime Date: Mon, 5 Jun 2023 15:32:23 +0200 Subject: Add typed component definition for Bert embedder --- .../model/builder/xml/dom/DomComponentBuilder.java | 2 + .../model/container/component/BertEmbedder.java | 70 ++++++++++++++++++++++ config-model/src/main/resources/schema/common.rnc | 29 +++++++-- 3 files changed, 95 insertions(+), 6 deletions(-) create mode 100644 config-model/src/main/java/com/yahoo/vespa/model/container/component/BertEmbedder.java (limited to 'config-model') diff --git a/config-model/src/main/java/com/yahoo/vespa/model/builder/xml/dom/DomComponentBuilder.java b/config-model/src/main/java/com/yahoo/vespa/model/builder/xml/dom/DomComponentBuilder.java index fa0ee3f9857..d0e1ede2cfa 100644 --- a/config-model/src/main/java/com/yahoo/vespa/model/builder/xml/dom/DomComponentBuilder.java +++ b/config-model/src/main/java/com/yahoo/vespa/model/builder/xml/dom/DomComponentBuilder.java @@ -7,6 +7,7 @@ import com.yahoo.config.model.producer.AnyConfigProducer; import com.yahoo.config.model.producer.TreeConfigProducer; import com.yahoo.osgi.provider.model.ComponentModel; import com.yahoo.text.XML; +import com.yahoo.vespa.model.container.component.BertEmbedder; import com.yahoo.vespa.model.container.component.Component; import com.yahoo.vespa.model.container.component.HuggingFaceEmbedder; import com.yahoo.vespa.model.container.component.HuggingFaceTokenizer; @@ -44,6 +45,7 @@ public class DomComponentBuilder extends VespaDomBuilder.DomConfigProducerBuilde return switch (type) { case "hugging-face-embedder" -> new HuggingFaceEmbedder(spec, state); case "hugging-face-tokenizer" -> new HuggingFaceTokenizer(spec, state); + case "bert-embedder" -> new BertEmbedder(spec, state); default -> throw new IllegalArgumentException("Unknown component type '%s'".formatted(type)); }; } else { diff --git a/config-model/src/main/java/com/yahoo/vespa/model/container/component/BertEmbedder.java b/config-model/src/main/java/com/yahoo/vespa/model/container/component/BertEmbedder.java new file mode 100644 index 00000000000..980dbcf0a76 --- /dev/null +++ b/config-model/src/main/java/com/yahoo/vespa/model/container/component/BertEmbedder.java @@ -0,0 +1,70 @@ +// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. + +package com.yahoo.vespa.model.container.component; + +import com.yahoo.config.ModelReference; +import com.yahoo.config.model.deploy.DeployState; +import com.yahoo.embedding.BertBaseEmbedderConfig; +import com.yahoo.vespa.model.container.xml.ModelIdResolver; +import org.w3c.dom.Element; + +import static com.yahoo.config.model.builder.xml.XmlHelper.getOptionalChildValue; +import static com.yahoo.text.XML.getChild; +import static com.yahoo.vespa.model.container.ContainerModelEvaluation.INTEGRATION_BUNDLE_NAME; + +/** + * @author bjorncs + */ +public class BertEmbedder extends TypedComponent implements BertBaseEmbedderConfig.Producer { + + private final ModelReference model; + private final ModelReference vocab; + private final Integer maxTokens; + private final String transformerInputIds; + private final String transformerAttentionMask; + private final String transformerTokenTypeIds; + private final String transformerOutput; + private final Integer tranformerStartSequenceToken; + private final Integer transformerEndSequenceToken; + private final String poolingStrategy; + private final String onnxExecutionMode; + private final Integer onnxInteropThreads; + private final Integer onnxIntraopThreads; + private final Integer onnxGpuDevice; + + + public BertEmbedder(Element xml, DeployState state) { + super("ai.vespa.embedding.BertBaseEmbedder", INTEGRATION_BUNDLE_NAME, xml); + model = ModelIdResolver.resolveToModelReference(getChild(xml, "transformer-model"), state.isHosted()); + vocab = ModelIdResolver.resolveToModelReference(getChild(xml, "tokenizer-vocab"), state.isHosted()); + maxTokens = getOptionalChildValue(xml, "max-tokens").map(Integer::parseInt).orElse(null); + transformerInputIds = getOptionalChildValue(xml, "transformer-input-ids").orElse(null); + transformerAttentionMask = getOptionalChildValue(xml, "transformer-attention-mask").orElse(null); + transformerTokenTypeIds = getOptionalChildValue(xml, "transformer-token-type-ids").orElse(null); + transformerOutput = getOptionalChildValue(xml, "transformer-output").orElse(null); + tranformerStartSequenceToken = getOptionalChildValue(xml, "transformer-start-sequence-token").map(Integer::parseInt).orElse(null); + transformerEndSequenceToken = getOptionalChildValue(xml, "transformer-end-sequence-token").map(Integer::parseInt).orElse(null); + poolingStrategy = getOptionalChildValue(xml, "pooling-strategy").orElse(null); + onnxExecutionMode = getOptionalChildValue(xml, "onnx-execution-mode").orElse(null); + onnxInteropThreads = getOptionalChildValue(xml, "onnx-interop-threads").map(Integer::parseInt).orElse(null); + onnxIntraopThreads = getOptionalChildValue(xml, "onnx-intraop-threads").map(Integer::parseInt).orElse(null); + onnxGpuDevice = getOptionalChildValue(xml, "onnx-gpu-device").map(Integer::parseInt).orElse(null); + } + + @Override + public void getConfig(BertBaseEmbedderConfig.Builder b) { + b.transformerModel(model).tokenizerVocab(vocab); + if (maxTokens != null) b.transformerMaxTokens(maxTokens); + if (transformerInputIds != null) b.transformerInputIds(transformerInputIds); + if (transformerAttentionMask != null) b.transformerAttentionMask(transformerAttentionMask); + if (transformerTokenTypeIds != null) b.transformerTokenTypeIds(transformerTokenTypeIds); + if (transformerOutput != null) b.transformerOutput(transformerOutput); + if (tranformerStartSequenceToken != null) b.transformerStartSequenceToken(tranformerStartSequenceToken); + if (transformerEndSequenceToken != null) b.transformerEndSequenceToken(transformerEndSequenceToken); + if (poolingStrategy != null) b.poolingStrategy(BertBaseEmbedderConfig.PoolingStrategy.Enum.valueOf(poolingStrategy)); + if (onnxExecutionMode != null) b.onnxExecutionMode(BertBaseEmbedderConfig.OnnxExecutionMode.Enum.valueOf(onnxExecutionMode)); + if (onnxInteropThreads != null) b.onnxInterOpThreads(onnxInteropThreads); + if (onnxIntraopThreads != null) b.onnxIntraOpThreads(onnxIntraopThreads); + if (onnxGpuDevice != null) b.onnxGpuDevice(onnxGpuDevice); + } +} diff --git a/config-model/src/main/resources/schema/common.rnc b/config-model/src/main/resources/schema/common.rnc index 4e7cb526efb..b2b71950a0c 100644 --- a/config-model/src/main/resources/schema/common.rnc +++ b/config-model/src/main/resources/schema/common.rnc @@ -80,7 +80,7 @@ ComponentDefinition = TypedComponentDefinition = attribute id { xsd:Name } & - (HuggingFaceEmbedder | HuggingFaceTokenizer) & + (HuggingFaceEmbedder | HuggingFaceTokenizer | BertBaseEmbedder) & GenericConfig* & Component* @@ -94,14 +94,31 @@ HuggingFaceEmbedder = element transformer-token-type-ids { xsd:string }? & element transformer-output { xsd:string }? & element normalize { xsd:boolean }? & - element onnx-execution-mode { "parallel" | "sequential" }? & - element onnx-interop-threads { xsd:integer }? & - element onnx-intraop-threads { xsd:integer }? & - element onnx-gpu-device { xsd:integer }? + OnnxModelExecutionParams HuggingFaceTokenizer = attribute type { "hugging-face-tokenizer" } & element model { attribute language { xsd:string }? & ModelReference }+ & element special-tokens { xsd:boolean }? & element max-length { xsd:integer }? & - element truncation { xsd:boolean }? \ No newline at end of file + element truncation { xsd:boolean }? + +BertBaseEmbedder = + attribute type { "bert-embedder" } & + element transformer-model { ModelReference } & + element tokenizer-vocab { ModelReference } & + element max-tokens { xsd:nonNegativeInteger }? & + element pooling-strategy { "cls" | "mean" }? & + element transformer-input-ids { xsd:string }? & + element transformer-attention-mask { xsd:string }? & + element transformer-token-type-ids { xsd:string }? & + element transformer-output { xsd:string }? & + element transformer-start-sequence-token { xsd:integer }? & + element transformer-end-sequence-token { xsd:integer }? & + OnnxModelExecutionParams + +OnnxModelExecutionParams = + element onnx-execution-mode { "parallel" | "sequential" }? & + element onnx-interop-threads { xsd:integer }? & + element onnx-intraop-threads { xsd:integer }? & + element onnx-gpu-device { xsd:integer }? \ No newline at end of file -- cgit v1.2.3