diff options
author | Henning Baldersheim <balder@yahoo-inc.com> | 2022-04-21 18:17:17 +0200 |
---|---|---|
committer | GitHub <noreply@github.com> | 2022-04-21 18:17:17 +0200 |
commit | 56e3b7942d0b838bcb5ca4c2ab36fdf7b0fb9c81 (patch) | |
tree | d22f02357d1ca0079a7f62c0c3ebbaf9cacdaff6 /searchlib | |
parent | 49c7cfc737300dd4f475cc0b07f4f703aa789d46 (diff) | |
parent | 7244cb4291c5e34b722dc2bf0a8823588e59e3f7 (diff) |
Merge pull request #22199 from vespa-engine/toregge/always-use-multi-value-read-view-for-dot-product-feature-on-weighted-set-attribute
Always use MultiValueReadView for dot product feature on weighted set attributes.
Diffstat (limited to 'searchlib')
-rw-r--r-- | searchlib/src/vespa/searchlib/features/dotproductfeature.cpp | 81 | ||||
-rw-r--r-- | searchlib/src/vespa/searchlib/features/dotproductfeature.h | 30 |
2 files changed, 34 insertions, 77 deletions
diff --git a/searchlib/src/vespa/searchlib/features/dotproductfeature.cpp b/searchlib/src/vespa/searchlib/features/dotproductfeature.cpp index 8f75b6ecc7d..1ee5fe90773 100644 --- a/searchlib/src/vespa/searchlib/features/dotproductfeature.cpp +++ b/searchlib/src/vespa/searchlib/features/dotproductfeature.cpp @@ -52,49 +52,6 @@ template class VectorBase<uint32_t, uint32_t, double>; template class IntegerVectorT<int64_t>; - -template <typename Vector, typename Buffer> -DotProductExecutorByCopy<Vector, Buffer>::DotProductExecutorByCopy(const IAttributeVector * attribute, const Vector & queryVector) : - FeatureExecutor(), - _attribute(attribute), - _queryVector(queryVector), - _end(_queryVector.getDimMap().end()), - _buffer(), - _backing() -{ - _buffer.allocate(_attribute->getMaxValueCount()); -} - -template <typename Vector, typename Buffer> -DotProductExecutorByCopy<Vector, Buffer>::DotProductExecutorByCopy(const IAttributeVector * attribute, std::unique_ptr<Vector> queryVector) : - FeatureExecutor(), - _attribute(attribute), - _queryVector(*queryVector), - _end(_queryVector.getDimMap().end()), - _buffer(), - _backing(std::move(queryVector)) -{ - _buffer.allocate(_attribute->getMaxValueCount()); -} - -template <typename Vector, typename Buffer> -DotProductExecutorByCopy<Vector, Buffer>::~DotProductExecutorByCopy() = default; - -template <typename Vector, typename Buffer> -void -DotProductExecutorByCopy<Vector, Buffer>::execute(uint32_t docId) -{ - feature_t val = 0; - _buffer.fill(*_attribute, docId); - for (size_t i = 0; i < _buffer.size(); ++i) { - auto itr = _queryVector.getDimMap().find(_buffer[i].getValue()); - if (itr != _end) { - val += _buffer[i].getWeight() * itr->second; - } - } - outputs().set_number(0, val); -} - StringVector::StringVector() = default; StringVector::~StringVector() = default; @@ -229,6 +186,7 @@ template <typename BaseType> class SingleDotProductByWeightedValueExecutor final : public fef::FeatureExecutor { public: using WeightedSetReadView = attribute::IWeightedSetReadView<BaseType>; + using StoredKeyType = std::conditional_t<std::is_same_v<BaseType,const char*>,vespalib::string,BaseType>; SingleDotProductByWeightedValueExecutor(const WeightedSetReadView * weighted_set_read_view, BaseType key, feature_t value) : _weighted_set_read_view(weighted_set_read_view), _key(key), @@ -247,7 +205,7 @@ public: } private: const WeightedSetReadView* _weighted_set_read_view; - BaseType _key; + StoredKeyType _key; feature_t _value; }; @@ -650,23 +608,38 @@ std::pair<T, feature_t> extractElem(const std::unique_ptr<dotproduct::wset::Inte return extractElem(*v, idx); } -template <typename A, typename V> +size_t extractSize(const dotproduct::wset::StringVector& v) { + return v.getVector().size(); +} + +std::pair<const char*, feature_t> extractElem(const dotproduct::wset::StringVector& v, size_t idx) { + const auto & pair = v.getVector()[idx]; + return std::pair<const char*, feature_t>(pair.first.c_str(), pair.second); +} + +size_t extractSize(const std::unique_ptr<dotproduct::wset::StringVector>& v) { + return extractSize(*v); +} + +std::pair<const char*, feature_t> extractElem(const std::unique_ptr<dotproduct::wset::StringVector>& v, size_t idx) { + return extractElem(*v, idx); +} + +template <typename T, typename V> FeatureExecutor & createForDirectWSetImpl(const IAttributeVector * attribute, V && vector, vespalib::Stash & stash) { using namespace dotproduct::wset; - using T = typename A::BaseType; - const A * iattr = dynamic_cast<const A *>(attribute); using VT = multivalue::WeightedValue<T>; auto weighted_set_read_view = make_multi_value_read_view<VT>(*attribute, stash); - if (!attribute->isImported() && (iattr != nullptr) && weighted_set_read_view != nullptr) { + if (weighted_set_read_view != nullptr) { if (extractSize(vector) == 1) { auto elem = extractElem(vector, 0ul); return stash.create<SingleDotProductByWeightedValueExecutor<T>>(weighted_set_read_view, elem.first, elem.second); } return stash.create<DotProductByWeightedSetReadViewExecutor<T>>(weighted_set_read_view, std::forward<V>(vector)); } - return stash.create<DotProductExecutorByCopy<IntegerVectorT<T>, WeightedIntegerContent>>(attribute, std::forward<V>(vector)); + return stash.create<SingleZeroValueExecutor>(); } template <typename T> @@ -676,7 +649,7 @@ createForDirectIntegerWSet(const IAttributeVector * attribute, const dotproduct: using namespace dotproduct::wset; return vector.empty() ? stash.create<SingleZeroValueExecutor>() - : createForDirectWSetImpl<IntegerAttributeTemplate<T>>(attribute, vector, stash); + : createForDirectWSetImpl<T>(attribute, vector, stash); } FeatureExecutor & @@ -727,14 +700,13 @@ createFromObject(const IAttributeVector * attribute, const fef::Anything & objec } return stash.create<DotProductExecutorByEnum>(weighted_set_enum_read_view, vector); } - return stash.create<DotProductExecutorByCopy<EnumVector, WeightedEnumContent>>(attribute, vector); } else { if (attribute->isStringType()) { const auto & vector = dynamic_cast<const StringVector &>(object); if (vector.empty()) { return stash.create<SingleZeroValueExecutor>(); } - return stash.create<DotProductExecutorByCopy<StringVector, WeightedConstCharContent>>(attribute, vector); + return createForDirectWSetImpl<const char*>(attribute, vector, stash); } else if (attribute->isIntegerType()) { if (attribute->getBasicType() == BasicType::INT32) { return createForDirectIntegerWSet<int32_t>(attribute, dynamic_cast<const IntegerVectorT<int32_t> &>(object), stash); @@ -800,7 +772,7 @@ createForDirectIntegerWSet(const IAttributeVector * attribute, const Property & vector->syncMap(); return vector->empty() ? stash.create<SingleZeroValueExecutor>() - : createForDirectWSetImpl<IntegerAttributeTemplate<T>>(attribute, std::move(vector), stash); + : createForDirectWSetImpl<T>(attribute, std::move(vector), stash); } FeatureExecutor & @@ -822,7 +794,6 @@ createTypedWsetExecutor(const IAttributeVector * attribute, const Property & pro } return stash.create<DotProductExecutorByEnum>(weighted_set_enum_read_view, std::move(vector)); } - return stash.create<DotProductExecutorByCopy<EnumVector, WeightedEnumContent>>(attribute, std::move(vector)); } else { if (attribute->isStringType()) { auto vector = std::make_unique<StringVector>(); @@ -831,7 +802,7 @@ createTypedWsetExecutor(const IAttributeVector * attribute, const Property & pro return stash.create<SingleZeroValueExecutor>(); } vector->syncMap(); - return stash.create<DotProductExecutorByCopy<StringVector, WeightedConstCharContent>>(attribute, std::move(vector)); + return createForDirectWSetImpl<const char*>(attribute, std::move(vector), stash); } else if (attribute->isIntegerType()) { if (attribute->getBasicType() == BasicType::INT32) { return createForDirectIntegerWSet<int32_t>(attribute, prop, stash); diff --git a/searchlib/src/vespa/searchlib/features/dotproductfeature.h b/searchlib/src/vespa/searchlib/features/dotproductfeature.h index 051a08025d8..ee0a85e689b 100644 --- a/searchlib/src/vespa/searchlib/features/dotproductfeature.h +++ b/searchlib/src/vespa/searchlib/features/dotproductfeature.h @@ -65,11 +65,14 @@ public: bool empty() const { return _vector.empty(); } }; +template <typename T> +using NumericVectorBaseT = VectorBase<T, T, feature_t>; + /** * Represents a vector where the dimensions are integers. **/ template<typename T> -class IntegerVectorT : public VectorBase<T, T, feature_t> { +class IntegerVectorT : public NumericVectorBaseT<T> { public: void insert(vespalib::stringref label, vespalib::stringref value) { this->_vector.emplace_back(util::strToNum<T>(label), util::strToNum<feature_t>(value)); @@ -82,10 +85,12 @@ extern template class IntegerVectorT<int64_t>; using IntegerVector = IntegerVectorT<int64_t>; +using StringVectorBase = VectorBase<vespalib::string, const char*, feature_t, ConstCharComparator>; + /** * Represents a vector where the dimensions are string values. **/ -class StringVector : public VectorBase<vespalib::string, const char *, feature_t, ConstCharComparator> { +class StringVector : public StringVectorBase { public: StringVector(); StringVector(StringVector &&) = default; @@ -121,7 +126,7 @@ template <typename BaseType> class DotProductExecutorBase : public fef::FeatureExecutor { public: using AT = multivalue::WeightedValue<BaseType>; - using V = VectorBase<BaseType, BaseType, feature_t>; + using V = std::conditional_t<std::is_same_v<BaseType,const char*>,StringVectorBase,NumericVectorBaseT<BaseType>>; private: const V & _queryVector; const typename V::HashMap::const_iterator _end; @@ -149,25 +154,6 @@ public: ~DotProductByWeightedSetReadViewExecutor(); }; - -/** - * Implements the executor for the dotproduct feature. - */ -template <typename Vector, typename Buffer> -class DotProductExecutorByCopy final : public fef::FeatureExecutor { -private: - const attribute::IAttributeVector * _attribute; - const Vector & _queryVector; - const typename Vector::HashMap::const_iterator _end; - Buffer _buffer; - std::unique_ptr<Vector> _backing; -public: - DotProductExecutorByCopy(const attribute::IAttributeVector * attribute, const Vector & queryVector); - DotProductExecutorByCopy(const attribute::IAttributeVector * attribute, std::unique_ptr<Vector> queryVector); - ~DotProductExecutorByCopy() override; - void execute(uint32_t docId) override; -}; - } namespace array { |