diff options
author | Henning Baldersheim <balder@yahoo-inc.com> | 2024-05-14 23:04:19 +0000 |
---|---|---|
committer | Henning Baldersheim <balder@yahoo-inc.com> | 2024-05-14 23:04:19 +0000 |
commit | 51dd79b028db920f0749dd183200455f2f7a1f71 (patch) | |
tree | aba8e53c1d17ce107a0d9719d63515d2896dd116 | |
parent | cf84c1de017cc9e3cfd1b8859ddfbfba41a350e5 (diff) |
Speed up bfloat16 to float conversion
9 files changed, 48 insertions, 6 deletions
diff --git a/searchlib/src/vespa/searchlib/tensor/temporary_vector_store.cpp b/searchlib/src/vespa/searchlib/tensor/temporary_vector_store.cpp index 4753e9d7c87..5d29f38cf2a 100644 --- a/searchlib/src/vespa/searchlib/tensor/temporary_vector_store.cpp +++ b/searchlib/src/vespa/searchlib/tensor/temporary_vector_store.cpp @@ -1,11 +1,13 @@ // Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. #include "temporary_vector_store.h" +#include <vespa/vespalib/hwaccelrated/iaccelrated.h> using vespalib::ConstArrayRef; using vespalib::ArrayRef; using vespalib::eval::CellType; using vespalib::eval::TypedCells; +using vespalib::hwaccelrated::IAccelrated; namespace search::tensor { @@ -13,18 +15,29 @@ namespace { template<typename FromType, typename ToType> ConstArrayRef<ToType> +convert_cells(ArrayRef<ToType> space, TypedCells cells) noexcept __attribute_noinline__; + +template<typename FromType, typename ToType> +ConstArrayRef<ToType> convert_cells(ArrayRef<ToType> space, TypedCells cells) noexcept { - assert(cells.size == space.size()); - auto old_cells = cells.typify<FromType>(); + auto old_cells = cells.unsafe_typify<FromType>(); ToType *p = space.data(); for (FromType value : old_cells) { - ToType conv(value); - *p++ = conv; + *p++ = value; } return space; } +template<> +ConstArrayRef<float> +convert_cells<vespalib::BFloat16, float>(ArrayRef<float> space, TypedCells cells) noexcept +{ + static const IAccelrated & accelrator = IAccelrated::getAccelerator(); + accelrator.convert_bfloat16_to_float(reinterpret_cast<const uint16_t *>(cells.data), space.data(), space.size()); + return space; +} + template <typename ToType> struct ConvertCellsSelector { diff --git a/vespalib/src/vespa/vespalib/hwaccelrated/avx2.cpp b/vespalib/src/vespa/vespalib/hwaccelrated/avx2.cpp index 66441b3c08b..296aa001e58 100644 --- a/vespalib/src/vespa/vespalib/hwaccelrated/avx2.cpp +++ b/vespalib/src/vespa/vespalib/hwaccelrated/avx2.cpp @@ -35,4 +35,9 @@ Avx2Accelrator::or128(size_t offset, const std::vector<std::pair<const void *, b helper::orChunks<32u, 4u>(offset, src, dest); } +void +Avx2Accelrator::convert_bfloat16_to_float(const uint16_t * src, float * dest, size_t sz) const noexcept { + helper::convert_bfloat16_to_float(src, dest, sz); +} + } diff --git a/vespalib/src/vespa/vespalib/hwaccelrated/avx2.h b/vespalib/src/vespa/vespalib/hwaccelrated/avx2.h index af46035666c..a82cc30eaf4 100644 --- a/vespalib/src/vespa/vespalib/hwaccelrated/avx2.h +++ b/vespalib/src/vespa/vespalib/hwaccelrated/avx2.h @@ -16,6 +16,7 @@ public: double squaredEuclideanDistance(const int8_t * a, const int8_t * b, size_t sz) const noexcept override; double squaredEuclideanDistance(const float * a, const float * b, size_t sz) const noexcept override; double squaredEuclideanDistance(const double * a, const double * b, size_t sz) const noexcept override; + void convert_bfloat16_to_float(const uint16_t * src, float * dest, size_t sz) const noexcept override; void and128(size_t offset, const std::vector<std::pair<const void *, bool>> &src, void *dest) const noexcept override; void or128(size_t offset, const std::vector<std::pair<const void *, bool>> &src, void *dest) const noexcept override; }; diff --git a/vespalib/src/vespa/vespalib/hwaccelrated/avx512.cpp b/vespalib/src/vespa/vespalib/hwaccelrated/avx512.cpp index 5f408c05fef..80dc08f24c8 100644 --- a/vespalib/src/vespa/vespalib/hwaccelrated/avx512.cpp +++ b/vespalib/src/vespa/vespalib/hwaccelrated/avx512.cpp @@ -45,4 +45,9 @@ Avx512Accelrator::or128(size_t offset, const std::vector<std::pair<const void *, helper::orChunks<64, 2>(offset, src, dest); } +void +Avx512Accelrator::convert_bfloat16_to_float(const uint16_t * src, float * dest, size_t sz) const noexcept { + helper::convert_bfloat16_to_float(src, dest, sz); +} + } diff --git a/vespalib/src/vespa/vespalib/hwaccelrated/avx512.h b/vespalib/src/vespa/vespalib/hwaccelrated/avx512.h index a86a2787d5a..85cb3f62de9 100644 --- a/vespalib/src/vespa/vespalib/hwaccelrated/avx512.h +++ b/vespalib/src/vespa/vespalib/hwaccelrated/avx512.h @@ -18,6 +18,7 @@ public: double squaredEuclideanDistance(const int8_t * a, const int8_t * b, size_t sz) const noexcept override; double squaredEuclideanDistance(const float * a, const float * b, size_t sz) const noexcept override; double squaredEuclideanDistance(const double * a, const double * b, size_t sz) const noexcept override; + void convert_bfloat16_to_float(const uint16_t * src, float * dest, size_t sz) const noexcept override; void and128(size_t offset, const std::vector<std::pair<const void *, bool>> &src, void *dest) const noexcept override; void or128(size_t offset, const std::vector<std::pair<const void *, bool>> &src, void *dest) const noexcept override; }; diff --git a/vespalib/src/vespa/vespalib/hwaccelrated/generic.cpp b/vespalib/src/vespa/vespalib/hwaccelrated/generic.cpp index f0112aaddf7..4307b38d18b 100644 --- a/vespalib/src/vespa/vespalib/hwaccelrated/generic.cpp +++ b/vespalib/src/vespa/vespalib/hwaccelrated/generic.cpp @@ -152,6 +152,11 @@ GenericAccelrator::notBit(void * aOrg, size_t bytes) const noexcept } } +void +GenericAccelrator::convert_bfloat16_to_float(const uint16_t * src, float * dest, size_t sz) const noexcept { + helper::convert_bfloat16_to_float(src, dest, sz); +} + size_t GenericAccelrator::populationCount(const uint64_t *a, size_t sz) const noexcept { return helper::populationCount(a, sz); diff --git a/vespalib/src/vespa/vespalib/hwaccelrated/generic.h b/vespalib/src/vespa/vespalib/hwaccelrated/generic.h index ba986656635..fee1fec6165 100644 --- a/vespalib/src/vespa/vespalib/hwaccelrated/generic.h +++ b/vespalib/src/vespa/vespalib/hwaccelrated/generic.h @@ -23,6 +23,7 @@ public: void andNotBit(void * a, const void * b, size_t bytes) const noexcept override; void notBit(void * a, size_t bytes) const noexcept override; size_t populationCount(const uint64_t *a, size_t sz) const noexcept override; + void convert_bfloat16_to_float(const uint16_t * src, float * dest, size_t sz) const noexcept override; double squaredEuclideanDistance(const int8_t * a, const int8_t * b, size_t sz) const noexcept override; double squaredEuclideanDistance(const float * a, const float * b, size_t sz) const noexcept override; double squaredEuclideanDistance(const double * a, const double * b, size_t sz) const noexcept override; diff --git a/vespalib/src/vespa/vespalib/hwaccelrated/iaccelrated.h b/vespalib/src/vespa/vespalib/hwaccelrated/iaccelrated.h index f070f206b7e..337dc3b4ab1 100644 --- a/vespalib/src/vespa/vespalib/hwaccelrated/iaccelrated.h +++ b/vespalib/src/vespa/vespalib/hwaccelrated/iaccelrated.h @@ -28,6 +28,7 @@ public: virtual void andNotBit(void * a, const void * b, size_t bytes) const noexcept = 0; virtual void notBit(void * a, size_t bytes) const noexcept = 0; virtual size_t populationCount(const uint64_t *a, size_t sz) const noexcept = 0; + virtual void convert_bfloat16_to_float(const uint16_t * src, float * dest, size_t sz) const noexcept = 0; virtual double squaredEuclideanDistance(const int8_t * a, const int8_t * b, size_t sz) const noexcept = 0; virtual double squaredEuclideanDistance(const float * a, const float * b, size_t sz) const noexcept = 0; virtual double squaredEuclideanDistance(const double * a, const double * b, size_t sz) const noexcept = 0; diff --git a/vespalib/src/vespa/vespalib/hwaccelrated/private_helpers.hpp b/vespalib/src/vespa/vespalib/hwaccelrated/private_helpers.hpp index a53716a2973..173fe151831 100644 --- a/vespalib/src/vespa/vespalib/hwaccelrated/private_helpers.hpp +++ b/vespalib/src/vespa/vespalib/hwaccelrated/private_helpers.hpp @@ -101,15 +101,25 @@ double squaredEuclideanDistanceT(const int8_t * a, const int8_t * b, size_t sz) inline double squaredEuclideanDistance(const int8_t * a, const int8_t * b, size_t sz) { - constexpr size_t LOOP_COUNT = 0x10000; + constexpr size_t LOOP_COUNT = 0x200; double sum(0); size_t i=0; for (; i + LOOP_COUNT <= sz; i += LOOP_COUNT) { sum += squaredEuclideanDistanceT<int32_t>(a + i, b + i, LOOP_COUNT); } - sum += squaredEuclideanDistanceT<int32_t>(a + i, b + i, sz - i); + if (sz > i) [[unlikely]] { + sum += squaredEuclideanDistanceT<int32_t>(a + i, b + i, sz - i); + } return sum; } +inline void +convert_bfloat16_to_float(const uint16_t * src, float * dest, size_t sz) noexcept { + uint32_t * asu32 = reinterpret_cast<uint32_t *>(dest); + for (size_t i(0); i < sz; i++) { + asu32[i] = src[i] << 16; + } +} + } } |