summary | refs | log | tree | commit | diff | stats
path: root/searchlib
diff options
context:
space:
mode:
author: Henning Baldersheim <balder@oath.com> 2018-07-25 18:05:12 +0200
committer: Henning Baldersheim <balder@oath.com> 2018-07-26 13:33:18 +0200
commit: 7b29d355ecda19506885868476f2a9f884469fd9 (patch)
tree: 610c2fb83722adec81dfe82da9a4e3bfcb6d22bc /searchlib
parent: c4742fbc2329ddac79e5aa00f51c30f94b187932 (diff)
Make diversifier virtual for easier reuse and minimal runtime impact.
Diffstat (limited to 'searchlib')
-rw-r--r-- searchlib/src/vespa/searchlib/attribute/diversity.h | 72
-rw-r--r-- searchlib/src/vespa/searchlib/attribute/diversity.hpp | 20
2 files changed, 50 insertions, 42 deletions
diff --git a/searchlib/src/vespa/searchlib/attribute/diversity.h b/searchlib/src/vespa/searchlib/attribute/diversity.h
index 25dd5ef9d2e..90b499252c2 100644
--- a/searchlib/src/vespa/searchlib/attribute/diversity.h
+++ b/searchlib/src/vespa/searchlib/attribute/diversity.h
@@ -97,11 +97,20 @@ struct FetchFloat {
ValueType get(uint32_t docid) const { return attr.getFloat(docid); }
};
-template <typename Fetcher>
class DiversityFilter {
+public:
+ DiversityFilter(size_t max_total) : _max_total(max_total) {} // Stores the limit returned by getMaxTotal(). NOTE(review): single-argument constructor is not 'explicit'.
+ virtual ~DiversityFilter() {} // Virtual destructor: this is a polymorphic base (accepted() is pure virtual).
+ virtual bool accepted(uint32_t docId) = 0; // Returns whether docId is accepted; the decision logic is supplied by the subclass.
+ size_t getMaxTotal() const { return _max_total; } // Read-only access to the limit passed to the constructor.
+protected:
+ size_t _max_total; // Protected so that DiversityFilterT<Fetcher>::accepted() can compare its running count against it.
+};
+
+template <typename Fetcher>
+class DiversityFilterT final : public DiversityFilter {
private:
size_t _total_count;
- size_t _max_total;
const Fetcher &_diversity;
size_t _max_per_group;
size_t _cutoff_max_groups;
@@ -110,39 +119,18 @@ private:
typedef vespalib::hash_map<typename Fetcher::ValueType, uint32_t> Diversity;
Diversity _seen;
public:
- DiversityFilter(const Fetcher &diversity, size_t max_per_group,
+ DiversityFilterT(const Fetcher &diversity, size_t max_per_group,
size_t cutoff_max_groups, bool cutoff_strict, size_t max_total)
- : _total_count(0), _max_total(max_total), _diversity(diversity), _max_per_group(max_per_group),
+ : DiversityFilter(max_total), _total_count(0), _diversity(diversity), _max_per_group(max_per_group),
_cutoff_max_groups(cutoff_max_groups), _cutoff_strict(cutoff_strict),
_seen(std::min(cutoff_max_groups, 10000ul)*3)
{ }
- size_t getMaxTotal() const { return _max_total; }
- bool accepted(uint32_t docId) {
- if (_total_count < _max_total) {
- if ((_seen.size() < _cutoff_max_groups) || _cutoff_strict) {
- typename Fetcher::ValueType group = _diversity.get(docId);
- if (_seen.size() < _cutoff_max_groups) {
- return conditional_add(_seen[group]);
- } else {
- auto found = _seen.find(group);
- if (found == _seen.end()) {
- add();
- return true;
- } else {
- return conditional_add(found->second);
- }
- }
- } else if ( !_cutoff_strict) {
- add();
- return true;
- }
- }
- return false;
- }
+ bool accepted(uint32_t docId) override;
private:
- void add() {
+ bool add() {
++_total_count;
+ return true;
}
bool conditional_add(uint32_t & group_count) {
if (group_count < _max_per_group) {
@@ -154,13 +142,13 @@ private:
}
};
-template <typename Filter, typename Result>
+template <typename Result>
class DiversityRecorder {
private:
- Filter & _filter;
+ DiversityFilter & _filter;
Result &_result;
public:
- DiversityRecorder(Filter & filter, Result &result)
+ DiversityRecorder(DiversityFilter & filter, Result &result)
: _filter(filter), _result(result)
{ }
@@ -173,12 +161,12 @@ public:
};
-template <typename DictRange, typename PostingStore, typename Filter, typename Result>
-void diversify_3(const DictRange &range_in, const PostingStore &posting, Filter & filter,
+template <typename DictRange, typename PostingStore, typename Result>
+void diversify_3(const DictRange &range_in, const PostingStore &posting, DiversityFilter & filter,
Result &result, std::vector<size_t> &fragments)
{
- DiversityRecorder<Filter, Result> recorder(filter, result);
+ DiversityRecorder<Result> recorder(filter, result);
DictRange range(range_in);
using DataType = typename PostingStore::DataType;
using KeyDataType = typename PostingStore::KeyDataType;
@@ -202,10 +190,10 @@ void diversify_2(const DictRange &range_in, const PostingStore &posting, size_t
if (diversity_attr.hasEnum()) { // must handle enum first
FetchEnumFast fastEnum(diversity_attr);
if (fastEnum.valid()) {
- DiversityFilter<FetchEnumFast> filter(fastEnum, max_per_group, cutoff_max_groups, cutoff_strict, wanted_hits);
+ DiversityFilterT<FetchEnumFast> filter(fastEnum, max_per_group, cutoff_max_groups, cutoff_strict, wanted_hits);
diversify_3(range_in, posting, filter, result, fragments);
} else {
- DiversityFilter<FetchEnum> filter(FetchEnum(diversity_attr), max_per_group, cutoff_max_groups, cutoff_strict, wanted_hits);
+ DiversityFilterT<FetchEnum> filter(FetchEnum(diversity_attr), max_per_group, cutoff_max_groups, cutoff_strict, wanted_hits);
diversify_3(range_in, posting, filter, result, fragments);
}
} else if (diversity_attr.isIntegerType()) {
@@ -215,13 +203,13 @@ void diversify_2(const DictRange &range_in, const PostingStore &posting, size_t
FetchInt32Fast fastInt32(diversity_attr);
FetchInt64Fast fastInt64(diversity_attr);
if (fastInt32.valid()) {
- DiversityFilter<FetchInt32Fast> filter(fastInt32, max_per_group, cutoff_max_groups, cutoff_strict, wanted_hits);
+ DiversityFilterT<FetchInt32Fast> filter(fastInt32, max_per_group, cutoff_max_groups, cutoff_strict, wanted_hits);
diversify_3(range_in, posting, filter, result, fragments);
} else if (fastInt64.valid()) {
- DiversityFilter<FetchInt64Fast> filter(fastInt64, max_per_group, cutoff_max_groups, cutoff_strict, wanted_hits);
+ DiversityFilterT<FetchInt64Fast> filter(fastInt64, max_per_group, cutoff_max_groups, cutoff_strict, wanted_hits);
diversify_3(range_in, posting, filter, result, fragments);
} else {
- DiversityFilter<FetchInteger> filter(FetchInteger(diversity_attr), max_per_group, cutoff_max_groups, cutoff_strict, wanted_hits);
+ DiversityFilterT<FetchInteger> filter(FetchInteger(diversity_attr), max_per_group, cutoff_max_groups, cutoff_strict, wanted_hits);
diversify_3(range_in, posting, filter, result, fragments);
}
} else if (diversity_attr.isFloatingPointType()) {
@@ -230,13 +218,13 @@ void diversify_2(const DictRange &range_in, const PostingStore &posting, size_t
FetchFloatFast fastFloat(diversity_attr);
FetchDoubleFast fastDouble(diversity_attr);
if (fastFloat.valid()) {
- DiversityFilter<FetchFloatFast> filter(fastFloat, max_per_group, cutoff_max_groups, cutoff_strict, wanted_hits);
+ DiversityFilterT<FetchFloatFast> filter(fastFloat, max_per_group, cutoff_max_groups, cutoff_strict, wanted_hits);
diversify_3(range_in, posting, filter, result, fragments);
} else if (fastDouble.valid()) {
- DiversityFilter<FetchDoubleFast> filter(fastDouble, max_per_group, cutoff_max_groups, cutoff_strict, wanted_hits);
+ DiversityFilterT<FetchDoubleFast> filter(fastDouble, max_per_group, cutoff_max_groups, cutoff_strict, wanted_hits);
diversify_3(range_in, posting, filter, result, fragments);
} else {
- DiversityFilter<FetchFloat> filter(FetchFloat(diversity_attr), max_per_group, cutoff_max_groups, cutoff_strict, wanted_hits);
+ DiversityFilterT<FetchFloat> filter(FetchFloat(diversity_attr), max_per_group, cutoff_max_groups, cutoff_strict, wanted_hits);
diversify_3(range_in, posting, filter, result, fragments);
}
}
diff --git a/searchlib/src/vespa/searchlib/attribute/diversity.hpp b/searchlib/src/vespa/searchlib/attribute/diversity.hpp
index 698f482dec1..52e6e5ede04 100644
--- a/searchlib/src/vespa/searchlib/attribute/diversity.hpp
+++ b/searchlib/src/vespa/searchlib/attribute/diversity.hpp
@@ -31,4 +31,24 @@ ReverseRange<ITR>::ReverseRange(const ITR &lower, const ITR &upper)
template <typename ITR>
ReverseRange<ITR>::~ReverseRange() = default;
+template <typename Fetcher>
+bool
+DiversityFilterT<Fetcher>::accepted(uint32_t docId) { // Decides whether docId is accepted; accepting via add() increments _total_count.
+ if (_total_count < _max_total) { // Nothing more is accepted once _total_count reaches _max_total.
+ if ((_seen.size() < _cutoff_max_groups) || _cutoff_strict) { // The group value is only fetched when it can influence the decision.
+ typename Fetcher::ValueType group = _diversity.get(docId); // Group key for this document, read through the fetcher.
+ if (_seen.size() < _cutoff_max_groups) {
+ return conditional_add(_seen[group]); // Below the group cutoff; presumably operator[] creates a zero count for a new group - TODO confirm for vespalib::hash_map.
+ } else { // Only reachable when the cutoff has been reached and _cutoff_strict is true.
+ auto found = _seen.find(group);
+ return (found == _seen.end()) ? add() : conditional_add(found->second); // Group not in _seen: accepted via add() without being recorded; group in _seen: conditional_add() decides using its stored count.
+ }
+ } else if ( !_cutoff_strict) { // Always true here: this branch is only reached when the cutoff is reached and _cutoff_strict is false.
+ return add(); // Accept without fetching or tracking the group.
+ }
+ }
+ return false; // Reached when _total_count >= _max_total.
+}
+
+
}