summaryrefslogtreecommitdiffstats
path: root/eval
diff options
context:
space:
mode:
authorHåvard Pettersen <havardpe@oath.com>2021-04-27 09:17:10 +0000
committerHåvard Pettersen <havardpe@oath.com>2021-04-27 09:17:10 +0000
commit2036e1fdd99504cb75ec3be1b0f3ba6bb1dc8fa3 (patch)
tree94c46e4cfb74be9231bc20b624ab75cb6360085f /eval
parent18946df1a0c58f591fad8a30c42209959691bb0c (diff)
what would onnx runtime do?
Diffstat (limited to 'eval')
-rw-r--r--eval/src/tests/tensor/onnx_wrapper/float_to_int8.onnx12
-rwxr-xr-xeval/src/tests/tensor/onnx_wrapper/float_to_int8.py23
-rw-r--r--eval/src/tests/tensor/onnx_wrapper/onnx_wrapper_test.cpp35
3 files changed, 70 insertions, 0 deletions
diff --git a/eval/src/tests/tensor/onnx_wrapper/float_to_int8.onnx b/eval/src/tests/tensor/onnx_wrapper/float_to_int8.onnx
new file mode 100644
index 00000000000..cde81d428bd
--- /dev/null
+++ b/eval/src/tests/tensor/onnx_wrapper/float_to_int8.onnx
@@ -0,0 +1,12 @@
+float_to_int8.py:P
+
+inout"Cast*
+to  float_to_int8Z
+in
+
+
+b
+out
+
+
+B \ No newline at end of file
diff --git a/eval/src/tests/tensor/onnx_wrapper/float_to_int8.py b/eval/src/tests/tensor/onnx_wrapper/float_to_int8.py
new file mode 100755
index 00000000000..2a8e47b3ffa
--- /dev/null
+++ b/eval/src/tests/tensor/onnx_wrapper/float_to_int8.py
@@ -0,0 +1,23 @@
+# Copyright Verizon Media. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+import onnx
+from onnx import helper, TensorProto
+
+IN = helper.make_tensor_value_info('in', TensorProto.FLOAT, [7])
+OUT = helper.make_tensor_value_info('out', TensorProto.INT8, [7])
+
+nodes = [
+ helper.make_node(
+ 'Cast',
+ ['in'],
+ ['out'],
+ to=getattr(TensorProto, 'INT8'),
+ ),
+]
+graph_def = helper.make_graph(
+ nodes,
+ 'float_to_int8',
+ [IN],
+ [OUT],
+)
+model_def = helper.make_model(graph_def, producer_name='float_to_int8.py', opset_imports=[onnx.OperatorSetIdProto(version=13)])
+onnx.save(model_def, 'float_to_int8.onnx')
diff --git a/eval/src/tests/tensor/onnx_wrapper/onnx_wrapper_test.cpp b/eval/src/tests/tensor/onnx_wrapper/onnx_wrapper_test.cpp
index 4fd527a7dfb..9b44dd7519e 100644
--- a/eval/src/tests/tensor/onnx_wrapper/onnx_wrapper_test.cpp
+++ b/eval/src/tests/tensor/onnx_wrapper/onnx_wrapper_test.cpp
@@ -26,6 +26,7 @@ std::string dynamic_model = source_dir + "/dynamic.onnx";
std::string int_types_model = source_dir + "/int_types.onnx";
std::string guess_batch_model = source_dir + "/guess_batch.onnx";
std::string unstable_types_model = source_dir + "/unstable_types.onnx";
+std::string float_to_int8_model = source_dir + "/float_to_int8.onnx";
void dump_info(const char *ctx, const std::vector<TensorInfo> &info) {
fprintf(stderr, "%s:\n", ctx);
@@ -401,4 +402,38 @@ TEST(OnnxTest, converted_unstable_types) {
//-------------------------------------------------------------------------
}
+TEST(OnnxTest, inspect_float_to_int8_conversion) {
+ Onnx model(float_to_int8_model, Onnx::Optimize::ENABLE);
+ ASSERT_EQ(model.inputs().size(), 1);
+ ASSERT_EQ(model.outputs().size(), 1);
+
+ ValueType in_type = ValueType::from_spec("tensor<float>(a[7])");
+ const float my_nan = std::numeric_limits<float>::quiet_NaN();
+ const float my_inf = std::numeric_limits<float>::infinity();
+ std::vector<float> in_values({-my_inf, -142, -42, my_nan, 42, 142, my_inf});
+ DenseValueView in(in_type, TypedCells(in_values));
+
+ Onnx::WirePlanner planner;
+ EXPECT_TRUE(planner.bind_input_type(in_type, model.inputs()[0]));
+ EXPECT_EQ(planner.make_output_type(model.outputs()[0]).to_spec(), "tensor<int8>(d0[7])");
+
+ auto wire_info = planner.get_wire_info(model);
+ Onnx::EvalContext ctx(model, wire_info);
+
+ const Value &out = ctx.get_result(0);
+ EXPECT_EQ(out.type().to_spec(), "tensor<int8>(d0[7])");
+ //-------------------------------------------------------------------------
+ ctx.bind_param(0, in);
+ ctx.eval();
+ auto cells = out.cells();
+ ASSERT_EQ(cells.type, CellType::INT8);
+ ASSERT_EQ(cells.size, 7);
+ auto out_values = cells.typify<Int8Float>();
+ for (size_t i = 0; i < 7; ++i) {
+ fprintf(stderr, "convert(float->int8): '%g' -> '%d'\n",
+ in_values[i], out_values[i].get_bits());
+ }
+ //-------------------------------------------------------------------------
+}
+
GTEST_MAIN_RUN_ALL_TESTS()