diff options
-rw-r--r-- | searchlib/src/vespa/searchlib/features/dotproductfeature.cpp | 65 | ||||
-rw-r--r-- | searchlib/src/vespa/searchlib/features/dotproductfeature.h | 18 | ||||
-rw-r--r-- | vespalib/src/vespa/vespalib/util/arrayref.h | 1 |
3 files changed, 42 insertions, 42 deletions
diff --git a/searchlib/src/vespa/searchlib/features/dotproductfeature.cpp b/searchlib/src/vespa/searchlib/features/dotproductfeature.cpp index 59ea82c83bb..06e6e9157ae 100644 --- a/searchlib/src/vespa/searchlib/features/dotproductfeature.cpp +++ b/searchlib/src/vespa/searchlib/features/dotproductfeature.cpp @@ -112,9 +112,8 @@ DotProductExecutorBase<BaseType>::~DotProductExecutorBase() = default; template <typename BaseType> void DotProductExecutorBase<BaseType>::execute(uint32_t docId) { feature_t val = 0; - const AT * values(nullptr); - uint32_t sz = getAttributeValues(docId, values); - for (size_t i = 0; i < sz; ++i) { + auto values = getAttributeValues(docId); + for (size_t i = 0; i < values.size(); ++i) { auto itr = _queryVector.getDimMap().find(values[i].value()); if (itr != _end) { val += values[i].weight() * itr->second; @@ -143,10 +142,12 @@ template <typename A> DotProductExecutor<A>::~DotProductExecutor() = default; template <typename A> -size_t -DotProductExecutor<A>::getAttributeValues(uint32_t docId, const AT * & values) +vespalib::ConstArrayRef<typename DotProductExecutor<A>::AT> +DotProductExecutor<A>::getAttributeValues(uint32_t docId) { - return _attribute->getRawValues(docId, values); + const AT* values = nullptr; + auto size = _attribute->getRawValues(docId, values); + return vespalib::ConstArrayRef(values, size); } namespace { @@ -271,12 +272,11 @@ DotProductExecutorBase<BaseType>::~DotProductExecutorBase() = default; template <typename BaseType> void DotProductExecutorBase<BaseType>::execute(uint32_t docId) { - const AT *values(nullptr); - size_t count = getAttributeValues(docId, values); - size_t commonRange = std::min(count, _queryVector.size()); + auto values = getAttributeValues(docId); + size_t commonRange = std::min(values.size(), _queryVector.size()); static_assert(std::is_same<typename AT::ValueType, BaseType>::value); outputs().set_number(0, _multiplier.dotProduct( - &_queryVector[0], reinterpret_cast<const typename AT::ValueType *>(values), commonRange)); + &_queryVector[0], reinterpret_cast<const typename AT::ValueType *>(values.data()), commonRange)); } template <typename A> @@ -290,10 +290,12 @@ template <typename A> DotProductExecutor<A>::~DotProductExecutor() = default; template <typename A> -size_t -DotProductExecutor<A>::getAttributeValues(uint32_t docId, const AT * & values) +vespalib::ConstArrayRef<typename DotProductExecutor<A>::AT> +DotProductExecutor<A>::getAttributeValues(uint32_t docId) { - return _attribute->getRawValues(docId, values); + const AT* values = nullptr; + auto size = _attribute->getRawValues(docId, values); + return vespalib::ConstArrayRef(values, size); } template <typename A> @@ -308,17 +310,16 @@ template <typename A> SparseDotProductExecutor<A>::~SparseDotProductExecutor() = default; template <typename A> -size_t -SparseDotProductExecutor<A>::getAttributeValues(uint32_t docId, const AT * & values) +vespalib::ConstArrayRef<typename SparseDotProductExecutor<A>::AT> +SparseDotProductExecutor<A>::getAttributeValues(uint32_t docId) { - const AT *allValues(NULL); + const AT* allValues(nullptr); size_t count = this->_attribute->getRawValues(docId, allValues); - values = &_scratch[0]; size_t i(0); for (; (i < _queryIndexes.size()) && (_queryIndexes[i] < count); i++) { _scratch[i] = allValues[_queryIndexes[i]]; } - return i; + return vespalib::ConstArrayRef(_scratch.data(), i); } template <typename A> @@ -332,16 +333,15 @@ template <typename A> DotProductByCopyExecutor<A>::~DotProductByCopyExecutor() = default; template <typename A> -size_t -DotProductByCopyExecutor<A>::getAttributeValues(uint32_t docId, const AT * & values) +vespalib::ConstArrayRef<typename DotProductByCopyExecutor<A>::AT> +DotProductByCopyExecutor<A>::getAttributeValues(uint32_t docId) { size_t count = this->_attribute->getAll(docId, &_copy[0], _copy.size()); if (count > _copy.size()) { _copy.resize(count); count = this->_attribute->getAll(docId, &_copy[0], _copy.size()); } - values = reinterpret_cast<const AT *>(&_copy[0]); - return count; + return vespalib::ConstArrayRef(reinterpret_cast<const AT *>(_copy.data()), count); } template <typename A> @@ -355,8 +355,8 @@ template <typename A> SparseDotProductByCopyExecutor<A>::~SparseDotProductByCopyExecutor() = default; template <typename A> -size_t -SparseDotProductByCopyExecutor<A>::getAttributeValues(uint32_t docId, const AT * & values) +vespalib::ConstArrayRef<typename SparseDotProductByCopyExecutor<A>::AT> +SparseDotProductByCopyExecutor<A>::getAttributeValues(uint32_t docId) { size_t count = this->_attribute->getAll(docId, &_copy[0], _copy.size()); if (count > _copy.size()) { @@ -367,8 +367,7 @@ SparseDotProductByCopyExecutor<A>::getAttributeValues(uint32_t docId, const AT * for (const IV & iv(this->_queryIndexes); (i < iv.size()) && (iv[i] < count); i++) { _copy[i] = _copy[iv[i]]; } - values = reinterpret_cast<const AT *>(&_copy[0]); - return i; + return vespalib::ConstArrayRef(reinterpret_cast<const AT *>(_copy.data()), i); } template <typename BaseType> @@ -403,11 +402,11 @@ constexpr void sanity_check_reinterpret_cast_compatibility() { } template <typename BaseType> -size_t DotProductByContentFillExecutor<BaseType>::getAttributeValues(uint32_t docid, const AT * & values) { +vespalib::ConstArrayRef<typename DotProductByContentFillExecutor<BaseType>::AT> +DotProductByContentFillExecutor<BaseType>::getAttributeValues(uint32_t docid) { _filler.fill(*_attribute, docid); sanity_check_reinterpret_cast_compatibility<BaseType, AT, decltype(*_filler.data())>(); - values = reinterpret_cast<const AT *>(_filler.data()); - return _filler.size(); + return vespalib::ConstArrayRef(reinterpret_cast<const AT *>(_filler.data()), _filler.size()); } template <typename BaseType> @@ -427,7 +426,9 @@ template <typename BaseType> SparseDotProductByContentFillExecutor<BaseType>::~SparseDotProductByContentFillExecutor() = default; template <typename BaseType> -size_t SparseDotProductByContentFillExecutor<BaseType>::getAttributeValues(uint32_t docid, const AT * & values) { +vespalib::ConstArrayRef<typename SparseDotProductByContentFillExecutor<BaseType>::AT> +SparseDotProductByContentFillExecutor<BaseType>::getAttributeValues(uint32_t docid) +{ _filler.fill(*_attribute, docid); const size_t count = _filler.size(); @@ -438,11 +439,9 @@ size_t SparseDotProductByContentFillExecutor<BaseType>::getAttributeValues(uint3 } sanity_check_reinterpret_cast_compatibility<BaseType, AT, decltype(*_filler.data())>(); - values = reinterpret_cast<const AT *>(data); - return i; + return vespalib::ConstArrayRef(reinterpret_cast<const AT *>(data), i); } - } namespace { diff --git a/searchlib/src/vespa/searchlib/features/dotproductfeature.h b/searchlib/src/vespa/searchlib/features/dotproductfeature.h index ad5864f86c7..916efa54e11 100644 --- a/searchlib/src/vespa/searchlib/features/dotproductfeature.h +++ b/searchlib/src/vespa/searchlib/features/dotproductfeature.h @@ -124,7 +124,7 @@ public: private: const V & _queryVector; const typename V::HashMap::const_iterator _end; - virtual size_t getAttributeValues(uint32_t docid, const AT * & count) = 0; + virtual vespalib::ConstArrayRef<AT> getAttributeValues(uint32_t docid) = 0; public: DotProductExecutorBase(const V & queryVector); ~DotProductExecutorBase() override; @@ -140,7 +140,7 @@ protected: const A * _attribute; private: std::unique_ptr<V> _backing; - size_t getAttributeValues(uint32_t docid, const AT * & count) override; + vespalib::ConstArrayRef<AT> getAttributeValues(uint32_t docid) override; public: DotProductExecutor(const A * attribute, const V & queryVector); DotProductExecutor(const A * attribute, std::unique_ptr<V> queryVector); @@ -183,7 +183,7 @@ public: private: const vespalib::hwaccelrated::IAccelrated & _multiplier; V _queryVector; - virtual size_t getAttributeValues(uint32_t docid, const AT * & count) = 0; + virtual vespalib::ConstArrayRef<AT> getAttributeValues(uint32_t docid) = 0; public: DotProductExecutorBase(const V & queryVector); ~DotProductExecutorBase() override; @@ -201,7 +201,7 @@ public: protected: const A * _attribute; private: - size_t getAttributeValues(uint32_t docid, const AT * & count) override; + vespalib::ConstArrayRef<AT> getAttributeValues(uint32_t docid) override; public: DotProductExecutor(const A * attribute, const V & queryVector); ~DotProductExecutor(); @@ -215,7 +215,7 @@ public: ~DotProductByCopyExecutor(); private: typedef typename DotProductExecutor<A>::AT AT; - size_t getAttributeValues(uint32_t docid, const AT * & count) final override; + vespalib::ConstArrayRef<AT> getAttributeValues(uint32_t docid) final override; std::vector<typename A::BaseType> _copy; }; @@ -240,7 +240,7 @@ public: DotProductByContentFillExecutor(const attribute::IAttributeVector * attribute, const V & queryVector); ~DotProductByContentFillExecutor(); private: - size_t getAttributeValues(uint32_t docid, const AT * & values) final override; + vespalib::ConstArrayRef<AT> getAttributeValues(uint32_t docid) final override; const attribute::IAttributeVector* _attribute; ValueFiller _filler; @@ -255,7 +255,7 @@ public: ~SparseDotProductExecutor(); private: typedef typename DotProductExecutor<A>::AT AT; - size_t getAttributeValues(uint32_t docid, const AT * & count) override; + vespalib::ConstArrayRef<AT> getAttributeValues(uint32_t docid) override; protected: IV _queryIndexes; std::vector<AT> _scratch; @@ -270,7 +270,7 @@ public: ~SparseDotProductByCopyExecutor(); private: typedef typename DotProductExecutor<A>::AT AT; - size_t getAttributeValues(uint32_t docid, const AT * & count) final override; + vespalib::ConstArrayRef<AT> getAttributeValues(uint32_t docid) final override; std::vector<typename A::BaseType> _copy; }; @@ -291,7 +291,7 @@ public: const IV & queryIndexes); ~SparseDotProductByContentFillExecutor() override; private: - size_t getAttributeValues(uint32_t docid, const AT * & values) final override; + vespalib::ConstArrayRef<AT> getAttributeValues(uint32_t docid) final override; const attribute::IAttributeVector* _attribute; IV _queryIndexes; diff --git a/vespalib/src/vespa/vespalib/util/arrayref.h b/vespalib/src/vespa/vespalib/util/arrayref.h index bc1fc540a6c..337833d2457 100644 --- a/vespalib/src/vespa/vespalib/util/arrayref.h +++ b/vespalib/src/vespa/vespalib/util/arrayref.h @@ -50,6 +50,7 @@ public: const T *cend() const { return _v + _sz; } const T *begin() const { return _v; } const T *end() const { return _v + _sz; } + const T *data() const noexcept { return _v; } private: const T *_v; size_t _sz; |