summaryrefslogtreecommitdiffstats
path: root/model-integration/src/test/models/onnx/simple/gather.py
diff options
context:
space:
mode:
Diffstat (limited to 'model-integration/src/test/models/onnx/simple/gather.py')
-rwxr-xr-xmodel-integration/src/test/models/onnx/simple/gather.py23
1 files changed, 23 insertions, 0 deletions
diff --git a/model-integration/src/test/models/onnx/simple/gather.py b/model-integration/src/test/models/onnx/simple/gather.py
new file mode 100755
index 00000000000..63a2103fd86
--- /dev/null
+++ b/model-integration/src/test/models/onnx/simple/gather.py
@@ -0,0 +1,23 @@
+# Copyright 2020 Oath Inc. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+import onnx
+import numpy as np
+from onnx import helper, TensorProto
+
+data_type = helper.make_tensor_value_info('data', TensorProto.FLOAT, [3,2])
+indices_type = helper.make_tensor_value_info('indices', TensorProto.FLOAT, [2,2])
+output_type = helper.make_tensor_value_info('y', TensorProto.FLOAT, [2,2,2])
+
+node = onnx.helper.make_node(
+ 'Gather',
+ inputs=['data', 'indices'],
+ outputs=['y'],
+ axis=0,
+)
+graph_def = onnx.helper.make_graph(
+ nodes = [node],
+ name = 'gather_test',
+ inputs = [data_type, indices_type],
+ outputs = [output_type]
+)
+model_def = helper.make_model(graph_def, producer_name='gather.py')
+onnx.save(model_def, 'gather.onnx')