From 0ed7434c6cb6eba04d809b6fc60f1c8a0f94bf2d Mon Sep 17 00:00:00 2001
From: Håvard Pettersen
Date: Fri, 9 Apr 2021 12:11:57 +0000
Subject: onnx integration with unstable cell types

---
 eval/src/tests/tensor/onnx_wrapper/dynamic.py      |  2 +-
 eval/src/tests/tensor/onnx_wrapper/guess_batch.py  |  2 +-
 eval/src/tests/tensor/onnx_wrapper/int_types.py    |  2 +-
 .../tensor/onnx_wrapper/onnx_wrapper_test.cpp      | 95 ++++++++++++++++++++++
 eval/src/tests/tensor/onnx_wrapper/simple.py       |  3 +-
 .../tests/tensor/onnx_wrapper/unstable_types.onnx  | 23 ++++++
 .../tests/tensor/onnx_wrapper/unstable_types.py    | 31 +++++++
 eval/src/vespa/eval/onnx/onnx_wrapper.cpp          | 63 +++++++-------
 eval/src/vespa/eval/onnx/onnx_wrapper.h            |  2 +-
 9 files changed, 188 insertions(+), 35 deletions(-)
 create mode 100644 eval/src/tests/tensor/onnx_wrapper/unstable_types.onnx
 create mode 100755 eval/src/tests/tensor/onnx_wrapper/unstable_types.py

diff --git a/eval/src/tests/tensor/onnx_wrapper/dynamic.py b/eval/src/tests/tensor/onnx_wrapper/dynamic.py
index d098324fae8..cdf59c4f700 100755
--- a/eval/src/tests/tensor/onnx_wrapper/dynamic.py
+++ b/eval/src/tests/tensor/onnx_wrapper/dynamic.py
@@ -35,5 +35,5 @@ graph_def = helper.make_graph(
     ],
     [OUTPUT],
 )
-model_def = helper.make_model(graph_def, producer_name='dynamic.py')
+model_def = helper.make_model(graph_def, producer_name='dynamic.py', opset_imports=[onnx.OperatorSetIdProto(version=12)])
 onnx.save(model_def, 'dynamic.onnx')
diff --git a/eval/src/tests/tensor/onnx_wrapper/guess_batch.py b/eval/src/tests/tensor/onnx_wrapper/guess_batch.py
index c43448c58a7..63b2c84e934 100755
--- a/eval/src/tests/tensor/onnx_wrapper/guess_batch.py
+++ b/eval/src/tests/tensor/onnx_wrapper/guess_batch.py
@@ -22,5 +22,5 @@ graph_def = helper.make_graph(
     ],
     [OUT],
 )
-model_def = helper.make_model(graph_def, producer_name='guess_batch.py')
+model_def = helper.make_model(graph_def, producer_name='guess_batch.py', opset_imports=[onnx.OperatorSetIdProto(version=12)])
 onnx.save(model_def, 'guess_batch.onnx')
diff --git a/eval/src/tests/tensor/onnx_wrapper/int_types.py b/eval/src/tests/tensor/onnx_wrapper/int_types.py
index cd82bfd44b5..e5adf035e4b 100755
--- a/eval/src/tests/tensor/onnx_wrapper/int_types.py
+++ b/eval/src/tests/tensor/onnx_wrapper/int_types.py
@@ -29,5 +29,5 @@ graph_def = helper.make_graph(
     ],
     [OUTPUT],
 )
-model_def = helper.make_model(graph_def, producer_name='int_types.py')
+model_def = helper.make_model(graph_def, producer_name='int_types.py', opset_imports=[onnx.OperatorSetIdProto(version=12)])
 onnx.save(model_def, 'int_types.onnx')
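The three script changes above all do the same thing: without an explicit
opset_imports argument, helper.make_model() stamps the model with the newest
opset known to whatever onnx package happens to be installed, so regenerated
test models would drift with the build environment. Pinning the opset keeps
them stable (version 13 is used further down for unstable_types.py, since
Cast only supports BFLOAT16 from opset 13 onwards). A quick way to verify
what a generated model was pinned to (a minimal sketch using the standard
onnx python API; it assumes the scripts have been run so the .onnx files
exist in the current directory):

    import onnx

    # opset_import holds one entry per operator-set domain; an empty domain
    # string means the default 'ai.onnx' domain pinned via opset_imports
    model = onnx.load('int_types.onnx')
    for opset in model.opset_import:
        print(opset.domain or 'ai.onnx', opset.version)
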
diff --git a/eval/src/tests/tensor/onnx_wrapper/onnx_wrapper_test.cpp b/eval/src/tests/tensor/onnx_wrapper/onnx_wrapper_test.cpp
index b474d2458b9..4fd527a7dfb 100644
--- a/eval/src/tests/tensor/onnx_wrapper/onnx_wrapper_test.cpp
+++ b/eval/src/tests/tensor/onnx_wrapper/onnx_wrapper_test.cpp
@@ -1,12 +1,16 @@
 // Copyright Verizon Media. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
 #include <vespa/eval/eval/tensor_spec.h>
+#include <vespa/eval/eval/int8float.h>
 #include <vespa/eval/onnx/onnx_wrapper.h>
+#include <vespa/vespalib/util/bfloat16.h>
 #include <vespa/vespalib/util/stringfmt.h>
 #include <vespa/vespalib/gtest/gtest.h>
 
 using namespace vespalib::eval;
+using vespalib::BFloat16;
+
 using vespalib::make_string_short::fmt;
 
 using TensorInfo = Onnx::TensorInfo;
 using ElementType = Onnx::ElementType;
@@ -21,6 +25,7 @@ std::string simple_model = source_dir + "/simple.onnx";
 std::string dynamic_model = source_dir + "/dynamic.onnx";
 std::string int_types_model = source_dir + "/int_types.onnx";
 std::string guess_batch_model = source_dir + "/guess_batch.onnx";
+std::string unstable_types_model = source_dir + "/unstable_types.onnx";
 
 void dump_info(const char *ctx, const std::vector<TensorInfo> &info) {
     fprintf(stderr, "%s:\n", ctx);
@@ -306,4 +311,94 @@ TEST(OnnxTest, we_guess_batch_dimension_size_when_inference_fails) {
     //-------------------------------------------------------------------------
 }
 
+TEST(OnnxTest, zero_copy_unstable_types) {
+    Onnx model(unstable_types_model, Onnx::Optimize::ENABLE);
+    ASSERT_EQ(model.inputs().size(), 2);
+    ASSERT_EQ(model.outputs().size(), 2);
+
+    ValueType in8_type = ValueType::from_spec("tensor<int8>(a[3])");
+    std::vector<Int8Float> in8_values({1.0, 2.0, 3.0});
+    DenseValueView in8(in8_type, TypedCells(in8_values));
+
+    ValueType in16_type = ValueType::from_spec("tensor<bfloat16>(a[3])");
+    std::vector<BFloat16> in16_values({4.0, 5.0, 6.0});
+    DenseValueView in16(in16_type, TypedCells(in16_values));
+
+    Onnx::WirePlanner planner;
+    EXPECT_TRUE(planner.bind_input_type(in8_type, model.inputs()[0]));
+    EXPECT_TRUE(planner.bind_input_type(in16_type, model.inputs()[1]));
+    EXPECT_EQ(planner.make_output_type(model.outputs()[0]).to_spec(), "tensor<int8>(d0[3])");
+    EXPECT_EQ(planner.make_output_type(model.outputs()[1]).to_spec(), "tensor<bfloat16>(d0[3])");
+
+    auto wire_info = planner.get_wire_info(model);
+    Onnx::EvalContext ctx(model, wire_info);
+
+    const Value &out8 = ctx.get_result(0);
+    const Value &out16 = ctx.get_result(1);
+    EXPECT_EQ(out8.type().to_spec(), "tensor<int8>(d0[3])");
+    EXPECT_EQ(out16.type().to_spec(), "tensor<bfloat16>(d0[3])");
+    //-------------------------------------------------------------------------
+    ctx.bind_param(0, in8);
+    ctx.bind_param(1, in16);
+    ctx.eval();
+    auto cells8 = out8.cells();
+    auto cells16 = out16.cells();
+    ASSERT_EQ(cells8.type, CellType::INT8);
+    ASSERT_EQ(cells16.type, CellType::BFLOAT16);
+    ASSERT_EQ(cells8.size, 3);
+    ASSERT_EQ(cells16.size, 3);
+    EXPECT_EQ(cells8.typify<Int8Float>()[0], 4.0);
+    EXPECT_EQ(cells8.typify<Int8Float>()[1], 5.0);
+    EXPECT_EQ(cells8.typify<Int8Float>()[2], 6.0);
+    EXPECT_EQ(cells16.typify<BFloat16>()[0], 1.0);
+    EXPECT_EQ(cells16.typify<BFloat16>()[1], 2.0);
+    EXPECT_EQ(cells16.typify<BFloat16>()[2], 3.0);
+    //-------------------------------------------------------------------------
+}
+
+TEST(OnnxTest, converted_unstable_types) {
+    Onnx model(unstable_types_model, Onnx::Optimize::ENABLE);
+    ASSERT_EQ(model.inputs().size(), 2);
+    ASSERT_EQ(model.outputs().size(), 2);
+
+    ValueType in8_type = ValueType::from_spec("tensor<float>(a[3])");
+    std::vector<float> in8_values({1.0, 2.0, 3.0});
+    DenseValueView in8(in8_type, TypedCells(in8_values));
+
+    ValueType in16_type = ValueType::from_spec("tensor<float>(a[3])");
+    std::vector<float> in16_values({4.0, 5.0, 6.0});
+    DenseValueView in16(in16_type, TypedCells(in16_values));
+
+    Onnx::WirePlanner planner;
+    EXPECT_TRUE(planner.bind_input_type(in8_type, model.inputs()[0]));
+    EXPECT_TRUE(planner.bind_input_type(in16_type, model.inputs()[1]));
+    EXPECT_EQ(planner.make_output_type(model.outputs()[0]).to_spec(), "tensor<int8>(d0[3])");
+    EXPECT_EQ(planner.make_output_type(model.outputs()[1]).to_spec(), "tensor<bfloat16>(d0[3])");
+
+    auto wire_info = planner.get_wire_info(model);
+    Onnx::EvalContext ctx(model, wire_info);
+
+    const Value &out8 = ctx.get_result(0);
+    const Value &out16 = ctx.get_result(1);
+    EXPECT_EQ(out8.type().to_spec(), "tensor<int8>(d0[3])");
+    EXPECT_EQ(out16.type().to_spec(), "tensor<bfloat16>(d0[3])");
+    //-------------------------------------------------------------------------
+    ctx.bind_param(0, in8);
+    ctx.bind_param(1, in16);
+    ctx.eval();
+    auto cells8 = out8.cells();
+    auto cells16 = out16.cells();
+    ASSERT_EQ(cells8.type, CellType::INT8);
+    ASSERT_EQ(cells16.type, CellType::BFLOAT16);
+    ASSERT_EQ(cells8.size, 3);
+    ASSERT_EQ(cells16.size, 3);
+    EXPECT_EQ(cells8.typify<Int8Float>()[0], 4.0);
+    EXPECT_EQ(cells8.typify<Int8Float>()[1], 5.0);
+    EXPECT_EQ(cells8.typify<Int8Float>()[2], 6.0);
+    EXPECT_EQ(cells16.typify<BFloat16>()[0], 1.0);
+    EXPECT_EQ(cells16.typify<BFloat16>()[1], 2.0);
+    EXPECT_EQ(cells16.typify<BFloat16>()[2], 3.0);
+    //-------------------------------------------------------------------------
+}
+
 GTEST_MAIN_RUN_ALL_TESTS()
diff --git a/eval/src/tests/tensor/onnx_wrapper/simple.py b/eval/src/tests/tensor/onnx_wrapper/simple.py
index a3cd2425d58..c8db58b5ebb 100755
--- a/eval/src/tests/tensor/onnx_wrapper/simple.py
+++ b/eval/src/tests/tensor/onnx_wrapper/simple.py
@@ -29,5 +29,6 @@ graph_def = helper.make_graph(
     ],
     [OUTPUT],
 )
-model_def = helper.make_model(graph_def, producer_name='simple.py')
+
+model_def = helper.make_model(graph_def, producer_name='simple.py', opset_imports=[onnx.OperatorSetIdProto(version=12)])
 onnx.save(model_def, 'simple.onnx')
diff --git a/eval/src/tests/tensor/onnx_wrapper/unstable_types.onnx b/eval/src/tests/tensor/onnx_wrapper/unstable_types.onnx
new file mode 100644
index 00000000000..b833086ddd0
--- /dev/null
+++ b/eval/src/tests/tensor/onnx_wrapper/unstable_types.onnx
@@ -0,0 +1,23 @@
+unstable_types.py:ž
+
+in8out16"Cast*
+to
+
+in16out8"Cast*
+to unstable_typesZ
+in8
+
+
+Z
+in16
+
+
+b
+out8
+
+
+b
+out16
+
+
+B
\ No newline at end of file
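The .onnx file above is a serialized protobuf, so the diff body is mostly
unprintable; the readable fragments ('Cast', 'in8', 'out16', ...) are the node
and tensor names embedded in the binary. To inspect the actual structure,
something like the following sketch can be used (it assumes a checkout where
the file exists at this path; onnx.checker and TensorProto.DataType are
standard onnx APIs):

    import onnx

    model = onnx.load('eval/src/tests/tensor/onnx_wrapper/unstable_types.onnx')
    onnx.checker.check_model(model)
    for vi in list(model.graph.input) + list(model.graph.output):
        elem = vi.type.tensor_type.elem_type
        # expected: in8/out8 -> INT8, in16/out16 -> BFLOAT16
        print(vi.name, onnx.TensorProto.DataType.Name(elem))

The script that generated the file follows in the next diff.
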
diff --git a/eval/src/tests/tensor/onnx_wrapper/unstable_types.py b/eval/src/tests/tensor/onnx_wrapper/unstable_types.py
new file mode 100755
index 00000000000..94a1975a560
--- /dev/null
+++ b/eval/src/tests/tensor/onnx_wrapper/unstable_types.py
@@ -0,0 +1,31 @@
+# Copyright Verizon Media. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+import onnx
+from onnx import helper, TensorProto
+
+IN8 = helper.make_tensor_value_info('in8', TensorProto.INT8, [3])
+IN16 = helper.make_tensor_value_info('in16', TensorProto.BFLOAT16, [3])
+OUT8 = helper.make_tensor_value_info('out8', TensorProto.INT8, [3])
+OUT16 = helper.make_tensor_value_info('out16', TensorProto.BFLOAT16, [3])
+
+nodes = [
+    helper.make_node(
+        'Cast',
+        ['in8'],
+        ['out16'],
+        to=getattr(TensorProto, 'BFLOAT16')
+    ),
+    helper.make_node(
+        'Cast',
+        ['in16'],
+        ['out8'],
+        to=getattr(TensorProto, 'INT8')
+    ),
+]
+graph_def = helper.make_graph(
+    nodes,
+    'unstable_types',
+    [IN8, IN16],
+    [OUT8, OUT16],
+)
+model_def = helper.make_model(graph_def, producer_name='unstable_types.py', opset_imports=[onnx.OperatorSetIdProto(version=13)])
+onnx.save(model_def, 'unstable_types.onnx')
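Zero-copy wiring of these cell types is possible because bfloat16 is simply
the upper half of an IEEE-754 float32 bit pattern (and int8 cells are plain
bytes), so buffers can be handed to onnxruntime without rewriting. The bit
relationship can be illustrated in pure python (an illustrative sketch, not
project code; production converters typically round to nearest-even rather
than truncating):

    import struct

    def float32_to_bfloat16_bits(x):
        # reinterpret the float32 bit pattern and keep only the top 16 bits
        (bits,) = struct.unpack('<I', struct.pack('<f', x))
        return bits >> 16

    def bfloat16_bits_to_float32(b):
        # widening is exact: zero-fill the 16 mantissa bits that were dropped
        (x,) = struct.unpack('<f', struct.pack('<I', b << 16))
        return x

    # the test values survive the round trip exactly, which is why the
    # EXPECT_EQ comparisons against 1.0 ... 6.0 in the tests above are safe
    for v in [1.0, 2.0, 3.0, 4.0, 5.0, 6.0]:
        assert bfloat16_bits_to_float32(float32_to_bfloat16_bits(v)) == v
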
diff --git a/eval/src/vespa/eval/onnx/onnx_wrapper.cpp b/eval/src/vespa/eval/onnx/onnx_wrapper.cpp
index e9758f2ddc8..3a593f491d8 100644
--- a/eval/src/vespa/eval/onnx/onnx_wrapper.cpp
+++ b/eval/src/vespa/eval/onnx/onnx_wrapper.cpp
@@ -38,16 +38,17 @@ struct TypifyOnnxElementType {
     template <typename T> using Result = TypifyResultType<T>;
     template <typename F> static decltype(auto) resolve(Onnx::ElementType value, F &&f) {
         switch(value) {
-        case Onnx::ElementType::INT8: return f(Result<int8_t>());
-        case Onnx::ElementType::INT16: return f(Result<int16_t>());
-        case Onnx::ElementType::INT32: return f(Result<int32_t>());
-        case Onnx::ElementType::INT64: return f(Result<int64_t>());
-        case Onnx::ElementType::UINT8: return f(Result<uint8_t>());
-        case Onnx::ElementType::UINT16: return f(Result<uint16_t>());
-        case Onnx::ElementType::UINT32: return f(Result<uint32_t>());
-        case Onnx::ElementType::UINT64: return f(Result<uint64_t>());
-        case Onnx::ElementType::FLOAT: return f(Result<float>());
-        case Onnx::ElementType::DOUBLE: return f(Result<double>());
+        case Onnx::ElementType::INT8:     return f(Result<Int8Float>());
+        case Onnx::ElementType::INT16:    return f(Result<int16_t>());
+        case Onnx::ElementType::INT32:    return f(Result<int32_t>());
+        case Onnx::ElementType::INT64:    return f(Result<int64_t>());
+        case Onnx::ElementType::UINT8:    return f(Result<uint8_t>());
+        case Onnx::ElementType::UINT16:   return f(Result<uint16_t>());
+        case Onnx::ElementType::UINT32:   return f(Result<uint32_t>());
+        case Onnx::ElementType::UINT64:   return f(Result<uint64_t>());
+        case Onnx::ElementType::BFLOAT16: return f(Result<BFloat16>());
+        case Onnx::ElementType::FLOAT:    return f(Result<float>());
+        case Onnx::ElementType::DOUBLE:   return f(Result<double>());
         }
         abort();
     }
@@ -118,32 +119,34 @@ auto convert_optimize(Onnx::Optimize optimize) {
 
 CellType to_cell_type(Onnx::ElementType type) {
     switch (type) {
-    case Onnx::ElementType::INT8: [[fallthrough]];
-    case Onnx::ElementType::INT16: [[fallthrough]];
-    case Onnx::ElementType::UINT8: [[fallthrough]];
-    case Onnx::ElementType::UINT16: [[fallthrough]];
-    case Onnx::ElementType::FLOAT: return CellType::FLOAT;
-    case Onnx::ElementType::INT32: [[fallthrough]];
-    case Onnx::ElementType::INT64: [[fallthrough]];
-    case Onnx::ElementType::UINT32: [[fallthrough]];
-    case Onnx::ElementType::UINT64: [[fallthrough]];
-    case Onnx::ElementType::DOUBLE: return CellType::DOUBLE;
+    case Onnx::ElementType::INT8:     return CellType::INT8;
+    case Onnx::ElementType::BFLOAT16: return CellType::BFLOAT16;
+    case Onnx::ElementType::UINT8:    [[fallthrough]];
+    case Onnx::ElementType::INT16:    [[fallthrough]];
+    case Onnx::ElementType::UINT16:   [[fallthrough]];
+    case Onnx::ElementType::FLOAT:    return CellType::FLOAT;
+    case Onnx::ElementType::INT32:    [[fallthrough]];
+    case Onnx::ElementType::INT64:    [[fallthrough]];
+    case Onnx::ElementType::UINT32:   [[fallthrough]];
+    case Onnx::ElementType::UINT64:   [[fallthrough]];
+    case Onnx::ElementType::DOUBLE:   return CellType::DOUBLE;
     }
     abort();
 }
 
 Onnx::ElementType make_element_type(ONNXTensorElementDataType element_type) {
     switch (element_type) {
-    case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8: return Onnx::ElementType::INT8;
-    case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT16: return Onnx::ElementType::INT16;
-    case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32: return Onnx::ElementType::INT32;
-    case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64: return Onnx::ElementType::INT64;
-    case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8: return Onnx::ElementType::UINT8;
-    case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT16: return Onnx::ElementType::UINT16;
-    case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT32: return Onnx::ElementType::UINT32;
-    case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT64: return Onnx::ElementType::UINT64;
-    case ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT: return Onnx::ElementType::FLOAT;
-    case ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE: return Onnx::ElementType::DOUBLE;
+    case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8:     return Onnx::ElementType::INT8;
+    case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT16:    return Onnx::ElementType::INT16;
+    case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32:    return Onnx::ElementType::INT32;
+    case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64:    return Onnx::ElementType::INT64;
+    case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8:    return Onnx::ElementType::UINT8;
+    case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT16:   return Onnx::ElementType::UINT16;
+    case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT32:   return Onnx::ElementType::UINT32;
+    case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT64:   return Onnx::ElementType::UINT64;
+    case ONNX_TENSOR_ELEMENT_DATA_TYPE_BFLOAT16: return Onnx::ElementType::BFLOAT16;
+    case ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT:    return Onnx::ElementType::FLOAT;
+    case ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE:   return Onnx::ElementType::DOUBLE;
     default:
         throw Ort::Exception(fmt("[onnx wrapper] unsupported element type: %d", element_type), ORT_FAIL);
     }
diff --git a/eval/src/vespa/eval/onnx/onnx_wrapper.h b/eval/src/vespa/eval/onnx/onnx_wrapper.h
index 68c31f04cdc..1f36d576c33 100644
--- a/eval/src/vespa/eval/onnx/onnx_wrapper.h
+++ b/eval/src/vespa/eval/onnx/onnx_wrapper.h
@@ -51,7 +51,7 @@ public:
     };
 
     // supported onnx element types
-    enum class ElementType { INT8, INT16, INT32, INT64, UINT8, UINT16, UINT32, UINT64, FLOAT, DOUBLE };
+    enum class ElementType { INT8, INT16, INT32, INT64, UINT8, UINT16, UINT32, UINT64, BFLOAT16, FLOAT, DOUBLE };
 
     // information about a single input or output tensor
     struct TensorInfo {
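The switch statements above have to stay in sync with ONNX's element-type
numbering (onnxruntime's ONNXTensorElementDataType values mirror the
TensorProto.DataType protobuf enum). A throwaway check of where BFLOAT16
sits in that numbering, using the standard onnx python API:

    import onnx

    # prints e.g. '3 INT8', '10 FLOAT16', '16 BFLOAT16'; note that plain
    # FLOAT16 is still not handled by make_element_type() and would hit
    # the ORT_FAIL branch
    for name, value in sorted(onnx.TensorProto.DataType.items(), key=lambda kv: kv[1]):
        print(value, name)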