aboutsummaryrefslogtreecommitdiffstats
path: root/eval/src/tests/tensor/onnx_wrapper/probe_model.py
diff options
context:
space:
mode:
Diffstat (limited to 'eval/src/tests/tensor/onnx_wrapper/probe_model.py')
-rwxr-xr-xeval/src/tests/tensor/onnx_wrapper/probe_model.py35
1 files changed, 35 insertions, 0 deletions
diff --git a/eval/src/tests/tensor/onnx_wrapper/probe_model.py b/eval/src/tests/tensor/onnx_wrapper/probe_model.py
new file mode 100755
index 00000000000..529fa23b2b1
--- /dev/null
+++ b/eval/src/tests/tensor/onnx_wrapper/probe_model.py
@@ -0,0 +1,35 @@
+# Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+import onnx
+from onnx import helper, TensorProto
+
+IN1 = helper.make_tensor_value_info('in1', TensorProto.FLOAT, [-1, 'inner'])
+IN2 = helper.make_tensor_value_info('in2', TensorProto.FLOAT, ['outer', -1])
+OUT1 = helper.make_tensor_value_info('out1', TensorProto.FLOAT, [-1, 'inner'])
+OUT2 = helper.make_tensor_value_info('out2', TensorProto.FLOAT, ['outer', -1])
+OUT3 = helper.make_tensor_value_info('out3', TensorProto.FLOAT, [-1, -1])
+
+nodes = [
+ helper.make_node(
+ 'Add',
+ ['in1', 'in2'],
+ ['out1'],
+ ),
+ helper.make_node(
+ 'Sub',
+ ['in1', 'in2'],
+ ['out2'],
+ ),
+ helper.make_node(
+ 'Mul',
+ ['in1', 'in2'],
+ ['out3'],
+ ),
+]
+graph_def = helper.make_graph(
+ nodes,
+ 'probe_model',
+ [IN1, IN2],
+ [OUT1, OUT2, OUT3],
+)
+model_def = helper.make_model(graph_def, producer_name='probe_model.py', opset_imports=[onnx.OperatorSetIdProto(version=12)])
+onnx.save(model_def, 'probe_model.onnx')