# model-integration/src/test/models/onnx/llm/random_llm.py

import torch
import torch.onnx
import torch.nn as nn
from torch.nn import TransformerEncoderLayer, TransformerEncoder, TransformerDecoder, TransformerDecoderLayer


class EncoderModel(nn.Module):
    def __init__(self, vocab_size, emb_size, hidden_dim_size, num_heads, num_layers, dropout=0.2, batch_first=True):
        super(EncoderModel, self).__init__()
        self.embedding = nn.Embedding(vocab_size, emb_size)
        encoder_layers = TransformerEncoderLayer(emb_size, num_heads, hidden_dim_size, dropout, batch_first=batch_first)
        self.transformer_encoder = TransformerEncoder(encoder_layers, num_layers)

    def forward(self, tokens, attention_mask):
        src = self.embedding(tokens * attention_mask)  # zero out padded token ids, then embed: (batch, source tokens, emb)
        output = self.transformer_encoder(src)
        return output


class DecoderModel(nn.Module):
    def __init__(self, vocab_size, emb_size, hidden_dim_size, num_heads, num_layers, dropout=0.2, batch_first=True):
        super(DecoderModel, self).__init__()
        self.embedding = nn.Embedding(vocab_size, emb_size)
        decoder_layers = TransformerDecoderLayer(emb_size, num_heads, hidden_dim_size, dropout, batch_first=batch_first)
        self.transformer_decoder = TransformerDecoder(decoder_layers, num_layers)
        self.linear = nn.Linear(emb_size, vocab_size)

    def forward(self, tokens, attention_mask, encoder_hidden_state):
        tgt = self.embedding(tokens)  # (batch, target tokens, emb)
        # The encoder attention mask is applied as an additive memory mask; with a
        # single-token target it already has the required (target, source) shape.
        out = self.transformer_decoder(tgt, encoder_hidden_state, memory_mask=attention_mask)
        logits = self.linear(out)
        return logits


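# Illustrative only, not used by the export below: a minimal sketch of how the two
# modules could be chained for greedy decoding, matching how the export wires the
# encoder attention mask into the decoder as a memory mask. The start token id 0
# and the `steps` cut-off are assumptions for illustration; note the decoder applies
# no causal target mask, since its forward() does not take one.
def greedy_decode_sketch(encoder, decoder, tokens, attention_mask, steps=5):
    memory = encoder(tokens, attention_mask)                          # (batch, source, emb)
    generated = torch.zeros((tokens.size(0), 1), dtype=torch.long)    # assumed start token 0
    for _ in range(steps):
        # memory_mask must be (target, source); repeat the (1, source) mask over target rows
        memory_mask = attention_mask.float().expand(generated.size(1), -1)
        logits = decoder(generated, memory_mask, memory)              # (batch, target, vocab)
        next_token = logits[:, -1, :].argmax(dim=-1, keepdim=True)
        generated = torch.cat([generated, next_token], dim=1)
    return generated
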
def main():
    vocabulary_size = 10000
    embedding_size = 8
    hidden_dim_size = 16
    num_heads = 1
    num_layers = 1

    encoder = EncoderModel(vocabulary_size, embedding_size, hidden_dim_size, num_heads, num_layers)
    decoder = DecoderModel(vocabulary_size, embedding_size, hidden_dim_size, num_heads, num_layers)

    # Skip training; export the randomly initialized networks as-is

    tokens = torch.LongTensor([[1, 2, 3, 4, 5]])
    attention_mask = torch.LongTensor([[1, 1, 1, 1, 1]])

    torch.onnx.export(encoder,
                      (tokens, attention_mask),
                      "random_encoder.onnx",
                      input_names=["input_ids", "attention_mask"],
                      output_names=["last_hidden_state"],
                      dynamic_axes={
                          "input_ids": {0: "batch", 1: "tokens"},
                          "attention_mask": {0: "batch", 1: "tokens"},
                          "last_hidden_state": {0: "batch", 1: "tokens"},
                      },
                      opset_version=12)

    last_hidden_state = encoder(tokens, attention_mask)
    tokens = torch.LongTensor([[0]])  # a single start token for the decoder

    torch.onnx.export(decoder,
                      (tokens, attention_mask.float(), last_hidden_state),
                      "random_decoder.onnx",
                      input_names=["input_ids", "encoder_attention_mask", "encoder_hidden_states"],
                      output_names=["logits"],
                      dynamic_axes={
                          "input_ids": {0: "batch", 1: "target_tokens"},
                          "encoder_attention_mask": {0: "batch", 1: "source_tokens"},
                          "encoder_hidden_states": {0: "batch", 1: "source_tokens"},
                          "logits": {0: "batch", 1: "target_tokens"},
                      },
                      opset_version=12)
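

# Optional sanity check, separate from the export flow above: a minimal sketch,
# assuming onnxruntime and numpy are available, that loads the exported files and
# exercises the declared input/output names. The expected shapes follow from the
# sizes hard-coded in main(). Call check_exports() manually after running main().
def check_exports():
    import numpy as np
    import onnxruntime as ort

    tokens = np.array([[1, 2, 3, 4, 5]], dtype=np.int64)
    mask = np.ones_like(tokens)

    encoder_session = ort.InferenceSession("random_encoder.onnx")
    (hidden,) = encoder_session.run(["last_hidden_state"],
                                    {"input_ids": tokens, "attention_mask": mask})
    assert hidden.shape == (1, 5, 8)       # batch, source tokens, embedding size

    decoder_session = ort.InferenceSession("random_decoder.onnx")
    (logits,) = decoder_session.run(["logits"],
                                    {"input_ids": np.array([[0]], dtype=np.int64),
                                     "encoder_attention_mask": mask.astype(np.float32),
                                     "encoder_hidden_states": hidden})
    assert logits.shape == (1, 1, 10000)   # batch, target tokens, vocabulary size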


if __name__ == "__main__":
    main()