about summary refs log tree commit diff stats
path: root/model-integration/src/test/models/onnx/simple/concat.py
blob: b77cf5decc1f5c2f69b52721f7c36f06398af2e4 (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
# Copyright 2020 Oath Inc. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
import onnx
import numpy as np
from onnx import helper, TensorProto

# Graph inputs: three float tensors of shape [1], named 'i', 'j' and 'k'.
i_type, j_type, k_type = (
    helper.make_tensor_value_info(input_name, TensorProto.FLOAT, [1])
    for input_name in ('i', 'j', 'k')
)

# Graph output 'y': concatenating three shape-[1] tensors along axis 0
# produces a tensor of shape [3].
output_type = helper.make_tensor_value_info('y', TensorProto.FLOAT, [3])

# A single Concat node joining i, j and k (in that order) along axis 0 into y.
node = helper.make_node('Concat', ['i', 'j', 'k'], ['y'], axis=0)

# Arguments are positional: nodes, graph name, graph inputs, graph outputs.
graph_def = helper.make_graph(
    [node],
    'concat_test',
    [i_type, j_type, k_type],
    [output_type],
)

# Wrap the graph in a model and write it to 'concat.onnx'. The path is
# relative, so the file lands in the current working directory, which is
# not necessarily the directory containing this script.
model_def = helper.make_model(graph_def, producer_name='concat.py')
onnx.save(model_def, 'concat.onnx')