diff options
Diffstat (limited to 'vespalib')
5 files changed, 15 insertions, 1 deletions
diff --git a/vespalib/src/vespa/vespalib/hwaccelrated/avx2.cpp b/vespalib/src/vespa/vespalib/hwaccelrated/avx2.cpp index 296aa001e58..3b63a904b2f 100644 --- a/vespalib/src/vespa/vespalib/hwaccelrated/avx2.cpp +++ b/vespalib/src/vespa/vespalib/hwaccelrated/avx2.cpp @@ -40,4 +40,10 @@ Avx2Accelrator::convert_bfloat16_to_float(const uint16_t * src, float * dest, si helper::convert_bfloat16_to_float(src, dest, sz); } +int64_t +Avx2Accelrator::dotProduct(const int8_t * a, const int8_t * b, size_t sz) const noexcept +{ + return helper::multiplyAdd(a, b, sz); +} + } diff --git a/vespalib/src/vespa/vespalib/hwaccelrated/avx2.h b/vespalib/src/vespa/vespalib/hwaccelrated/avx2.h index a82cc30eaf4..279483fcec3 100644 --- a/vespalib/src/vespa/vespalib/hwaccelrated/avx2.h +++ b/vespalib/src/vespa/vespalib/hwaccelrated/avx2.h @@ -17,6 +17,7 @@ public: double squaredEuclideanDistance(const float * a, const float * b, size_t sz) const noexcept override; double squaredEuclideanDistance(const double * a, const double * b, size_t sz) const noexcept override; void convert_bfloat16_to_float(const uint16_t * src, float * dest, size_t sz) const noexcept override; + int64_t dotProduct(const int8_t * a, const int8_t * b, size_t sz) const noexcept override; void and128(size_t offset, const std::vector<std::pair<const void *, bool>> &src, void *dest) const noexcept override; void or128(size_t offset, const std::vector<std::pair<const void *, bool>> &src, void *dest) const noexcept override; }; diff --git a/vespalib/src/vespa/vespalib/hwaccelrated/avx512.cpp b/vespalib/src/vespa/vespalib/hwaccelrated/avx512.cpp index 80dc08f24c8..aec81c718c7 100644 --- a/vespalib/src/vespa/vespalib/hwaccelrated/avx512.cpp +++ b/vespalib/src/vespa/vespalib/hwaccelrated/avx512.cpp @@ -50,4 +50,10 @@ Avx512Accelrator::convert_bfloat16_to_float(const uint16_t * src, float * dest, helper::convert_bfloat16_to_float(src, dest, sz); } +int64_t +Avx512Accelrator::dotProduct(const int8_t * a, const int8_t * b, size_t sz) const noexcept +{ + return helper::multiplyAdd(a, b, sz); +} + } diff --git a/vespalib/src/vespa/vespalib/hwaccelrated/avx512.h b/vespalib/src/vespa/vespalib/hwaccelrated/avx512.h index 85cb3f62de9..49d2dc63f77 100644 --- a/vespalib/src/vespa/vespalib/hwaccelrated/avx512.h +++ b/vespalib/src/vespa/vespalib/hwaccelrated/avx512.h @@ -19,6 +19,7 @@ public: double squaredEuclideanDistance(const float * a, const float * b, size_t sz) const noexcept override; double squaredEuclideanDistance(const double * a, const double * b, size_t sz) const noexcept override; void convert_bfloat16_to_float(const uint16_t * src, float * dest, size_t sz) const noexcept override; + int64_t dotProduct(const int8_t * a, const int8_t * b, size_t sz) const noexcept override; void and128(size_t offset, const std::vector<std::pair<const void *, bool>> &src, void *dest) const noexcept override; void or128(size_t offset, const std::vector<std::pair<const void *, bool>> &src, void *dest) const noexcept override; }; diff --git a/vespalib/src/vespa/vespalib/hwaccelrated/private_helpers.hpp b/vespalib/src/vespa/vespalib/hwaccelrated/private_helpers.hpp index b2bc087b7e9..fbcf3dff526 100644 --- a/vespalib/src/vespa/vespalib/hwaccelrated/private_helpers.hpp +++ b/vespalib/src/vespa/vespalib/hwaccelrated/private_helpers.hpp @@ -130,7 +130,7 @@ ACCUM multiplyAddT(const int8_t *a, const int8_t *b, size_t sz) noexcept { ACCUM sum = 0; for (size_t i(0); i < sz; i++) { - sum += uint16_t(a[i]) * uint16_t(b[i]); + sum += int16_t(a[i]) * int16_t(b[i]); } return sum; } |