aboutsummaryrefslogtreecommitdiffstats
path: root/vespalib/src/vespa/vespalib/util/bfloat16.h
blob: 0fa201915252d00e069208208ebc540719cbfd62 (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
// Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.

#pragma once

#include <bit>
#include <cstdint>
#include <cstring>
#include <limits>

namespace vespalib {

/**
 * Class holding 16-bit floating-point numbers.
 * Truncated version of normal 32-bit float; the sign and
 * exponent are kept as-is but the mantissa has only 8-bit
 * precision.  Well suited for ML / AI, halving memory
 * requirements for large vectors and similar data.
 * Direct HW support possible (AVX-512 BF16 extension etc.)
 * See also:
 * https://en.wikipedia.org/wiki/Bfloat16_floating-point_format
 **/
class BFloat16 {
private:
    uint16_t _bits;
    struct TwoU16 {
        uint16_t u1;
        uint16_t u2;
    };

    template<std::endian native_endian = std::endian::native>
    static constexpr uint16_t float_to_bits(float value) noexcept {
        TwoU16 both{0,0};
        static_assert(sizeof(TwoU16) == sizeof(float));
        memcpy(&both, &value, sizeof(float));
        if constexpr (native_endian == std::endian::big) {
            return both.u1;
        } else {
            static_assert(native_endian == std::endian::little,
                          "Unknown endian, cannot handle");
            return both.u2;
        }
    }

    template<std::endian native_endian = std::endian::native>
    static constexpr float bits_to_float(uint16_t bits) noexcept {
        TwoU16 both{0,0};
        if constexpr (native_endian == std::endian::big) {
            both.u1 = bits;
        } else {
            static_assert(native_endian == std::endian::little,
                          "Unknown endian, cannot handle");
            both.u2 = bits;
        }
        float result = 0.0;
        static_assert(sizeof(TwoU16) == sizeof(float));
        memcpy(&result, &both, sizeof(float));
        return result;
    }
public:
    constexpr BFloat16(float value) noexcept : _bits(float_to_bits(value)) {}
    BFloat16() noexcept = default;
    ~BFloat16() noexcept = default;
    constexpr BFloat16(const BFloat16 &other) noexcept = default;
    constexpr BFloat16(BFloat16 &&other) noexcept = default;
    constexpr BFloat16& operator=(const BFloat16 &other) noexcept = default;
    constexpr BFloat16& operator=(BFloat16 &&other) noexcept = default;
    constexpr BFloat16& operator=(float value) noexcept {
        _bits = float_to_bits(value);
        return *this;
    }

    constexpr operator float () const noexcept { return bits_to_float(_bits); }

    constexpr float to_float() const noexcept { return bits_to_float(_bits); }
    constexpr void assign(float value) noexcept { _bits = float_to_bits(value); }

    constexpr uint16_t get_bits() const { return _bits; }
    constexpr void assign_bits(uint16_t value) noexcept { _bits = value; }
};

}

namespace std {
/**
 * std::numeric_limits specialization describing vespalib::BFloat16.
 * Range-related values (exponents, min, denorm_min) match float,
 * since bfloat16 keeps the full 8-bit float exponent; precision-related
 * values reflect the 8-bit (7 stored + 1 implicit) significand.
 * Conversion from float truncates, hence round_toward_zero and a
 * round_error of 1.0.
 **/
template<> class numeric_limits<vespalib::BFloat16> {
public:
    static constexpr bool is_specialized = true;
    static constexpr bool is_signed = true;
    static constexpr bool is_integer = false;
    static constexpr bool is_exact = false;
    // The all-ones exponent with zero mantissa encodes +/-infinity.
    static constexpr bool has_infinity = true;
    static constexpr bool has_quiet_NaN = true;
    static constexpr bool has_signaling_NaN = true;
    // NOTE(review): the primary template declares has_denorm as
    // std::float_denorm_style (deprecated in C++23); bool true compares
    // equal to std::denorm_present.
    static constexpr bool has_denorm = true;
    static constexpr bool has_denorm_loss = false;
    static constexpr bool is_iec559 = false;
    static constexpr bool is_bounded = true;
    static constexpr bool is_modulo = false;
    static constexpr bool traps = false;
    static constexpr bool tinyness_before = false;

    // float -> BFloat16 conversion drops the low mantissa bits.
    static constexpr std::float_round_style round_style = std::round_toward_zero;
    static constexpr int radix = 2;

    // 7 stored mantissa bits + 1 implicit leading bit.
    static constexpr int digits = 8;
    static constexpr int digits10 = 2;
    static constexpr int max_digits10 = 4;

    static constexpr int min_exponent = -125;
    // Smallest power of ten that is a normalized value: 1e-37, because the
    // smallest normal value (2^-126, about 1.18e-38) is the same as float's.
    static constexpr int min_exponent10 = -37;

    static constexpr int max_exponent = 128;
    static constexpr int max_exponent10 = 38;

    static constexpr vespalib::BFloat16 denorm_min() noexcept { return 0x1.0p-133; }
    static constexpr vespalib::BFloat16 epsilon() noexcept { return 0x1.0p-7; }
    static constexpr vespalib::BFloat16 lowest() noexcept { return -0x1.FEp127; }
    static constexpr vespalib::BFloat16 max() noexcept { return 0x1.FEp127; }
    static constexpr vespalib::BFloat16 min() noexcept { return 0x1.0p-126; }
    // Truncation can lose up to one unit in the last place.
    static constexpr vespalib::BFloat16 round_error() noexcept { return 1.0; }
    static constexpr vespalib::BFloat16 infinity() noexcept {
        return std::numeric_limits<float>::infinity();
    }
    static constexpr vespalib::BFloat16 quiet_NaN() noexcept {
        return std::numeric_limits<float>::quiet_NaN();
    }
    static constexpr vespalib::BFloat16 signaling_NaN() noexcept {
        return std::numeric_limits<float>::signaling_NaN();
    }
};

}