diff options
18 files changed, 346 insertions, 211 deletions
diff --git a/config-model/src/main/java/com/yahoo/vespa/model/builder/xml/dom/DomComponentBuilder.java b/config-model/src/main/java/com/yahoo/vespa/model/builder/xml/dom/DomComponentBuilder.java index fa0ee3f9857..d0e1ede2cfa 100644 --- a/config-model/src/main/java/com/yahoo/vespa/model/builder/xml/dom/DomComponentBuilder.java +++ b/config-model/src/main/java/com/yahoo/vespa/model/builder/xml/dom/DomComponentBuilder.java @@ -7,6 +7,7 @@ import com.yahoo.config.model.producer.AnyConfigProducer; import com.yahoo.config.model.producer.TreeConfigProducer; import com.yahoo.osgi.provider.model.ComponentModel; import com.yahoo.text.XML; +import com.yahoo.vespa.model.container.component.BertEmbedder; import com.yahoo.vespa.model.container.component.Component; import com.yahoo.vespa.model.container.component.HuggingFaceEmbedder; import com.yahoo.vespa.model.container.component.HuggingFaceTokenizer; @@ -44,6 +45,7 @@ public class DomComponentBuilder extends VespaDomBuilder.DomConfigProducerBuilde return switch (type) { case "hugging-face-embedder" -> new HuggingFaceEmbedder(spec, state); case "hugging-face-tokenizer" -> new HuggingFaceTokenizer(spec, state); + case "bert-embedder" -> new BertEmbedder(spec, state); default -> throw new IllegalArgumentException("Unknown component type '%s'".formatted(type)); }; } else { diff --git a/config-model/src/main/java/com/yahoo/vespa/model/container/component/BertEmbedder.java b/config-model/src/main/java/com/yahoo/vespa/model/container/component/BertEmbedder.java new file mode 100644 index 00000000000..56aa974da48 --- /dev/null +++ b/config-model/src/main/java/com/yahoo/vespa/model/container/component/BertEmbedder.java @@ -0,0 +1,70 @@ +// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. + +package com.yahoo.vespa.model.container.component; + +import com.yahoo.config.ModelReference; +import com.yahoo.config.model.deploy.DeployState; +import com.yahoo.embedding.BertBaseEmbedderConfig; +import com.yahoo.vespa.model.container.xml.ModelIdResolver; +import org.w3c.dom.Element; + +import static com.yahoo.config.model.builder.xml.XmlHelper.getOptionalChildValue; +import static com.yahoo.text.XML.getChild; +import static com.yahoo.vespa.model.container.ContainerModelEvaluation.INTEGRATION_BUNDLE_NAME; + +/** + * @author bjorncs + */ +public class BertEmbedder extends TypedComponent implements BertBaseEmbedderConfig.Producer { + + private final ModelReference model; + private final ModelReference vocab; + private final Integer maxTokens; + private final String transformerInputIds; + private final String transformerAttentionMask; + private final String transformerTokenTypeIds; + private final String transformerOutput; + private final Integer tranformerStartSequenceToken; + private final Integer transformerEndSequenceToken; + private final String poolingStrategy; + private final String onnxExecutionMode; + private final Integer onnxInteropThreads; + private final Integer onnxIntraopThreads; + private final Integer onnxGpuDevice; + + + public BertEmbedder(Element xml, DeployState state) { + super("ai.vespa.embedding.BertBaseEmbedder", INTEGRATION_BUNDLE_NAME, xml); + model = ModelIdResolver.resolveToModelReference(getChild(xml, "transformer-model"), state); + vocab = ModelIdResolver.resolveToModelReference(getChild(xml, "tokenizer-vocab"), state); + maxTokens = getOptionalChildValue(xml, "max-tokens").map(Integer::parseInt).orElse(null); + transformerInputIds = getOptionalChildValue(xml, "transformer-input-ids").orElse(null); + transformerAttentionMask = getOptionalChildValue(xml, "transformer-attention-mask").orElse(null); + transformerTokenTypeIds = getOptionalChildValue(xml, "transformer-token-type-ids").orElse(null); + transformerOutput = getOptionalChildValue(xml, "transformer-output").orElse(null); + tranformerStartSequenceToken = getOptionalChildValue(xml, "transformer-start-sequence-token").map(Integer::parseInt).orElse(null); + transformerEndSequenceToken = getOptionalChildValue(xml, "transformer-end-sequence-token").map(Integer::parseInt).orElse(null); + poolingStrategy = getOptionalChildValue(xml, "pooling-strategy").orElse(null); + onnxExecutionMode = getOptionalChildValue(xml, "onnx-execution-mode").orElse(null); + onnxInteropThreads = getOptionalChildValue(xml, "onnx-interop-threads").map(Integer::parseInt).orElse(null); + onnxIntraopThreads = getOptionalChildValue(xml, "onnx-intraop-threads").map(Integer::parseInt).orElse(null); + onnxGpuDevice = getOptionalChildValue(xml, "onnx-gpu-device").map(Integer::parseInt).orElse(null); + } + + @Override + public void getConfig(BertBaseEmbedderConfig.Builder b) { + b.transformerModel(model).tokenizerVocab(vocab); + if (maxTokens != null) b.transformerMaxTokens(maxTokens); + if (transformerInputIds != null) b.transformerInputIds(transformerInputIds); + if (transformerAttentionMask != null) b.transformerAttentionMask(transformerAttentionMask); + if (transformerTokenTypeIds != null) b.transformerTokenTypeIds(transformerTokenTypeIds); + if (transformerOutput != null) b.transformerOutput(transformerOutput); + if (tranformerStartSequenceToken != null) b.transformerStartSequenceToken(tranformerStartSequenceToken); + if (transformerEndSequenceToken != null) b.transformerEndSequenceToken(transformerEndSequenceToken); + if (poolingStrategy != null) b.poolingStrategy(BertBaseEmbedderConfig.PoolingStrategy.Enum.valueOf(poolingStrategy)); + if (onnxExecutionMode != null) b.onnxExecutionMode(BertBaseEmbedderConfig.OnnxExecutionMode.Enum.valueOf(onnxExecutionMode)); + if (onnxInteropThreads != null) b.onnxInterOpThreads(onnxInteropThreads); + if (onnxIntraopThreads != null) b.onnxIntraOpThreads(onnxIntraopThreads); + if (onnxGpuDevice != null) b.onnxGpuDevice(onnxGpuDevice); + } +} diff --git a/config-model/src/main/java/com/yahoo/vespa/model/container/component/HuggingFaceEmbedder.java b/config-model/src/main/java/com/yahoo/vespa/model/container/component/HuggingFaceEmbedder.java index 1c36716699e..6e7a1cc31dd 100644 --- a/config-model/src/main/java/com/yahoo/vespa/model/container/component/HuggingFaceEmbedder.java +++ b/config-model/src/main/java/com/yahoo/vespa/model/container/component/HuggingFaceEmbedder.java @@ -31,15 +31,15 @@ public class HuggingFaceEmbedder extends TypedComponent implements HuggingFaceEm private final Integer onnxInteropThreads; private final Integer onnxIntraopThreads; private final Integer onnxGpuDevice; + private final String poolingStrategy; public HuggingFaceEmbedder(Element xml, DeployState state) { super("ai.vespa.embedding.huggingface.HuggingFaceEmbedder", INTEGRATION_BUNDLE_NAME, xml); - boolean hosted = state.isHosted(); var transformerModelElem = getOptionalChild(xml, "transformer-model").orElseThrow(); - model = ModelIdResolver.resolveToModelReference(transformerModelElem, hosted); + model = ModelIdResolver.resolveToModelReference(transformerModelElem, state); vocab = getOptionalChild(xml, "tokenizer-model") - .map(elem -> ModelIdResolver.resolveToModelReference(elem, hosted)) - .orElseGet(() -> resolveDefaultVocab(transformerModelElem, hosted)); + .map(elem -> ModelIdResolver.resolveToModelReference(elem, state)) + .orElseGet(() -> resolveDefaultVocab(transformerModelElem, state)); maxTokens = getOptionalChildValue(xml, "max-tokens").map(Integer::parseInt).orElse(null); transformerInputIds = getOptionalChildValue(xml, "transformer-input-ids").orElse(null); transformerAttentionMask = getOptionalChildValue(xml, "transformer-attention-mask").orElse(null); @@ -50,13 +50,14 @@ public class HuggingFaceEmbedder extends TypedComponent implements HuggingFaceEm onnxInteropThreads = getOptionalChildValue(xml, "onnx-interop-threads").map(Integer::parseInt).orElse(null); onnxIntraopThreads = getOptionalChildValue(xml, "onnx-intraop-threads").map(Integer::parseInt).orElse(null); onnxGpuDevice = getOptionalChildValue(xml, "onnx-gpu-device").map(Integer::parseInt).orElse(null); + poolingStrategy = getOptionalChildValue(xml, "pooling-strategy").orElse(null); } - private static ModelReference resolveDefaultVocab(Element model, boolean hosted) { - if (hosted && model.hasAttribute("model-id")) { + private static ModelReference resolveDefaultVocab(Element model, DeployState state) { + if (state.isHosted() && model.hasAttribute("model-id")) { var implicitVocabId = model.getAttribute("model-id") + "-vocab"; return ModelIdResolver.resolveToModelReference( - "tokenizer-model", Optional.of(implicitVocabId), Optional.empty(), Optional.empty(), true); + "tokenizer-model", Optional.of(implicitVocabId), Optional.empty(), Optional.empty(), state); } throw new IllegalArgumentException("'tokenizer-model' must be specified"); } @@ -75,5 +76,6 @@ public class HuggingFaceEmbedder extends TypedComponent implements HuggingFaceEm if (onnxInteropThreads != null) b.transformerInterOpThreads(onnxInteropThreads); if (onnxIntraopThreads != null) b.transformerIntraOpThreads(onnxIntraopThreads); if (onnxGpuDevice != null) b.transformerGpuDevice(onnxGpuDevice); + if (poolingStrategy != null) b.poolingStrategy(HuggingFaceEmbedderConfig.PoolingStrategy.Enum.valueOf(poolingStrategy)); } } diff --git a/config-model/src/main/java/com/yahoo/vespa/model/container/component/HuggingFaceTokenizer.java b/config-model/src/main/java/com/yahoo/vespa/model/container/component/HuggingFaceTokenizer.java index ba8521a0089..966dbe8260a 100644 --- a/config-model/src/main/java/com/yahoo/vespa/model/container/component/HuggingFaceTokenizer.java +++ b/config-model/src/main/java/com/yahoo/vespa/model/container/component/HuggingFaceTokenizer.java @@ -28,7 +28,7 @@ public class HuggingFaceTokenizer extends TypedComponent implements HuggingFaceT super("com.yahoo.language.huggingface.HuggingFaceTokenizer", LINGUISTICS_BUNDLE_NAME, xml); for (Element element : XML.getChildren(xml, "model")) { var lang = element.hasAttribute("language") ? element.getAttribute("language") : "unknown"; - langToModel.put(lang, ModelIdResolver.resolveToModelReference(element, state.isHosted())); + langToModel.put(lang, ModelIdResolver.resolveToModelReference(element, state)); } specialTokens = getOptionalChildValue(xml, "special-tokens").map(Boolean::parseBoolean).orElse(null); maxLength = getOptionalChildValue(xml, "max-length").map(Integer::parseInt).orElse(null); diff --git a/config-model/src/main/java/com/yahoo/vespa/model/container/xml/ModelIdResolver.java b/config-model/src/main/java/com/yahoo/vespa/model/container/xml/ModelIdResolver.java index c0f49f3148d..96f653bf793 100644 --- a/config-model/src/main/java/com/yahoo/vespa/model/container/xml/ModelIdResolver.java +++ b/config-model/src/main/java/com/yahoo/vespa/model/container/xml/ModelIdResolver.java @@ -1,10 +1,10 @@ // Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. package com.yahoo.vespa.model.container.xml; -import com.yahoo.config.FileReference; import com.yahoo.config.ModelReference; import com.yahoo.config.UrlReference; import com.yahoo.config.model.builder.xml.XmlHelper; +import com.yahoo.config.model.deploy.DeployState; import com.yahoo.text.XML; import org.w3c.dom.Element; @@ -80,25 +80,24 @@ public class ModelIdResolver { } - public static ModelReference resolveToModelReference(Element elem, boolean hosted) { + public static ModelReference resolveToModelReference(Element elem, DeployState state) { return resolveToModelReference( elem.getTagName(), XmlHelper.getOptionalAttribute(elem, "model-id"), - XmlHelper.getOptionalAttribute(elem, "url"), XmlHelper.getOptionalAttribute(elem, "path"), hosted); + XmlHelper.getOptionalAttribute(elem, "url"), XmlHelper.getOptionalAttribute(elem, "path"), state); } public static ModelReference resolveToModelReference( - String paramName, Optional<String> id, Optional<String> url, Optional<String> path, boolean hosted) { - if (id.isEmpty()) return ModelReference.unresolved( - Optional.empty(), url.map(UrlReference::valueOf), path.map(FileReference::new)); - else if (hosted) { - return ModelReference.unresolved( - id, Optional.of(UrlReference.valueOf(modelIdToUrl(paramName, id.get()))), Optional.empty()); - } else if (url.isEmpty() && path.isEmpty()) { - throw onlyModelIdInHostedException(paramName); - } else { - return ModelReference.unresolved( - Optional.empty(), url.map(UrlReference::valueOf), path.map(FileReference::new)); - } + String paramName, Optional<String> id, Optional<String> url, Optional<String> path, DeployState state) { + if (id.isEmpty()) return createModelReference(Optional.empty(), url, path, state); + else if (state.isHosted()) + return createModelReference(id, Optional.of(modelIdToUrl(paramName, id.get())), Optional.empty(), state); + else if (url.isEmpty() && path.isEmpty()) throw onlyModelIdInHostedException(paramName); + else return createModelReference(id, url, path, state); + } + + private static ModelReference createModelReference(Optional<String> id, Optional<String> url, Optional<String> path, DeployState state) { + var fileRef = path.map(p -> state.getFileRegistry().addFile(p)); + return ModelReference.unresolved(id, url.map(UrlReference::valueOf), fileRef); } private static IllegalArgumentException onlyModelIdInHostedException(String paramName) { diff --git a/config-model/src/main/resources/schema/common.rnc b/config-model/src/main/resources/schema/common.rnc index 4e7cb526efb..061e54740f1 100644 --- a/config-model/src/main/resources/schema/common.rnc +++ b/config-model/src/main/resources/schema/common.rnc @@ -80,7 +80,7 @@ ComponentDefinition = TypedComponentDefinition = attribute id { xsd:Name } & - (HuggingFaceEmbedder | HuggingFaceTokenizer) & + (HuggingFaceEmbedder | HuggingFaceTokenizer | BertBaseEmbedder) & GenericConfig* & Component* @@ -94,14 +94,34 @@ HuggingFaceEmbedder = element transformer-token-type-ids { xsd:string }? & element transformer-output { xsd:string }? & element normalize { xsd:boolean }? & - element onnx-execution-mode { "parallel" | "sequential" }? & - element onnx-interop-threads { xsd:integer }? & - element onnx-intraop-threads { xsd:integer }? & - element onnx-gpu-device { xsd:integer }? + OnnxModelExecutionParams & + EmbedderPoolingStrategy HuggingFaceTokenizer = attribute type { "hugging-face-tokenizer" } & element model { attribute language { xsd:string }? & ModelReference }+ & element special-tokens { xsd:boolean }? & element max-length { xsd:integer }? & - element truncation { xsd:boolean }?
\ No newline at end of file + element truncation { xsd:boolean }? + +BertBaseEmbedder = + attribute type { "bert-embedder" } & + element transformer-model { ModelReference } & + element tokenizer-vocab { ModelReference } & + element max-tokens { xsd:nonNegativeInteger }? & + element transformer-input-ids { xsd:string }? & + element transformer-attention-mask { xsd:string }? & + element transformer-token-type-ids { xsd:string }? & + element transformer-output { xsd:string }? & + element transformer-start-sequence-token { xsd:integer }? & + element transformer-end-sequence-token { xsd:integer }? & + OnnxModelExecutionParams & + EmbedderPoolingStrategy + +OnnxModelExecutionParams = + element onnx-execution-mode { "parallel" | "sequential" }? & + element onnx-interop-threads { xsd:integer }? & + element onnx-intraop-threads { xsd:integer }? & + element onnx-gpu-device { xsd:integer }? + +EmbedderPoolingStrategy = element pooling-strategy { "cls" | "mean" }?
\ No newline at end of file diff --git a/config-model/src/test/cfg/application/embed/configdefinitions/embedding.bert-base-embedder.def b/config-model/src/test/cfg/application/embed/configdefinitions/embedding.bert-base-embedder.def deleted file mode 100644 index 144dfbd0001..00000000000 --- a/config-model/src/test/cfg/application/embed/configdefinitions/embedding.bert-base-embedder.def +++ /dev/null @@ -1,30 +0,0 @@ -# Copy of this Vespa config stored here because Vespa config definitions are not -# available in unit tests, and are needed (by DomConfigPayloadBuilder.parseLeaf) -# Alternatively, we could make that not need it as it is not strictly necessaery. - -namespace=embedding - -# Wordpiece tokenizer -tokenizerVocab model - -transformerModel model - -# Max length of token sequence model can handle -transformerMaxTokens int default=384 - -# Pooling strategy -poolingStrategy enum { cls, mean } default=mean - -# Input names -transformerInputIds string default=input_ids -transformerAttentionMask string default=attention_mask -transformerTokenTypeIds string default=token_type_ids - -# Output name -transformerOutput string default=output_0 - -# Settings for ONNX model evaluation -onnxExecutionMode enum { parallel, sequential } default=sequential -onnxInterOpThreads int default=1 -onnxIntraOpThreads int default=-4 # n=number of threads -> n<0: CPUs/(-n), n==0: CPUs, n>0: n - diff --git a/config-model/src/test/cfg/application/embed/configdefinitions/sentence-embedder.def b/config-model/src/test/cfg/application/embed/configdefinitions/sentence-embedder.def new file mode 100644 index 00000000000..87b80f1051a --- /dev/null +++ b/config-model/src/test/cfg/application/embed/configdefinitions/sentence-embedder.def @@ -0,0 +1,26 @@ +package=ai.vespa.example.paragraph + +# WordPiece tokenizer vocabulary +vocab model + +model model + +myValue string + +# Max length of token sequence model can handle +transforerMaxTokens int default=128 + +# Pooling strategy +poolingStrategy enum { cls, mean } default=mean + +# Input names +transformerInputIds string default=input_ids +transformerAttentionMask string default=attention_mask + +# Output name +transformerOutput string default=last_hidden_state + +# Settings for ONNX model evaluation +onnxExecutionMode enum { parallel, sequential } default=sequential +onnxInterOpThreads int default=1 +onnxIntraOpThreads int default=-4 diff --git a/config-model/src/test/cfg/application/embed/services.xml b/config-model/src/test/cfg/application/embed/services.xml index 99c89bc4324..6823ef900ae 100644 --- a/config-model/src/test/cfg/application/embed/services.xml +++ b/config-model/src/test/cfg/application/embed/services.xml @@ -16,6 +16,7 @@ <onnx-intraop-threads>10</onnx-intraop-threads> <onnx-interop-threads>8</onnx-interop-threads> <onnx-gpu-device>1</onnx-gpu-device> + <pooling-strategy>mean</pooling-strategy> </component> <component id="hf-tokenizer" type="hugging-face-tokenizer"> @@ -25,15 +26,24 @@ <truncation>true</truncation> </component> - <component id="transformer" class="ai.vespa.embedding.BertBaseEmbedder" bundle="model-integration"> - <config name="embedding.bert-base-embedder"> - <!-- model specifics --> - <transformerModel model-id="minilm-l6-v2" url="application-url"/> - <tokenizerVocab path="files/vocab.txt"/> + <component id="bert-embedder" type="bert-embedder"> + <!-- model specifics --> + <transformer-model model-id="minilm-l6-v2" url="application-url"/> + <tokenizer-vocab path="files/vocab.txt"/> + <max-tokens>512</max-tokens> + <transformer-input-ids>my_input_ids</transformer-input-ids> + <transformer-attention-mask>my_attention_mask</transformer-attention-mask> + <transformer-token-type-ids>my_token_type_ids</transformer-token-type-ids> + <transformer-output>my_output</transformer-output> + <transformer-start-sequence-token>101</transformer-start-sequence-token> + <transformer-end-sequence-token>102</transformer-end-sequence-token> - <!-- tunable parameters: number of threads etc --> - <onnxIntraOpThreads>4</onnxIntraOpThreads> - </config> + + <!-- tunable parameters: number of threads etc --> + <onnx-execution-mode>parallel</onnx-execution-mode> + <onnx-intraop-threads>4</onnx-intraop-threads> + <onnx-interop-threads>8</onnx-interop-threads> + <onnx-gpu-device>1</onnx-gpu-device> </component> <nodes> diff --git a/config-model/src/test/cfg/application/embed_cloud_only/configdefinitions/embedding.bert-base-embedder.def b/config-model/src/test/cfg/application/embed_cloud_only/configdefinitions/embedding.bert-base-embedder.def deleted file mode 100644 index 144dfbd0001..00000000000 --- a/config-model/src/test/cfg/application/embed_cloud_only/configdefinitions/embedding.bert-base-embedder.def +++ /dev/null @@ -1,30 +0,0 @@ -# Copy of this Vespa config stored here because Vespa config definitions are not -# available in unit tests, and are needed (by DomConfigPayloadBuilder.parseLeaf) -# Alternatively, we could make that not need it as it is not strictly necessaery. - -namespace=embedding - -# Wordpiece tokenizer -tokenizerVocab model - -transformerModel model - -# Max length of token sequence model can handle -transformerMaxTokens int default=384 - -# Pooling strategy -poolingStrategy enum { cls, mean } default=mean - -# Input names -transformerInputIds string default=input_ids -transformerAttentionMask string default=attention_mask -transformerTokenTypeIds string default=token_type_ids - -# Output name -transformerOutput string default=output_0 - -# Settings for ONNX model evaluation -onnxExecutionMode enum { parallel, sequential } default=sequential -onnxInterOpThreads int default=1 -onnxIntraOpThreads int default=-4 # n=number of threads -> n<0: CPUs/(-n), n==0: CPUs, n>0: n - diff --git a/config-model/src/test/cfg/application/embed_cloud_only/configdefinitions/sentence-embedder.def b/config-model/src/test/cfg/application/embed_cloud_only/configdefinitions/sentence-embedder.def new file mode 100644 index 00000000000..87b80f1051a --- /dev/null +++ b/config-model/src/test/cfg/application/embed_cloud_only/configdefinitions/sentence-embedder.def @@ -0,0 +1,26 @@ +package=ai.vespa.example.paragraph + +# WordPiece tokenizer vocabulary +vocab model + +model model + +myValue string + +# Max length of token sequence model can handle +transforerMaxTokens int default=128 + +# Pooling strategy +poolingStrategy enum { cls, mean } default=mean + +# Input names +transformerInputIds string default=input_ids +transformerAttentionMask string default=attention_mask + +# Output name +transformerOutput string default=last_hidden_state + +# Settings for ONNX model evaluation +onnxExecutionMode enum { parallel, sequential } default=sequential +onnxInterOpThreads int default=1 +onnxIntraOpThreads int default=-4 diff --git a/config-model/src/test/cfg/application/embed_cloud_only/services.xml b/config-model/src/test/cfg/application/embed_cloud_only/services.xml index 57db4f5bfae..e203ec56669 100644 --- a/config-model/src/test/cfg/application/embed_cloud_only/services.xml +++ b/config-model/src/test/cfg/application/embed_cloud_only/services.xml @@ -4,14 +4,11 @@ <container version="1.0"> - <component id="transformer" class="ai.vespa.embedding.BertBaseEmbedder" bundle="model-integration"> - <config name="embedding.bert-base-embedder"> - <!-- No fallback to url or path when deploying outside cloud --> - <transformerModel model-id="minilm-l6-v2"/> - <tokenizerVocab path="files/vocab.txt"/> - - <!-- tunable parameters: number of threads etc --> - <onnxIntraOpThreads>4</onnxIntraOpThreads> + <component id="transformer" class="ai.vespa.example.paragraph.ApplicationSpecificEmbedder" bundle="app"> + <config name='ai.vespa.example.paragraph.sentence-embedder'> + <model model-id="minilm-l6-v2"/> + <vocab path="files/vocab.txt"/> + <myValue>foo</myValue> </config> </component> diff --git a/config-model/src/test/java/com/yahoo/vespa/model/container/xml/EmbedderTestCase.java b/config-model/src/test/java/com/yahoo/vespa/model/container/xml/EmbedderTestCase.java index 69981233c3f..2a82daef9e3 100644 --- a/config-model/src/test/java/com/yahoo/vespa/model/container/xml/EmbedderTestCase.java +++ b/config-model/src/test/java/com/yahoo/vespa/model/container/xml/EmbedderTestCase.java @@ -2,9 +2,13 @@ package com.yahoo.vespa.model.container.xml; import com.yahoo.component.ComponentId; +import com.yahoo.config.InnerNode; +import com.yahoo.config.ModelNode; +import com.yahoo.config.ModelReference; import com.yahoo.config.model.application.provider.FilesApplicationPackage; import com.yahoo.config.model.deploy.DeployState; import com.yahoo.config.model.deploy.TestProperties; +import com.yahoo.embedding.BertBaseEmbedderConfig; import com.yahoo.embedding.huggingface.HuggingFaceEmbedderConfig; import com.yahoo.language.huggingface.config.HuggingFaceTokenizerConfig; import com.yahoo.path.Path; @@ -13,6 +17,7 @@ import com.yahoo.vespa.config.ConfigDefinitionKey; import com.yahoo.vespa.config.ConfigPayloadBuilder; import com.yahoo.vespa.model.VespaModel; import com.yahoo.vespa.model.container.ApplicationContainerCluster; +import com.yahoo.vespa.model.container.component.BertEmbedder; import com.yahoo.vespa.model.container.component.Component; import com.yahoo.vespa.model.container.component.HuggingFaceEmbedder; import com.yahoo.vespa.model.container.component.HuggingFaceTokenizer; @@ -35,55 +40,18 @@ import static org.junit.jupiter.api.Assertions.fail; public class EmbedderTestCase { - private static final String BUNDLED_EMBEDDER_CLASS = "ai.vespa.embedding.BertBaseEmbedder"; - private static final String BUNDLED_EMBEDDER_CONFIG = "embedding.bert-base-embedder"; - - @Test - void testBundledEmbedder_selfhosted() throws IOException, SAXException { - String input = "<component id='test' class='" + BUNDLED_EMBEDDER_CLASS + "' bundle='model-integration'>" + - " <config name='" + BUNDLED_EMBEDDER_CONFIG + "'>" + - " <transformerModel id='my_model_id' url='my-model-url' />" + - " <tokenizerVocab id='my_vocab_id' url='my-vocab-url' />" + - " </config>" + - "</component>"; - String component = "<component id='test' class='" + BUNDLED_EMBEDDER_CLASS + "' bundle='model-integration'>" + - " <config name='" + BUNDLED_EMBEDDER_CONFIG + "'>" + - " <transformerModel id='my_model_id' url='my-model-url' />" + - " <tokenizerVocab id='my_vocab_id' url='my-vocab-url' />" + - " </config>" + - "</component>"; - assertTransform(input, component, false); - } - - @Test - void testBundledEmbedder_hosted() throws IOException, SAXException { - String input = "<component id='test' class='" + BUNDLED_EMBEDDER_CLASS + "' bundle='model-integration'>" + - " <config name='" + BUNDLED_EMBEDDER_CONFIG + "'>" + - " <transformerModel model-id='minilm-l6-v2' />" + - " <tokenizerVocab model-id='bert-base-uncased' path='ignored.txt'/>" + - " </config>" + - "</component>"; - String component = "<component id='test' class='" + BUNDLED_EMBEDDER_CLASS + "' bundle='model-integration'>" + - " <config name='" + BUNDLED_EMBEDDER_CONFIG + "'>" + - " <transformerModel model-id='minilm-l6-v2' url='https://data.vespa.oath.cloud/onnx_models/sentence_all_MiniLM_L6_v2.onnx' />" + - " <tokenizerVocab model-id='bert-base-uncased' url='https://data.vespa.oath.cloud/onnx_models/bert-base-uncased-vocab.txt' />" + - " </config>" + - "</component>"; - assertTransform(input, component, true); - } - @Test void testApplicationComponentWithModelReference_hosted() throws IOException, SAXException { - String input = "<component id='test' class='ApplicationSpecificEmbedder' bundle='model-integration'>" + - " <config name='" + BUNDLED_EMBEDDER_CONFIG + "'>" + - " <transformerModel model-id='minilm-l6-v2' />" + - " <tokenizerVocab model-id='bert-base-uncased' />" + + String input = "<component id='test' class='ai.vespa.example.paragraph.ApplicationSpecificEmbedder' bundle='app'>" + + " <config name='ai.vespa.example.paragraph.sentence-embedder'>" + + " <model model-id='minilm-l6-v2' />" + + " <vocab model-id='bert-base-uncased' />" + " </config>" + "</component>"; - String component = "<component id='test' class='ApplicationSpecificEmbedder' bundle='model-integration'>" + - " <config name='" + BUNDLED_EMBEDDER_CONFIG + "'>" + - " <transformerModel model-id='minilm-l6-v2' url='https://data.vespa.oath.cloud/onnx_models/sentence_all_MiniLM_L6_v2.onnx' />" + - " <tokenizerVocab model-id='bert-base-uncased' url='https://data.vespa.oath.cloud/onnx_models/bert-base-uncased-vocab.txt' />" + + String component = "<component id='test' class='ai.vespa.example.paragraph.ApplicationSpecificEmbedder' bundle='app'>" + + " <config name='ai.vespa.example.paragraph.sentence-embedder'>" + + " <model model-id='minilm-l6-v2' url='https://data.vespa.oath.cloud/onnx_models/sentence_all_MiniLM_L6_v2.onnx' />" + + " <vocab model-id='bert-base-uncased' url='https://data.vespa.oath.cloud/onnx_models/bert-base-uncased-vocab.txt' />" + " </config>" + "</component>"; assertTransform(input, component, true); @@ -91,64 +59,65 @@ public class EmbedderTestCase { @Test void testUnknownModelId_hosted() throws IOException, SAXException { - String embedder = "<component id='test' class='" + BUNDLED_EMBEDDER_CLASS + "'>" + - " <config name='" + BUNDLED_EMBEDDER_CONFIG + "'>" + - " <transformerModel model-id='my_model_id' />" + - " <tokenizerVocab model-id='my_vocab_id' />" + + String embedder = "<component id='test' class='ai.vespa.example.paragraph.ApplicationSpecificEmbedder'>" + + " <config name='ai.vespa.example.paragraph.sentence-embedder'>" + + " <model model-id='my_model_id' />" + + " <vocab model-id='my_vocab_id' />" + " </config>" + "</component>"; assertTransformThrows(embedder, - "Unknown model id 'my_model_id' on 'transformerModel'", + "Unknown model id 'my_model_id' on 'model'", true); } @Test - void testApplicationPackageWithEmbedder_selfhosted() throws Exception { - Path applicationDir = Path.fromString("src/test/cfg/application/embed/"); - VespaModel model = loadModel(applicationDir, false); - ApplicationContainerCluster containerCluster = model.getContainerClusters().get("container"); + void huggingfaceEmbedder_selfhosted() throws Exception { + var model = loadModel(Path.fromString("src/test/cfg/application/embed/"), false); + var cluster = model.getContainerClusters().get("container"); + var embedderCfg = assertHuggingfaceEmbedderComponentPresent(cluster); + assertEquals("my_input_ids", embedderCfg.transformerInputIds()); + assertEquals("https://my/url/model.onnx", modelReference(embedderCfg, "transformerModel").url().orElseThrow().value()); + var tokenizerCfg = assertHuggingfaceTokenizerComponentPresent(cluster); + assertEquals("https://my/url/tokenizer.json", modelReference(tokenizerCfg.model().get(0), "path").url().orElseThrow().value()); + assertEquals(768, tokenizerCfg.maxLength()); + } - Component<?, ?> transformer = containerCluster.getComponentsMap().get(new ComponentId("transformer")); - ConfigPayloadBuilder config = transformer.getUserConfigs().get(new ConfigDefinitionKey("bert-base-embedder", "embedding")); - assertEquals("minilm-l6-v2 application-url \"\"", config.getObject("transformerModel").getValue()); - assertEquals("\"\" \"\" files/vocab.txt", config.getObject("tokenizerVocab").getValue()); - assertEquals("4", config.getObject("onnxIntraOpThreads").getValue()); - - { - var hfEmbedder = (HuggingFaceEmbedder)containerCluster.getComponentsMap().get(new ComponentId("hf-embedder")); - assertEquals("ai.vespa.embedding.huggingface.HuggingFaceEmbedder", hfEmbedder.getClassId().getName()); - var cfgBuilder = new HuggingFaceEmbedderConfig.Builder(); - hfEmbedder.getConfig(cfgBuilder); - var cfg = cfgBuilder.build(); - assertEquals("my_input_ids", cfg.transformerInputIds()); - } - { - var hfTokenizer = (HuggingFaceTokenizer)containerCluster.getComponentsMap().get(new ComponentId("hf-tokenizer")); - assertEquals("com.yahoo.language.huggingface.HuggingFaceTokenizer", hfTokenizer.getClassId().getName()); - var cfgBuilder = new HuggingFaceTokenizerConfig.Builder(); - hfTokenizer.getConfig(cfgBuilder); - var cfg = cfgBuilder.build(); - assertEquals(768, cfg.maxLength()); - } + @Test + void huggingfaceEmbedder_hosted() throws Exception { + var model = loadModel(Path.fromString("src/test/cfg/application/embed/"), true); + var cluster = model.getContainerClusters().get("container"); + var embedderCfg = assertHuggingfaceEmbedderComponentPresent(cluster); + assertEquals("my_input_ids", embedderCfg.transformerInputIds()); + assertEquals("https://data.vespa.oath.cloud/onnx_models/e5-base-v2/model.onnx", modelReference(embedderCfg, "transformerModel").url().orElseThrow().value()); + var tokenizerCfg = assertHuggingfaceTokenizerComponentPresent(cluster); + assertEquals("https://data.vespa.oath.cloud/onnx_models/multilingual-e5-base/tokenizer.json", modelReference(tokenizerCfg.model().get(0), "path").url().orElseThrow().value()); + assertEquals(768, tokenizerCfg.maxLength()); } + @Test - void passesXmlValdiation() { - new VespaModelCreatorWithFilePkg("src/test/cfg/application/embed/").create(); + void bertEmbedder_selfhosted() throws Exception { + var model = loadModel(Path.fromString("src/test/cfg/application/embed/"), false); + var cluster = model.getContainerClusters().get("container"); + var embedderCfg = assertBertEmbedderComponentPresent(cluster); + assertEquals("application-url", modelReference(embedderCfg, "transformerModel").url().orElseThrow().value()); + assertEquals("files/vocab.txt", modelReference(embedderCfg, "tokenizerVocab").path().orElseThrow().value()); } @Test - void testApplicationPackageWithEmbedder_hosted() throws Exception { - Path applicationDir = Path.fromString("src/test/cfg/application/embed/"); - VespaModel model = loadModel(applicationDir, true); - ApplicationContainerCluster containerCluster = model.getContainerClusters().get("container"); + void bertEmbedder_hosted() throws Exception { + var model = loadModel(Path.fromString("src/test/cfg/application/embed/"), true); + var cluster = model.getContainerClusters().get("container"); + var embedderCfg = assertBertEmbedderComponentPresent(cluster); + assertEquals("https://data.vespa.oath.cloud/onnx_models/sentence_all_MiniLM_L6_v2.onnx", + modelReference(embedderCfg, "transformerModel").url().orElseThrow().value()); + assertTrue(modelReference(embedderCfg, "tokenizerVocab").url().isEmpty()); + assertEquals("files/vocab.txt", modelReference(embedderCfg, "tokenizerVocab").path().orElseThrow().value()); + } - Component<?, ?> transformer = containerCluster.getComponentsMap().get(new ComponentId("transformer")); - ConfigPayloadBuilder config = transformer.getUserConfigs().get(new ConfigDefinitionKey("bert-base-embedder", "embedding")); - assertEquals("minilm-l6-v2 https://data.vespa.oath.cloud/onnx_models/sentence_all_MiniLM_L6_v2.onnx \"\"", - config.getObject("transformerModel").getValue()); - assertEquals("\"\" \"\" files/vocab.txt", config.getObject("tokenizerVocab").getValue()); - assertEquals("4", config.getObject("onnxIntraOpThreads").getValue()); + @Test + void passesXmlValidation() { + new VespaModelCreatorWithFilePkg("src/test/cfg/application/embed/").create(); } @Test @@ -184,7 +153,7 @@ public class EmbedderTestCase { fail("Expected failure"); } catch (IllegalArgumentException e) { - assertEquals("transformerModel is configured with only a 'model-id'. Add a 'path' or 'url' to deploy this outside Vespa Cloud", + assertEquals("model is configured with only a 'model-id'. Add a 'path' or 'url' to deploy this outside Vespa Cloud", Exceptions.toMessageString(e)); } } @@ -244,4 +213,39 @@ public class EmbedderTestCase { return (Element) doc.getFirstChild(); } + private static HuggingFaceTokenizerConfig assertHuggingfaceTokenizerComponentPresent(ApplicationContainerCluster cluster) { + var hfTokenizer = (HuggingFaceTokenizer) cluster.getComponentsMap().get(new ComponentId("hf-tokenizer")); + assertEquals("com.yahoo.language.huggingface.HuggingFaceTokenizer", hfTokenizer.getClassId().getName()); + var cfgBuilder = new HuggingFaceTokenizerConfig.Builder(); + hfTokenizer.getConfig(cfgBuilder); + return cfgBuilder.build(); + } + + private static HuggingFaceEmbedderConfig assertHuggingfaceEmbedderComponentPresent(ApplicationContainerCluster cluster) { + var hfEmbedder = (HuggingFaceEmbedder) cluster.getComponentsMap().get(new ComponentId("hf-embedder")); + assertEquals("ai.vespa.embedding.huggingface.HuggingFaceEmbedder", hfEmbedder.getClassId().getName()); + var cfgBuilder = new HuggingFaceEmbedderConfig.Builder(); + hfEmbedder.getConfig(cfgBuilder); + return cfgBuilder.build(); + } + + private static BertBaseEmbedderConfig assertBertEmbedderComponentPresent(ApplicationContainerCluster cluster) { + var bertEmbedder = (BertEmbedder) cluster.getComponentsMap().get(new ComponentId("bert-embedder")); + assertEquals("ai.vespa.embedding.BertBaseEmbedder", bertEmbedder.getClassId().getName()); + var cfgBuilder = new BertBaseEmbedderConfig.Builder(); + bertEmbedder.getConfig(cfgBuilder); + return cfgBuilder.build(); + } + + // Ugly hack to read underlying model reference from config instance + private static ModelReference modelReference(InnerNode cfg, String name) { + try { + var f = cfg.getClass().getDeclaredField(name); + f.setAccessible(true); + return ((ModelNode) f.get(cfg)).getModelReference(); + } catch (NoSuchFieldException | IllegalAccessException e) { + throw new RuntimeException(e); + } + } + } diff --git a/model-integration/src/main/resources/configdefinitions/embedding.bert-base-embedder.def b/configdefinitions/src/vespa/embedding.bert-base-embedder.def index 2d8e840377b..2d8e840377b 100644 --- a/model-integration/src/main/resources/configdefinitions/embedding.bert-base-embedder.def +++ b/configdefinitions/src/vespa/embedding.bert-base-embedder.def diff --git a/configdefinitions/src/vespa/hugging-face-embedder.def b/configdefinitions/src/vespa/hugging-face-embedder.def index 36957004e02..7ea4227b3cd 100644 --- a/configdefinitions/src/vespa/hugging-face-embedder.def +++ b/configdefinitions/src/vespa/hugging-face-embedder.def @@ -21,6 +21,8 @@ transformerOutput string default=last_hidden_state # Normalize tensors from tokenizer normalize bool default=false +poolingStrategy enum { cls, mean } default=mean + # Settings for ONNX model evaluation transformerExecutionMode enum { parallel, sequential } default=sequential transformerInterOpThreads int default=1 diff --git a/model-integration/src/main/java/ai/vespa/embedding/BertBaseEmbedder.java b/model-integration/src/main/java/ai/vespa/embedding/BertBaseEmbedder.java index b172ef7beee..a12424c7d12 100644 --- a/model-integration/src/main/java/ai/vespa/embedding/BertBaseEmbedder.java +++ b/model-integration/src/main/java/ai/vespa/embedding/BertBaseEmbedder.java @@ -10,7 +10,6 @@ import com.yahoo.language.process.Embedder; import com.yahoo.language.wordpiece.WordPieceEmbedder; import com.yahoo.tensor.IndexedTensor; import com.yahoo.tensor.Tensor; -import com.yahoo.tensor.TensorAddress; import com.yahoo.tensor.TensorType; import java.util.ArrayList; @@ -39,7 +38,7 @@ public class BertBaseEmbedder extends AbstractComponent implements Embedder { private final String attentionMaskName; private final String tokenTypeIdsName; private final String outputName; - private final String poolingStrategy; + private final PoolingStrategy poolingStrategy; private final WordPieceEmbedder tokenizer; private final OnnxEvaluator evaluator; @@ -53,7 +52,7 @@ public class BertBaseEmbedder extends AbstractComponent implements Embedder { attentionMaskName = config.transformerAttentionMask(); tokenTypeIdsName = config.transformerTokenTypeIds(); outputName = config.transformerOutput(); - poolingStrategy = config.poolingStrategy().toString(); + poolingStrategy = PoolingStrategy.fromString(config.poolingStrategy().toString()); OnnxEvaluatorOptions options = new OnnxEvaluatorOptions(); options.setExecutionMode(config.onnxExecutionMode().toString()); @@ -124,20 +123,7 @@ public class BertBaseEmbedder extends AbstractComponent implements Embedder { Tensor tokenEmbeddings = outputs.get(outputName); - Tensor.Builder builder = Tensor.Builder.of(type); - if (poolingStrategy.equals("mean")) { // average over tokens - Tensor summedEmbeddings = tokenEmbeddings.sum("d1"); - Tensor summedAttentionMask = attentionMask.expand("d0").sum("d1"); - Tensor averaged = summedEmbeddings.join(summedAttentionMask, (x, y) -> x / y); - for (int i = 0; i < type.dimensions().get(0).size().get(); i++) { - builder.cell(averaged.get(TensorAddress.of(0,i)), i); - } - } else { // CLS - use first token - for (int i = 0; i < type.dimensions().get(0).size().get(); i++) { - builder.cell(tokenEmbeddings.get(TensorAddress.of(0,0,i)), i); - } - } - return builder.build(); + return poolingStrategy.toSentenceEmbedding(type, tokenEmbeddings, attentionMask); } private List<Integer> embedWithSeparatorTokens(String text, Context context, int maxLength) { diff --git a/model-integration/src/main/java/ai/vespa/embedding/PoolingStrategy.java b/model-integration/src/main/java/ai/vespa/embedding/PoolingStrategy.java new file mode 100644 index 00000000000..28104d8eeef --- /dev/null +++ b/model-integration/src/main/java/ai/vespa/embedding/PoolingStrategy.java @@ -0,0 +1,48 @@ +// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. + +// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. + +package ai.vespa.embedding; + +import com.yahoo.tensor.Tensor; +import com.yahoo.tensor.TensorAddress; +import com.yahoo.tensor.TensorType; + +/** + * @author bjorncs + */ +public enum PoolingStrategy { + MEAN { + @Override + public Tensor toSentenceEmbedding(TensorType type, Tensor tokenEmbeddings, Tensor attentionMask) { + var builder = Tensor.Builder.of(type); + var summedEmbeddings = tokenEmbeddings.sum("d1"); + var summedAttentionMask = attentionMask.expand("d0").sum("d1"); + var averaged = summedEmbeddings.join(summedAttentionMask, (x, y) -> x / y); + for (int i = 0; i < type.dimensions().get(0).size().get(); i++) { + builder.cell(averaged.get(TensorAddress.of(0, i)), i); + } + return builder.build(); + } + }, + CLS { + @Override + public Tensor toSentenceEmbedding(TensorType type, Tensor tokenEmbeddings, Tensor ignored) { + var builder = Tensor.Builder.of(type); + for (int i = 0; i < type.dimensions().get(0).size().get(); i++) { + builder.cell(tokenEmbeddings.get(TensorAddress.of(0,0,i)), i); + } + return builder.build(); + } + }; + + public abstract Tensor toSentenceEmbedding(TensorType type, Tensor tokenEmbeddings, Tensor attentionMask); + + public static PoolingStrategy fromString(String strategy) { + return switch (strategy.toLowerCase()) { + case "mean" -> MEAN; + case "cls" -> CLS; + default -> throw new IllegalArgumentException("Unknown pooling strategy '%s'".formatted(strategy)); + }; + } +} diff --git a/model-integration/src/main/java/ai/vespa/embedding/huggingface/HuggingFaceEmbedder.java b/model-integration/src/main/java/ai/vespa/embedding/huggingface/HuggingFaceEmbedder.java index 01804656bb6..f93b1a3c1f8 100644 --- a/model-integration/src/main/java/ai/vespa/embedding/huggingface/HuggingFaceEmbedder.java +++ b/model-integration/src/main/java/ai/vespa/embedding/huggingface/HuggingFaceEmbedder.java @@ -1,5 +1,6 @@ package ai.vespa.embedding.huggingface; +import ai.vespa.embedding.PoolingStrategy; import ai.vespa.modelintegration.evaluator.OnnxEvaluator; import ai.vespa.modelintegration.evaluator.OnnxEvaluatorOptions; import ai.vespa.modelintegration.evaluator.OnnxRuntime; @@ -28,6 +29,7 @@ public class HuggingFaceEmbedder extends AbstractComponent implements Embedder { private final boolean normalize; private final HuggingFaceTokenizer tokenizer; private final OnnxEvaluator evaluator; + private final PoolingStrategy poolingStrategy; @Inject public HuggingFaceEmbedder(OnnxRuntime onnx, HuggingFaceEmbedderConfig config) { @@ -42,6 +44,7 @@ public class HuggingFaceEmbedder extends AbstractComponent implements Embedder { .setTruncation(true) .setMaxLength(config.transformerMaxTokens()) .build(); + poolingStrategy = PoolingStrategy.fromString(config.poolingStrategy().toString()); var onnxOpts = new OnnxEvaluatorOptions(); if (config.transformerGpuDevice() >= 0) onnxOpts.setGpuDevice(config.transformerGpuDevice()); |