diff options
Diffstat (limited to 'model-integration/src/test/models/onnx/simple/gather.py')
-rwxr-xr-x | model-integration/src/test/models/onnx/simple/gather.py | 23 |
1 files changed, 23 insertions, 0 deletions
diff --git a/model-integration/src/test/models/onnx/simple/gather.py b/model-integration/src/test/models/onnx/simple/gather.py new file mode 100755 index 00000000000..63a2103fd86 --- /dev/null +++ b/model-integration/src/test/models/onnx/simple/gather.py @@ -0,0 +1,23 @@ +# Copyright 2020 Oath Inc. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +import onnx +import numpy as np +from onnx import helper, TensorProto + +data_type = helper.make_tensor_value_info('data', TensorProto.FLOAT, [3,2]) +indices_type = helper.make_tensor_value_info('indices', TensorProto.FLOAT, [2,2]) +output_type = helper.make_tensor_value_info('y', TensorProto.FLOAT, [2,2,2]) + +node = onnx.helper.make_node( + 'Gather', + inputs=['data', 'indices'], + outputs=['y'], + axis=0, +) +graph_def = onnx.helper.make_graph( + nodes = [node], + name = 'gather_test', + inputs = [data_type, indices_type], + outputs = [output_type] +) +model_def = helper.make_model(graph_def, producer_name='gather.py') +onnx.save(model_def, 'gather.onnx') |