diff options
7 files changed, 0 insertions, 320 deletions
diff --git a/model-integration/pom.xml b/model-integration/pom.xml index d064a3ff709..41139394690 100644 --- a/model-integration/pom.xml +++ b/model-integration/pom.xml @@ -45,18 +45,6 @@ <version>${project.version}</version> <scope>provided</scope> </dependency> - <dependency> - <groupId>com.yahoo.vespa</groupId> - <artifactId>linguistics</artifactId> - <version>${project.version}</version> - <scope>provided</scope> - </dependency> - <dependency> - <groupId>com.yahoo.vespa</groupId> - <artifactId>linguistics-components</artifactId> - <version>${project.version}</version> - <scope>provided</scope> - </dependency> <dependency> <groupId>com.google.guava</groupId> diff --git a/model-integration/src/main/java/ai/vespa/embedding/BertBaseEmbedder.java b/model-integration/src/main/java/ai/vespa/embedding/BertBaseEmbedder.java deleted file mode 100644 index 42e3d653359..00000000000 --- a/model-integration/src/main/java/ai/vespa/embedding/BertBaseEmbedder.java +++ /dev/null @@ -1,165 +0,0 @@ -package ai.vespa.embedding; - -import ai.vespa.modelintegration.evaluator.OnnxEvaluator; -import ai.vespa.modelintegration.evaluator.OnnxEvaluatorOptions; -import com.yahoo.component.annotation.Inject; -import com.yahoo.language.process.Embedder; -import com.yahoo.language.wordpiece.WordPieceEmbedder; -import com.yahoo.tensor.IndexedTensor; -import com.yahoo.tensor.Tensor; -import com.yahoo.tensor.TensorAddress; -import com.yahoo.tensor.TensorType; - -import java.util.ArrayList; -import java.util.List; -import java.util.Map; - - -/** - * A BERT Base compatible embedder. This embedder uses a WordPiece embedder to - * produce a token sequence that is input to a transformer model. A BERT base - * compatible transformer model must have three inputs: - * - * - A token sequence (input_ids) - * - An attention mask (attention_mask) - * - Token types for cross encoding (token_type_ids) - * - * See bert-base-embedder.def for configurable parameters. - * - * @author lesters - */ -public class BertBaseEmbedder implements Embedder { - - private final static int TOKEN_CLS = 101; // [CLS] - private final static int TOKEN_SEP = 102; // [SEP] - - private final int maxTokens; - private final String inputIdsName; - private final String attentionMaskName; - private final String tokenTypeIdsName; - private final String outputName; - private final String poolingStrategy; - - private final WordPieceEmbedder tokenizer; - private final OnnxEvaluator evaluator; - - @Inject - public BertBaseEmbedder(BertBaseEmbedderConfig config) { - maxTokens = config.transformerMaxTokens(); - inputIdsName = config.transformerInputIds(); - attentionMaskName = config.transformerAttentionMask(); - tokenTypeIdsName = config.transformerTokenTypeIds(); - outputName = config.transformerOutput(); - poolingStrategy = config.poolingStrategy().toString(); - - OnnxEvaluatorOptions options = new OnnxEvaluatorOptions(); - options.setExecutionMode(config.onnxExecutionMode().toString()); - options.setInterOpThreads(modifyThreadCount(config.onnxInterOpThreads())); - options.setIntraOpThreads(modifyThreadCount(config.onnxIntraOpThreads())); - - // Todo: use either file or url - tokenizer = new WordPieceEmbedder.Builder(config.tokenizerVocabUrl().getAbsolutePath()).build(); - evaluator = new OnnxEvaluator(config.transformerModelUrl().getAbsolutePath(), options); - - validateModel(); - } - - private void validateModel() { - Map<String, TensorType> inputs = evaluator.getInputInfo(); - validateName(inputs, inputIdsName, "input"); - validateName(inputs, attentionMaskName, "input"); - validateName(inputs, tokenTypeIdsName, "input"); - - Map<String, TensorType> outputs = evaluator.getOutputInfo(); - validateName(outputs, outputName, "output"); - } - - private void validateName(Map<String, TensorType> types, String name, String type) { - if ( ! types.containsKey(name)) { - throw new IllegalArgumentException("Model does not contain required " + type + ": '" + name + "'. " + - "Model contains: " + String.join(",", types.keySet())); - } - } - - @Override - public List<Integer> embed(String text, Context context) { - return tokenizer.embed(text, context); - } - - @Override - public Tensor embed(String text, Context context, TensorType type) { - if (type.dimensions().size() != 1) { - throw new IllegalArgumentException("Error in embedding to type '" + type + "': should only have one dimension."); - } - if (!type.dimensions().get(0).isIndexed()) { - throw new IllegalArgumentException("Error in embedding to type '" + type + "': dimension should be indexed."); - } - List<Integer> tokens = embedWithSeperatorTokens(text, context, maxTokens); - return embedTokens(tokens, type); - } - - Tensor embedTokens(List<Integer> tokens, TensorType type) { - Tensor inputSequence = createTensorRepresentation(tokens, "d1"); - Tensor attentionMask = createAttentionMask(inputSequence); - Tensor tokenTypeIds = createTokenTypeIds(inputSequence); - - Map<String, Tensor> inputs = Map.of(inputIdsName, inputSequence.expand("d0"), - attentionMaskName, attentionMask.expand("d0"), - tokenTypeIdsName, tokenTypeIds.expand("d0")); - Map<String, Tensor> outputs = evaluator.evaluate(inputs); - - Tensor tokenEmbeddings = outputs.get(outputName); - - Tensor.Builder builder = Tensor.Builder.of(type); - if (poolingStrategy.equals("mean")) { // average over tokens - Tensor summedEmbeddings = tokenEmbeddings.sum("d1"); - Tensor summedAttentionMask = attentionMask.expand("d0").sum("d1"); - Tensor averaged = summedEmbeddings.join(summedAttentionMask, (x, y) -> x / y); - for (int i = 0; i < type.dimensions().get(0).size().get(); i++) { - builder.cell(averaged.get(TensorAddress.of(0,i)), i); - } - } else { // CLS - use first token - for (int i = 0; i < type.dimensions().get(0).size().get(); i++) { - builder.cell(tokenEmbeddings.get(TensorAddress.of(0,0,i)), i); - } - } - return builder.build(); - } - - private List<Integer> embedWithSeperatorTokens(String text, Context context, int maxLength) { - List<Integer> tokens = new ArrayList<>(); - tokens.add(TOKEN_CLS); - tokens.addAll(embed(text, context)); - tokens.add(TOKEN_SEP); - if (tokens.size() > maxLength) { - tokens = tokens.subList(0, maxLength-1); - tokens.add(TOKEN_SEP); - } - return tokens; - } - - private IndexedTensor createTensorRepresentation(List<Integer> input, String dimension) { - int size = input.size(); - TensorType type = new TensorType.Builder(TensorType.Value.FLOAT).indexed(dimension, size).build(); - IndexedTensor.Builder builder = IndexedTensor.Builder.of(type); - for (int i = 0; i < size; ++i) { - builder.cell(input.get(i), i); - } - return builder.build(); - } - - private static Tensor createAttentionMask(Tensor d) { - return d.map((x) -> x > 0 ? 1:0); - } - - private static Tensor createTokenTypeIds(Tensor d) { - return d.map((x) -> x > 0 ? 0:0); - } - - private int modifyThreadCount(int numThreads) { - if (numThreads >= 0) - return numThreads; - return Math.max(1, (int) Math.ceil(((double) Runtime.getRuntime().availableProcessors()) / (-1 * numThreads))); - } - -} diff --git a/model-integration/src/main/resources/configdefinitions/bert-base-embedder.def b/model-integration/src/main/resources/configdefinitions/bert-base-embedder.def deleted file mode 100644 index 7e3ff151466..00000000000 --- a/model-integration/src/main/resources/configdefinitions/bert-base-embedder.def +++ /dev/null @@ -1,27 +0,0 @@ -package=ai.vespa.embedding - -# Transformer model settings -transformerModelUrl url - -# Max length of token sequence model can handle -transformerMaxTokens int default=384 - -# Pooling strategy -poolingStrategy enum { cls, mean } default=mean - -# Input names -transformerInputIds string default=input_ids -transformerAttentionMask string default=attention_mask -transformerTokenTypeIds string default=token_type_ids - -# Output name -transformerOutput string default=output_0 - -# Settings for ONNX model evaluation -onnxExecutionMode enum { parallel, sequential } default=sequential -onnxInterOpThreads int default=1 -onnxIntraOpThreads int default=-4 # n=number of threads -> n<0: CPUs/(-n), n==0: CPUs, n>0: n - -# Settings for wordpiece tokenizer -tokenizerVocabUrl url - diff --git a/model-integration/src/test/java/ai/vespa/embedding/BertBaseEmbedderTest.java b/model-integration/src/test/java/ai/vespa/embedding/BertBaseEmbedderTest.java deleted file mode 100644 index 519f24795ca..00000000000 --- a/model-integration/src/test/java/ai/vespa/embedding/BertBaseEmbedderTest.java +++ /dev/null @@ -1,29 +0,0 @@ -package ai.vespa.embedding; - -import com.yahoo.config.UrlReference; -import com.yahoo.tensor.Tensor; -import com.yahoo.tensor.TensorType; -import org.junit.Test; - -import java.util.List; - -import static org.junit.Assert.assertEquals; - -public class BertBaseEmbedderTest { - - @Test - public void testEmbedder() { - BertBaseEmbedderConfig.Builder builder = new BertBaseEmbedderConfig.Builder(); - builder.tokenizerVocabUrl(new UrlReference("src/test/models/onnx/transformer/dummy_vocab.txt")); - builder.transformerModelUrl(new UrlReference("src/test/models/onnx/transformer/dummy_transformer.onnx")); - BertBaseEmbedder embedder = new BertBaseEmbedder(builder.build()); - - TensorType destType = TensorType.fromSpec("tensor<float>(x[7])"); - List<Integer> tokens = List.of(1,2,3,4,5); // use random tokens instead of invoking the tokenizer - Tensor embedding = embedder.embedTokens(tokens, destType); - - Tensor expected = Tensor.from("tensor<float>(x[7]):[-0.6178509, -0.8135831, 0.34416935, 0.3912577, -0.13068882, 2.5897025E-4, -0.18638384]"); - assertEquals(embedding, expected); - } - -} diff --git a/model-integration/src/test/models/onnx/transformer/dummy_transformer.onnx b/model-integration/src/test/models/onnx/transformer/dummy_transformer.onnx Binary files differdeleted file mode 100644 index 2101beec786..00000000000 --- a/model-integration/src/test/models/onnx/transformer/dummy_transformer.onnx +++ /dev/null diff --git a/model-integration/src/test/models/onnx/transformer/dummy_transformer.py b/model-integration/src/test/models/onnx/transformer/dummy_transformer.py deleted file mode 100644 index 1028035d7c0..00000000000 --- a/model-integration/src/test/models/onnx/transformer/dummy_transformer.py +++ /dev/null @@ -1,52 +0,0 @@ -# Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. - -import torch -import torch.onnx -import torch.nn as nn -from torch.nn import TransformerEncoder, TransformerEncoderLayer - - -class TransformerModel(nn.Module): - def __init__(self, vocab_size, emb_size, num_heads, hidden_dim_size, num_layers, dropout=0.2): - super(TransformerModel, self).__init__() - self.encoder = nn.Embedding(vocab_size, emb_size) - encoder_layers = TransformerEncoderLayer(emb_size, num_heads, hidden_dim_size, dropout) - self.transformer_encoder = TransformerEncoder(encoder_layers, num_layers) - - def forward(self, tokens, attention_mask, token_type_ids): - src = self.encoder((tokens * attention_mask) + token_type_ids) - output = self.transformer_encoder(src) - return output - - -def main(): - vocabulary_size = 20 - embedding_size = 16 - hidden_dim_size = 32 - num_layers = 2 - num_heads = 2 - model = TransformerModel(vocabulary_size, embedding_size, num_heads, hidden_dim_size, num_layers) - - # Omit training - just export randomly initialized network - - tokens = torch.LongTensor([[1,2,3,4,5]]) - attention_mask = torch.LongTensor([[1,1,1,1,1]]) - token_type_ids = torch.LongTensor([[0,0,0,0,0]]) - torch.onnx.export(model, - (tokens, attention_mask, token_type_ids), - "dummy_transformer.onnx", - input_names = ["input_ids", "attention_mask", "token_type_ids"], - output_names = ["output_0"], - dynamic_axes = { - "input_ids": {0:"batch", 1:"tokens"}, - "attention_mask": {0:"batch", 1:"tokens"}, - "token_type_ids": {0:"batch", 1:"tokens"}, - "output_0": {0:"batch", 1:"tokens"}, - }, - opset_version=12) - - -if __name__ == "__main__": - main() - - diff --git a/model-integration/src/test/models/onnx/transformer/dummy_vocab.txt b/model-integration/src/test/models/onnx/transformer/dummy_vocab.txt deleted file mode 100644 index 7dc0c6ecb6e..00000000000 --- a/model-integration/src/test/models/onnx/transformer/dummy_vocab.txt +++ /dev/null @@ -1,35 +0,0 @@ -a -b -c -d -e -f -g -h -i -j -k -l -m -n -o -p -q -r -s -t -u -v -x -y -z -0 -1 -2 -3 -4 -5 -6 -7 -8 -9
\ No newline at end of file |