diff options
author | Lester Solbakken <lesters@oath.com> | 2022-04-01 09:42:13 +0200 |
---|---|---|
committer | Lester Solbakken <lesters@oath.com> | 2022-04-01 09:42:13 +0200 |
commit | 4cc5a56f7006c40b059855f345ded365ace8550c (patch) | |
tree | 6142647b7fd6240e64d1ef21ca74ec50e920f004 /model-integration | |
parent | 02151a078dd5f45defaa42b19bf127bc2c999944 (diff) |
Revert "Revert "Add bert base embedder""
This reverts commit 80556883dcdc350e2b5cfebde8ef482baeb36872.
Diffstat (limited to 'model-integration')
7 files changed, 320 insertions, 0 deletions
diff --git a/model-integration/pom.xml b/model-integration/pom.xml index 41139394690..d064a3ff709 100644 --- a/model-integration/pom.xml +++ b/model-integration/pom.xml @@ -45,6 +45,18 @@ <version>${project.version}</version> <scope>provided</scope> </dependency> + <dependency> + <groupId>com.yahoo.vespa</groupId> + <artifactId>linguistics</artifactId> + <version>${project.version}</version> + <scope>provided</scope> + </dependency> + <dependency> + <groupId>com.yahoo.vespa</groupId> + <artifactId>linguistics-components</artifactId> + <version>${project.version}</version> + <scope>provided</scope> + </dependency> <dependency> <groupId>com.google.guava</groupId> diff --git a/model-integration/src/main/java/ai/vespa/embedding/BertBaseEmbedder.java b/model-integration/src/main/java/ai/vespa/embedding/BertBaseEmbedder.java new file mode 100644 index 00000000000..42e3d653359 --- /dev/null +++ b/model-integration/src/main/java/ai/vespa/embedding/BertBaseEmbedder.java @@ -0,0 +1,165 @@ +package ai.vespa.embedding; + +import ai.vespa.modelintegration.evaluator.OnnxEvaluator; +import ai.vespa.modelintegration.evaluator.OnnxEvaluatorOptions; +import com.yahoo.component.annotation.Inject; +import com.yahoo.language.process.Embedder; +import com.yahoo.language.wordpiece.WordPieceEmbedder; +import com.yahoo.tensor.IndexedTensor; +import com.yahoo.tensor.Tensor; +import com.yahoo.tensor.TensorAddress; +import com.yahoo.tensor.TensorType; + +import java.util.ArrayList; +import java.util.List; +import java.util.Map; + + +/** + * A BERT Base compatible embedder. This embedder uses a WordPiece embedder to + * produce a token sequence that is input to a transformer model. A BERT base + * compatible transformer model must have three inputs: + * + * - A token sequence (input_ids) + * - An attention mask (attention_mask) + * - Token types for cross encoding (token_type_ids) + * + * See bert-base-embedder.def for configurable parameters. + * + * @author lesters + */ +public class BertBaseEmbedder implements Embedder { + + private final static int TOKEN_CLS = 101; // [CLS] + private final static int TOKEN_SEP = 102; // [SEP] + + private final int maxTokens; + private final String inputIdsName; + private final String attentionMaskName; + private final String tokenTypeIdsName; + private final String outputName; + private final String poolingStrategy; + + private final WordPieceEmbedder tokenizer; + private final OnnxEvaluator evaluator; + + @Inject + public BertBaseEmbedder(BertBaseEmbedderConfig config) { + maxTokens = config.transformerMaxTokens(); + inputIdsName = config.transformerInputIds(); + attentionMaskName = config.transformerAttentionMask(); + tokenTypeIdsName = config.transformerTokenTypeIds(); + outputName = config.transformerOutput(); + poolingStrategy = config.poolingStrategy().toString(); + + OnnxEvaluatorOptions options = new OnnxEvaluatorOptions(); + options.setExecutionMode(config.onnxExecutionMode().toString()); + options.setInterOpThreads(modifyThreadCount(config.onnxInterOpThreads())); + options.setIntraOpThreads(modifyThreadCount(config.onnxIntraOpThreads())); + + // Todo: use either file or url + tokenizer = new WordPieceEmbedder.Builder(config.tokenizerVocabUrl().getAbsolutePath()).build(); + evaluator = new OnnxEvaluator(config.transformerModelUrl().getAbsolutePath(), options); + + validateModel(); + } + + private void validateModel() { + Map<String, TensorType> inputs = evaluator.getInputInfo(); + validateName(inputs, inputIdsName, "input"); + validateName(inputs, attentionMaskName, "input"); + validateName(inputs, tokenTypeIdsName, "input"); + + Map<String, TensorType> outputs = evaluator.getOutputInfo(); + validateName(outputs, outputName, "output"); + } + + private void validateName(Map<String, TensorType> types, String name, String type) { + if ( ! types.containsKey(name)) { + throw new IllegalArgumentException("Model does not contain required " + type + ": '" + name + "'. " + + "Model contains: " + String.join(",", types.keySet())); + } + } + + @Override + public List<Integer> embed(String text, Context context) { + return tokenizer.embed(text, context); + } + + @Override + public Tensor embed(String text, Context context, TensorType type) { + if (type.dimensions().size() != 1) { + throw new IllegalArgumentException("Error in embedding to type '" + type + "': should only have one dimension."); + } + if (!type.dimensions().get(0).isIndexed()) { + throw new IllegalArgumentException("Error in embedding to type '" + type + "': dimension should be indexed."); + } + List<Integer> tokens = embedWithSeperatorTokens(text, context, maxTokens); + return embedTokens(tokens, type); + } + + Tensor embedTokens(List<Integer> tokens, TensorType type) { + Tensor inputSequence = createTensorRepresentation(tokens, "d1"); + Tensor attentionMask = createAttentionMask(inputSequence); + Tensor tokenTypeIds = createTokenTypeIds(inputSequence); + + Map<String, Tensor> inputs = Map.of(inputIdsName, inputSequence.expand("d0"), + attentionMaskName, attentionMask.expand("d0"), + tokenTypeIdsName, tokenTypeIds.expand("d0")); + Map<String, Tensor> outputs = evaluator.evaluate(inputs); + + Tensor tokenEmbeddings = outputs.get(outputName); + + Tensor.Builder builder = Tensor.Builder.of(type); + if (poolingStrategy.equals("mean")) { // average over tokens + Tensor summedEmbeddings = tokenEmbeddings.sum("d1"); + Tensor summedAttentionMask = attentionMask.expand("d0").sum("d1"); + Tensor averaged = summedEmbeddings.join(summedAttentionMask, (x, y) -> x / y); + for (int i = 0; i < type.dimensions().get(0).size().get(); i++) { + builder.cell(averaged.get(TensorAddress.of(0,i)), i); + } + } else { // CLS - use first token + for (int i = 0; i < type.dimensions().get(0).size().get(); i++) { + builder.cell(tokenEmbeddings.get(TensorAddress.of(0,0,i)), i); + } + } + return builder.build(); + } + + private List<Integer> embedWithSeperatorTokens(String text, Context context, int maxLength) { + List<Integer> tokens = new ArrayList<>(); + tokens.add(TOKEN_CLS); + tokens.addAll(embed(text, context)); + tokens.add(TOKEN_SEP); + if (tokens.size() > maxLength) { + tokens = tokens.subList(0, maxLength-1); + tokens.add(TOKEN_SEP); + } + return tokens; + } + + private IndexedTensor createTensorRepresentation(List<Integer> input, String dimension) { + int size = input.size(); + TensorType type = new TensorType.Builder(TensorType.Value.FLOAT).indexed(dimension, size).build(); + IndexedTensor.Builder builder = IndexedTensor.Builder.of(type); + for (int i = 0; i < size; ++i) { + builder.cell(input.get(i), i); + } + return builder.build(); + } + + private static Tensor createAttentionMask(Tensor d) { + return d.map((x) -> x > 0 ? 1:0); + } + + private static Tensor createTokenTypeIds(Tensor d) { + return d.map((x) -> x > 0 ? 0:0); + } + + private int modifyThreadCount(int numThreads) { + if (numThreads >= 0) + return numThreads; + return Math.max(1, (int) Math.ceil(((double) Runtime.getRuntime().availableProcessors()) / (-1 * numThreads))); + } + +} diff --git a/model-integration/src/main/resources/configdefinitions/bert-base-embedder.def b/model-integration/src/main/resources/configdefinitions/bert-base-embedder.def new file mode 100644 index 00000000000..7e3ff151466 --- /dev/null +++ b/model-integration/src/main/resources/configdefinitions/bert-base-embedder.def @@ -0,0 +1,27 @@ +package=ai.vespa.embedding + +# Transformer model settings +transformerModelUrl url + +# Max length of token sequence model can handle +transformerMaxTokens int default=384 + +# Pooling strategy +poolingStrategy enum { cls, mean } default=mean + +# Input names +transformerInputIds string default=input_ids +transformerAttentionMask string default=attention_mask +transformerTokenTypeIds string default=token_type_ids + +# Output name +transformerOutput string default=output_0 + +# Settings for ONNX model evaluation +onnxExecutionMode enum { parallel, sequential } default=sequential +onnxInterOpThreads int default=1 +onnxIntraOpThreads int default=-4 # n=number of threads -> n<0: CPUs/(-n), n==0: CPUs, n>0: n + +# Settings for wordpiece tokenizer +tokenizerVocabUrl url + diff --git a/model-integration/src/test/java/ai/vespa/embedding/BertBaseEmbedderTest.java b/model-integration/src/test/java/ai/vespa/embedding/BertBaseEmbedderTest.java new file mode 100644 index 00000000000..519f24795ca --- /dev/null +++ b/model-integration/src/test/java/ai/vespa/embedding/BertBaseEmbedderTest.java @@ -0,0 +1,29 @@ +package ai.vespa.embedding; + +import com.yahoo.config.UrlReference; +import com.yahoo.tensor.Tensor; +import com.yahoo.tensor.TensorType; +import org.junit.Test; + +import java.util.List; + +import static org.junit.Assert.assertEquals; + +public class BertBaseEmbedderTest { + + @Test + public void testEmbedder() { + BertBaseEmbedderConfig.Builder builder = new BertBaseEmbedderConfig.Builder(); + builder.tokenizerVocabUrl(new UrlReference("src/test/models/onnx/transformer/dummy_vocab.txt")); + builder.transformerModelUrl(new UrlReference("src/test/models/onnx/transformer/dummy_transformer.onnx")); + BertBaseEmbedder embedder = new BertBaseEmbedder(builder.build()); + + TensorType destType = TensorType.fromSpec("tensor<float>(x[7])"); + List<Integer> tokens = List.of(1,2,3,4,5); // use random tokens instead of invoking the tokenizer + Tensor embedding = embedder.embedTokens(tokens, destType); + + Tensor expected = Tensor.from("tensor<float>(x[7]):[-0.6178509, -0.8135831, 0.34416935, 0.3912577, -0.13068882, 2.5897025E-4, -0.18638384]"); + assertEquals(embedding, expected); + } + +} diff --git a/model-integration/src/test/models/onnx/transformer/dummy_transformer.onnx b/model-integration/src/test/models/onnx/transformer/dummy_transformer.onnx Binary files differnew file mode 100644 index 00000000000..2101beec786 --- /dev/null +++ b/model-integration/src/test/models/onnx/transformer/dummy_transformer.onnx diff --git a/model-integration/src/test/models/onnx/transformer/dummy_transformer.py b/model-integration/src/test/models/onnx/transformer/dummy_transformer.py new file mode 100644 index 00000000000..1028035d7c0 --- /dev/null +++ b/model-integration/src/test/models/onnx/transformer/dummy_transformer.py @@ -0,0 +1,52 @@ +# Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. + +import torch +import torch.onnx +import torch.nn as nn +from torch.nn import TransformerEncoder, TransformerEncoderLayer + + +class TransformerModel(nn.Module): + def __init__(self, vocab_size, emb_size, num_heads, hidden_dim_size, num_layers, dropout=0.2): + super(TransformerModel, self).__init__() + self.encoder = nn.Embedding(vocab_size, emb_size) + encoder_layers = TransformerEncoderLayer(emb_size, num_heads, hidden_dim_size, dropout) + self.transformer_encoder = TransformerEncoder(encoder_layers, num_layers) + + def forward(self, tokens, attention_mask, token_type_ids): + src = self.encoder((tokens * attention_mask) + token_type_ids) + output = self.transformer_encoder(src) + return output + + +def main(): + vocabulary_size = 20 + embedding_size = 16 + hidden_dim_size = 32 + num_layers = 2 + num_heads = 2 + model = TransformerModel(vocabulary_size, embedding_size, num_heads, hidden_dim_size, num_layers) + + # Omit training - just export randomly initialized network + + tokens = torch.LongTensor([[1,2,3,4,5]]) + attention_mask = torch.LongTensor([[1,1,1,1,1]]) + token_type_ids = torch.LongTensor([[0,0,0,0,0]]) + torch.onnx.export(model, + (tokens, attention_mask, token_type_ids), + "dummy_transformer.onnx", + input_names = ["input_ids", "attention_mask", "token_type_ids"], + output_names = ["output_0"], + dynamic_axes = { + "input_ids": {0:"batch", 1:"tokens"}, + "attention_mask": {0:"batch", 1:"tokens"}, + "token_type_ids": {0:"batch", 1:"tokens"}, + "output_0": {0:"batch", 1:"tokens"}, + }, + opset_version=12) + + +if __name__ == "__main__": + main() + + diff --git a/model-integration/src/test/models/onnx/transformer/dummy_vocab.txt b/model-integration/src/test/models/onnx/transformer/dummy_vocab.txt new file mode 100644 index 00000000000..7dc0c6ecb6e --- /dev/null +++ b/model-integration/src/test/models/onnx/transformer/dummy_vocab.txt @@ -0,0 +1,35 @@ +a +b +c +d +e +f +g +h +i +j +k +l +m +n +o +p +q +r +s +t +u +v +x +y +z +0 +1 +2 +3 +4 +5 +6 +7 +8 +9
\ No newline at end of file |