diff options
author | Jon Bratseth <bratseth@oath.com> | 2021-05-19 13:34:26 +0200 |
---|---|---|
committer | GitHub <noreply@github.com> | 2021-05-19 13:34:26 +0200 |
commit | 390a26e1a42486fefedef5468c86a781d1d833d1 (patch) | |
tree | 65117eab0d5e5cafee49a04b10d1c79f1e3431fa /model-integration/src/test/models/onnx/pytorch/pytorch_one_layer.py | |
parent | 75eca8ab11fcd74e08b50f0076970a5c61f1c63a (diff) | |
parent | a186020aa62214a714f24091b7928a159a55b166 (diff) |
Merge pull request #17895 from vespa-engine/lesters/onnx-rt-evaluator
Add ONNX-RT evaluator to model-integration module
Diffstat (limited to 'model-integration/src/test/models/onnx/pytorch/pytorch_one_layer.py')
-rwxr-xr-x | model-integration/src/test/models/onnx/pytorch/pytorch_one_layer.py | 38 |
1 files changed, 38 insertions, 0 deletions
diff --git a/model-integration/src/test/models/onnx/pytorch/pytorch_one_layer.py b/model-integration/src/test/models/onnx/pytorch/pytorch_one_layer.py new file mode 100755 index 00000000000..1296d84e180 --- /dev/null +++ b/model-integration/src/test/models/onnx/pytorch/pytorch_one_layer.py @@ -0,0 +1,38 @@ +# Copyright Verizon Media. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. + +import torch +import torch.onnx + + +class MyModel(torch.nn.Module): + def __init__(self): + super(MyModel, self).__init__() + self.linear = torch.nn.Linear(in_features=3, out_features=1) + self.logistic = torch.nn.Sigmoid() + + def forward(self, vec): + return self.logistic(self.linear(vec)) + + +def main(): + model = MyModel() + + # Omit training - just export randomly initialized network + + data = torch.FloatTensor([[0.1, 0.2, 0.3],[0.4, 0.5, 0.6]]) + torch.onnx.export(model, + data, + "one_layer.onnx", + input_names = ["input"], + output_names = ["output"], + dynamic_axes = { + "input": {0: "batch"}, + "output": {0: "batch"}, + }, + opset_version=12) + + +if __name__ == "__main__": + main() + + |