diff options
author | Arne H Juul <arnej27959@users.noreply.github.com> | 2021-03-29 08:44:12 +0200 |
---|---|---|
committer | GitHub <noreply@github.com> | 2021-03-29 08:44:12 +0200 |
commit | a3b2ae2a910f60bbd212bd6c126492a9933372c6 (patch) | |
tree | 33572cc9ce9c05570333a5427ece50a8dbb2e7d3 /vespalib | |
parent | ac157c24c794028a1cdc446778e138fdc90ac918 (diff) | |
parent | 19d06063a500baa15383e5bc73eaf1ce421a70f2 (diff) |
Merge pull request #17130 from vespa-engine/arnej/compare-with-onnx-bfloat16
ensure vespa BFloat16 and onnxruntime BFloat16 behave the same
Diffstat (limited to 'vespalib')
-rw-r--r-- | vespalib/src/tests/util/bfloat16/bfloat16_test.cpp | 88 |
1 files changed, 87 insertions, 1 deletions
diff --git a/vespalib/src/tests/util/bfloat16/bfloat16_test.cpp b/vespalib/src/tests/util/bfloat16/bfloat16_test.cpp index 5d70c95b1d9..d8bb93d7972 100644 --- a/vespalib/src/tests/util/bfloat16/bfloat16_test.cpp +++ b/vespalib/src/tests/util/bfloat16/bfloat16_test.cpp @@ -5,7 +5,7 @@ #include <vespa/vespalib/gtest/gtest.h> #include <stdio.h> #include <cmath> -#include <cmath> +#include <cstring> #include <vector> using namespace vespalib; @@ -165,4 +165,90 @@ TEST(BFloat16Test, check_special_values) { EXPECT_EQ(memcmp(&f_snan, &f_from_b_snan, sizeof(float)), 0); } +#include <onnxruntime/core/framework/endian.h> + +// extract from onnx-internal header file: +namespace onnxruntime { + +//BFloat16 +struct BFloat16 { + uint16_t val{0}; + explicit BFloat16() = default; + explicit BFloat16(uint16_t v) : val(v) {} + explicit BFloat16(float v) { + if (endian::native == endian::little) { + std::memcpy(&val, reinterpret_cast<char*>(&v) + sizeof(uint16_t), sizeof(uint16_t)); + } else { + std::memcpy(&val, &v, sizeof(uint16_t)); + } + } + + float ToFloat() const { + float result; + char* const first = reinterpret_cast<char*>(&result); + char* const second = first + sizeof(uint16_t); + if (endian::native == endian::little) { + std::memset(first, 0, sizeof(uint16_t)); + std::memcpy(second, &val, sizeof(uint16_t)); + } else { + std::memcpy(first, &val, sizeof(uint16_t)); + std::memset(second, 0, sizeof(uint16_t)); + } + return result; + } +}; + +} // namespace onnxruntime + +TEST(OnnxBFloat16Test, has_same_encoding) { + EXPECT_EQ(sizeof(vespalib::BFloat16), sizeof(onnxruntime::BFloat16)); + EXPECT_EQ(sizeof(vespalib::BFloat16), sizeof(uint16_t)); + EXPECT_EQ(sizeof(onnxruntime::BFloat16), sizeof(uint16_t)); + vespalib::BFloat16 our_value; + uint32_t ok_count = 0; + uint32_t nan_count = 0; + for (uint32_t i = 0; i < (1u << 16u); ++i) { + uint16_t bits = i; + our_value.assign_bits(bits); + onnxruntime::BFloat16 their_value(bits); + if (our_value.get_bits() != bits) { + printf("bad bits %04x -> %04x (vespalib)\n", bits, our_value.get_bits()); + printf("onnx converts -> %04x\n", their_value.val); + EXPECT_EQ(our_value.get_bits(), their_value.val); + continue; + } + EXPECT_EQ(their_value.val, bits); + EXPECT_EQ(memcmp(&our_value, &their_value, sizeof(our_value)), 0); + if (their_value.val != bits) { + printf("bad bits %04x -> %04x (onnx)\n", bits, their_value.val); + continue; + } + EXPECT_EQ(our_value.get_bits(), their_value.val); + if (our_value.get_bits() != their_value.val) { + printf("vespalib bits %04x != %04x onnx bits\n", our_value.get_bits(), their_value.val); + printf("corresponds to floats %g and %g\n", our_value.to_float(), their_value.ToFloat()); + continue; + } + float our_float = our_value.to_float(); + float their_float = their_value.ToFloat(); + EXPECT_EQ(std::isnan(our_float), std::isnan(their_float)); + if (std::isnan(our_float) && std::isnan(their_float)) { + ++nan_count; + continue; + } + if (our_float != their_float) { + printf("bits %04x as float differs: vespalib %g != %g onnx\n", bits, our_value.to_float(), their_value.ToFloat()); + } else { + ++ok_count; + } + EXPECT_EQ(our_float, their_float); + vespalib::BFloat16 our_back(our_float); + onnxruntime::BFloat16 their_back(their_float); + EXPECT_EQ(our_back.get_bits(), their_back.val); + } + printf("normal floats behave equally OK in both vespalib and onnx: %d (0x%04x)\n", ok_count, ok_count); + printf("floats that are NaN in both vespalib and onnx: %d (0x%04x)\n", nan_count, nan_count); + printf("total count (OK + NaN): %d (0x%04x)\n", ok_count + nan_count, ok_count + nan_count); +} + GTEST_MAIN_RUN_ALL_TESTS() |