Diffstat (limited to 'model-integration/src/test/models/onnx/transformer/dummy_transformer.py')
 -rw-r--r--  model-integration/src/test/models/onnx/transformer/dummy_transformer.py | 52 ++++++++++++++++++++
 1 file changed, 52 insertions(+), 0 deletions(-)
diff --git a/model-integration/src/test/models/onnx/transformer/dummy_transformer.py b/model-integration/src/test/models/onnx/transformer/dummy_transformer.py
new file mode 100644
index 00000000000..1028035d7c0
--- /dev/null
+++ b/model-integration/src/test/models/onnx/transformer/dummy_transformer.py
@@ -0,0 +1,52 @@
+# Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+
+import torch
+import torch.onnx
+import torch.nn as nn
+from torch.nn import TransformerEncoder, TransformerEncoderLayer
+
+
+class TransformerModel(nn.Module):
+    def __init__(self, vocab_size, emb_size, num_heads, hidden_dim_size, num_layers, dropout=0.2):
+        super().__init__()
+        self.encoder = nn.Embedding(vocab_size, emb_size)
+        encoder_layers = TransformerEncoderLayer(emb_size, num_heads, hidden_dim_size, dropout)
+        self.transformer_encoder = TransformerEncoder(encoder_layers, num_layers)
+
+    def forward(self, tokens, attention_mask, token_type_ids):
+        # Combine all three inputs in an arbitrary way so that each of them
+        # becomes a used input of the exported ONNX graph
+        src = self.encoder((tokens * attention_mask) + token_type_ids)
+        output = self.transformer_encoder(src)
+        return output
+
+
+def main():
+    vocabulary_size = 20
+    embedding_size = 16
+    hidden_dim_size = 32
+    num_layers = 2
+    num_heads = 2
+    model = TransformerModel(vocabulary_size, embedding_size, num_heads, hidden_dim_size, num_layers)
+
+    # Omit training - just export the randomly initialized network
+
+    tokens = torch.LongTensor([[1, 2, 3, 4, 5]])
+    attention_mask = torch.LongTensor([[1, 1, 1, 1, 1]])
+    token_type_ids = torch.LongTensor([[0, 0, 0, 0, 0]])
+    torch.onnx.export(model,
+                      (tokens, attention_mask, token_type_ids),
+                      "dummy_transformer.onnx",
+                      input_names=["input_ids", "attention_mask", "token_type_ids"],
+                      output_names=["output_0"],
+                      dynamic_axes={
+                          "input_ids": {0: "batch", 1: "tokens"},
+                          "attention_mask": {0: "batch", 1: "tokens"},
+                          "token_type_ids": {0: "batch", 1: "tokens"},
+                          "output_0": {0: "batch", 1: "tokens"},
+                      },
+                      opset_version=12)
+
+
+if __name__ == "__main__":
+    main()
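
As a sanity check, the exported model can be loaded and evaluated with onnxruntime. This is a minimal sketch, not part of the committed file; it assumes onnxruntime and numpy are installed, and reuses the input and output names passed to torch.onnx.export above:

    import numpy as np
    import onnxruntime as ort

    # Load the exported graph on CPU (path assumes the script's working directory)
    session = ort.InferenceSession("dummy_transformer.onnx", providers=["CPUExecutionProvider"])

    # Feed names must match the input_names given at export time;
    # LongTensor inputs correspond to int64 arrays on the ONNX side
    feed = {
        "input_ids": np.array([[1, 2, 3, 4, 5]], dtype=np.int64),
        "attention_mask": np.ones((1, 5), dtype=np.int64),
        "token_type_ids": np.zeros((1, 5), dtype=np.int64),
    }
    output = session.run(["output_0"], feed)[0]
    print(output.shape)  # (1, 5, 16): batch x tokens x embedding_size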