| author    | Christophe Jolif <cjolif@gmail.com>                         | 2022-09-29 08:54:52 +0200 |
|-----------|-------------------------------------------------------------|---------------------------|
| committer | Christophe Jolif <cjolif@gmail.com>                         | 2022-10-20 13:59:59 +0200 |
| commit    | e39231d1f72080ebb6232f70bd5b388ba83232ec (patch)            |                           |
| tree      | 7903558a30a297e7d8db1f49a8cbbed8f123393e /model-integration |                           |
| parent    | 8a91e259064b40f2f5fde5f8233c9892446d105e (diff)             |                           |
support models without tokenTypeIds (like DistilBERT)
Diffstat (limited to 'model-integration')

    model-integration/src/main/java/ai/vespa/embedding/BertBaseEmbedder.java                   | 15 (+12/-3)
    model-integration/src/test/java/ai/vespa/embedding/BertBaseEmbedderTest.java               | 36 (+36)
    model-integration/src/test/models/onnx/transformer/dummy_transformer_without_type_ids.onnx | bin (new)
    model-integration/src/test/models/onnx/transformer/dummy_transformer_without_type_ids.py   | 49 (+49)

4 files changed, 97 insertions(+), 3 deletions(-)
```diff
diff --git a/model-integration/src/main/java/ai/vespa/embedding/BertBaseEmbedder.java b/model-integration/src/main/java/ai/vespa/embedding/BertBaseEmbedder.java
index d4a93999dff..002350ce3cf 100644
--- a/model-integration/src/main/java/ai/vespa/embedding/BertBaseEmbedder.java
+++ b/model-integration/src/main/java/ai/vespa/embedding/BertBaseEmbedder.java
@@ -67,8 +67,11 @@ public class BertBaseEmbedder implements Embedder {
         Map<String, TensorType> inputs = evaluator.getInputInfo();
         validateName(inputs, inputIdsName, "input");
         validateName(inputs, attentionMaskName, "input");
-        validateName(inputs, tokenTypeIdsName, "input");
-
+        // some BERT-inspired models, such as DistilBERT, have no token_type_ids input;
+        // one can declare such a model explicitly by setting that config value to the empty string
+        if (!"".equals(tokenTypeIdsName)) {
+            validateName(inputs, tokenTypeIdsName, "input");
+        }
         Map<String, TensorType> outputs = evaluator.getOutputInfo();
         validateName(outputs, outputName, "output");
     }
@@ -102,9 +105,15 @@ public class BertBaseEmbedder implements Embedder {
         Tensor attentionMask = createAttentionMask(inputSequence);
         Tensor tokenTypeIds = createTokenTypeIds(inputSequence);
 
-        Map<String, Tensor> inputs = Map.of(inputIdsName, inputSequence.expand("d0"),
+        Map<String, Tensor> inputs;
+        if (!"".equals(tokenTypeIdsName)) {
+            inputs = Map.of(inputIdsName, inputSequence.expand("d0"),
                 attentionMaskName, attentionMask.expand("d0"),
                 tokenTypeIdsName, tokenTypeIds.expand("d0"));
+        } else {
+            inputs = Map.of(inputIdsName, inputSequence.expand("d0"),
+                    attentionMaskName, attentionMask.expand("d0"));
+        }
 
         Map<String, Tensor> outputs = evaluator.evaluate(inputs);
         Tensor tokenEmbeddings = outputs.get(outputName);
diff --git a/model-integration/src/test/java/ai/vespa/embedding/BertBaseEmbedderTest.java b/model-integration/src/test/java/ai/vespa/embedding/BertBaseEmbedderTest.java
index dbcae24b28f..73359736536 100644
--- a/model-integration/src/test/java/ai/vespa/embedding/BertBaseEmbedderTest.java
+++ b/model-integration/src/test/java/ai/vespa/embedding/BertBaseEmbedderTest.java
@@ -9,9 +9,11 @@ import com.yahoo.tensor.Tensor;
 import com.yahoo.tensor.TensorType;
 import org.junit.Test;
 
+import java.lang.IllegalArgumentException;
 import java.util.List;
 
 import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertThrows;
 import static org.junit.Assume.assumeTrue;
 
 public class BertBaseEmbedderTest {
@@ -35,4 +37,38 @@ public class BertBaseEmbedderTest {
         assertEquals(embedding, expected);
     }
 
+    @Test
+    public void testEmbedderWithoutTokenTypeIdsName() {
+        String vocabPath = "src/test/models/onnx/transformer/dummy_vocab.txt";
+        String modelPath = "src/test/models/onnx/transformer/dummy_transformer_without_type_ids.onnx";
+        assumeTrue(OnnxEvaluator.isRuntimeAvailable(modelPath));
+
+        BertBaseEmbedderConfig.Builder builder = new BertBaseEmbedderConfig.Builder();
+        builder.tokenizerVocab(ModelReference.valueOf(vocabPath));
+        builder.transformerModel(ModelReference.valueOf(modelPath));
+        builder.transformerTokenTypeIds("");
+        BertBaseEmbedder embedder = new BertBaseEmbedder(builder.build());
+
+        TensorType destType = TensorType.fromSpec("tensor<float>(x[7])");
+        List<Integer> tokens = List.of(1,2,3,4,5); // use random tokens instead of invoking the tokenizer
+        Tensor embedding = embedder.embedTokens(tokens, destType);
+
+        Tensor expected = Tensor.from("tensor<float>(x[7]):[0.10873623, 0.56411576, 0.6044973, -0.4819714, 0.7519982, -0.83261716, 0.30430704]");
+        assertEquals(embedding, expected);
+    }
+
+    @Test
+    public void testEmbedderWithoutTokenTypeIdsNameButWithConfig() {
+        String vocabPath = "src/test/models/onnx/transformer/dummy_vocab.txt";
+        String modelPath = "src/test/models/onnx/transformer/dummy_transformer_without_type_ids.onnx";
+        assumeTrue(OnnxEvaluator.isRuntimeAvailable(modelPath));
+
+        BertBaseEmbedderConfig.Builder builder = new BertBaseEmbedderConfig.Builder();
+        builder.tokenizerVocab(ModelReference.valueOf(vocabPath));
+        builder.transformerModel(ModelReference.valueOf(modelPath));
+        // we did not configure BertBaseEmbedder to accept missing token type ids,
+        // so we expect the constructor to throw
+        assertThrows(IllegalArgumentException.class, () -> { new BertBaseEmbedder(builder.build()); });
+    }
+
 }
diff --git a/model-integration/src/test/models/onnx/transformer/dummy_transformer_without_type_ids.onnx b/model-integration/src/test/models/onnx/transformer/dummy_transformer_without_type_ids.onnx
new file mode 100644
index 00000000000..927f1c607c4
--- /dev/null
+++ b/model-integration/src/test/models/onnx/transformer/dummy_transformer_without_type_ids.onnx
Binary files differ
diff --git a/model-integration/src/test/models/onnx/transformer/dummy_transformer_without_type_ids.py b/model-integration/src/test/models/onnx/transformer/dummy_transformer_without_type_ids.py
new file mode 100644
index 00000000000..4c5f5ebe330
--- /dev/null
+++ b/model-integration/src/test/models/onnx/transformer/dummy_transformer_without_type_ids.py
@@ -0,0 +1,49 @@
+import torch
+import torch.onnx
+import torch.nn as nn
+from torch.nn import TransformerEncoder, TransformerEncoderLayer
+
+
+class TransformerModel(nn.Module):
+    def __init__(self, vocab_size, emb_size, num_heads, hidden_dim_size, num_layers, dropout=0.2):
+        super(TransformerModel, self).__init__()
+        self.encoder = nn.Embedding(vocab_size, emb_size)
+        encoder_layers = TransformerEncoderLayer(emb_size, num_heads, hidden_dim_size, dropout)
+        self.transformer_encoder = TransformerEncoder(encoder_layers, num_layers)
+
+    def forward(self, tokens, attention_mask):
+        src = self.encoder((tokens * attention_mask))
+        output = self.transformer_encoder(src)
+        return output
+
+
+def main():
+    vocabulary_size = 20
+    embedding_size = 16
+    hidden_dim_size = 32
+    num_layers = 2
+    num_heads = 2
+    model = TransformerModel(vocabulary_size, embedding_size, num_heads, hidden_dim_size, num_layers)
+
+    # Omit training - just export randomly initialized network
+
+    tokens = torch.LongTensor([[1,2,3,4,5]])
+    attention_mask = torch.LongTensor([[1,1,1,1,1]])
+    token_type_ids = torch.LongTensor([[0,0,0,0,0]])
+    torch.onnx.export(model,
+                      (tokens, attention_mask),
+                      "dummy_transformer_without_type_ids.onnx",
+                      input_names = ["input_ids", "attention_mask"],
+                      output_names = ["output_0"],
+                      dynamic_axes = {
+                          "input_ids": {0:"batch", 1:"tokens"},
+                          "attention_mask": {0:"batch", 1:"tokens"},
+                          "output_0": {0:"batch", 1:"tokens"},
+                      },
+                      opset_version=12)
+
+
+if __name__ == "__main__":
+    main()
+
```