aboutsummaryrefslogtreecommitdiffstats
path: root/model-integration/src/test/models/onnx/simple/gather.py
blob: 9db15cb20c9a7fbfb77953ca6a7cf3a15b0b63c4 (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
# Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
import onnx
import numpy as np
from onnx import helper, TensorProto

# Build a single-node ONNX model exercising the Gather operator and save it
# as 'gather.onnx' (relative path, i.e. the current working directory).
#
# With axis=0, Gather picks rows of `data` using the values in `indices`, so
# the output shape is indices.shape + data.shape[1:] = [2, 2] + [2] = [2, 2, 2].

data_type = helper.make_tensor_value_info('data', TensorProto.FLOAT, [3, 2])
# The ONNX Gather spec constrains `indices` (type Tind) to int32 or int64;
# declaring it FLOAT produces a model that does not conform to the operator
# definition.
indices_type = helper.make_tensor_value_info('indices', TensorProto.INT64, [2, 2])
output_type = helper.make_tensor_value_info('y', TensorProto.FLOAT, [2, 2, 2])

node = helper.make_node(
    'Gather',
    inputs=['data', 'indices'],
    outputs=['y'],
    axis=0,
)
graph_def = helper.make_graph(
    nodes=[node],
    name='gather_test',
    inputs=[data_type, indices_type],
    outputs=[output_type],
)
model_def = helper.make_model(graph_def, producer_name='gather.py')
# Validate the model structure before writing it, so a malformed graph fails
# here instead of later in whatever loads gather.onnx.
onnx.checker.check_model(model_def)
onnx.save(model_def, 'gather.onnx')