diff options
author | Henning Baldersheim <balder@oath.com> | 2018-07-25 18:05:12 +0200 |
---|---|---|
committer | Henning Baldersheim <balder@oath.com> | 2018-07-26 13:33:18 +0200 |
commit | 7b29d355ecda19506885868476f2a9f884469fd9 (patch) | |
tree | 610c2fb83722adec81dfe82da9a4e3bfcb6d22bc /searchlib | |
parent | c4742fbc2329ddac79e5aa00f51c30f94b187932 (diff) |
Make DiversityFilter a virtual base class (implemented by the templated, final DiversityFilterT) for easier reuse with minimal runtime impact.
Diffstat (limited to 'searchlib')
-rw-r--r-- | searchlib/src/vespa/searchlib/attribute/diversity.h | 72 | ||||
-rw-r--r-- | searchlib/src/vespa/searchlib/attribute/diversity.hpp | 20 |
2 files changed, 50 insertions, 42 deletions
diff --git a/searchlib/src/vespa/searchlib/attribute/diversity.h b/searchlib/src/vespa/searchlib/attribute/diversity.h index 25dd5ef9d2e..90b499252c2 100644 --- a/searchlib/src/vespa/searchlib/attribute/diversity.h +++ b/searchlib/src/vespa/searchlib/attribute/diversity.h @@ -97,11 +97,20 @@ struct FetchFloat { ValueType get(uint32_t docid) const { return attr.getFloat(docid); } }; -template <typename Fetcher> class DiversityFilter { +public: + DiversityFilter(size_t max_total) : _max_total(max_total) {} + virtual ~DiversityFilter() {} + virtual bool accepted(uint32_t docId) = 0; + size_t getMaxTotal() const { return _max_total; } +protected: + size_t _max_total; +}; + +template <typename Fetcher> +class DiversityFilterT final : public DiversityFilter { private: size_t _total_count; - size_t _max_total; const Fetcher &_diversity; size_t _max_per_group; size_t _cutoff_max_groups; @@ -110,39 +119,18 @@ private: typedef vespalib::hash_map<typename Fetcher::ValueType, uint32_t> Diversity; Diversity _seen; public: - DiversityFilter(const Fetcher &diversity, size_t max_per_group, + DiversityFilterT(const Fetcher &diversity, size_t max_per_group, size_t cutoff_max_groups, bool cutoff_strict, size_t max_total) - : _total_count(0), _max_total(max_total), _diversity(diversity), _max_per_group(max_per_group), + : DiversityFilter(max_total), _total_count(0), _diversity(diversity), _max_per_group(max_per_group), _cutoff_max_groups(cutoff_max_groups), _cutoff_strict(cutoff_strict), _seen(std::min(cutoff_max_groups, 10000ul)*3) { } - size_t getMaxTotal() const { return _max_total; } - bool accepted(uint32_t docId) { - if (_total_count < _max_total) { - if ((_seen.size() < _cutoff_max_groups) || _cutoff_strict) { - typename Fetcher::ValueType group = _diversity.get(docId); - if (_seen.size() < _cutoff_max_groups) { - return conditional_add(_seen[group]); - } else { - auto found = _seen.find(group); - if (found == _seen.end()) { - add(); - return true; - } else { - return 
conditional_add(found->second); - } - } - } else if ( !_cutoff_strict) { - add(); - return true; - } - } - return false; - } + bool accepted(uint32_t docId) override; private: - void add() { + bool add() { ++_total_count; + return true; } bool conditional_add(uint32_t & group_count) { if (group_count < _max_per_group) { @@ -154,13 +142,13 @@ private: } }; -template <typename Filter, typename Result> +template <typename Result> class DiversityRecorder { private: - Filter & _filter; + DiversityFilter & _filter; Result &_result; public: - DiversityRecorder(Filter & filter, Result &result) + DiversityRecorder(DiversityFilter & filter, Result &result) : _filter(filter), _result(result) { } @@ -173,12 +161,12 @@ public: }; -template <typename DictRange, typename PostingStore, typename Filter, typename Result> -void diversify_3(const DictRange &range_in, const PostingStore &posting, Filter & filter, +template <typename DictRange, typename PostingStore, typename Result> +void diversify_3(const DictRange &range_in, const PostingStore &posting, DiversityFilter & filter, Result &result, std::vector<size_t> &fragments) { - DiversityRecorder<Filter, Result> recorder(filter, result); + DiversityRecorder<Result> recorder(filter, result); DictRange range(range_in); using DataType = typename PostingStore::DataType; using KeyDataType = typename PostingStore::KeyDataType; @@ -202,10 +190,10 @@ void diversify_2(const DictRange &range_in, const PostingStore &posting, size_t if (diversity_attr.hasEnum()) { // must handle enum first FetchEnumFast fastEnum(diversity_attr); if (fastEnum.valid()) { - DiversityFilter<FetchEnumFast> filter(fastEnum, max_per_group, cutoff_max_groups, cutoff_strict, wanted_hits); + DiversityFilterT<FetchEnumFast> filter(fastEnum, max_per_group, cutoff_max_groups, cutoff_strict, wanted_hits); diversify_3(range_in, posting, filter, result, fragments); } else { - DiversityFilter<FetchEnum> filter(FetchEnum(diversity_attr), max_per_group, cutoff_max_groups, 
cutoff_strict, wanted_hits); + DiversityFilterT<FetchEnum> filter(FetchEnum(diversity_attr), max_per_group, cutoff_max_groups, cutoff_strict, wanted_hits); diversify_3(range_in, posting, filter, result, fragments); } } else if (diversity_attr.isIntegerType()) { @@ -215,13 +203,13 @@ void diversify_2(const DictRange &range_in, const PostingStore &posting, size_t FetchInt32Fast fastInt32(diversity_attr); FetchInt64Fast fastInt64(diversity_attr); if (fastInt32.valid()) { - DiversityFilter<FetchInt32Fast> filter(fastInt32, max_per_group, cutoff_max_groups, cutoff_strict, wanted_hits); + DiversityFilterT<FetchInt32Fast> filter(fastInt32, max_per_group, cutoff_max_groups, cutoff_strict, wanted_hits); diversify_3(range_in, posting, filter, result, fragments); } else if (fastInt64.valid()) { - DiversityFilter<FetchInt64Fast> filter(fastInt64, max_per_group, cutoff_max_groups, cutoff_strict, wanted_hits); + DiversityFilterT<FetchInt64Fast> filter(fastInt64, max_per_group, cutoff_max_groups, cutoff_strict, wanted_hits); diversify_3(range_in, posting, filter, result, fragments); } else { - DiversityFilter<FetchInteger> filter(FetchInteger(diversity_attr), max_per_group, cutoff_max_groups, cutoff_strict, wanted_hits); + DiversityFilterT<FetchInteger> filter(FetchInteger(diversity_attr), max_per_group, cutoff_max_groups, cutoff_strict, wanted_hits); diversify_3(range_in, posting, filter, result, fragments); } } else if (diversity_attr.isFloatingPointType()) { @@ -230,13 +218,13 @@ void diversify_2(const DictRange &range_in, const PostingStore &posting, size_t FetchFloatFast fastFloat(diversity_attr); FetchDoubleFast fastDouble(diversity_attr); if (fastFloat.valid()) { - DiversityFilter<FetchFloatFast> filter(fastFloat, max_per_group, cutoff_max_groups, cutoff_strict, wanted_hits); + DiversityFilterT<FetchFloatFast> filter(fastFloat, max_per_group, cutoff_max_groups, cutoff_strict, wanted_hits); diversify_3(range_in, posting, filter, result, fragments); } else if 
(fastDouble.valid()) { - DiversityFilter<FetchDoubleFast> filter(fastDouble, max_per_group, cutoff_max_groups, cutoff_strict, wanted_hits); + DiversityFilterT<FetchDoubleFast> filter(fastDouble, max_per_group, cutoff_max_groups, cutoff_strict, wanted_hits); diversify_3(range_in, posting, filter, result, fragments); } else { - DiversityFilter<FetchFloat> filter(FetchFloat(diversity_attr), max_per_group, cutoff_max_groups, cutoff_strict, wanted_hits); + DiversityFilterT<FetchFloat> filter(FetchFloat(diversity_attr), max_per_group, cutoff_max_groups, cutoff_strict, wanted_hits); diversify_3(range_in, posting, filter, result, fragments); } } diff --git a/searchlib/src/vespa/searchlib/attribute/diversity.hpp b/searchlib/src/vespa/searchlib/attribute/diversity.hpp index 698f482dec1..52e6e5ede04 100644 --- a/searchlib/src/vespa/searchlib/attribute/diversity.hpp +++ b/searchlib/src/vespa/searchlib/attribute/diversity.hpp @@ -31,4 +31,24 @@ ReverseRange<ITR>::ReverseRange(const ITR &lower, const ITR &upper) template <typename ITR> ReverseRange<ITR>::~ReverseRange() = default; +template <typename Fetcher> +bool +DiversityFilterT<Fetcher>::accepted(uint32_t docId) { + if (_total_count < _max_total) { + if ((_seen.size() < _cutoff_max_groups) || _cutoff_strict) { + typename Fetcher::ValueType group = _diversity.get(docId); + if (_seen.size() < _cutoff_max_groups) { + return conditional_add(_seen[group]); + } else { + auto found = _seen.find(group); + return (found == _seen.end()) ? add() : conditional_add(found->second); + } + } else if ( !_cutoff_strict) { + return add(); + } + } + return false; +} + + } |