summaryrefslogtreecommitdiffstats
path: root/eval/src/tests/tensor/onnx_wrapper/unstable_types.py
diff options
context:
space:
mode:
authorHåvard Pettersen <havardpe@oath.com>2021-04-09 12:11:57 +0000
committerHåvard Pettersen <havardpe@oath.com>2021-04-13 14:48:10 +0000
commit0ed7434c6cb6eba04d809b6fc60f1c8a0f94bf2d (patch)
tree6f4eefff81a452b6cc7685f479bc31fa08ad2ef1 /eval/src/tests/tensor/onnx_wrapper/unstable_types.py
parenta5f88e456dd105f1c47d2c42329a1c7f97cdde72 (diff)
onnx integration with unstable cell types
Diffstat (limited to 'eval/src/tests/tensor/onnx_wrapper/unstable_types.py')
-rwxr-xr-xeval/src/tests/tensor/onnx_wrapper/unstable_types.py31
1 files changed, 31 insertions, 0 deletions
diff --git a/eval/src/tests/tensor/onnx_wrapper/unstable_types.py b/eval/src/tests/tensor/onnx_wrapper/unstable_types.py
new file mode 100755
index 00000000000..94a1975a560
--- /dev/null
+++ b/eval/src/tests/tensor/onnx_wrapper/unstable_types.py
@@ -0,0 +1,31 @@
+# Copyright Verizon Media. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+import onnx
+from onnx import helper, TensorProto
+
+IN8 = helper.make_tensor_value_info('in8', TensorProto.INT8, [3])
+IN16 = helper.make_tensor_value_info('in16', TensorProto.BFLOAT16, [3])
+OUT8 = helper.make_tensor_value_info('out8', TensorProto.INT8, [3])
+OUT16 = helper.make_tensor_value_info('out16', TensorProto.BFLOAT16, [3])
+
+nodes = [
+ helper.make_node(
+ 'Cast',
+ ['in8'],
+ ['out16'],
+ to=getattr(TensorProto, 'BFLOAT16')
+ ),
+ helper.make_node(
+ 'Cast',
+ ['in16'],
+ ['out8'],
+ to=getattr(TensorProto, 'INT8')
+ ),
+]
+graph_def = helper.make_graph(
+ nodes,
+ 'unstable_types',
+ [IN8, IN16],
+ [OUT8, OUT16],
+)
+model_def = helper.make_model(graph_def, producer_name='unstable_types.py', opset_imports=[onnx.OperatorSetIdProto(version=13)])
+onnx.save(model_def, 'unstable_types.onnx')