diff options
Diffstat (limited to 'eval/src/tests/ann/nns-l2.h')
-rw-r--r-- | eval/src/tests/ann/nns-l2.h | 37 |
1 files changed, 33 insertions, 4 deletions
// diff --git a/eval/src/tests/ann/nns-l2.h b/eval/src/tests/ann/nns-l2.h
// index cfa5fed704f..dcad5f1bda6 100644
// --- a/eval/src/tests/ann/nns-l2.h
// +++ b/eval/src/tests/ann/nns-l2.h
//
// NOTE(review): the unified diff was collapsed onto a single line. The one
// definition that is visible in full (hw_l2_sq_dist, added by hunk 1) is written
// out below as plain C++ with fixes applied. The parts of the diff that only
// show fragments of struct L2DistCalc are kept as comments, untouched.
//
// @@ -1,8 +1,36 @@
//  // Copyright 2019 Oath Inc. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
//  #pragma once
//  #include <vespa/vespalib/hwaccelrated/iaccelrated.h>

#include <string.h>

/**
 * Squared Euclidean (L2) distance between two arrays of `sz` elements:
 * the sum over i in [0, sz) of (af[i] - bf[i])^2, returned as a double.
 *
 * The bulk of the input is processed VLEN bytes at a time using the GCC/Clang
 * `vector_size` extension; per-lane partial sums are kept in type T and only
 * widened to double when the lanes are added together at the end.
 *
 * @tparam T     element type of both arrays
 * @tparam VLEN  vector width in bytes; must hold a whole number of T elements
 * @param af     first array, at least `sz` elements (no alignment requirement)
 * @param bf     second array, at least `sz` elements (no alignment requirement)
 * @param sz     number of elements to compare; 0 yields 0.0
 * @return       the squared distance
 */
template <typename T, size_t VLEN>
static double hw_l2_sq_dist(const T * af, const T * bf, size_t sz)
{
    static_assert(VLEN >= sizeof(T) && (VLEN % sizeof(T)) == 0,
                  "VLEN must be a whole number of T elements");
    constexpr size_t OpsPerV = VLEN / sizeof(T);
    typedef T V __attribute__ ((vector_size (VLEN), aligned(VLEN)));

    V tmp_sum;
    memset(&tmp_sum, 0, sizeof(tmp_sum));

    const size_t numOps = sz / OpsPerV;
    for (size_t i = 0; i < numOps; ++i) {
        // Copy into properly aligned locals instead of casting af/bf to V*:
        // the inputs are not guaranteed to be VLEN-aligned, and dereferencing
        // a misaligned V* is undefined behaviour (and can fault).
        V a;
        V b;
        memcpy(&a, af + i * OpsPerV, sizeof(V));
        memcpy(&b, bf + i * OpsPerV, sizeof(V));
        const V diff = a - b;
        tmp_sum += diff * diff;
    }

    // Horizontal add of the per-lane partial sums.
    double sum = 0;
    for (size_t i = 0; i < OpsPerV; ++i) {
        sum += tmp_sum[i];
    }

    // Scalar tail: the sz % OpsPerV elements that do not fill a whole vector.
    for (size_t i = numOps * OpsPerV; i < sz; ++i) {
        const T diff = af[i] - bf[i];
        sum += diff * diff;
    }
    return sum;
}

// Remainder of the diff (struct L2DistCalc is only partially visible here):
//
//  template <typename FltType = float>
//  struct L2DistCalc {
//      vespalib::hwaccelrated::IAccelrated::UP _hw;
// @@ -11,7 +39,10 @@ struct L2DistCalc {
//      using Arr = vespalib::ArrayRef<FltType>;
//      using ConstArr = vespalib::ConstArrayRef<FltType>;
// -
// +
// +    double product(const FltType *v1, const FltType *v2, size_t sz) {
// +        return _hw->dotProduct(v1, v2, sz);
// +    }
//      double product(ConstArr v1, ConstArr v2) {
//          const FltType *p1 = v1.begin();
//          const FltType *p2 = v2.begin();
// @@ -28,9 +59,7 @@ struct L2DistCalc {
//          return l2sq(tmp);
//      }
//      double l2sq_dist(ConstArr v1, ConstArr v2) {
// -        std::vector<FltType> tmp;
// -        tmp.resize(v1.size());
// -        return l2sq_dist(v1, v2, Arr(tmp));
// +        return hw_l2_sq_dist<FltType, 32>(v1.cbegin(), v2.cbegin(), v1.size());
//      }
//  };