aboutsummaryrefslogtreecommitdiffstats
path: root/vespalib
diff options
context:
space:
mode:
authorArne Juul <arnej@verizonmedia.com>2021-03-23 12:24:11 +0000
committerArne Juul <arnej@verizonmedia.com>2021-03-23 12:24:11 +0000
commitea7f446ac7f29fbbde579513ee52f0943b8889a7 (patch)
treefbacd06e2f340fce224a821170d8d8fa70d9bdc3 /vespalib
parentbe9927c372d9762a03bf687dbe2a886750867a1c (diff)
ensure vespa BFloat16 and onnxruntime BFloat16 behave the same
Diffstat (limited to 'vespalib')
-rw-r--r--vespalib/src/tests/util/bfloat16/CMakeLists.txt2
-rw-r--r--vespalib/src/tests/util/bfloat16/bfloat16_test.cpp89
2 files changed, 90 insertions, 1 deletions
diff --git a/vespalib/src/tests/util/bfloat16/CMakeLists.txt b/vespalib/src/tests/util/bfloat16/CMakeLists.txt
index 39a42e6f148..1916c7bb365 100644
--- a/vespalib/src/tests/util/bfloat16/CMakeLists.txt
+++ b/vespalib/src/tests/util/bfloat16/CMakeLists.txt
@@ -5,5 +5,7 @@ vespa_add_executable(vespalib_bfloat16_test_app TEST
DEPENDS
vespalib
GTest::GTest
+ EXTERNAL_DEPENDS
+ onnxruntime
)
vespa_add_test(NAME vespalib_bfloat16_test_app COMMAND vespalib_bfloat16_test_app)
diff --git a/vespalib/src/tests/util/bfloat16/bfloat16_test.cpp b/vespalib/src/tests/util/bfloat16/bfloat16_test.cpp
index 5d70c95b1d9..a9d2ab0b8d4 100644
--- a/vespalib/src/tests/util/bfloat16/bfloat16_test.cpp
+++ b/vespalib/src/tests/util/bfloat16/bfloat16_test.cpp
@@ -5,8 +5,9 @@
#include <vespa/vespalib/gtest/gtest.h>
#include <stdio.h>
#include <cmath>
-#include <cmath>
#include <vector>
+#include <onnxruntime/core/session/onnxruntime_cxx_api.h>
+#include <onnxruntime/onnxruntime_cxx_api.h>
using namespace vespalib;
@@ -165,4 +166,90 @@ TEST(BFloat16Test, check_special_values) {
EXPECT_EQ(memcmp(&f_snan, &f_from_b_snan, sizeof(float)), 0);
}
+#include <onnxruntime/core/framework/endian.h>
+
+// extract from onnx-internal header file:
+namespace onnxruntime {
+
+//BFloat16
+struct BFloat16 {
+ uint16_t val{0};
+ explicit BFloat16() = default;
+ explicit BFloat16(uint16_t v) : val(v) {}
+ explicit BFloat16(float v) {
+ if (endian::native == endian::little) {
+ std::memcpy(&val, reinterpret_cast<char*>(&v) + sizeof(uint16_t), sizeof(uint16_t));
+ } else {
+ std::memcpy(&val, &v, sizeof(uint16_t));
+ }
+ }
+
+ float ToFloat() const {
+ float result;
+ char* const first = reinterpret_cast<char*>(&result);
+ char* const second = first + sizeof(uint16_t);
+ if (endian::native == endian::little) {
+ std::memset(first, 0, sizeof(uint16_t));
+ std::memcpy(second, &val, sizeof(uint16_t));
+ } else {
+ std::memcpy(first, &val, sizeof(uint16_t));
+ std::memset(second, 0, sizeof(uint16_t));
+ }
+ return result;
+ }
+
+ operator float() const {
+ return ToFloat();
+ }
+};
+
+} // namespace onnxruntime
+
+TEST(OnnxBFloat16Test, has_same_encoding) {
+ vespalib::BFloat16 our_value;
+ uint32_t ok_count = 0;
+ uint32_t nan_count = 0;
+ for (uint32_t i = 0; i < (1u << 16u); ++i) {
+ uint16_t bits = i;
+ our_value.assign_bits(bits);
+ onnxruntime::BFloat16 their_value(bits);
+ if (our_value.get_bits() != bits) {
+ printf("bad bits %04x -> %04x (vespalib)\n", bits, our_value.get_bits());
+ printf("onnx converts -> %04x\n", their_value.val);
+ EXPECT_EQ(our_value.get_bits(), their_value.val);
+ continue;
+ }
+ EXPECT_EQ(their_value.val, bits);
+ if (their_value.val != bits) {
+ printf("bad bits %04x -> %04x (onnx)\n", bits, their_value.val);
+ continue;
+ }
+ EXPECT_EQ(our_value.get_bits(), their_value.val);
+ if (our_value.get_bits() != their_value.val) {
+ printf("vespalib bits %04x != %04x onnx bits\n", our_value.get_bits(), their_value.val);
+ printf("corresponds to floats %g and %g\n", our_value.to_float(), their_value.ToFloat());
+ continue;
+ }
+ float our_float = our_value.to_float();
+ float their_float = their_value.ToFloat();
+ EXPECT_EQ(std::isnan(our_float), std::isnan(their_float));
+ if (std::isnan(our_float) && std::isnan(their_float)) {
+ ++nan_count;
+ continue;
+ }
+ if (our_float != their_float) {
+ printf("bits %04x as float differs: vespalib %g != %g onnx\n", bits, our_value.to_float(), their_value.ToFloat());
+ } else {
+ ++ok_count;
+ }
+ EXPECT_EQ(our_float, their_float);
+ vespalib::BFloat16 our_back(our_float);
+ onnxruntime::BFloat16 their_back(their_float);
+ EXPECT_EQ(our_back.get_bits(), their_back.val);
+ }
+ printf("normal floats behave equally OK in both vespalib and onnx: %d (0x%04x)\n", ok_count, ok_count);
+ printf("floats that are NaN in both vespalib and onnx: %d (0x%04x)\n", nan_count, nan_count);
+ printf("total count (OK + NaN): %d (0x%04x)\n", ok_count + nan_count, ok_count + nan_count);
+}
+
GTEST_MAIN_RUN_ALL_TESTS()