diff options
author | Tor Egge <Tor.Egge@yahooinc.com> | 2022-04-12 14:52:41 +0200 |
---|---|---|
committer | Tor Egge <Tor.Egge@yahooinc.com> | 2022-04-12 15:23:08 +0200 |
commit | bc85555178efa7dafe1e790a21906378402c398e (patch) | |
tree | 2810b0f4e7ff847283e97aa927dda8403b1cad33 /searchlib | |
parent | 063a689b449c9d3e8ba247f2dab95819b4c0ae36 (diff) |
Simplify dot product executors for array attribute vectors.
Diffstat (limited to 'searchlib')
-rw-r--r-- | searchlib/src/vespa/searchlib/features/dotproductfeature.cpp | 42 | ||||
-rw-r--r-- | searchlib/src/vespa/searchlib/features/dotproductfeature.h | 25 |
2 files changed, 19 insertions, 48 deletions
diff --git a/searchlib/src/vespa/searchlib/features/dotproductfeature.cpp b/searchlib/src/vespa/searchlib/features/dotproductfeature.cpp index 0b4e09dde0d..b279cf6ca08 100644 --- a/searchlib/src/vespa/searchlib/features/dotproductfeature.cpp +++ b/searchlib/src/vespa/searchlib/features/dotproductfeature.cpp @@ -272,9 +272,8 @@ template <typename BaseType> void DotProductExecutorBase<BaseType>::execute(uint32_t docId) { auto values = getAttributeValues(docId); size_t commonRange = std::min(values.size(), _queryVector.size()); - static_assert(std::is_same_v<multivalue::ValueType_t<AT>, BaseType>); outputs().set_number(0, _multiplier.dotProduct( - &_queryVector[0], reinterpret_cast<const multivalue::ValueType_t<AT> *>(values.data()), commonRange)); + &_queryVector[0], values.data(), commonRange)); } template <typename BaseType> @@ -288,7 +287,7 @@ template <typename BaseType> DotProductByArrayReadViewExecutor<BaseType>::~DotProductByArrayReadViewExecutor() = default; template <typename BaseType> -vespalib::ConstArrayRef<typename DotProductByArrayReadViewExecutor<BaseType>::AT> +vespalib::ConstArrayRef<BaseType> DotProductByArrayReadViewExecutor<BaseType>::getAttributeValues(uint32_t docId) { return _array_read_view->get_raw_values(docId); @@ -326,7 +325,7 @@ template <typename BaseType> SparseDotProductByArrayReadViewExecutor<BaseType>::~SparseDotProductByArrayReadViewExecutor() = default; template <typename BaseType> -vespalib::ConstArrayRef<typename SparseDotProductByArrayReadViewExecutor<BaseType>::AT> +vespalib::ConstArrayRef<BaseType> SparseDotProductByArrayReadViewExecutor<BaseType>::getAttributeValues(uint32_t docid) { auto allValues = _array_read_view->get_raw_values(docid); @@ -348,7 +347,7 @@ template <typename A> DotProductByCopyExecutor<A>::~DotProductByCopyExecutor() = default; template <typename A> -vespalib::ConstArrayRef<typename DotProductByCopyExecutor<A>::AT> +vespalib::ConstArrayRef<typename A::BaseType> DotProductByCopyExecutor<A>::getAttributeValues(uint32_t docId) { size_t count = this->_attribute->getAll(docId, &_copy[0], _copy.size()); @@ -356,7 +355,7 @@ DotProductByCopyExecutor<A>::getAttributeValues(uint32_t docId) _copy.resize(count); count = this->_attribute->getAll(docId, &_copy[0], _copy.size()); } - return vespalib::ConstArrayRef(reinterpret_cast<const AT *>(_copy.data()), count); + return vespalib::ConstArrayRef(_copy.data(), count); } template <typename A> @@ -371,7 +370,7 @@ template <typename A> SparseDotProductByCopyExecutor<A>::~SparseDotProductByCopyExecutor() = default; template <typename A> -vespalib::ConstArrayRef<typename SparseDotProductByCopyExecutor<A>::AT> +vespalib::ConstArrayRef<typename A::BaseType> SparseDotProductByCopyExecutor<A>::getAttributeValues(uint32_t docId) { size_t count = this->_attribute->getAll(docId, &_copy[0], _copy.size()); @@ -383,7 +382,7 @@ SparseDotProductByCopyExecutor<A>::getAttributeValues(uint32_t docId) for (const IV & iv(this->_queryIndexes); (i < iv.size()) && (iv[i] < count); i++) { _copy[i] = _copy[iv[i]]; } - return vespalib::ConstArrayRef(reinterpret_cast<const AT *>(_copy.data()), i); + return vespalib::ConstArrayRef(_copy.data(), i); } template <typename BaseType> @@ -400,29 +399,11 @@ DotProductByContentFillExecutor<BaseType>::DotProductByContentFillExecutor( template <typename BaseType> DotProductByContentFillExecutor<BaseType>::~DotProductByContentFillExecutor() = default; -namespace { - -template<typename T> struct IsNonWeightedType : std::true_type {}; -template<typename BaseType> struct IsNonWeightedType<multivalue::WeightedValue<BaseType>> : std::false_type {}; - -// Compile-time sanity check for type compatibility of gnarly BaseType <-> multivalue::Value -// reinterpret_cast used by some getAttributeValues calls. -template <typename BaseType, typename AttributeValueType, typename FillerValueType> -constexpr void sanity_check_reinterpret_cast_compatibility() { - static_assert(IsNonWeightedType<AttributeValueType>::value); - static_assert(sizeof(BaseType) == sizeof(AttributeValueType)); - static_assert(sizeof(BaseType) == sizeof(FillerValueType)); - static_assert(std::is_same_v<BaseType, multivalue::ValueType_t<AttributeValueType>>); -} - -} - template <typename BaseType> -vespalib::ConstArrayRef<typename DotProductByContentFillExecutor<BaseType>::AT> +vespalib::ConstArrayRef<BaseType> DotProductByContentFillExecutor<BaseType>::getAttributeValues(uint32_t docid) { _filler.fill(*_attribute, docid); - sanity_check_reinterpret_cast_compatibility<BaseType, AT, decltype(*_filler.data())>(); - return vespalib::ConstArrayRef(reinterpret_cast<const AT *>(_filler.data()), _filler.size()); + return vespalib::ConstArrayRef(_filler.data(), _filler.size()); } template <typename BaseType> @@ -442,7 +423,7 @@ template <typename BaseType> SparseDotProductByContentFillExecutor<BaseType>::~SparseDotProductByContentFillExecutor() = default; template <typename BaseType> -vespalib::ConstArrayRef<typename SparseDotProductByContentFillExecutor<BaseType>::AT> +vespalib::ConstArrayRef<BaseType> SparseDotProductByContentFillExecutor<BaseType>::getAttributeValues(uint32_t docid) { _filler.fill(*_attribute, docid); @@ -454,8 +435,7 @@ SparseDotProductByContentFillExecutor<BaseType>::getAttributeValues(uint32_t doc data[i] = data[_queryIndexes[i]]; } - sanity_check_reinterpret_cast_compatibility<BaseType, AT, decltype(*_filler.data())>(); - return vespalib::ConstArrayRef(reinterpret_cast<const AT *>(data), i); + return vespalib::ConstArrayRef(data, i); } } diff --git a/searchlib/src/vespa/searchlib/features/dotproductfeature.h b/searchlib/src/vespa/searchlib/features/dotproductfeature.h index 8ce545bd23d..051a08025d8 100644 --- a/searchlib/src/vespa/searchlib/features/dotproductfeature.h +++ b/searchlib/src/vespa/searchlib/features/dotproductfeature.h @@ -180,12 +180,11 @@ namespace array { template <typename BaseType> class DotProductExecutorBase : public fef::FeatureExecutor { public: - using AT = BaseType; using V = std::vector<BaseType>; private: const vespalib::hwaccelrated::IAccelrated & _multiplier; V _queryVector; - virtual vespalib::ConstArrayRef<AT> getAttributeValues(uint32_t docid) = 0; + virtual vespalib::ConstArrayRef<BaseType> getAttributeValues(uint32_t docid) = 0; public: DotProductExecutorBase(const V & queryVector); ~DotProductExecutorBase() override; @@ -198,13 +197,12 @@ public: template <typename BaseType> class DotProductByArrayReadViewExecutor : public DotProductExecutorBase<BaseType> { public: - using AT = typename DotProductExecutorBase<BaseType>::AT; using V = typename DotProductExecutorBase<BaseType>::V; using ArrayReadView = attribute::IArrayReadView<BaseType>; protected: const ArrayReadView* _array_read_view; private: - vespalib::ConstArrayRef<AT> getAttributeValues(uint32_t docid) override; + vespalib::ConstArrayRef<BaseType> getAttributeValues(uint32_t docid) override; public: DotProductByArrayReadViewExecutor(const ArrayReadView* array_read_view, const V & queryVector); ~DotProductByArrayReadViewExecutor(); @@ -216,7 +214,6 @@ public: template <typename A> class DotProductExecutor : public DotProductExecutorBase<typename A::BaseType> { public: - using AT = typename DotProductExecutorBase<typename A::BaseType>::AT; using V = typename DotProductExecutorBase<typename A::BaseType>::V; protected: const A * _attribute; @@ -232,8 +229,7 @@ public: DotProductByCopyExecutor(const A * attribute, const V & queryVector); ~DotProductByCopyExecutor(); private: - typedef typename DotProductExecutor<A>::AT AT; - vespalib::ConstArrayRef<AT> getAttributeValues(uint32_t docid) final override; + vespalib::ConstArrayRef<typename A::BaseType> getAttributeValues(uint32_t docid) final override; std::vector<typename A::BaseType> _copy; }; @@ -252,13 +248,12 @@ template <typename BaseType> class DotProductByContentFillExecutor : public DotProductExecutorBase<BaseType> { public: using V = typename DotProductExecutorBase<BaseType>::V; - using AT = typename DotProductExecutorBase<BaseType>::AT; using ValueFiller = attribute::AttributeContent<BaseType>; DotProductByContentFillExecutor(const attribute::IAttributeVector * attribute, const V & queryVector); ~DotProductByContentFillExecutor(); private: - vespalib::ConstArrayRef<AT> getAttributeValues(uint32_t docid) final override; + vespalib::ConstArrayRef<BaseType> getAttributeValues(uint32_t docid) final override; const attribute::IAttributeVector* _attribute; ValueFiller _filler; @@ -272,9 +267,8 @@ public: SparseDotProductExecutorBase(const V & queryVector, const IV & queryIndexes); ~SparseDotProductExecutorBase(); protected: - typedef typename DotProductExecutorBase<BaseType>::AT AT; IV _queryIndexes; - std::vector<AT> _scratch; + std::vector<BaseType> _scratch; }; template <typename BaseType> @@ -288,8 +282,7 @@ public: SparseDotProductByArrayReadViewExecutor(const ArrayReadView* array_read_view, const V & queryVector, const IV & queryIndexes); ~SparseDotProductByArrayReadViewExecutor(); private: - typedef typename SparseDotProductExecutorBase<BaseType>::AT AT; - vespalib::ConstArrayRef<AT> getAttributeValues(uint32_t docid) override; + vespalib::ConstArrayRef<BaseType> getAttributeValues(uint32_t docid) override; const ArrayReadView* _array_read_view; }; @@ -301,8 +294,7 @@ public: SparseDotProductByCopyExecutor(const A * attribute, const V & queryVector, const IV & queryIndexes); ~SparseDotProductByCopyExecutor(); private: - typedef typename SparseDotProductExecutorBase<typename A::BaseType>::AT AT; - vespalib::ConstArrayRef<AT> getAttributeValues(uint32_t docid) final override; + vespalib::ConstArrayRef<typename A::BaseType> getAttributeValues(uint32_t docid) final override; const A* _attribute; std::vector<typename A::BaseType> _copy; }; @@ -316,7 +308,6 @@ class SparseDotProductByContentFillExecutor : public DotProductExecutorBase<Base public: using IV = std::vector<uint32_t>; using V = typename DotProductExecutorBase<BaseType>::V; - using AT = typename DotProductExecutorBase<BaseType>::AT; using ValueFiller = attribute::AttributeContent<BaseType>; SparseDotProductByContentFillExecutor(const attribute::IAttributeVector * attribute, @@ -324,7 +315,7 @@ public: const IV & queryIndexes); ~SparseDotProductByContentFillExecutor() override; private: - vespalib::ConstArrayRef<AT> getAttributeValues(uint32_t docid) final override; + vespalib::ConstArrayRef<BaseType> getAttributeValues(uint32_t docid) final override; const attribute::IAttributeVector* _attribute; IV _queryIndexes; |