summaryrefslogtreecommitdiffstats
path: root/model-integration/src/test/models/onnx/pytorch/pytorch_one_layer.py
diff options
context:
space:
mode:
authorJon Bratseth <bratseth@oath.com>2021-05-19 13:34:26 +0200
committerGitHub <noreply@github.com>2021-05-19 13:34:26 +0200
commit390a26e1a42486fefedef5468c86a781d1d833d1 (patch)
tree65117eab0d5e5cafee49a04b10d1c79f1e3431fa /model-integration/src/test/models/onnx/pytorch/pytorch_one_layer.py
parent75eca8ab11fcd74e08b50f0076970a5c61f1c63a (diff)
parenta186020aa62214a714f24091b7928a159a55b166 (diff)
Merge pull request #17895 from vespa-engine/lesters/onnx-rt-evaluator
Add ONNX-RT evaluator to model-integration module
Diffstat (limited to 'model-integration/src/test/models/onnx/pytorch/pytorch_one_layer.py')
-rwxr-xr-xmodel-integration/src/test/models/onnx/pytorch/pytorch_one_layer.py38
1 files changed, 38 insertions, 0 deletions
diff --git a/model-integration/src/test/models/onnx/pytorch/pytorch_one_layer.py b/model-integration/src/test/models/onnx/pytorch/pytorch_one_layer.py
new file mode 100755
index 00000000000..1296d84e180
--- /dev/null
+++ b/model-integration/src/test/models/onnx/pytorch/pytorch_one_layer.py
@@ -0,0 +1,38 @@
+# Copyright Verizon Media. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+
+import torch
+import torch.onnx
+
+
+class MyModel(torch.nn.Module):
+ def __init__(self):
+ super(MyModel, self).__init__()
+ self.linear = torch.nn.Linear(in_features=3, out_features=1)
+ self.logistic = torch.nn.Sigmoid()
+
+ def forward(self, vec):
+ return self.logistic(self.linear(vec))
+
+
+def main():
+ model = MyModel()
+
+ # Omit training - just export randomly initialized network
+
+ data = torch.FloatTensor([[0.1, 0.2, 0.3],[0.4, 0.5, 0.6]])
+ torch.onnx.export(model,
+ data,
+ "one_layer.onnx",
+ input_names = ["input"],
+ output_names = ["output"],
+ dynamic_axes = {
+ "input": {0: "batch"},
+ "output": {0: "batch"},
+ },
+ opset_version=12)
+
+
+if __name__ == "__main__":
+ main()
+
+