From ea7f446ac7f29fbbde579513ee52f0943b8889a7 Mon Sep 17 00:00:00 2001 From: Arne Juul Date: Tue, 23 Mar 2021 12:24:11 +0000 Subject: ensure vespa BFloat16 and onnxruntime BFloat16 behave the same --- vespalib/src/tests/util/bfloat16/CMakeLists.txt | 2 + vespalib/src/tests/util/bfloat16/bfloat16_test.cpp | 89 +++++++++++++++++++++- 2 files changed, 90 insertions(+), 1 deletion(-) diff --git a/vespalib/src/tests/util/bfloat16/CMakeLists.txt b/vespalib/src/tests/util/bfloat16/CMakeLists.txt index 39a42e6f148..1916c7bb365 100644 --- a/vespalib/src/tests/util/bfloat16/CMakeLists.txt +++ b/vespalib/src/tests/util/bfloat16/CMakeLists.txt @@ -5,5 +5,7 @@ vespa_add_executable(vespalib_bfloat16_test_app TEST DEPENDS vespalib GTest::GTest + EXTERNAL_DEPENDS + onnxruntime ) vespa_add_test(NAME vespalib_bfloat16_test_app COMMAND vespalib_bfloat16_test_app) diff --git a/vespalib/src/tests/util/bfloat16/bfloat16_test.cpp b/vespalib/src/tests/util/bfloat16/bfloat16_test.cpp index 5d70c95b1d9..a9d2ab0b8d4 100644 --- a/vespalib/src/tests/util/bfloat16/bfloat16_test.cpp +++ b/vespalib/src/tests/util/bfloat16/bfloat16_test.cpp @@ -5,8 +5,9 @@ #include #include #include -#include #include +#include +#include using namespace vespalib; @@ -165,4 +166,90 @@ TEST(BFloat16Test, check_special_values) { EXPECT_EQ(memcmp(&f_snan, &f_from_b_snan, sizeof(float)), 0); } +#include + +// extract from onnx-internal header file: +namespace onnxruntime { + +//BFloat16 +struct BFloat16 { + uint16_t val{0}; + explicit BFloat16() = default; + explicit BFloat16(uint16_t v) : val(v) {} + explicit BFloat16(float v) { + if (endian::native == endian::little) { + std::memcpy(&val, reinterpret_cast(&v) + sizeof(uint16_t), sizeof(uint16_t)); + } else { + std::memcpy(&val, &v, sizeof(uint16_t)); + } + } + + float ToFloat() const { + float result; + char* const first = reinterpret_cast(&result); + char* const second = first + sizeof(uint16_t); + if (endian::native == endian::little) { + std::memset(first, 0, sizeof(uint16_t)); + std::memcpy(second, &val, sizeof(uint16_t)); + } else { + std::memcpy(first, &val, sizeof(uint16_t)); + std::memset(second, 0, sizeof(uint16_t)); + } + return result; + } + + operator float() const { + return ToFloat(); + } +}; + +} // namespace onnxruntime + +TEST(OnnxBFloat16Test, has_same_encoding) { + vespalib::BFloat16 our_value; + uint32_t ok_count = 0; + uint32_t nan_count = 0; + for (uint32_t i = 0; i < (1u << 16u); ++i) { + uint16_t bits = i; + our_value.assign_bits(bits); + onnxruntime::BFloat16 their_value(bits); + if (our_value.get_bits() != bits) { + printf("bad bits %04x -> %04x (vespalib)\n", bits, our_value.get_bits()); + printf("onnx converts -> %04x\n", their_value.val); + EXPECT_EQ(our_value.get_bits(), their_value.val); + continue; + } + EXPECT_EQ(their_value.val, bits); + if (their_value.val != bits) { + printf("bad bits %04x -> %04x (onnx)\n", bits, their_value.val); + continue; + } + EXPECT_EQ(our_value.get_bits(), their_value.val); + if (our_value.get_bits() != their_value.val) { + printf("vespalib bits %04x != %04x onnx bits\n", our_value.get_bits(), their_value.val); + printf("corresponds to floats %g and %g\n", our_value.to_float(), their_value.ToFloat()); + continue; + } + float our_float = our_value.to_float(); + float their_float = their_value.ToFloat(); + EXPECT_EQ(std::isnan(our_float), std::isnan(their_float)); + if (std::isnan(our_float) && std::isnan(their_float)) { + ++nan_count; + continue; + } + if (our_float != their_float) { + printf("bits %04x as float differs: vespalib %g != %g onnx\n", bits, our_value.to_float(), their_value.ToFloat()); + } else { + ++ok_count; + } + EXPECT_EQ(our_float, their_float); + vespalib::BFloat16 our_back(our_float); + onnxruntime::BFloat16 their_back(their_float); + EXPECT_EQ(our_back.get_bits(), their_back.val); + } + printf("normal floats behave equally OK in both vespalib and onnx: %d (0x%04x)\n", ok_count, ok_count); + printf("floats that are NaN in both vespalib and onnx: %d (0x%04x)\n", nan_count, nan_count); + printf("total count (OK + NaN): %d (0x%04x)\n", ok_count + nan_count, ok_count + nan_count); +} + GTEST_MAIN_RUN_ALL_TESTS() -- cgit v1.2.3