summaryrefslogtreecommitdiffstats
path: root/vespalib
diff options
context:
space:
mode:
authorArne H Juul <arnej27959@users.noreply.github.com>2021-03-29 08:44:12 +0200
committerGitHub <noreply@github.com>2021-03-29 08:44:12 +0200
commita3b2ae2a910f60bbd212bd6c126492a9933372c6 (patch)
tree33572cc9ce9c05570333a5427ece50a8dbb2e7d3 /vespalib
parentac157c24c794028a1cdc446778e138fdc90ac918 (diff)
parent19d06063a500baa15383e5bc73eaf1ce421a70f2 (diff)
Merge pull request #17130 from vespa-engine/arnej/compare-with-onnx-bfloat16
ensure vespa BFloat16 and onnxruntime BFloat16 behave the same
Diffstat (limited to 'vespalib')
-rw-r--r--vespalib/src/tests/util/bfloat16/bfloat16_test.cpp88
1 files changed, 87 insertions, 1 deletions
diff --git a/vespalib/src/tests/util/bfloat16/bfloat16_test.cpp b/vespalib/src/tests/util/bfloat16/bfloat16_test.cpp
index 5d70c95b1d9..d8bb93d7972 100644
--- a/vespalib/src/tests/util/bfloat16/bfloat16_test.cpp
+++ b/vespalib/src/tests/util/bfloat16/bfloat16_test.cpp
@@ -5,7 +5,7 @@
#include <vespa/vespalib/gtest/gtest.h>
#include <stdio.h>
#include <cmath>
-#include <cmath>
+#include <cstring>
#include <vector>
using namespace vespalib;
@@ -165,4 +165,90 @@ TEST(BFloat16Test, check_special_values) {
EXPECT_EQ(memcmp(&f_snan, &f_from_b_snan, sizeof(float)), 0);
}
+#include <onnxruntime/core/framework/endian.h>
+
+// extract from onnx-internal header file:
+namespace onnxruntime {
+
+//BFloat16
+struct BFloat16 {
+ uint16_t val{0};
+ explicit BFloat16() = default;
+ explicit BFloat16(uint16_t v) : val(v) {}
+ explicit BFloat16(float v) {
+ if (endian::native == endian::little) {
+ std::memcpy(&val, reinterpret_cast<char*>(&v) + sizeof(uint16_t), sizeof(uint16_t));
+ } else {
+ std::memcpy(&val, &v, sizeof(uint16_t));
+ }
+ }
+
+ float ToFloat() const {
+ float result;
+ char* const first = reinterpret_cast<char*>(&result);
+ char* const second = first + sizeof(uint16_t);
+ if (endian::native == endian::little) {
+ std::memset(first, 0, sizeof(uint16_t));
+ std::memcpy(second, &val, sizeof(uint16_t));
+ } else {
+ std::memcpy(first, &val, sizeof(uint16_t));
+ std::memset(second, 0, sizeof(uint16_t));
+ }
+ return result;
+ }
+};
+
+} // namespace onnxruntime
+
+TEST(OnnxBFloat16Test, has_same_encoding) {
+ EXPECT_EQ(sizeof(vespalib::BFloat16), sizeof(onnxruntime::BFloat16));
+ EXPECT_EQ(sizeof(vespalib::BFloat16), sizeof(uint16_t));
+ EXPECT_EQ(sizeof(onnxruntime::BFloat16), sizeof(uint16_t));
+ vespalib::BFloat16 our_value;
+ uint32_t ok_count = 0;
+ uint32_t nan_count = 0;
+ for (uint32_t i = 0; i < (1u << 16u); ++i) {
+ uint16_t bits = i;
+ our_value.assign_bits(bits);
+ onnxruntime::BFloat16 their_value(bits);
+ if (our_value.get_bits() != bits) {
+ printf("bad bits %04x -> %04x (vespalib)\n", bits, our_value.get_bits());
+ printf("onnx converts -> %04x\n", their_value.val);
+ EXPECT_EQ(our_value.get_bits(), their_value.val);
+ continue;
+ }
+ EXPECT_EQ(their_value.val, bits);
+ EXPECT_EQ(memcmp(&our_value, &their_value, sizeof(our_value)), 0);
+ if (their_value.val != bits) {
+ printf("bad bits %04x -> %04x (onnx)\n", bits, their_value.val);
+ continue;
+ }
+ EXPECT_EQ(our_value.get_bits(), their_value.val);
+ if (our_value.get_bits() != their_value.val) {
+ printf("vespalib bits %04x != %04x onnx bits\n", our_value.get_bits(), their_value.val);
+ printf("corresponds to floats %g and %g\n", our_value.to_float(), their_value.ToFloat());
+ continue;
+ }
+ float our_float = our_value.to_float();
+ float their_float = their_value.ToFloat();
+ EXPECT_EQ(std::isnan(our_float), std::isnan(their_float));
+ if (std::isnan(our_float) && std::isnan(their_float)) {
+ ++nan_count;
+ continue;
+ }
+ if (our_float != their_float) {
+ printf("bits %04x as float differs: vespalib %g != %g onnx\n", bits, our_value.to_float(), their_value.ToFloat());
+ } else {
+ ++ok_count;
+ }
+ EXPECT_EQ(our_float, their_float);
+ vespalib::BFloat16 our_back(our_float);
+ onnxruntime::BFloat16 their_back(their_float);
+ EXPECT_EQ(our_back.get_bits(), their_back.val);
+ }
+ printf("normal floats behave equally OK in both vespalib and onnx: %d (0x%04x)\n", ok_count, ok_count);
+ printf("floats that are NaN in both vespalib and onnx: %d (0x%04x)\n", nan_count, nan_count);
+ printf("total count (OK + NaN): %d (0x%04x)\n", ok_count + nan_count, ok_count + nan_count);
+}
+
GTEST_MAIN_RUN_ALL_TESTS()