aboutsummaryrefslogtreecommitdiffstats
path: root/model-integration/src/test/models/onnx/simple/concat.py
blob: ca79f77a469b9701d1597e9376444785c4fde2fb (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
# Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.

import onnx
from onnx import helper, TensorProto

# Generates a minimal ONNX model containing a single Concat node and writes it
# to 'concat.onnx' (a relative path, so it lands in the current working
# directory).

# Graph inputs: three FLOAT tensors named 'i', 'j' and 'k', each of shape [1].
# They are declared in that order; the generator does not leak `name`.
i_type, j_type, k_type = (
    helper.make_tensor_value_info(name, TensorProto.FLOAT, [1])
    for name in ('i', 'j', 'k')
)

# Graph output: a FLOAT tensor 'y' of shape [3].
# Concatenating three [1] tensors along axis 0 yields three elements.
output_type = helper.make_tensor_value_info('y', TensorProto.FLOAT, [3])

# The only operator in the graph: y = Concat(i, j, k) along axis 0.
node = helper.make_node('Concat', inputs=['i', 'j', 'k'], outputs=['y'], axis=0)

graph_def = helper.make_graph(
    nodes=[node],
    name='concat_test',
    inputs=[i_type, j_type, k_type],
    outputs=[output_type],
)

# Wrap the graph in a model tagged with this script as its producer,
# then serialize it to disk.
model_def = helper.make_model(graph_def, producer_name='concat.py')
onnx.save(model_def, 'concat.onnx')