diff options
author | Geir Storli <geirst@yahooinc.com> | 2022-04-11 13:13:58 +0200 |
---|---|---|
committer | GitHub <noreply@github.com> | 2022-04-11 13:13:58 +0200 |
commit | c2f805788295a976a70dc03cd8a1dc6250029892 (patch) | |
tree | a9dfaf3022816c3a42bc624bb3ec43d879b4eca9 | |
parent | e5b9d3a634e2d8a1ecf783ff69c129e48732afad (diff) | |
parent | 8119178033ca4f61f847e9f822699e3f7505ccfd (diff) |
Merge pull request #22080 from vespa-engine/toregge/use-multi-value-read-view-api-in-dot-product-feature-instead-of-getrawvalues
Use IMultiValueReadView in dot product feature instead of getRawValues().
3 files changed, 137 insertions, 98 deletions
diff --git a/searchlib/src/tests/features/prod_features.cpp b/searchlib/src/tests/features/prod_features.cpp index 7ebc3759813..002cfdc7f3b 100644 --- a/searchlib/src/tests/features/prod_features.cpp +++ b/searchlib/src/tests/features/prod_features.cpp @@ -1356,8 +1356,8 @@ Test::testDotProduct() TEST_DO(verifyCorrectDotProductExecutor(_factory, "wsstr", "{a:1,b:2}", "search::features::dotproduct::wset::(anonymous namespace)::DotProductExecutorByEnum")); TEST_DO(verifyCorrectDotProductExecutor(_factory, "wsstr", "{a:1}", "search::features::dotproduct::wset::(anonymous namespace)::SingleDotProductExecutorByEnum")); TEST_DO(verifyCorrectDotProductExecutor(_factory, "wsstr", "{unknown:1}", "search::features::SingleZeroValueExecutor")); - TEST_DO(verifyCorrectDotProductExecutor(_factory, "wsint", "{1:1, 2:3}", "search::features::dotproduct::wset::DotProductExecutor<search::MultiValueNumericAttribute<search::IntegerAttributeTemplate<int>, search::multivalue::WeightedValue<int> > >")); - TEST_DO(verifyCorrectDotProductExecutor(_factory, "wsint", "{1:1}", "search::features::dotproduct::wset::(anonymous namespace)::SingleDotProductExecutorByValue<search::MultiValueNumericAttribute<search::IntegerAttributeTemplate<int>, search::multivalue::WeightedValue<int> > >")); + TEST_DO(verifyCorrectDotProductExecutor(_factory, "wsint", "{1:1, 2:3}", "search::features::dotproduct::wset::DotProductByWeightedSetReadViewExecutor<int>")); + TEST_DO(verifyCorrectDotProductExecutor(_factory, "wsint", "{1:1}", "search::features::dotproduct::wset::(anonymous namespace)::SingleDotProductByWeightedValueExecutor<int>")); TEST_DO(verifyCorrectDotProductExecutor(_factory, "wsint", "{}", "search::features::SingleZeroValueExecutor")); } diff --git a/searchlib/src/vespa/searchlib/features/dotproductfeature.cpp b/searchlib/src/vespa/searchlib/features/dotproductfeature.cpp index b18c6687561..d872280a8d3 100644 --- a/searchlib/src/vespa/searchlib/features/dotproductfeature.cpp +++ b/searchlib/src/vespa/searchlib/features/dotproductfeature.cpp @@ -122,32 +122,30 @@ void DotProductExecutorBase<BaseType>::execute(uint32_t docId) { outputs().set_number(0, val); } -template <typename A> -DotProductExecutor<A>::DotProductExecutor(const A * attribute, const V & queryVector) : - DotProductExecutorBase<typename A::BaseType>(queryVector), - _attribute(attribute), +template <typename BaseType> +DotProductByWeightedSetReadViewExecutor<BaseType>::DotProductByWeightedSetReadViewExecutor(const WeightedSetReadView* weighted_set_read_view, const V & queryVector) : + DotProductExecutorBase<BaseType>(queryVector), + _weighted_set_read_view(weighted_set_read_view), _backing() { } -template <typename A> -DotProductExecutor<A>::DotProductExecutor(const A * attribute, std::unique_ptr<V> queryVector) : - DotProductExecutorBase<typename A::BaseType>(*queryVector), - _attribute(attribute), +template <typename BaseType> +DotProductByWeightedSetReadViewExecutor<BaseType>::DotProductByWeightedSetReadViewExecutor(const WeightedSetReadView* weighted_set_read_view, std::unique_ptr<V> queryVector) : + DotProductExecutorBase<BaseType>(*queryVector), + _weighted_set_read_view(weighted_set_read_view), _backing(std::move(queryVector)) { } -template <typename A> -DotProductExecutor<A>::~DotProductExecutor() = default; +template <typename BaseType> +DotProductByWeightedSetReadViewExecutor<BaseType>::~DotProductByWeightedSetReadViewExecutor() = default; -template <typename A> -vespalib::ConstArrayRef<typename DotProductExecutor<A>::AT> -DotProductExecutor<A>::getAttributeValues(uint32_t docId) +template <typename BaseType> +vespalib::ConstArrayRef<typename DotProductByWeightedSetReadViewExecutor<BaseType>::AT> +DotProductByWeightedSetReadViewExecutor<BaseType>::getAttributeValues(uint32_t docId) { - const AT* values = nullptr; - auto size = _attribute->getRawValues(docId, values); - return vespalib::ConstArrayRef(values, size); + return _weighted_set_read_view->get_raw_values(docId); } namespace { @@ -227,19 +225,19 @@ private: feature_t _value; }; -template <typename A> -class SingleDotProductExecutorByValue final : public fef::FeatureExecutor { +template <typename BaseType> +class SingleDotProductByWeightedValueExecutor final : public fef::FeatureExecutor { public: - SingleDotProductExecutorByValue(const A * attribute, typename A::BaseType key, feature_t value) - : _attribute(attribute), + using WeightedSetReadView = attribute::IWeightedSetReadView<BaseType>; + SingleDotProductByWeightedValueExecutor(const WeightedSetReadView * weighted_set_read_view, BaseType key, feature_t value) + : _weighted_set_read_view(weighted_set_read_view), _key(key), _value(value) {} void execute(uint32_t docId) override { - const multivalue::WeightedValue<typename A::BaseType> *values(nullptr); - uint32_t sz = _attribute->getRawValues(docId, values); - for (size_t i = 0; i < sz; ++i) { + auto values = _weighted_set_read_view->get_raw_values(docId); + for (size_t i = 0; i < values.size(); ++i) { if (values[i].value() == _key) { outputs().set_number(0, values[i].weight() * _value); return; @@ -248,9 +246,9 @@ public: outputs().set_number(0, 0); } private: - const A * _attribute; - typename A::BaseType _key; - feature_t _value; + const WeightedSetReadView* _weighted_set_read_view; + BaseType _key; + feature_t _value; }; } @@ -279,6 +277,23 @@ void DotProductExecutorBase<BaseType>::execute(uint32_t docId) { &_queryVector[0], reinterpret_cast<const multivalue::ValueType_t<AT> *>(values.data()), commonRange)); } +template <typename BaseType> +DotProductByArrayReadViewExecutor<BaseType>::DotProductByArrayReadViewExecutor(const ArrayReadView* array_read_view, const V & queryVector) : + DotProductExecutorBase<BaseType>(queryVector), + _array_read_view(array_read_view) +{ +} + +template <typename BaseType> +DotProductByArrayReadViewExecutor<BaseType>::~DotProductByArrayReadViewExecutor() = default; + +template <typename BaseType> +vespalib::ConstArrayRef<typename DotProductByArrayReadViewExecutor<BaseType>::AT> +DotProductByArrayReadViewExecutor<BaseType>::getAttributeValues(uint32_t docId) +{ + return _array_read_view->get_raw_values(docId); +} + template <typename A> DotProductExecutor<A>::DotProductExecutor(const A * attribute, const V & queryVector) : DotProductExecutorBase<typename A::BaseType>(queryVector), @@ -289,34 +304,34 @@ DotProductExecutor<A>::DotProductExecutor(const A * attribute, const V & queryVe template <typename A> DotProductExecutor<A>::~DotProductExecutor() = default; -template <typename A> -vespalib::ConstArrayRef<typename DotProductExecutor<A>::AT> -DotProductExecutor<A>::getAttributeValues(uint32_t docId) +template <typename BaseType> +SparseDotProductExecutorBase<BaseType>::SparseDotProductExecutorBase(const V & queryVector, const IV & queryIndexes) : + DotProductExecutorBase<BaseType>(queryVector), + _queryIndexes(queryIndexes), + _scratch(queryIndexes.size()) { - const AT* values = nullptr; - auto size = _attribute->getRawValues(docId, values); - return vespalib::ConstArrayRef(values, size); } -template <typename A> -SparseDotProductExecutor<A>::SparseDotProductExecutor(const A * attribute, const V & queryVector, const IV & queryIndexes) : - DotProductExecutor<A>(attribute, queryVector), - _queryIndexes(queryIndexes), - _scratch(queryIndexes.size()) +template <typename BaseType> +SparseDotProductExecutorBase<BaseType>::~SparseDotProductExecutorBase() = default; + +template <typename BaseType> +SparseDotProductByArrayReadViewExecutor<BaseType>::SparseDotProductByArrayReadViewExecutor(const ArrayReadView* array_read_view, const V & queryVector, const IV & queryIndexes) + : SparseDotProductExecutorBase<BaseType>(queryVector, queryIndexes), + _array_read_view(array_read_view) { } -template <typename A> -SparseDotProductExecutor<A>::~SparseDotProductExecutor() = default; +template <typename BaseType> +SparseDotProductByArrayReadViewExecutor<BaseType>::~SparseDotProductByArrayReadViewExecutor() = default; -template <typename A> -vespalib::ConstArrayRef<typename SparseDotProductExecutor<A>::AT> -SparseDotProductExecutor<A>::getAttributeValues(uint32_t docId) +template <typename BaseType> +vespalib::ConstArrayRef<typename SparseDotProductByArrayReadViewExecutor<BaseType>::AT> +SparseDotProductByArrayReadViewExecutor<BaseType>::getAttributeValues(uint32_t docid) { - const AT* allValues(nullptr); - size_t count = this->_attribute->getRawValues(docId, allValues); + auto allValues = _array_read_view->get_raw_values(docid); size_t i(0); - for (; (i < _queryIndexes.size()) && (_queryIndexes[i] < count); i++) { + for (; (i < _queryIndexes.size()) && (_queryIndexes[i] < allValues.size()); i++) { _scratch[i] = allValues[_queryIndexes[i]]; } return vespalib::ConstArrayRef(_scratch.data(), i); @@ -346,7 +361,8 @@ DotProductByCopyExecutor<A>::getAttributeValues(uint32_t docId) template <typename A> SparseDotProductByCopyExecutor<A>::SparseDotProductByCopyExecutor(const A * attribute, const V & queryVector, const IV & queryIndexes) : - SparseDotProductExecutor<A>(attribute, queryVector, queryIndexes), + SparseDotProductExecutorBase<typename A::BaseType>(queryVector, queryIndexes), + _attribute(attribute), _copy(std::max(static_cast<size_t>(attribute->getMaxValueCount()), queryIndexes.size())) { } @@ -535,16 +551,15 @@ namespace { using dotproduct::ArrayParam; -template <typename A, typename B> -bool supportsGetRawValues(const A & attr) noexcept { - try { - const B * tmp = nullptr; - attr.getRawValues(0, tmp); // Throws if unsupported - return true; - } catch (const std::runtime_error & e) { - (void) e; - return false; +template <typename AT> +const attribute::IMultiValueReadView<AT>* +get_multi_value_read_view(const IAttributeVector& attribute) +{ + auto multi_value_attribute = attribute.as_multi_value_attribute(); + if (multi_value_attribute != nullptr) { + return multi_value_attribute->as_read_view(attribute::IMultiValueAttribute::Tag<AT>()); } + return nullptr; } bool supportsGetEnumHandles(const IWeightedIndexVector * attr) noexcept { @@ -574,21 +589,16 @@ createForDirectArrayImpl(const IAttributeVector * attribute, const A * iattr = dynamic_cast<const A *>(attribute); using T = typename A::BaseType; using VT = multivalue::Value<T>; + auto array_read_view = get_multi_value_read_view<VT>(*attribute); if (indexes.empty()) { - if (supportsGetRawValues<A,VT>(*iattr)) { - using ExactA = MultiValueNumericAttribute<A, VT>; - - auto * exactA = dynamic_cast<const ExactA *>(iattr); - if (exactA != nullptr) { - return stash.create<dotproduct::array::DotProductExecutor<ExactA>>(exactA, values); - } - return stash.create<dotproduct::array::DotProductExecutor<A>>(iattr, values); + if (array_read_view != nullptr) { + return stash.create<dotproduct::array::DotProductByArrayReadViewExecutor<T>>(array_read_view, values); } else { return stash.create<dotproduct::array::DotProductByCopyExecutor<A>>(iattr, values); } } else { - if (supportsGetRawValues<A, VT>(*iattr)) { - return stash.create<dotproduct::array::SparseDotProductExecutor<A>>(iattr, values, indexes); + if (array_read_view != nullptr) { + return stash.create<dotproduct::array::SparseDotProductByArrayReadViewExecutor<T>>(array_read_view, values, indexes); } else { return stash.create<dotproduct::array::SparseDotProductByCopyExecutor<A>>(iattr, values, indexes); } @@ -681,17 +691,13 @@ createForDirectWSetImpl(const IAttributeVector * attribute, V && vector, vespali using T = typename A::BaseType; const A * iattr = dynamic_cast<const A *>(attribute); using VT = multivalue::WeightedValue<T>; - using ExactA = MultiValueNumericAttribute<A, VT>; - if (!attribute->isImported() && (iattr != nullptr) && supportsGetRawValues<A, VT>(*iattr)) { - auto * exactA = dynamic_cast<const ExactA *>(iattr); - if (exactA != nullptr) { - if (extractSize(vector) == 1) { - auto elem = extractElem(vector, 0ul); - return stash.create<SingleDotProductExecutorByValue<ExactA>>(exactA, elem.first, elem.second); - } - return stash.create<DotProductExecutor<ExactA>>(exactA, std::forward<V>(vector)); + auto weighted_set_read_view = get_multi_value_read_view<VT>(*attribute); + if (!attribute->isImported() && (iattr != nullptr) && weighted_set_read_view != nullptr) { + if (extractSize(vector) == 1) { + auto elem = extractElem(vector, 0ul); + return stash.create<SingleDotProductByWeightedValueExecutor<T>>(weighted_set_read_view, elem.first, elem.second); } - return stash.create<DotProductExecutor<A>>(iattr, std::forward<V>(vector)); + return stash.create<DotProductByWeightedSetReadViewExecutor<T>>(weighted_set_read_view, std::forward<V>(vector)); } return stash.create<DotProductExecutorByCopy<IntegerVectorT<T>, WeightedIntegerContent>>(attribute, std::forward<V>(vector)); } diff --git a/searchlib/src/vespa/searchlib/features/dotproductfeature.h b/searchlib/src/vespa/searchlib/features/dotproductfeature.h index 916efa54e11..4e94c9bce86 100644 --- a/searchlib/src/vespa/searchlib/features/dotproductfeature.h +++ b/searchlib/src/vespa/searchlib/features/dotproductfeature.h @@ -4,6 +4,7 @@ #include "utils.h" #include <vespa/searchcommon/attribute/attributecontent.h> +#include <vespa/searchcommon/attribute/i_multi_value_read_view.h> #include <vespa/searchcommon/attribute/multivalue.h> #include <vespa/searchlib/fef/blueprint.h> #include <vespa/vespalib/hwaccelrated/iaccelrated.h> @@ -131,20 +132,21 @@ public: void execute(uint32_t docId) override; }; -template <typename A> -class DotProductExecutor final : public DotProductExecutorBase<typename A::BaseType> { +template <typename BaseType> +class DotProductByWeightedSetReadViewExecutor final : public DotProductExecutorBase<BaseType> { public: - using AT = typename DotProductExecutorBase<typename A::BaseType>::AT; - using V = typename DotProductExecutorBase<typename A::BaseType>::V; + using WeightedSetReadView = attribute::IWeightedSetReadView<BaseType>; + using AT = typename DotProductExecutorBase<BaseType>::AT; + using V = typename DotProductExecutorBase<BaseType>::V; protected: - const A * _attribute; + const WeightedSetReadView * _weighted_set_read_view; private: std::unique_ptr<V> _backing; vespalib::ConstArrayRef<AT> getAttributeValues(uint32_t docid) override; public: - DotProductExecutor(const A * attribute, const V & queryVector); - DotProductExecutor(const A * attribute, std::unique_ptr<V> queryVector); - ~DotProductExecutor(); + DotProductByWeightedSetReadViewExecutor(const WeightedSetReadView* weighted_set_read_view, const V & queryVector); + DotProductByWeightedSetReadViewExecutor(const WeightedSetReadView * weighted_set_read_view, std::unique_ptr<V> queryVector); + ~DotProductByWeightedSetReadViewExecutor(); }; @@ -191,6 +193,24 @@ public: }; /** + * Implements the executor for the dotproduct feature using array read view. + */ +template <typename BaseType> +class DotProductByArrayReadViewExecutor : public DotProductExecutorBase<BaseType> { +public: + using AT = typename DotProductExecutorBase<BaseType>::AT; + using V = typename DotProductExecutorBase<BaseType>::V; + using ArrayReadView = attribute::IArrayReadView<BaseType>; +protected: + const ArrayReadView* _array_read_view; +private: + vespalib::ConstArrayRef<AT> getAttributeValues(uint32_t docid) override; +public: + DotProductByArrayReadViewExecutor(const ArrayReadView* array_read_view, const V & queryVector); + ~DotProductByArrayReadViewExecutor(); +}; + +/** * Implements the executor for the dotproduct feature. */ template <typename A> @@ -200,8 +220,6 @@ public: using V = typename DotProductExecutorBase<typename A::BaseType>::V; protected: const A * _attribute; -private: - vespalib::ConstArrayRef<AT> getAttributeValues(uint32_t docid) override; public: DotProductExecutor(const A * attribute, const V & queryVector); ~DotProductExecutor(); @@ -246,31 +264,46 @@ private: ValueFiller _filler; }; -template <typename A> -class SparseDotProductExecutor : public DotProductExecutor<A> { +template <typename BaseType> +class SparseDotProductExecutorBase : public DotProductExecutorBase<BaseType> { public: typedef std::vector<uint32_t> IV; - typedef typename DotProductExecutor<A>::V V; - SparseDotProductExecutor(const A * attribute, const V & queryVector, const IV & queryIndexes); - ~SparseDotProductExecutor(); -private: - typedef typename DotProductExecutor<A>::AT AT; - vespalib::ConstArrayRef<AT> getAttributeValues(uint32_t docid) override; + typedef typename DotProductExecutorBase<BaseType>::V V; + SparseDotProductExecutorBase(const V & queryVector, const IV & queryIndexes); + ~SparseDotProductExecutorBase(); protected: + typedef typename DotProductExecutorBase<BaseType>::AT AT; IV _queryIndexes; std::vector<AT> _scratch; }; +template <typename BaseType> +class SparseDotProductByArrayReadViewExecutor : public SparseDotProductExecutorBase<BaseType> { +public: + using SparseDotProductExecutorBase<BaseType>::_queryIndexes; + using SparseDotProductExecutorBase<BaseType>::_scratch; + typedef std::vector<uint32_t> IV; + typedef typename SparseDotProductExecutorBase<BaseType>::V V; + using ArrayReadView = attribute::IArrayReadView<BaseType>; + SparseDotProductByArrayReadViewExecutor(const ArrayReadView* array_read_view, const V & queryVector, const IV & queryIndexes); + ~SparseDotProductByArrayReadViewExecutor(); +private: + typedef typename SparseDotProductExecutorBase<BaseType>::AT AT; + vespalib::ConstArrayRef<AT> getAttributeValues(uint32_t docid) override; + const ArrayReadView* _array_read_view; +}; + template <typename A> -class SparseDotProductByCopyExecutor : public SparseDotProductExecutor<A> { +class SparseDotProductByCopyExecutor : public SparseDotProductExecutorBase<typename A::BaseType> { public: typedef std::vector<uint32_t> IV; - typedef typename DotProductExecutor<A>::V V; + typedef typename SparseDotProductExecutorBase<typename A::BaseType>::V V; SparseDotProductByCopyExecutor(const A * attribute, const V & queryVector, const IV & queryIndexes); ~SparseDotProductByCopyExecutor(); private: - typedef typename DotProductExecutor<A>::AT AT; + typedef typename SparseDotProductExecutorBase<typename A::BaseType>::AT AT; vespalib::ConstArrayRef<AT> getAttributeValues(uint32_t docid) final override; + const A* _attribute; std::vector<typename A::BaseType> _copy; }; |