aboutsummaryrefslogtreecommitdiffstats
path: root/model-integration/src/test/models/onnx/llm/random_llm.py
blob: de580d639fed442e2609aecefe241a081f712385 (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
# Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
import torch
import torch.onnx
import torch.nn as nn
from torch.nn import TransformerEncoderLayer, TransformerEncoder, TransformerDecoder, TransformerDecoderLayer


class EncoderModel(nn.Module):
    """Token embedding followed by a stack of Transformer encoder layers.

    Args:
        vocab_size: Number of rows in the embedding table.
        emb_size: Embedding width, also used as the Transformer model dimension.
        hidden_dim_size: Width of the feed-forward sublayer in each encoder layer.
        num_heads: Number of attention heads per encoder layer.
        num_layers: Number of stacked encoder layers.
        dropout: Dropout probability handed to each encoder layer.
        batch_first: Forwarded to the encoder layer to select the input layout.
    """

    def __init__(self, vocab_size, emb_size, hidden_dim_size, num_heads, num_layers, dropout=0.2, batch_first=True):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, emb_size)
        layer_template = TransformerEncoderLayer(
            d_model=emb_size,
            nhead=num_heads,
            dim_feedforward=hidden_dim_size,
            dropout=dropout,
            batch_first=batch_first,
        )
        self.transformer_encoder = TransformerEncoder(layer_template, num_layers=num_layers)

    def forward(self, tokens, attention_mask):
        """Embed the masked token ids and run them through the encoder stack.

        The mask is multiplied into the token ids, so every position whose mask
        value is 0 is looked up as token id 0. The mask is not passed on to the
        attention layers themselves.
        """
        masked_ids = tokens * attention_mask
        embedded = self.embedding(masked_ids)  # N, S, E
        return self.transformer_encoder(embedded)


class DecoderModel(nn.Module):
    """Token embedding, Transformer decoder stack and a linear projection to vocabulary logits.

    Args:
        vocab_size: Number of rows in the embedding table and width of the output logits.
        emb_size: Embedding width, also used as the Transformer model dimension.
        hidden_dim_size: Width of the feed-forward sublayer in each decoder layer.
        num_heads: Number of attention heads per decoder layer.
        num_layers: Number of stacked decoder layers.
        dropout: Dropout probability handed to each decoder layer.
        batch_first: Forwarded to the decoder layer to select the input layout.
    """

    def __init__(self, vocab_size, emb_size, hidden_dim_size, num_heads, num_layers, dropout=0.2, batch_first=True):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, emb_size)
        # Pass `dropout` through; previously the argument was accepted but ignored,
        # so the layer silently fell back to its own default.
        decoder_layers = nn.TransformerDecoderLayer(emb_size, num_heads, hidden_dim_size, dropout,
                                                    batch_first=batch_first)
        self.transformer_decoder = nn.TransformerDecoder(decoder_layers, num_layers)
        self.linear = nn.Linear(emb_size, vocab_size)

    def forward(self, tokens, attention_mask, encoder_hidden_state):
        """Return vocabulary logits for each target position.

        Args:
            tokens: Target token ids.
            attention_mask: Per-batch mask over encoder positions (N, S); positions
                whose value equals 0 are excluded from cross-attention.
            encoder_hidden_state: Encoder output that the decoder attends to.
        """
        tgt = self.embedding(tokens)  # N, T, E
        # The mask is per batch element over source positions (N, S), which is what
        # `memory_key_padding_mask` expects (True = ignore). `memory_mask` is (T, S)
        # and treats float values as additive, so a 1/0 mask would mask nothing.
        ignored_positions = attention_mask == 0
        out = self.transformer_decoder(tgt, encoder_hidden_state, memory_key_padding_mask=ignored_positions)
        logits = self.linear(out)
        return logits


def main():
    """Export a randomly initialized encoder and decoder to ONNX files.

    No training is performed. Writes "random_encoder.onnx" and
    "random_decoder.onnx" to the current working directory.
    """
    vocabulary_size = 10000
    embedding_size = 8
    hidden_dim_size = 16
    num_heads = 1
    num_layers = 1

    encoder = EncoderModel(vocabulary_size, embedding_size, hidden_dim_size, num_heads, num_layers)
    decoder = DecoderModel(vocabulary_size, embedding_size, hidden_dim_size, num_heads, num_layers)

    # Omit training - just export randomly initialized network.
    # Switch to inference mode so dropout is inactive for the example forward pass below.
    encoder.eval()
    decoder.eval()

    tokens = torch.LongTensor([[1, 2, 3, 4, 5]])
    attention_mask = torch.LongTensor([[1, 1, 1, 1, 1]])

    torch.onnx.export(encoder,
                      (tokens, attention_mask),
                      "random_encoder.onnx",
                      input_names=["input_ids", "attention_mask"],
                      output_names=["last_hidden_state"],
                      dynamic_axes={
                          "input_ids": {0: "batch", 1: "tokens"},
                          "attention_mask": {0: "batch", 1: "tokens"},
                          "last_hidden_state": {0: "batch", 1: "tokens"},
                      },
                      opset_version=12)

    # The encoder output is only needed as an example input for tracing the decoder,
    # so compute it without autograd tracking and through the regular module call.
    with torch.no_grad():
        last_hidden_state = encoder(tokens, attention_mask)
    decoder_tokens = torch.LongTensor([[0]])

    torch.onnx.export(decoder,
                      (decoder_tokens, attention_mask.float(), last_hidden_state),
                      "random_decoder.onnx",
                      input_names=["input_ids", "encoder_attention_mask", "encoder_hidden_states"],
                      output_names=["logits"],
                      dynamic_axes={
                          "input_ids": {0: "batch", 1: "target_tokens"},
                          "encoder_attention_mask": {0: "batch", 1: "source_tokens"},
                          "encoder_hidden_states": {0: "batch", 1: "source_tokens"},
                          "logits": {0: "batch", 1: "target_tokens"},
                      },
                      opset_version=12)


# Run the export only when this file is executed as a script, not when it is imported.
if __name__ == "__main__":
    main()