aboutsummaryrefslogtreecommitdiffstats
path: root/model-integration/src/test/models/onnx/transformer/dummy_transformer.py
blob: 1028035d7c0cda2878d45586f4dc09ae75b4d38f (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
# Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.

import torch
import torch.onnx
import torch.nn as nn
from torch.nn import TransformerEncoder, TransformerEncoderLayer


class TransformerModel(nn.Module):
    """Tiny, untrained transformer encoder used as an ONNX test fixture.

    Token ids are embedded with ``nn.Embedding`` and passed through a stack of
    ``TransformerEncoderLayer``s. Inputs are batch-major, ``(batch, tokens)``,
    and the output is ``(batch, tokens, emb_size)``.

    Args:
        vocab_size: Number of rows in the embedding table.
        emb_size: Embedding / model dimension (``d_model``); PyTorch requires
            it to be divisible by ``num_heads``.
        num_heads: Number of attention heads per encoder layer.
        hidden_dim_size: Width of each layer's feed-forward sub-network.
        num_layers: Number of stacked encoder layers.
        dropout: Dropout probability used inside each encoder layer.
    """

    def __init__(self, vocab_size, emb_size, num_heads, hidden_dim_size, num_layers, dropout=0.2):
        super().__init__()
        self.encoder = nn.Embedding(vocab_size, emb_size)
        encoder_layers = TransformerEncoderLayer(emb_size, num_heads, hidden_dim_size, dropout)
        self.transformer_encoder = TransformerEncoder(encoder_layers, num_layers)

    def forward(self, tokens, attention_mask, token_type_ids):
        """Encode a batch of token sequences.

        Args:
            tokens: Integer token ids, shape ``(batch, tokens)``.
            attention_mask: Integer tensor of the same shape (presumably 0/1).
                It is multiplied into the token ids, so masked positions become
                id 0 (plus their type id). It is NOT passed to the encoder as an
                attention mask, so masked positions are still attended to.
            token_type_ids: Integer tensor of the same shape, added to the
                (masked) token ids before embedding.

        Returns:
            Float tensor of shape ``(batch, tokens, emb_size)``.
        """
        src = self.encoder((tokens * attention_mask) + token_type_ids)
        # The encoder layers are built with PyTorch's default sequence-major
        # layout, (seq, batch, emb). Feeding the batch-major embeddings directly
        # would make attention run across the batch axis, so each token would
        # only ever see itself. Transpose in and out instead of constructing the
        # layers with batch_first=True, which keeps the layer configuration as-is.
        output = self.transformer_encoder(src.transpose(0, 1))
        return output.transpose(0, 1)


def main():
    """Build the dummy transformer and export it to ``dummy_transformer.onnx``.

    The network is not trained; its randomly initialized weights are exported
    as-is. The file is written to the current working directory. The batch and
    token dimensions of every input and of the output are declared dynamic in
    the exported graph.
    """
    vocabulary_size = 20
    embedding_size = 16
    hidden_dim_size = 32
    num_layers = 2
    num_heads = 2

    # Fixed seed so re-running the script regenerates the same weights
    # (for a given PyTorch version) instead of a different fixture each time.
    torch.manual_seed(0)
    model = TransformerModel(vocabulary_size, embedding_size, num_heads, hidden_dim_size, num_layers)

    # Omit training - just export the randomly initialized network.
    # Switch to inference mode explicitly so dropout is inactive during export.
    model.eval()

    # Example inputs used for tracing; their axes are declared dynamic below.
    tokens = torch.LongTensor([[1, 2, 3, 4, 5]])
    attention_mask = torch.LongTensor([[1, 1, 1, 1, 1]])
    token_type_ids = torch.LongTensor([[0, 0, 0, 0, 0]])
    torch.onnx.export(
        model,
        (tokens, attention_mask, token_type_ids),
        "dummy_transformer.onnx",
        input_names=["input_ids", "attention_mask", "token_type_ids"],
        output_names=["output_0"],
        dynamic_axes={
            "input_ids": {0: "batch", 1: "tokens"},
            "attention_mask": {0: "batch", 1: "tokens"},
            "token_type_ids": {0: "batch", 1: "tokens"},
            "output_0": {0: "batch", 1: "tokens"},
        },
        opset_version=12,
    )


# Running this file as a script writes dummy_transformer.onnx to the current directory.
if __name__ == "__main__":
    main()