aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorHÃ¥vard Pettersen <havardpe@oath.com>2021-04-09 12:11:57 +0000
committerHÃ¥vard Pettersen <havardpe@oath.com>2021-04-13 14:48:10 +0000
commit0ed7434c6cb6eba04d809b6fc60f1c8a0f94bf2d (patch)
tree6f4eefff81a452b6cc7685f479bc31fa08ad2ef1
parenta5f88e456dd105f1c47d2c42329a1c7f97cdde72 (diff)
onnx integration with unstable cell types
-rwxr-xr-xeval/src/tests/tensor/onnx_wrapper/dynamic.py2
-rwxr-xr-xeval/src/tests/tensor/onnx_wrapper/guess_batch.py2
-rwxr-xr-xeval/src/tests/tensor/onnx_wrapper/int_types.py2
-rw-r--r--eval/src/tests/tensor/onnx_wrapper/onnx_wrapper_test.cpp95
-rwxr-xr-xeval/src/tests/tensor/onnx_wrapper/simple.py3
-rw-r--r--eval/src/tests/tensor/onnx_wrapper/unstable_types.onnx23
-rwxr-xr-xeval/src/tests/tensor/onnx_wrapper/unstable_types.py31
-rw-r--r--eval/src/vespa/eval/onnx/onnx_wrapper.cpp63
-rw-r--r--eval/src/vespa/eval/onnx/onnx_wrapper.h2
9 files changed, 188 insertions, 35 deletions
diff --git a/eval/src/tests/tensor/onnx_wrapper/dynamic.py b/eval/src/tests/tensor/onnx_wrapper/dynamic.py
index d098324fae8..cdf59c4f700 100755
--- a/eval/src/tests/tensor/onnx_wrapper/dynamic.py
+++ b/eval/src/tests/tensor/onnx_wrapper/dynamic.py
@@ -35,5 +35,5 @@ graph_def = helper.make_graph(
],
[OUTPUT],
)
-model_def = helper.make_model(graph_def, producer_name='dynamic.py')
+model_def = helper.make_model(graph_def, producer_name='dynamic.py', opset_imports=[onnx.OperatorSetIdProto(version=12)])
onnx.save(model_def, 'dynamic.onnx')
diff --git a/eval/src/tests/tensor/onnx_wrapper/guess_batch.py b/eval/src/tests/tensor/onnx_wrapper/guess_batch.py
index c43448c58a7..63b2c84e934 100755
--- a/eval/src/tests/tensor/onnx_wrapper/guess_batch.py
+++ b/eval/src/tests/tensor/onnx_wrapper/guess_batch.py
@@ -22,5 +22,5 @@ graph_def = helper.make_graph(
],
[OUT],
)
-model_def = helper.make_model(graph_def, producer_name='guess_batch.py')
+model_def = helper.make_model(graph_def, producer_name='guess_batch.py', opset_imports=[onnx.OperatorSetIdProto(version=12)])
onnx.save(model_def, 'guess_batch.onnx')
diff --git a/eval/src/tests/tensor/onnx_wrapper/int_types.py b/eval/src/tests/tensor/onnx_wrapper/int_types.py
index cd82bfd44b5..e5adf035e4b 100755
--- a/eval/src/tests/tensor/onnx_wrapper/int_types.py
+++ b/eval/src/tests/tensor/onnx_wrapper/int_types.py
@@ -29,5 +29,5 @@ graph_def = helper.make_graph(
],
[OUTPUT],
)
-model_def = helper.make_model(graph_def, producer_name='int_types.py')
+model_def = helper.make_model(graph_def, producer_name='int_types.py', opset_imports=[onnx.OperatorSetIdProto(version=12)])
onnx.save(model_def, 'int_types.onnx')
diff --git a/eval/src/tests/tensor/onnx_wrapper/onnx_wrapper_test.cpp b/eval/src/tests/tensor/onnx_wrapper/onnx_wrapper_test.cpp
index b474d2458b9..4fd527a7dfb 100644
--- a/eval/src/tests/tensor/onnx_wrapper/onnx_wrapper_test.cpp
+++ b/eval/src/tests/tensor/onnx_wrapper/onnx_wrapper_test.cpp
@@ -1,12 +1,16 @@
// Copyright Verizon Media. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
#include <vespa/eval/eval/tensor_spec.h>
+#include <vespa/eval/eval/int8float.h>
#include <vespa/eval/onnx/onnx_wrapper.h>
+#include <vespa/vespalib/util/bfloat16.h>
#include <vespa/vespalib/util/stringfmt.h>
#include <vespa/vespalib/gtest/gtest.h>
using namespace vespalib::eval;
+using vespalib::BFloat16;
+
using vespalib::make_string_short::fmt;
using TensorInfo = Onnx::TensorInfo;
using ElementType = Onnx::ElementType;
@@ -21,6 +25,7 @@ std::string simple_model = source_dir + "/simple.onnx";
std::string dynamic_model = source_dir + "/dynamic.onnx";
std::string int_types_model = source_dir + "/int_types.onnx";
std::string guess_batch_model = source_dir + "/guess_batch.onnx";
+std::string unstable_types_model = source_dir + "/unstable_types.onnx";
void dump_info(const char *ctx, const std::vector<TensorInfo> &info) {
fprintf(stderr, "%s:\n", ctx);
@@ -306,4 +311,94 @@ TEST(OnnxTest, we_guess_batch_dimension_size_when_inference_fails) {
//-------------------------------------------------------------------------
}
+TEST(OnnxTest, zero_copy_unstable_types) {
+ Onnx model(unstable_types_model, Onnx::Optimize::ENABLE);
+ ASSERT_EQ(model.inputs().size(), 2);
+ ASSERT_EQ(model.outputs().size(), 2);
+
+ ValueType in8_type = ValueType::from_spec("tensor<int8>(a[3])");
+ std::vector<Int8Float> in8_values({1.0, 2.0, 3.0});
+ DenseValueView in8(in8_type, TypedCells(in8_values));
+
+ ValueType in16_type = ValueType::from_spec("tensor<bfloat16>(a[3])");
+ std::vector<BFloat16> in16_values({4.0, 5.0, 6.0});
+ DenseValueView in16(in16_type, TypedCells(in16_values));
+
+ Onnx::WirePlanner planner;
+ EXPECT_TRUE(planner.bind_input_type(in8_type, model.inputs()[0]));
+ EXPECT_TRUE(planner.bind_input_type(in16_type, model.inputs()[1]));
+ EXPECT_EQ(planner.make_output_type(model.outputs()[0]).to_spec(), "tensor<int8>(d0[3])");
+ EXPECT_EQ(planner.make_output_type(model.outputs()[1]).to_spec(), "tensor<bfloat16>(d0[3])");
+
+ auto wire_info = planner.get_wire_info(model);
+ Onnx::EvalContext ctx(model, wire_info);
+
+ const Value &out8 = ctx.get_result(0);
+ const Value &out16 = ctx.get_result(1);
+ EXPECT_EQ(out8.type().to_spec(), "tensor<int8>(d0[3])");
+ EXPECT_EQ(out16.type().to_spec(), "tensor<bfloat16>(d0[3])");
+ //-------------------------------------------------------------------------
+ ctx.bind_param(0, in8);
+ ctx.bind_param(1, in16);
+ ctx.eval();
+ auto cells8 = out8.cells();
+ auto cells16 = out16.cells();
+ ASSERT_EQ(cells8.type, CellType::INT8);
+ ASSERT_EQ(cells16.type, CellType::BFLOAT16);
+ ASSERT_EQ(cells8.size, 3);
+ ASSERT_EQ(cells16.size, 3);
+ EXPECT_EQ(cells8.typify<Int8Float>()[0], 4.0);
+ EXPECT_EQ(cells8.typify<Int8Float>()[1], 5.0);
+ EXPECT_EQ(cells8.typify<Int8Float>()[2], 6.0);
+ EXPECT_EQ(cells16.typify<BFloat16>()[0], 1.0);
+ EXPECT_EQ(cells16.typify<BFloat16>()[1], 2.0);
+ EXPECT_EQ(cells16.typify<BFloat16>()[2], 3.0);
+ //-------------------------------------------------------------------------
+}
+
+TEST(OnnxTest, converted_unstable_types) {
+ Onnx model(unstable_types_model, Onnx::Optimize::ENABLE);
+ ASSERT_EQ(model.inputs().size(), 2);
+ ASSERT_EQ(model.outputs().size(), 2);
+
+ ValueType in8_type = ValueType::from_spec("tensor<float>(a[3])");
+ std::vector<float> in8_values({1.0, 2.0, 3.0});
+ DenseValueView in8(in8_type, TypedCells(in8_values));
+
+ ValueType in16_type = ValueType::from_spec("tensor<float>(a[3])");
+ std::vector<float> in16_values({4.0, 5.0, 6.0});
+ DenseValueView in16(in16_type, TypedCells(in16_values));
+
+ Onnx::WirePlanner planner;
+ EXPECT_TRUE(planner.bind_input_type(in8_type, model.inputs()[0]));
+ EXPECT_TRUE(planner.bind_input_type(in16_type, model.inputs()[1]));
+ EXPECT_EQ(planner.make_output_type(model.outputs()[0]).to_spec(), "tensor<int8>(d0[3])");
+ EXPECT_EQ(planner.make_output_type(model.outputs()[1]).to_spec(), "tensor<bfloat16>(d0[3])");
+
+ auto wire_info = planner.get_wire_info(model);
+ Onnx::EvalContext ctx(model, wire_info);
+
+ const Value &out8 = ctx.get_result(0);
+ const Value &out16 = ctx.get_result(1);
+ EXPECT_EQ(out8.type().to_spec(), "tensor<int8>(d0[3])");
+ EXPECT_EQ(out16.type().to_spec(), "tensor<bfloat16>(d0[3])");
+ //-------------------------------------------------------------------------
+ ctx.bind_param(0, in8);
+ ctx.bind_param(1, in16);
+ ctx.eval();
+ auto cells8 = out8.cells();
+ auto cells16 = out16.cells();
+ ASSERT_EQ(cells8.type, CellType::INT8);
+ ASSERT_EQ(cells16.type, CellType::BFLOAT16);
+ ASSERT_EQ(cells8.size, 3);
+ ASSERT_EQ(cells16.size, 3);
+ EXPECT_EQ(cells8.typify<Int8Float>()[0], 4.0);
+ EXPECT_EQ(cells8.typify<Int8Float>()[1], 5.0);
+ EXPECT_EQ(cells8.typify<Int8Float>()[2], 6.0);
+ EXPECT_EQ(cells16.typify<BFloat16>()[0], 1.0);
+ EXPECT_EQ(cells16.typify<BFloat16>()[1], 2.0);
+ EXPECT_EQ(cells16.typify<BFloat16>()[2], 3.0);
+ //-------------------------------------------------------------------------
+}
+
GTEST_MAIN_RUN_ALL_TESTS()
diff --git a/eval/src/tests/tensor/onnx_wrapper/simple.py b/eval/src/tests/tensor/onnx_wrapper/simple.py
index a3cd2425d58..c8db58b5ebb 100755
--- a/eval/src/tests/tensor/onnx_wrapper/simple.py
+++ b/eval/src/tests/tensor/onnx_wrapper/simple.py
@@ -29,5 +29,6 @@ graph_def = helper.make_graph(
],
[OUTPUT],
)
-model_def = helper.make_model(graph_def, producer_name='simple.py')
+
+model_def = helper.make_model(graph_def, producer_name='simple.py', opset_imports=[onnx.OperatorSetIdProto(version=12)])
onnx.save(model_def, 'simple.onnx')
diff --git a/eval/src/tests/tensor/onnx_wrapper/unstable_types.onnx b/eval/src/tests/tensor/onnx_wrapper/unstable_types.onnx
new file mode 100644
index 00000000000..b833086ddd0
--- /dev/null
+++ b/eval/src/tests/tensor/onnx_wrapper/unstable_types.onnx
@@ -0,0 +1,23 @@
+unstable_types.py:ž
+
+in8out16"Cast*
+to 
+
+in16out8"Cast*
+to unstable_typesZ
+in8
+
+
+Z
+in16
+
+
+b
+out8
+
+
+b
+out16
+
+
+B \ No newline at end of file
diff --git a/eval/src/tests/tensor/onnx_wrapper/unstable_types.py b/eval/src/tests/tensor/onnx_wrapper/unstable_types.py
new file mode 100755
index 00000000000..94a1975a560
--- /dev/null
+++ b/eval/src/tests/tensor/onnx_wrapper/unstable_types.py
@@ -0,0 +1,31 @@
+# Copyright Verizon Media. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+import onnx
+from onnx import helper, TensorProto
+
+IN8 = helper.make_tensor_value_info('in8', TensorProto.INT8, [3])
+IN16 = helper.make_tensor_value_info('in16', TensorProto.BFLOAT16, [3])
+OUT8 = helper.make_tensor_value_info('out8', TensorProto.INT8, [3])
+OUT16 = helper.make_tensor_value_info('out16', TensorProto.BFLOAT16, [3])
+
+nodes = [
+ helper.make_node(
+ 'Cast',
+ ['in8'],
+ ['out16'],
+ to=getattr(TensorProto, 'BFLOAT16')
+ ),
+ helper.make_node(
+ 'Cast',
+ ['in16'],
+ ['out8'],
+ to=getattr(TensorProto, 'INT8')
+ ),
+]
+graph_def = helper.make_graph(
+ nodes,
+ 'unstable_types',
+ [IN8, IN16],
+ [OUT8, OUT16],
+)
+model_def = helper.make_model(graph_def, producer_name='unstable_types.py', opset_imports=[onnx.OperatorSetIdProto(version=13)])
+onnx.save(model_def, 'unstable_types.onnx')
diff --git a/eval/src/vespa/eval/onnx/onnx_wrapper.cpp b/eval/src/vespa/eval/onnx/onnx_wrapper.cpp
index e9758f2ddc8..3a593f491d8 100644
--- a/eval/src/vespa/eval/onnx/onnx_wrapper.cpp
+++ b/eval/src/vespa/eval/onnx/onnx_wrapper.cpp
@@ -38,16 +38,17 @@ struct TypifyOnnxElementType {
template <typename T> using Result = TypifyResultType<T>;
template <typename F> static decltype(auto) resolve(Onnx::ElementType value, F &&f) {
switch(value) {
- case Onnx::ElementType::INT8: return f(Result<int8_t>());
- case Onnx::ElementType::INT16: return f(Result<int16_t>());
- case Onnx::ElementType::INT32: return f(Result<int32_t>());
- case Onnx::ElementType::INT64: return f(Result<int64_t>());
- case Onnx::ElementType::UINT8: return f(Result<uint8_t>());
- case Onnx::ElementType::UINT16: return f(Result<uint16_t>());
- case Onnx::ElementType::UINT32: return f(Result<uint32_t>());
- case Onnx::ElementType::UINT64: return f(Result<uint64_t>());
- case Onnx::ElementType::FLOAT: return f(Result<float>());
- case Onnx::ElementType::DOUBLE: return f(Result<double>());
+ case Onnx::ElementType::INT8: return f(Result<Int8Float>());
+ case Onnx::ElementType::INT16: return f(Result<int16_t>());
+ case Onnx::ElementType::INT32: return f(Result<int32_t>());
+ case Onnx::ElementType::INT64: return f(Result<int64_t>());
+ case Onnx::ElementType::UINT8: return f(Result<uint8_t>());
+ case Onnx::ElementType::UINT16: return f(Result<uint16_t>());
+ case Onnx::ElementType::UINT32: return f(Result<uint32_t>());
+ case Onnx::ElementType::UINT64: return f(Result<uint64_t>());
+ case Onnx::ElementType::BFLOAT16: return f(Result<BFloat16>());
+ case Onnx::ElementType::FLOAT: return f(Result<float>());
+ case Onnx::ElementType::DOUBLE: return f(Result<double>());
}
abort();
}
@@ -118,32 +119,34 @@ auto convert_optimize(Onnx::Optimize optimize) {
CellType to_cell_type(Onnx::ElementType type) {
switch (type) {
- case Onnx::ElementType::INT8: [[fallthrough]];
- case Onnx::ElementType::INT16: [[fallthrough]];
- case Onnx::ElementType::UINT8: [[fallthrough]];
- case Onnx::ElementType::UINT16: [[fallthrough]];
- case Onnx::ElementType::FLOAT: return CellType::FLOAT;
- case Onnx::ElementType::INT32: [[fallthrough]];
- case Onnx::ElementType::INT64: [[fallthrough]];
- case Onnx::ElementType::UINT32: [[fallthrough]];
- case Onnx::ElementType::UINT64: [[fallthrough]];
- case Onnx::ElementType::DOUBLE: return CellType::DOUBLE;
+ case Onnx::ElementType::INT8: return CellType::INT8;
+ case Onnx::ElementType::BFLOAT16: return CellType::BFLOAT16;
+ case Onnx::ElementType::UINT8: [[fallthrough]];
+ case Onnx::ElementType::INT16: [[fallthrough]];
+ case Onnx::ElementType::UINT16: [[fallthrough]];
+ case Onnx::ElementType::FLOAT: return CellType::FLOAT;
+ case Onnx::ElementType::INT32: [[fallthrough]];
+ case Onnx::ElementType::INT64: [[fallthrough]];
+ case Onnx::ElementType::UINT32: [[fallthrough]];
+ case Onnx::ElementType::UINT64: [[fallthrough]];
+ case Onnx::ElementType::DOUBLE: return CellType::DOUBLE;
}
abort();
}
Onnx::ElementType make_element_type(ONNXTensorElementDataType element_type) {
switch (element_type) {
- case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8: return Onnx::ElementType::INT8;
- case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT16: return Onnx::ElementType::INT16;
- case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32: return Onnx::ElementType::INT32;
- case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64: return Onnx::ElementType::INT64;
- case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8: return Onnx::ElementType::UINT8;
- case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT16: return Onnx::ElementType::UINT16;
- case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT32: return Onnx::ElementType::UINT32;
- case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT64: return Onnx::ElementType::UINT64;
- case ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT: return Onnx::ElementType::FLOAT;
- case ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE: return Onnx::ElementType::DOUBLE;
+ case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8: return Onnx::ElementType::INT8;
+ case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT16: return Onnx::ElementType::INT16;
+ case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32: return Onnx::ElementType::INT32;
+ case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64: return Onnx::ElementType::INT64;
+ case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8: return Onnx::ElementType::UINT8;
+ case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT16: return Onnx::ElementType::UINT16;
+ case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT32: return Onnx::ElementType::UINT32;
+ case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT64: return Onnx::ElementType::UINT64;
+ case ONNX_TENSOR_ELEMENT_DATA_TYPE_BFLOAT16: return Onnx::ElementType::BFLOAT16;
+ case ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT: return Onnx::ElementType::FLOAT;
+ case ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE: return Onnx::ElementType::DOUBLE;
default:
throw Ort::Exception(fmt("[onnx wrapper] unsupported element type: %d", element_type), ORT_FAIL);
}
diff --git a/eval/src/vespa/eval/onnx/onnx_wrapper.h b/eval/src/vespa/eval/onnx/onnx_wrapper.h
index 68c31f04cdc..1f36d576c33 100644
--- a/eval/src/vespa/eval/onnx/onnx_wrapper.h
+++ b/eval/src/vespa/eval/onnx/onnx_wrapper.h
@@ -51,7 +51,7 @@ public:
};
// supported onnx element types
- enum class ElementType { INT8, INT16, INT32, INT64, UINT8, UINT16, UINT32, UINT64, FLOAT, DOUBLE };
+ enum class ElementType { INT8, INT16, INT32, INT64, UINT8, UINT16, UINT32, UINT64, BFLOAT16, FLOAT, DOUBLE };
// information about a single input or output tensor
struct TensorInfo {