diff options
author | Bjørn Christian Seime <bjorn.christian@seime.no> | 2023-09-22 08:18:42 +0200 |
---|---|---|
committer | GitHub <noreply@github.com> | 2023-09-22 08:18:42 +0200 |
commit | 0a5a58845e53ae6995b89427808010cd55996458 (patch) | |
tree | 9ced00fe0437b67e539466d1cc02141aeb58a817 | |
parent | 29c5e644e8f68b133d7fcce1a24c4ef08b7e2ddc (diff) | |
parent | 474215107d7ed448dbc71bef0685de913805addc (diff) |
Merge pull request #28599 from vespa-engine/jobergum/colbert-embedder
Add ColBERT embedder
11 files changed, 818 insertions, 6 deletions
diff --git a/config-model/src/main/java/com/yahoo/vespa/model/builder/xml/dom/DomComponentBuilder.java b/config-model/src/main/java/com/yahoo/vespa/model/builder/xml/dom/DomComponentBuilder.java index d0e1ede2cfa..7501f6162c7 100644 --- a/config-model/src/main/java/com/yahoo/vespa/model/builder/xml/dom/DomComponentBuilder.java +++ b/config-model/src/main/java/com/yahoo/vespa/model/builder/xml/dom/DomComponentBuilder.java @@ -7,10 +7,11 @@ import com.yahoo.config.model.producer.AnyConfigProducer; import com.yahoo.config.model.producer.TreeConfigProducer; import com.yahoo.osgi.provider.model.ComponentModel; import com.yahoo.text.XML; -import com.yahoo.vespa.model.container.component.BertEmbedder; -import com.yahoo.vespa.model.container.component.Component; import com.yahoo.vespa.model.container.component.HuggingFaceEmbedder; import com.yahoo.vespa.model.container.component.HuggingFaceTokenizer; +import com.yahoo.vespa.model.container.component.BertEmbedder; +import com.yahoo.vespa.model.container.component.ColBertEmbedder; +import com.yahoo.vespa.model.container.component.Component; import com.yahoo.vespa.model.container.xml.BundleInstantiationSpecificationBuilder; import org.w3c.dom.Element; @@ -46,6 +47,7 @@ public class DomComponentBuilder extends VespaDomBuilder.DomConfigProducerBuilde case "hugging-face-embedder" -> new HuggingFaceEmbedder(spec, state); case "hugging-face-tokenizer" -> new HuggingFaceTokenizer(spec, state); case "bert-embedder" -> new BertEmbedder(spec, state); + case "colbert-embedder" -> new ColBertEmbedder(spec, state); default -> throw new IllegalArgumentException("Unknown component type '%s'".formatted(type)); }; } else { diff --git a/config-model/src/main/java/com/yahoo/vespa/model/container/component/ColBertEmbedder.java b/config-model/src/main/java/com/yahoo/vespa/model/container/component/ColBertEmbedder.java new file mode 100644 index 00000000000..c0fdfe3dc64 --- /dev/null +++ b/config-model/src/main/java/com/yahoo/vespa/model/container/component/ColBertEmbedder.java @@ -0,0 +1,93 @@ +// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. + +package com.yahoo.vespa.model.container.component; + +import com.yahoo.config.ModelReference; +import com.yahoo.config.model.deploy.DeployState; +import com.yahoo.embedding.ColBertEmbedderConfig; +import com.yahoo.embedding.huggingface.HuggingFaceEmbedderConfig; +import com.yahoo.vespa.model.container.xml.ModelIdResolver; +import org.w3c.dom.Element; + +import java.util.Optional; + +import static com.yahoo.config.model.builder.xml.XmlHelper.getOptionalChild; +import static com.yahoo.text.XML.getChildValue; +import static com.yahoo.vespa.model.container.ContainerModelEvaluation.INTEGRATION_BUNDLE_NAME; + + +/** + * @author bergum + */ +public class ColBertEmbedder extends TypedComponent implements ColBertEmbedderConfig.Producer { + private final ModelReference model; + private final ModelReference vocab; + + private final Integer maxQueryTokens; + + private final Integer maxDocumentTokens; + + private final Integer transformerStartSequenceToken; + private final Integer transformerEndSequenceToken; + private final Integer transformerMaskToken; + private final Integer maxTokens; + private final String transformerInputIds; + private final String transformerAttentionMask; + + private final String transformerOutput; + private final String onnxExecutionMode; + private final Integer onnxInteropThreads; + private final Integer onnxIntraopThreads; + private final Integer onnxGpuDevice; + + public ColBertEmbedder(Element xml, DeployState state) { + super("ai.vespa.embedding.ColBertEmbedder", INTEGRATION_BUNDLE_NAME, xml); + var transformerModelElem = getOptionalChild(xml, "transformer-model").orElseThrow(); + model = ModelIdResolver.resolveToModelReference(transformerModelElem, state); + vocab = getOptionalChild(xml, "tokenizer-model") + .map(elem -> ModelIdResolver.resolveToModelReference(elem, state)) + .orElseGet(() -> resolveDefaultVocab(transformerModelElem, state)); + maxTokens = getChildValue(xml, "max-tokens").map(Integer::parseInt).orElse(null); + maxQueryTokens = getChildValue(xml, "max-query-tokens").map(Integer::parseInt).orElse(null); + maxDocumentTokens = getChildValue(xml, "max-document-tokens").map(Integer::parseInt).orElse(null); + transformerStartSequenceToken = getChildValue(xml, "transformer-start-sequence-token").map(Integer::parseInt).orElse(null); + transformerEndSequenceToken = getChildValue(xml, "transformer-end-sequence-token").map(Integer::parseInt).orElse(null); + transformerMaskToken = getChildValue(xml, "transformer-mask-token").map(Integer::parseInt).orElse(null); + transformerInputIds = getChildValue(xml, "transformer-input-ids").orElse(null); + transformerAttentionMask = getChildValue(xml, "transformer-attention-mask").orElse(null); + transformerOutput = getChildValue(xml, "transformer-output").orElse(null); + onnxExecutionMode = getChildValue(xml, "onnx-execution-mode").orElse(null); + onnxInteropThreads = getChildValue(xml, "onnx-interop-threads").map(Integer::parseInt).orElse(null); + onnxIntraopThreads = getChildValue(xml, "onnx-intraop-threads").map(Integer::parseInt).orElse(null); + onnxGpuDevice = getChildValue(xml, "onnx-gpu-device").map(Integer::parseInt).orElse(null); + + } + + private static ModelReference resolveDefaultVocab(Element model, DeployState state) { + if (state.isHosted() && model.hasAttribute("model-id")) { + var implicitVocabId = model.getAttribute("model-id") + "-vocab"; + return ModelIdResolver.resolveToModelReference( + "tokenizer-model", Optional.of(implicitVocabId), Optional.empty(), Optional.empty(), state); + } + throw new IllegalArgumentException("'tokenizer-model' must be specified"); + } + + @Override + public void getConfig(ColBertEmbedderConfig.Builder b) { + b.transformerModel(model).tokenizerPath(vocab); + if (maxTokens != null) b.transformerMaxTokens(maxTokens); + if (transformerInputIds != null) b.transformerInputIds(transformerInputIds); + if (transformerAttentionMask != null) b.transformerAttentionMask(transformerAttentionMask); + if (transformerOutput != null) b.transformerOutput(transformerOutput); + if (maxQueryTokens != null) b.maxQueryTokens(maxQueryTokens); + if (maxDocumentTokens != null) b.maxDocumentTokens(maxDocumentTokens); + if (transformerStartSequenceToken != null) b.transformerStartSequenceToken(transformerStartSequenceToken); + if (transformerEndSequenceToken != null) b.transformerEndSequenceToken(transformerEndSequenceToken); + if (transformerMaskToken != null) b.transformerMaskToken(transformerMaskToken); + if (onnxExecutionMode != null) b.transformerExecutionMode( + ColBertEmbedderConfig.TransformerExecutionMode.Enum.valueOf(onnxExecutionMode)); + if (onnxInteropThreads != null) b.transformerInterOpThreads(onnxInteropThreads); + if (onnxIntraopThreads != null) b.transformerIntraOpThreads(onnxIntraopThreads); + if (onnxGpuDevice != null) b.transformerGpuDevice(onnxGpuDevice); + } +} diff --git a/config-model/src/main/resources/schema/common.rnc b/config-model/src/main/resources/schema/common.rnc index ba7e2b6674e..e0d5e6a3344 100644 --- a/config-model/src/main/resources/schema/common.rnc +++ b/config-model/src/main/resources/schema/common.rnc @@ -80,7 +80,7 @@ ComponentDefinition = TypedComponentDefinition = attribute id { xsd:Name } & - (HuggingFaceEmbedder | HuggingFaceTokenizer | BertBaseEmbedder) & + (HuggingFaceEmbedder | HuggingFaceTokenizer | BertBaseEmbedder | ColBertEmbedder) & GenericConfig* & Component* @@ -110,15 +110,36 @@ BertBaseEmbedder = element transformer-attention-mask { xsd:string }? & element transformer-token-type-ids { xsd:string }? & element transformer-output { xsd:string }? & - element transformer-start-sequence-token { xsd:integer }? & - element transformer-end-sequence-token { xsd:integer }? & + StartOfSequence & + EndOfSequence & OnnxModelExecutionParams & EmbedderPoolingStrategy + +ColBertEmbedder = + attribute type { "colbert-embedder" } & + element transformer-model { ModelReference } & + element tokenizer-model { ModelReference }? & + element max-tokens { xsd:positiveInteger }? & + element max-query-tokens { xsd:positiveInteger }? & + element max-document-tokens { xsd:positiveInteger }? & + element transformer-mask-token { xsd:integer }? & + element transformer-input-ids { xsd:string }? & + element transformer-attention-mask { xsd:string }? & + element transformer-token-type-ids { xsd:string }? & + element transformer-output { xsd:string }? & + element normalize { xsd:boolean }? & + OnnxModelExecutionParams & + StartOfSequence & + EndOfSequence + OnnxModelExecutionParams = element onnx-execution-mode { "parallel" | "sequential" }? & element onnx-interop-threads { xsd:integer }? & element onnx-intraop-threads { xsd:integer }? & element onnx-gpu-device { xsd:integer }? -EmbedderPoolingStrategy = element pooling-strategy { "cls" | "mean" }?
\ No newline at end of file +EmbedderPoolingStrategy = element pooling-strategy { "cls" | "mean" }? + +StartOfSequence = element transformer-start-sequence-token { xsd:integer }? +EndOfSequence = element transformer-end-sequence-token { xsd:integer }?
\ No newline at end of file diff --git a/config-model/src/test/cfg/application/embed/services.xml b/config-model/src/test/cfg/application/embed/services.xml index 70eef7ea54a..efb33d36761 100644 --- a/config-model/src/test/cfg/application/embed/services.xml +++ b/config-model/src/test/cfg/application/embed/services.xml @@ -43,6 +43,25 @@ <onnx-gpu-device>1</onnx-gpu-device> </component> + <component id="colbert" type="colbert-embedder"> + <transformer-model model-id="e5-base-v2" url="https://my/url/model.onnx"/> + <tokenizer-model model-id="e5-base-v2-vocab" path="app/tokenizer.json"/> + <max-tokens>1024</max-tokens> + <max-query-tokens>32</max-query-tokens> + <max-document-tokens>512</max-document-tokens> + <transformer-start-sequence-token>101</transformer-start-sequence-token> + <transformer-end-sequence-token>102</transformer-end-sequence-token> + <transformer-mask-token>103</transformer-mask-token> + <transformer-input-ids>my_input_ids</transformer-input-ids> + <transformer-attention-mask>my_attention_mask</transformer-attention-mask> + <transformer-token-type-ids>my_token_type_ids</transformer-token-type-ids> + <transformer-output>my_output</transformer-output> + <onnx-execution-mode>parallel</onnx-execution-mode> + <onnx-intraop-threads>10</onnx-intraop-threads> + <onnx-interop-threads>8</onnx-interop-threads> + <onnx-gpu-device>1</onnx-gpu-device> + </component> + <nodes> <node hostalias="node1" /> </nodes> diff --git a/config-model/src/test/java/com/yahoo/vespa/model/container/xml/EmbedderTestCase.java b/config-model/src/test/java/com/yahoo/vespa/model/container/xml/EmbedderTestCase.java index 42b78db66b1..5832445d0d7 100644 --- a/config-model/src/test/java/com/yahoo/vespa/model/container/xml/EmbedderTestCase.java +++ b/config-model/src/test/java/com/yahoo/vespa/model/container/xml/EmbedderTestCase.java @@ -9,6 +9,7 @@ import com.yahoo.config.model.application.provider.FilesApplicationPackage; import com.yahoo.config.model.deploy.DeployState; import com.yahoo.config.model.deploy.TestProperties; import com.yahoo.embedding.BertBaseEmbedderConfig; +import com.yahoo.embedding.ColBertEmbedderConfig; import com.yahoo.embedding.huggingface.HuggingFaceEmbedderConfig; import com.yahoo.language.huggingface.config.HuggingFaceTokenizerConfig; import com.yahoo.path.Path; @@ -21,6 +22,7 @@ import com.yahoo.vespa.model.container.component.BertEmbedder; import com.yahoo.vespa.model.container.component.Component; import com.yahoo.vespa.model.container.component.HuggingFaceEmbedder; import com.yahoo.vespa.model.container.component.HuggingFaceTokenizer; +import com.yahoo.vespa.model.container.component.ColBertEmbedder; import com.yahoo.vespa.model.test.utils.VespaModelCreatorWithFilePkg; import com.yahoo.yolean.Exceptions; import org.junit.jupiter.api.Test; @@ -96,6 +98,29 @@ public class EmbedderTestCase { assertEquals(-1, tokenizerCfg.maxLength()); } + void colBertEmbedder_selfhosted() throws Exception { + var model = loadModel(Path.fromString("src/test/cfg/application/embed/"), false); + var cluster = model.getContainerClusters().get("container"); + var embedderCfg = assertColBertEmbedderComponentPresent(cluster); + assertEquals("my_input_ids", embedderCfg.transformerInputIds()); + assertEquals("https://my/url/model.onnx", modelReference(embedderCfg, "transformerModel").url().orElseThrow().value()); + assertEquals(1024, embedderCfg.transformerMaxTokens()); + var tokenizerCfg = assertHuggingfaceTokenizerComponentPresent(cluster); + assertEquals("https://my/url/tokenizer.json", modelReference(tokenizerCfg.model().get(0), "path").url().orElseThrow().value()); + assertEquals(-1, tokenizerCfg.maxLength()); + } + + void colBertEmbedder_hosted() throws Exception { + var model = loadModel(Path.fromString("src/test/cfg/application/embed/"), true); + var cluster = model.getContainerClusters().get("container"); + var embedderCfg = assertColBertEmbedderComponentPresent(cluster); + assertEquals("my_input_ids", embedderCfg.transformerInputIds()); + assertEquals("https://data.vespa.oath.cloud/onnx_models/e5-base-v2/model.onnx", modelReference(embedderCfg, "transformerModel").url().orElseThrow().value()); + assertEquals(1024, embedderCfg.transformerMaxTokens()); + var tokenizerCfg = assertHuggingfaceTokenizerComponentPresent(cluster); + assertEquals("https://data.vespa.oath.cloud/onnx_models/multilingual-e5-base/tokenizer.json", modelReference(tokenizerCfg.model().get(0), "path").url().orElseThrow().value()); + assertEquals(-1, tokenizerCfg.maxLength()); + } @Test void bertEmbedder_selfhosted() throws Exception { @@ -233,6 +258,14 @@ public class EmbedderTestCase { return cfgBuilder.build(); } + private static ColBertEmbedderConfig assertColBertEmbedderComponentPresent(ApplicationContainerCluster cluster) { + var colbert = (ColBertEmbedder) cluster.getComponentsMap().get(new ComponentId("colbert-embedder")); + assertEquals("ai.vespa.embedding.ColBertEmbedder", colbert.getClassId().getName()); + var cfgBuilder = new ColBertEmbedderConfig.Builder(); + colbert.getConfig(cfgBuilder); + return cfgBuilder.build(); + } + private static BertBaseEmbedderConfig assertBertEmbedderComponentPresent(ApplicationContainerCluster cluster) { var bertEmbedder = (BertEmbedder) cluster.getComponentsMap().get(new ComponentId("bert-embedder")); assertEquals("ai.vespa.embedding.BertBaseEmbedder", bertEmbedder.getClassId().getName()); diff --git a/configdefinitions/src/vespa/CMakeLists.txt b/configdefinitions/src/vespa/CMakeLists.txt index 29ed0f53421..bb29db2e3e3 100644 --- a/configdefinitions/src/vespa/CMakeLists.txt +++ b/configdefinitions/src/vespa/CMakeLists.txt @@ -88,5 +88,6 @@ install_config_definition(dataplane-proxy.def cloud.config.dataplane-proxy.def) install_config_definition(hugging-face-embedder.def embedding.huggingface.hugging-face-embedder.def) install_config_definition(hugging-face-tokenizer.def language.huggingface.config.hugging-face-tokenizer.def) install_config_definition(bert-base-embedder.def embedding.bert-base-embedder.def) +install_config_definition(col-bert-embedder.def embedding.col-bert-embedder.def) install_config_definition(cloud-data-plane-filter.def jdisc.http.filter.security.cloud.config.cloud-data-plane-filter.def) install_config_definition(cloud-token-data-plane-filter.def jdisc.http.filter.security.cloud.config.cloud-token-data-plane-filter.def) diff --git a/configdefinitions/src/vespa/col-bert-embedder.def b/configdefinitions/src/vespa/col-bert-embedder.def new file mode 100644 index 00000000000..c7944847d8b --- /dev/null +++ b/configdefinitions/src/vespa/col-bert-embedder.def @@ -0,0 +1,36 @@ + +namespace=embedding + +# Path to tokenizer.json +tokenizerPath model + +# Path to model.onnx +transformerModel model + +# Max query tokens for ColBERT +maxQueryTokens int default=32 + +# Max document query tokens for ColBERT +maxDocumentTokens int default=512 + +# Max length of token sequence model can handle +transformerMaxTokens int default=512 + +# Input names +transformerInputIds string default=input_ids +transformerAttentionMask string default=attention_mask + +# special token ids +transformerStartSequenceToken int default=101 +transformerEndSequenceToken int default=102 +transformerMaskToken int default=103 + +# Output name +transformerOutput string default=contextual + +# Settings for ONNX model evaluation +transformerExecutionMode enum { parallel, sequential } default=sequential +transformerInterOpThreads int default=1 +transformerIntraOpThreads int default=-4 +# GPU device id, -1 for CPU +transformerGpuDevice int default=0 diff --git a/model-integration/src/main/java/ai/vespa/embedding/ColBertEmbedder.java b/model-integration/src/main/java/ai/vespa/embedding/ColBertEmbedder.java new file mode 100644 index 00000000000..aafb9877c27 --- /dev/null +++ b/model-integration/src/main/java/ai/vespa/embedding/ColBertEmbedder.java @@ -0,0 +1,306 @@ +// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package ai.vespa.embedding; + +import ai.vespa.modelintegration.evaluator.OnnxEvaluator; +import ai.vespa.modelintegration.evaluator.OnnxEvaluatorOptions; +import com.yahoo.api.annotations.Beta; +import ai.vespa.modelintegration.evaluator.OnnxRuntime; +import com.yahoo.component.AbstractComponent; +import com.yahoo.component.annotation.Inject; +import com.yahoo.embedding.ColBertEmbedderConfig; +import com.yahoo.language.huggingface.HuggingFaceTokenizer; +import com.yahoo.language.process.Embedder; +import com.yahoo.tensor.IndexedTensor; +import com.yahoo.tensor.Tensor; +import com.yahoo.tensor.TensorAddress; +import com.yahoo.tensor.TensorType; +import com.yahoo.tensor.functions.Reduce; +import java.nio.file.Paths; +import java.util.Map; +import java.util.List; +import java.util.ArrayList; +import java.util.Set; +import java.util.HashSet; +import java.util.BitSet; +import java.util.Arrays; + +import static com.yahoo.language.huggingface.ModelInfo.TruncationStrategy.LONGEST_FIRST; + +/** + * A ColBERT embedder implementation that maps text to multiple vectors, one vector per subword id. + * This embedder uses a HuggingFace tokenizer to produce a token sequence that is then input to a transformer model. + * + * See col-bert-embedder.def for configurable parameters. + * @author bergum + */ +@Beta +public class ColBertEmbedder extends AbstractComponent implements Embedder { + private final Embedder.Runtime runtime; + private final String inputIdsName; + private final String attentionMaskName; + + private final String outputName; + + private final HuggingFaceTokenizer tokenizer; + private final OnnxEvaluator evaluator; + + private final int maxTransformerTokens; + private final int maxQueryTokens; + private final int maxDocumentTokens; + + private final long startSequenceToken; + private final long endSequenceToken; + private final long maskSequenceToken; + + + @Inject + public ColBertEmbedder(OnnxRuntime onnx, Embedder.Runtime runtime, ColBertEmbedderConfig config) { + this.runtime = runtime; + inputIdsName = config.transformerInputIds(); + attentionMaskName = config.transformerAttentionMask(); + outputName = config.transformerOutput(); + maxTransformerTokens = config.transformerMaxTokens(); + if(config.maxDocumentTokens() > maxTransformerTokens) + throw new IllegalArgumentException("maxDocumentTokens must be less than or equal to transformerMaxTokens"); + maxDocumentTokens = config.maxDocumentTokens(); + maxQueryTokens = config.maxQueryTokens(); + startSequenceToken = config.transformerStartSequenceToken(); + endSequenceToken = config.transformerEndSequenceToken(); + maskSequenceToken = config.transformerMaskToken(); + + var tokenizerPath = Paths.get(config.tokenizerPath().toString()); + var builder = new HuggingFaceTokenizer.Builder() + .addSpecialTokens(false) + .addDefaultModel(tokenizerPath) + .setPadding(false); + var info = HuggingFaceTokenizer.getModelInfo(tokenizerPath); + if (info.maxLength() == -1 || info.truncation() != LONGEST_FIRST) { + // Force truncation to max token vector length accepted by model if tokenizer.json contains no valid truncation configuration + int maxLength = info.maxLength() > 0 && info.maxLength() <= config.transformerMaxTokens() + ? info.maxLength() + : config.transformerMaxTokens(); + builder.setTruncation(true).setMaxLength(maxLength); + } + this.tokenizer = builder.build(); + var onnxOpts = new OnnxEvaluatorOptions(); + + if (config.transformerGpuDevice() >= 0) + onnxOpts.setGpuDevice(config.transformerGpuDevice()); + onnxOpts.setExecutionMode(config.transformerExecutionMode().toString()); + onnxOpts.setThreads(config.transformerInterOpThreads(), config.transformerIntraOpThreads()); + evaluator = onnx.evaluatorOf(config.transformerModel().toString(), onnxOpts); + validateModel(); + } + + public void validateModel() { + Map<String, TensorType> inputs = evaluator.getInputInfo(); + validateName(inputs, inputIdsName, "input"); + validateName(inputs, attentionMaskName, "input"); + Map<String, TensorType> outputs = evaluator.getOutputInfo(); + validateName(outputs, outputName, "output"); + } + + private void validateName(Map<String, TensorType> types, String name, String type) { + if (!types.containsKey(name)) { + throw new IllegalArgumentException("Model does not contain required " + type + ": '" + name + "'. " + + "Model contains: " + String.join(",", types.keySet())); + } + } + + @Override + public List<Integer> embed(String text, Context context) { + throw new UnsupportedOperationException("This embedder only supports embed with tensor type"); + } + + @Override + public Tensor embed(String text, Context context, TensorType tensorType) { + if(!verifyTensorType(tensorType)) { + throw new IllegalArgumentException("Invalid ColBERT embedder tensor destination." + + "Wanted a mixed 2-d mapped-indexed tensor, got " + tensorType.toString()); + } + if (context.getDestination().startsWith("query")) { + return embedQuery(text, context, tensorType); + } else { + return embedDocument(text, context, tensorType); + } + } + + @Override + public void deconstruct() { + evaluator.close(); + tokenizer.close(); + } + + protected Tensor embedQuery(String text, Context context, TensorType tensorType) { + if(tensorType.valueType() == TensorType.Value.INT8) + throw new IllegalArgumentException("ColBert query embed does not accept int8 tensor value type"); + + long Q_TOKEN_ID = 1; // [unused0] token id used during training to differentiate query versus document. + + var start = System.nanoTime(); + var encoding = tokenizer.encode(text, context.getLanguage()); + runtime.sampleSequenceLength(encoding.ids().size(), context); + + List<Long> ids = encoding.ids(); + if (ids.size() > maxQueryTokens - 3) + ids = ids.subList(0, maxQueryTokens - 3); + + List<Long> inputIds = new ArrayList<>(maxQueryTokens); + List<Long> attentionMask = new ArrayList<>(maxQueryTokens); + + inputIds.add(startSequenceToken); + inputIds.add(Q_TOKEN_ID); + inputIds.addAll(ids); + inputIds.add(endSequenceToken); + int length = inputIds.size(); + + int padding = maxQueryTokens - length; + for (int i = 0; i < padding; i++) + inputIds.add(maskSequenceToken); + + for (int i = 0; i < length; i++) + attentionMask.add((long) 1); + for (int i = 0; i < padding; i++) + attentionMask.add((long) 0);//Do not attend to mask paddings + + Tensor inputIdsTensor = createTensorRepresentation(inputIds, "d1"); + Tensor attentionMaskTensor = createTensorRepresentation(attentionMask, "d1"); + + var inputs = Map.of(inputIdsName, inputIdsTensor.expand("d0"), + attentionMaskName, attentionMaskTensor.expand("d0")); + Map<String, Tensor> outputs = evaluator.evaluate(inputs); + Tensor tokenEmbeddings = outputs.get(outputName); + IndexedTensor result = (IndexedTensor) tokenEmbeddings.reduce(Reduce.Aggregator.min, "d0"); + + int dims = tensorType.indexedSubtype().dimensions().get(0).size().get().intValue(); + if(dims != result.shape()[1]) { + throw new IllegalArgumentException("Token dimensionality does not" + + " match indexed dimensionality of " + dims); + } + Tensor.Builder builder = Tensor.Builder.of(tensorType); + for (int token = 0; token < result.shape()[0]; token++) + for (int d = 0; d < result.shape()[1]; d++) + builder.cell(TensorAddress.of(token, d), result.get(TensorAddress.of(token, d))); + runtime.sampleEmbeddingLatency((System.nanoTime() - start) / 1_000_000d, context); + return builder.build(); + } + + protected Tensor embedDocument(String text, Context context, TensorType tensorType) { + long D_TOKEN_ID = 2; // [unused1] token id used during training to differentiate query versus document. + var start = System.nanoTime(); + var encoding = tokenizer.encode(text, context.getLanguage()); + runtime.sampleSequenceLength(encoding.ids().size(), context); + + List<Long> ids = encoding.ids().stream().filter(token + -> !PUNCTUATION_TOKEN_IDS.contains(token)).toList(); + ; + + if (ids.size() > maxDocumentTokens - 3) + ids = ids.subList(0, maxDocumentTokens - 3); + List<Long> inputIds = new ArrayList<>(maxDocumentTokens); + List<Long> attentionMask = new ArrayList<>(maxDocumentTokens); + inputIds.add(startSequenceToken); + inputIds.add(D_TOKEN_ID); + inputIds.addAll(ids); + inputIds.add(endSequenceToken); + for (int i = 0; i < inputIds.size(); i++) + attentionMask.add((long) 1); + + Tensor inputIdsTensor = createTensorRepresentation(inputIds, "d1"); + Tensor attentionMaskTensor = createTensorRepresentation(attentionMask, "d1"); + + var inputs = Map.of(inputIdsName, inputIdsTensor.expand("d0"), + attentionMaskName, attentionMaskTensor.expand("d0")); + + Map<String, Tensor> outputs = evaluator.evaluate(inputs); + Tensor tokenEmbeddings = outputs.get(outputName); + IndexedTensor result = (IndexedTensor) tokenEmbeddings.reduce(Reduce.Aggregator.min, "d0"); + Tensor contextualEmbeddings; + if(tensorType.valueType() == TensorType.Value.INT8) { + contextualEmbeddings = toBitTensor(result, tensorType); + } else { + contextualEmbeddings = toFloatTensor(result, tensorType); + } + + runtime.sampleEmbeddingLatency((System.nanoTime() - start) / 1_000_000d, context); + return contextualEmbeddings; + } + + public static Tensor toFloatTensor(IndexedTensor result, TensorType type) { + int size = type.indexedSubtype().dimensions().size(); + if (size != 1) + throw new IllegalArgumentException("Indexed tensor must have one dimension"); + int dims = type.indexedSubtype().dimensions().get(0).size().get().intValue(); + int resultDim = (int)result.shape()[1]; + if(resultDim != dims) { + throw new IllegalArgumentException("Not possible to map token vector embedding with " + resultDim + + " + dimensions into tensor with " + dims); + } + Tensor.Builder builder = Tensor.Builder.of(type); + for (int token = 0; token < result.shape()[0]; token++) { + for (int d = 0; d < result.shape()[1]; d++) { + var value = result.get(TensorAddress.of(token, d)); + builder.cell(TensorAddress.of(token,d),value); + } + } + return builder.build(); + } + + public static Tensor toBitTensor(IndexedTensor result, TensorType type) { + if (type.valueType() != TensorType.Value.INT8) + throw new IllegalArgumentException("Only a int8 tensor type can be" + + " the destination of bit packing"); + int size = type.indexedSubtype().dimensions().size(); + if (size != 1) + throw new IllegalArgumentException("Indexed tensor must have one dimension"); + int dims = type.indexedSubtype().dimensions().get(0).size().get().intValue(); + int resultDim = (int)result.shape()[1]; + if(resultDim/8 != dims) { + throw new IllegalArgumentException("Not possible to pack " + resultDim + + " + dimensions into " + dims); + } + Tensor.Builder builder = Tensor.Builder.of(type); + for (int token = 0; token < result.shape()[0]; token++) { + BitSet bitSet = new BitSet(8); + int key = 0; + for (int d = 0; d < result.shape()[1]; d++) { + var value = result.get(TensorAddress.of(token, d)); + int bitIndex = 7 - (d % 8); + if (value > 0.0) { + bitSet.set(bitIndex); + } else { + bitSet.clear(bitIndex); + } + if ((d + 1) % 8 == 0) { + byte[] bytes = bitSet.toByteArray(); + byte packed = (bytes.length == 0) ? 0 : bytes[0]; + builder.cell(TensorAddress.of(token, key), packed); + key++; + bitSet = new BitSet(8); + } + } + } + return builder.build(); + } + + protected boolean verifyTensorType(TensorType target) { + return target.dimensions().size() == 2 && + target.indexedSubtype().rank() == 1 && target.mappedSubtype().rank() == 1; + } + + private IndexedTensor createTensorRepresentation(List<Long> input, String dimension) { + int size = input.size(); + TensorType type = new TensorType.Builder(TensorType.Value.FLOAT).indexed(dimension, size).build(); + IndexedTensor.Builder builder = IndexedTensor.Builder.of(type); + for (int i = 0; i < size; ++i) { + builder.cell(input.get(i), i); + } + return builder.build(); + } + + private static final Set<Long> PUNCTUATION_TOKEN_IDS = new HashSet<>( + Arrays.asList(999L, 1000L, 1001L, 1002L, 1003L, 1004L, 1005L, 1006L, + 1007L, 1008L, 1009L, 1010L, 1011L, 1012L, 1013L, 1024L, + 1025L, 1026L, 1027L, 1028L, 1029L, 1030L, 1031L, 1032L, + 1033L, 1034L, 1035L, 1036L, 1063L, 1064L, 1065L, 1066L)); +} diff --git a/model-integration/src/test/java/ai/vespa/embedding/ColBertEmbedderTest.java b/model-integration/src/test/java/ai/vespa/embedding/ColBertEmbedderTest.java new file mode 100644 index 00000000000..8516f6e6689 --- /dev/null +++ b/model-integration/src/test/java/ai/vespa/embedding/ColBertEmbedderTest.java @@ -0,0 +1,126 @@ +// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package ai.vespa.embedding; + +import ai.vespa.modelintegration.evaluator.OnnxRuntime; +import com.yahoo.config.ModelReference; +import com.yahoo.embedding.ColBertEmbedderConfig; +import com.yahoo.language.process.Embedder; +import com.yahoo.tensor.IndexedTensor; +import com.yahoo.tensor.MixedTensor; +import com.yahoo.tensor.Tensor; +import com.yahoo.tensor.TensorType; +import org.junit.Test; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertThrows; +import static org.junit.Assume.assumeTrue; + +public class ColBertEmbedderTest { + + @Test + public void testPacking() { + assertPackedRight( + "" + + "tensor<float>(d1[6],d2[8]):" + + "[" + + "[0, 0, 0, 0, 0, 0, 0, 1]," + + "[0, 0, 0, 0, 0, 1, 0, 1]," + + "[0, 0, 0, 0, 0, 0, 1, 1]," + + "[0, 1, 1, 1, 1, 1, 1, 1]," + + "[1, 0, 0, 0, 0, 0, 0, 0]," + + "[1, 1, 1, 1, 1, 1, 1, 1]" + + "]", + TensorType.fromSpec("tensor<int8>(dt{},x[1])"), + "tensor<int8>(dt{},x[1]):{0:1.0, 1:5.0, 2:3.0, 3:127.0, 4:-128.0, 5:-1.0}" + ); + assertPackedRight( + "" + + "tensor<float>(d1[2],d2[16]):" + + "[" + + "[0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0, 0, 0, 0]," + + "[0, 0, 0, 0, 0, 1, 0, 1, 0, 0, 0, 0, 0, 0, 0, 1]" + + "]", + TensorType.fromSpec("tensor<int8>(dt{},x[2])"), + "tensor<int8>(dt{},x[2]):{0:[1.0, -128.0], 1:[5.0, 1.0]}" + ); + } + + @Test + public void testEmbedder() { + assertEmbed("tensor<float>(dt{},x[128])", "this is a document", indexingContext); + assertEmbed("tensor<int8>(dt{},x[16])", "this is a document", indexingContext); + assertEmbed("tensor<float>(qt{},x[128])", "this is a query", queryContext); + + assertThrows(IllegalArgumentException.class, () -> { + //throws because int8 is not supported for query context + assertEmbed("tensor<int8>(qt{},x[16])", "this is a query", queryContext); + }); + assertThrows(IllegalArgumentException.class, () -> { + //throws because 16 is less than model output (128) and we want float + assertEmbed("tensor<float>(qt{},x[16])", "this is a query", queryContext); + }); + + assertThrows(IllegalArgumentException.class, () -> { + //throws because 128/8 does not fit into 15 + assertEmbed("tensor<int8>(qt{},x[15])", "this is a query", indexingContext); + }); + } + + @Test + public void testLenghtLimits() { + StringBuilder sb = new StringBuilder(); + for(int i = 0; i < 1024; i++) { + sb.append("annoyance"); + sb.append(" "); + } + String text = sb.toString(); + Tensor fullFloat = assertEmbed("tensor<float>(dt{},x[128])", text, indexingContext); + assertEquals(512*128,fullFloat.size()); + + Tensor query = assertEmbed("tensor<float>(dt{},x[128])", text, queryContext); + assertEquals(32*128,query.size()); + + Tensor binaryRep = assertEmbed("tensor<int8>(dt{},x[16])", text, indexingContext); + assertEquals(512*16,binaryRep.size()); + + Tensor shortDoc = assertEmbed("tensor<int8>(dt{},x[16])", "annoyance", indexingContext); + // 4 tokens, 16 bytes each = 64 bytes + //because of CLS, special, sequence, SEP + assertEquals(4*16,shortDoc.size());; + } + + static Tensor assertEmbed(String tensorSpec, String text, Embedder.Context context) { + TensorType destType = TensorType.fromSpec(tensorSpec); + Tensor result = embedder.embed(text, context, destType); + assertEquals(destType,result.type()); + MixedTensor mixedTensor = (MixedTensor) result; + if(context == queryContext) { + assertEquals(32*mixedTensor.denseSubspaceSize(),mixedTensor.size()); + } + return result; + } + + static void assertPackedRight(String numbers, TensorType destination,String expected) { + Tensor packed = ColBertEmbedder.toBitTensor((IndexedTensor) Tensor.from(numbers), destination); + assertEquals(expected,packed.toString()); + } + + static final Embedder embedder; + static final Embedder.Context indexingContext; + static final Embedder.Context queryContext; + static { + indexingContext = new Embedder.Context("schema.indexing"); + queryContext = new Embedder.Context("query(qt)"); + embedder = getEmbedder(); + } + private static Embedder getEmbedder() { + String vocabPath = "src/test/models/onnx/transformer/tokenizer.json"; + String modelPath = "src/test/models/onnx/transformer/colbert-dummy-v2.onnx"; + assumeTrue(OnnxRuntime.isRuntimeAvailable(modelPath)); + ColBertEmbedderConfig.Builder builder = new ColBertEmbedderConfig.Builder(); + builder.tokenizerPath(ModelReference.valueOf(vocabPath)); + builder.transformerModel(ModelReference.valueOf(modelPath)); + builder.transformerGpuDevice(-1); + return new ColBertEmbedder(new OnnxRuntime(), Embedder.Runtime.testInstance(), builder.build()); + } +}
\ No newline at end of file diff --git a/model-integration/src/test/models/onnx/transformer/colbert-dummy-v2.onnx b/model-integration/src/test/models/onnx/transformer/colbert-dummy-v2.onnx Binary files differnew file mode 100644 index 00000000000..5ab1060e59e --- /dev/null +++ b/model-integration/src/test/models/onnx/transformer/colbert-dummy-v2.onnx diff --git a/model-integration/src/test/models/onnx/transformer/tokenizer.json b/model-integration/src/test/models/onnx/transformer/tokenizer.json new file mode 100644 index 00000000000..28340f289bb --- /dev/null +++ b/model-integration/src/test/models/onnx/transformer/tokenizer.json @@ -0,0 +1,175 @@ +{ + "version": "1.0", + "truncation": null, + "padding": null, + "added_tokens": [ + { + "id": 0, + "content": "[PAD]", + "single_word": false, + "lstrip": false, + "rstrip": false, + "normalized": false, + "special": true + }, + { + "id": 100, + "content": "[UNK]", + "single_word": false, + "lstrip": false, + "rstrip": false, + "normalized": false, + "special": true + }, + { + "id": 101, + "content": "[CLS]", + "single_word": false, + "lstrip": false, + "rstrip": false, + "normalized": false, + "special": true + }, + { + "id": 102, + "content": "[SEP]", + "single_word": false, + "lstrip": false, + "rstrip": false, + "normalized": false, + "special": true + }, + { + "id": 103, + "content": "[MASK]", + "single_word": false, + "lstrip": false, + "rstrip": false, + "normalized": false, + "special": true + } + ], + "normalizer": { + "type": "BertNormalizer", + "clean_text": true, + "handle_chinese_chars": true, + "strip_accents": null, + "lowercase": true + }, + "pre_tokenizer": { + "type": "BertPreTokenizer" + }, + "post_processor": { + "type": "TemplateProcessing", + "single": [ + { + "SpecialToken": { + "id": "[CLS]", + "type_id": 0 + } + }, + { + "Sequence": { + "id": "A", + "type_id": 0 + } + }, + { + "SpecialToken": { + "id": "[SEP]", + "type_id": 0 + } + } + ], + "pair": [ + { + "SpecialToken": { + "id": "[CLS]", + "type_id": 0 + } + }, + { + "Sequence": { + "id": "A", + "type_id": 0 + } + }, + { + "SpecialToken": { + "id": "[SEP]", + "type_id": 0 + } + }, + { + "Sequence": { + "id": "B", + "type_id": 1 + } + }, + { + "SpecialToken": { + "id": "[SEP]", + "type_id": 1 + } + } + ], + "special_tokens": { + "[CLS]": { + "id": "[CLS]", + "ids": [101], + "tokens": ["[CLS]"] + }, + "[SEP]": { + "id": "[SEP]", + "ids": [102], + "tokens": ["[SEP]"] + } + } + }, + "decoder": { + "type": "WordPiece", + "prefix": "##", + "cleanup": true + }, + "model": { + "type": "WordPiece", + "unk_token": "[UNK]", + "continuing_subword_prefix": "##", + "max_input_chars_per_word": 100, + "vocab": { + "[PAD]": 0, + "[unused0]": 1, + "[unused1]": 2, + "[UNK]": 100, + "[CLS]": 101, + "[SEP]": 102, + "[MASK]": 103, + "a": 1037, + "b": 1038, + "c": 1039, + "d": 1040, + "e": 1041, + "f": 1042, + "g": 1043, + "h": 1044, + "i": 1045, + "j": 1046, + "k": 1047, + "l": 1048, + "m": 1049, + "n": 1050, + "o": 1051, + "p": 1052, + "q": 1053, + "r": 1054, + "s": 1055, + "t": 1056, + "u": 1057, + "v": 1058, + "w": 1059, + "x": 1060, + "y": 1061, + "z": 1062 + } + } +} |