summaryrefslogtreecommitdiffstats
path: root/model-integration/src/test/models/onnx/cast_bfloat16_float.py
diff options
context:
space:
mode:
Diffstat (limited to 'model-integration/src/test/models/onnx/cast_bfloat16_float.py')
-rwxr-xr-xmodel-integration/src/test/models/onnx/cast_bfloat16_float.py24
1 files changed, 24 insertions, 0 deletions
diff --git a/model-integration/src/test/models/onnx/cast_bfloat16_float.py b/model-integration/src/test/models/onnx/cast_bfloat16_float.py
new file mode 100755
index 00000000000..14b05347262
--- /dev/null
+++ b/model-integration/src/test/models/onnx/cast_bfloat16_float.py
@@ -0,0 +1,24 @@
+# Copyright Verizon Media. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+
+import onnx
+from onnx import helper, TensorProto
+
+INPUT_1 = helper.make_tensor_value_info('input1', TensorProto.BFLOAT16, [1])
+OUTPUT = helper.make_tensor_value_info('output', TensorProto.FLOAT, [1])
+
+nodes = [
+ helper.make_node(
+ 'Cast',
+ ['input1'],
+ ['output'],
+ to=TensorProto.FLOAT
+ ),
+]
+graph_def = helper.make_graph(
+ nodes,
+ 'cast',
+ [INPUT_1],
+ [OUTPUT],
+)
+model_def = helper.make_model(graph_def, producer_name='cast_bfloat16_float.py', opset_imports=[onnx.OperatorSetIdProto(version=12)])
+onnx.save(model_def, 'cast_bfloat16_float.onnx')