aboutsummaryrefslogtreecommitdiffstats
path: root/model-evaluation/src/test/resources/config/onnx/models/pytorch_one_layer.py
blob: 3975c5ca34e55fe0b10fcbb9ca2cc2f0dd8d178e (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
# Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.

import torch
import torch.onnx


class MyModel(torch.nn.Module):
    """Minimal one-layer network: a 3 -> 1 linear layer followed by a sigmoid.

    No training happens in this class; parameters keep the random values
    assigned by ``torch.nn.Linear`` at construction time.
    """

    def __init__(self):
        # Zero-argument super() is the Python 3 idiom; the explicit
        # super(MyModel, self) form is a Python 2 leftover.
        super().__init__()
        self.linear = torch.nn.Linear(in_features=3, out_features=1)
        self.logistic = torch.nn.Sigmoid()

    def forward(self, vec):
        """Return ``sigmoid(linear(vec))``.

        Args:
            vec: float tensor whose last dimension has size 3
                (``torch.nn.Linear`` acts on the last dimension).

        Returns:
            Tensor with the same leading dimensions as ``vec`` and a last
            dimension of size 1, squashed by the sigmoid into [0, 1].
        """
        return self.logistic(self.linear(vec))


def main(output_path="one_layer.onnx"):
    """Export an untrained ``MyModel`` to an ONNX file.

    Args:
        output_path: destination file for the exported model. Defaults to
            ``one_layer.onnx``; a relative path is resolved against the
            current working directory.
    """
    model = MyModel()

    # Omit training - just export the randomly initialized network.

    # Example input used to trace the graph: a batch of 2 vectors of size 3.
    # torch.tensor with an explicit dtype replaces the legacy
    # torch.FloatTensor constructor and yields the same float32 CPU tensor.
    data = torch.tensor([[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]], dtype=torch.float32)
    torch.onnx.export(
        model,
        data,
        output_path,
        input_names=["input"],
        output_names=["output"],
        # Mark dim 0 as a symbolic "batch" axis so the exported model is not
        # fixed to the batch size of 2 used for tracing.
        dynamic_axes={
            "input": {0: "batch"},
            "output": {0: "batch"},
        },
        opset_version=12,
    )


if __name__ == "__main__":
    # Only export when run as a script, never as a side effect of import.
    main()