summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--searchlib/src/vespa/searchlib/attribute/diversity.cpp136
-rw-r--r--searchlib/src/vespa/searchlib/attribute/diversity.h161
-rw-r--r--searchlib/src/vespa/searchlib/queryeval/idiversifier.h16
-rw-r--r--searchlib/src/vespa/searchlib/queryeval/isourceselector.cpp7
-rw-r--r--searchlib/src/vespa/searchlib/queryeval/isourceselector.h8
5 files changed, 185 insertions, 143 deletions
diff --git a/searchlib/src/vespa/searchlib/attribute/diversity.cpp b/searchlib/src/vespa/searchlib/attribute/diversity.cpp
index 4c6fe054b12..65a84baa212 100644
--- a/searchlib/src/vespa/searchlib/attribute/diversity.cpp
+++ b/searchlib/src/vespa/searchlib/attribute/diversity.cpp
@@ -1,7 +1,143 @@
// Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
#include "diversity.hpp"
+#include "singleenumattribute.h"
+#include "singlenumericattribute.h"
+#include <vespa/vespalib/stllike/hash_map.h>
+using std::make_unique;
namespace search::attribute::diversity {
+/**
+ * Reads group values straight from a concrete attribute class T.
+ * The constructor downcasts the generic attribute with dynamic_cast; when the
+ * attribute is not a T the stored pointer is null, valid() returns false and
+ * get() must not be called.
+ */
+template <typename T>
+struct FetchNumberFast {
+    using ValueType = typename T::LoadedValueType;
+    const T * const attr;
+
+    FetchNumberFast(const IAttributeVector &attr_in)
+        : attr(dynamic_cast<const T *>(&attr_in))
+    { }
+
+    /// True when the downcast performed by the constructor succeeded.
+    bool valid() const { return attr != nullptr; }
+    /// Value for docid via T::getFast(); attr is dereferenced without a check.
+    ValueType get(uint32_t docid) const { return attr->getFast(docid); }
+};
+
+/**
+ * Reads enum handles through SingleValueEnumAttributeBase::getE().
+ * The attribute is downcast with dynamic_cast; valid() tells whether that
+ * worked, and get() may only be used when it did.
+ */
+struct FetchEnumFast {
+    using ValueType = uint32_t;
+    const SingleValueEnumAttributeBase * const attr;
+
+    FetchEnumFast(const IAttributeVector &attr_in)
+        : attr(dynamic_cast<const SingleValueEnumAttributeBase *>(&attr_in))
+    { }
+
+    /// True when the attribute really is a SingleValueEnumAttributeBase.
+    bool valid() const { return attr != nullptr; }
+    /// Enum handle for docid; attr is dereferenced without a check.
+    ValueType get(uint32_t docid) const { return attr->getE(docid); }
+};
+
+/**
+ * Generic fetcher that obtains the group value through
+ * IAttributeVector::getEnum(). Keeps a non-owning pointer to the attribute.
+ */
+struct FetchEnum {
+    using ValueType = uint32_t;
+    const IAttributeVector *attr;
+
+    FetchEnum(const IAttributeVector & attr_in)
+        : attr(&attr_in)
+    { }
+
+    /// Result of attr->getEnum(docid).
+    ValueType get(uint32_t docid) const { return attr->getEnum(docid); }
+};
+
+/**
+ * Generic fetcher that obtains the group value through
+ * IAttributeVector::getInt(). Keeps a non-owning pointer to the attribute.
+ */
+struct FetchInteger {
+    using ValueType = int64_t;
+    const IAttributeVector * attr;
+
+    FetchInteger(const IAttributeVector & attr_in)
+        : attr(&attr_in)
+    { }
+
+    /// Result of attr->getInt(docid).
+    ValueType get(uint32_t docid) const { return attr->getInt(docid); }
+};
+
+/**
+ * Generic fetcher that obtains the group value through
+ * IAttributeVector::getFloat(). Keeps a non-owning pointer to the attribute.
+ */
+struct FetchFloat {
+    using ValueType = double;
+    const IAttributeVector * attr;
+
+    FetchFloat(const IAttributeVector & attr_in)
+        : attr(&attr_in)
+    { }
+
+    /// Result of attr->getFloat(docid).
+    ValueType get(uint32_t docid) const { return attr->getFloat(docid); }
+};
+
+/**
+ * DiversityFilter implementation that groups documents by the value a Fetcher
+ * returns for them and keeps a count of accepted documents per group.
+ *
+ * Fetcher must expose a ValueType usable as a hash_map key and a
+ * `ValueType get(uint32_t docid) const` member. The fetcher is stored by value.
+ */
+template <typename Fetcher>
+class DiversityFilterT final : public DiversityFilter {
+private:
+    size_t  _total_count;        // number of documents accepted so far
+    Fetcher _diversity;          // maps a docid to its group value
+    size_t  _max_per_group;      // cap on accepted documents per tracked group
+    size_t  _cutoff_max_groups;  // no new groups are tracked once _seen has this many
+    bool    _cutoff_strict;      // selects the behaviour once the group cutoff is reached
+
+    typedef vespalib::hash_map<typename Fetcher::ValueType, uint32_t> Diversity;
+    Diversity _seen;             // group value -> number of accepted documents
+public:
+    /**
+     * @param diversity         fetcher producing the group value of a document
+     * @param max_per_group     maximum accepted documents per tracked group
+     * @param cutoff_max_groups maximum number of groups to track
+     * @param cutoff_strict     see accepted() for the effect after the cutoff
+     * @param max_total         maximum number of documents accepted in total
+     */
+    DiversityFilterT(Fetcher diversity, size_t max_per_group, size_t cutoff_max_groups,
+                     bool cutoff_strict, size_t max_total)
+        : DiversityFilter(max_total),
+          _total_count(0),
+          _diversity(diversity),
+          _max_per_group(max_per_group),
+          _cutoff_max_groups(cutoff_max_groups),
+          _cutoff_strict(cutoff_strict),
+          // The explicit <size_t> keeps std::min's argument deduction unambiguous
+          // on platforms where size_t is not unsigned long; the value is unchanged.
+          _seen(std::min<size_t>(cutoff_max_groups, 10000) * 3)
+    { }
+
+    bool accepted(uint32_t docId) override;
+private:
+    /// Counts one accepted document; always returns true.
+    bool add() {
+        ++_total_count;
+        return true;
+    }
+    /// Accepts the document only while its group is below _max_per_group,
+    /// bumping both the group counter and the total on success.
+    bool conditional_add(uint32_t & group_count) {
+        if (group_count >= _max_per_group) {
+            return false;
+        }
+        ++group_count;
+        return add();
+    }
+};
+
+/**
+ * Decides whether docId is kept and updates the counters accordingly.
+ *
+ * - Nothing is accepted once _total_count has reached _max_total.
+ * - While fewer than _cutoff_max_groups groups are tracked, the document's
+ *   group is looked up (and created if missing) and accepted if that group is
+ *   still below _max_per_group.
+ * - After the cutoff, non-strict mode accepts every document without reading
+ *   its group; strict mode applies the per-group limit to groups already
+ *   tracked and accepts documents of untracked groups.
+ */
+template <typename Fetcher>
+bool
+DiversityFilterT<Fetcher>::accepted(uint32_t docId) {
+    if (_total_count >= _max_total) {
+        return false;
+    }
+    const bool can_track_more = (_seen.size() < _cutoff_max_groups);
+    if (!can_track_more && !_cutoff_strict) {
+        return add();
+    }
+    typename Fetcher::ValueType group_value = _diversity.get(docId);
+    if (can_track_more) {
+        // operator[] inserts a zero counter for a group not seen before
+        return conditional_add(_seen[group_value]);
+    }
+    auto entry = _seen.find(group_value);
+    if (entry == _seen.end()) {
+        return add();
+    }
+    return conditional_add(entry->second);
+}
+
+/**
+ * Builds the DiversityFilterT instantiation that fits diversity_attr.
+ *
+ * Enum attributes are handled first, then integer, then floating point. Within
+ * each category the dynamic_cast based "fast" fetchers are tried before the
+ * generic IAttributeVector based one. wanted_hits becomes the filter's
+ * max_total. An empty pointer is returned when the attribute falls into none
+ * of the three categories.
+ */
+std::unique_ptr<DiversityFilter>
+DiversityFilter::create(const IAttributeVector &diversity_attr, size_t wanted_hits,
+                        size_t max_per_group,size_t cutoff_max_groups, bool cutoff_strict)
+{
+    // Wraps whatever fetcher it is given in the matching filter type.
+    auto build = [&](auto fetcher) -> std::unique_ptr<DiversityFilter> {
+        return make_unique<DiversityFilterT<decltype(fetcher)>>(fetcher, max_per_group, cutoff_max_groups,
+                                                                cutoff_strict, wanted_hits);
+    };
+
+    if (diversity_attr.hasEnum()) { // must handle enum first
+        FetchEnumFast direct_enum(diversity_attr);
+        if (direct_enum.valid()) {
+            return build(direct_enum);
+        }
+        return build(FetchEnum(diversity_attr));
+    }
+    if (diversity_attr.isIntegerType()) {
+        using FetchInt32Fast = FetchNumberFast<SingleValueNumericAttribute<IntegerAttributeTemplate<int32_t> > >;
+        using FetchInt64Fast = FetchNumberFast<SingleValueNumericAttribute<IntegerAttributeTemplate<int64_t> > >;
+
+        FetchInt32Fast direct_i32(diversity_attr);
+        if (direct_i32.valid()) {
+            return build(direct_i32);
+        }
+        FetchInt64Fast direct_i64(diversity_attr);
+        if (direct_i64.valid()) {
+            return build(direct_i64);
+        }
+        return build(FetchInteger(diversity_attr));
+    }
+    if (diversity_attr.isFloatingPointType()) {
+        using FetchFloatFast = FetchNumberFast<SingleValueNumericAttribute<FloatingPointAttributeTemplate<float> > >;
+        using FetchDoubleFast = FetchNumberFast<SingleValueNumericAttribute<FloatingPointAttributeTemplate<double> > >;
+
+        FetchFloatFast direct_f32(diversity_attr);
+        if (direct_f32.valid()) {
+            return build(direct_f32);
+        }
+        FetchDoubleFast direct_f64(diversity_attr);
+        if (direct_f64.valid()) {
+            return build(direct_f64);
+        }
+        return build(FetchFloat(diversity_attr));
+    }
+    return {};
+}
+
}
diff --git a/searchlib/src/vespa/searchlib/attribute/diversity.h b/searchlib/src/vespa/searchlib/attribute/diversity.h
index fe2874a65a1..da5df19f9e8 100644
--- a/searchlib/src/vespa/searchlib/attribute/diversity.h
+++ b/searchlib/src/vespa/searchlib/attribute/diversity.h
@@ -2,9 +2,8 @@
#pragma once
-#include "singleenumattribute.h"
-#include "singlenumericattribute.h"
-#include <vespa/vespalib/stllike/hash_map.h>
+#include <vespa/searchcommon/attribute/iattributevector.h>
+#include <vespa/searchlib/queryeval/idiversifier.h>
/**
* This file contains low-level code used to implement diversified
@@ -59,169 +58,67 @@ public:
bool has_next() const { return _lower != _upper; }
};
-template <typename T>
-struct FetchNumberFast {
- const T * const attr;
- typedef typename T::LoadedValueType ValueType;
- FetchNumberFast(const IAttributeVector &attr_in) : attr(dynamic_cast<const T *>(&attr_in)) {}
- ValueType get(uint32_t docid) const { return attr->getFast(docid); }
- bool valid() const { return (attr != nullptr); }
-};
-
-struct FetchEnumFast {
- const SingleValueEnumAttributeBase * const attr;
- typedef uint32_t ValueType;
- FetchEnumFast(const IAttributeVector &attr_in) : attr(dynamic_cast<const SingleValueEnumAttributeBase *>(&attr_in)) {}
- ValueType get(uint32_t docid) const { return attr->getE(docid); }
- bool valid() const { return (attr != nullptr); }
-};
-
-struct FetchEnum {
- const IAttributeVector &attr;
- typedef uint32_t ValueType;
- FetchEnum(const IAttributeVector &attr_in) : attr(attr_in) {}
- ValueType get(uint32_t docid) const { return attr.getEnum(docid); }
-};
-
-struct FetchInteger {
- const IAttributeVector &attr;
- typedef int64_t ValueType;
- FetchInteger(const IAttributeVector &attr_in) : attr(attr_in) {}
- ValueType get(uint32_t docid) const { return attr.getInt(docid); }
-};
-
-struct FetchFloat {
- const IAttributeVector &attr;
- typedef double ValueType;
- FetchFloat(const IAttributeVector &attr_in) : attr(attr_in) {}
- ValueType get(uint32_t docid) const { return attr.getFloat(docid); }
+/**
+ * Base class for diversity filters: an IDiversifier that also remembers the
+ * maximum total number of hits it was configured with. Concrete filters are
+ * obtained through create().
+ */
+class DiversityFilter : public queryeval::IDiversifier {
+public:
+    /// explicit: a bare size_t should never convert silently into a filter.
+    explicit DiversityFilter(size_t max_total) : _max_total(max_total) {}
+    /// The max_total value given at construction.
+    size_t getMaxTotal() const { return _max_total; }
+    /**
+     * Creates a filter grouping on diversity_attr, with wanted_hits used as
+     * the filter's max_total.
+     * The result is an empty pointer when the attribute is neither an enum,
+     * an integer nor a floating point attribute, so callers must check it
+     * before dereferencing.
+     */
+    static std::unique_ptr<DiversityFilter>
+    create(const IAttributeVector &diversity_attr, size_t wanted_hits,
+           size_t max_per_group,size_t cutoff_max_groups, bool cutoff_strict);
+protected:
+    size_t _max_total;
};
-template <typename Fetcher, typename Result>
-class DiversityFilter {
+template <typename Result> // Result must provide push_back(Item)
+class DiversityRecorder { // Appends to a Result only the items whose _key the DiversityFilter accepts
private:
- size_t _total_count;
- size_t _max_total;
- const Fetcher &_diversity;
- size_t _max_per_group;
- size_t _cutoff_max_groups;
- bool _cutoff_strict;
-
- typedef vespalib::hash_map<typename Fetcher::ValueType, uint32_t> Diversity;
- Diversity _seen;
+ DiversityFilter & _filter; // held by reference, not owned; asked once per pushed item
Result &_result;
public:
- DiversityFilter(const Fetcher &diversity, size_t max_per_group,
- size_t cutoff_max_groups, bool cutoff_strict,
- Result &result, size_t max_total)
- : _total_count(0), _max_total(max_total), _diversity(diversity), _max_per_group(max_per_group),
- _cutoff_max_groups(cutoff_max_groups), _cutoff_strict(cutoff_strict),
- _seen(std::min(cutoff_max_groups, 10000ul)*3), _result(result)
+ DiversityRecorder(DiversityFilter & filter, Result &result) // both arguments are stored as references
+ : _filter(filter), _result(result)
{ }
+
template <typename Item>
void push_back(Item item) {
- if (_total_count < _max_total) {
- if ((_seen.size() < _cutoff_max_groups) || _cutoff_strict) {
- typename Fetcher::ValueType group = _diversity.get(item._key);
- if (_seen.size() < _cutoff_max_groups) {
- conditional_add(_seen[group], item);
- } else {
- auto found = _seen.find(group);
- if (found == _seen.end()) {
- add(item);
- } else {
- conditional_add(found->second, item);
- }
- }
- } else if ( !_cutoff_strict) {
- add(item);
- }
- }
- }
-private:
- template <typename Item>
- void add(Item item) {
- ++_total_count;
- _result.push_back(item);
- }
- template <typename Item>
- void conditional_add(uint32_t & group_count, Item item) {
- if (group_count < _max_per_group) {
- ++group_count;
- add(item);
+ if (_filter.accepted(item._key)) { // item._key is handed to the filter as the docId
+ _result.push_back(item); // rejected items are dropped without touching _result
}
}
+
};
-template <typename DictRange, typename PostingStore, typename Fetcher, typename Result>
-void diversify_3(const DictRange &range_in, const PostingStore &posting, size_t wanted_hits,
- const Fetcher &diversity, size_t max_per_group,
- size_t cutoff_max_groups, bool cutoff_strict,
+template <typename DictRange, typename PostingStore, typename Result> // uses DictRange::has_next()/Next, PostingStore::foreach_frozen(), Result::size()/push_back()
+void diversify_2(const DictRange &range_in, const PostingStore &posting, DiversityFilter & filter, // NOTE(review): fragments.back() is read below without an emptiness check - callers must pass a non-empty vector
Result &result, std::vector<size_t> &fragments)
{
+
+ DiversityRecorder<Result> recorder(filter, result); // only entries accepted by filter end up in result
DictRange range(range_in);
using DataType = typename PostingStore::DataType;
using KeyDataType = typename PostingStore::KeyDataType;
- DiversityFilter<Fetcher, Result> filter(diversity, max_per_group, cutoff_max_groups, cutoff_strict, result, wanted_hits);
- while (range.has_next() && (result.size() < wanted_hits)) {
+ while (range.has_next() && (result.size() < filter.getMaxTotal())) { // one dictionary entry per pass; stops once result holds getMaxTotal() items
typename DictRange::Next dict_entry(range);
posting.foreach_frozen(dict_entry.get().getData(),
[&](uint32_t key, const DataType &data)
- { filter.push_back(KeyDataType(key, data)); });
+ { recorder.push_back(KeyDataType(key, data)); }); // a new fragment boundary is recorded below only if this entry grew result
if (fragments.back() < result.size()) {
fragments.push_back(result.size());
}
}
}
-template <typename DictRange, typename PostingStore, typename Result>
-void diversify_2(const DictRange &range_in, const PostingStore &posting, size_t wanted_hits,
- const IAttributeVector &diversity_attr, size_t max_per_group,
- size_t cutoff_max_groups, bool cutoff_strict,
- Result &result, std::vector<size_t> &fragments)
-{
- if (diversity_attr.hasEnum()) { // must handle enum first
- FetchEnumFast fastEnum(diversity_attr);
- if (fastEnum.valid()) {
- diversify_3(range_in, posting, wanted_hits, fastEnum, max_per_group, cutoff_max_groups, cutoff_strict, result, fragments);
- } else {
- diversify_3(range_in, posting, wanted_hits, FetchEnum(diversity_attr), max_per_group, cutoff_max_groups, cutoff_strict, result, fragments);
- }
- } else if (diversity_attr.isIntegerType()) {
- FetchNumberFast<SingleValueNumericAttribute<IntegerAttributeTemplate<int32_t> > > fastInt32(diversity_attr);
- FetchNumberFast<SingleValueNumericAttribute<IntegerAttributeTemplate<int64_t> > > fastInt64(diversity_attr);
- if (fastInt32.valid()) {
- diversify_3(range_in, posting, wanted_hits, fastInt32, max_per_group, cutoff_max_groups, cutoff_strict, result, fragments);
- } else if (fastInt64.valid()) {
- diversify_3(range_in, posting, wanted_hits, fastInt64, max_per_group, cutoff_max_groups, cutoff_strict, result, fragments);
- } else {
- diversify_3(range_in, posting, wanted_hits, FetchInteger(diversity_attr), max_per_group, cutoff_max_groups, cutoff_strict, result, fragments);
- }
- } else if (diversity_attr.isFloatingPointType()) {
- FetchNumberFast<SingleValueNumericAttribute<FloatingPointAttributeTemplate<float> > > fastFloat(diversity_attr);
- FetchNumberFast<SingleValueNumericAttribute<FloatingPointAttributeTemplate<double> > > fastDouble(diversity_attr);
- if (fastFloat.valid()) {
- diversify_3(range_in, posting, wanted_hits, fastFloat, max_per_group, cutoff_max_groups, cutoff_strict, result, fragments);
- } else if (fastDouble.valid()) {
- diversify_3(range_in, posting, wanted_hits, fastDouble, max_per_group, cutoff_max_groups, cutoff_strict, result, fragments);
- } else {
- diversify_3(range_in, posting, wanted_hits, FetchFloat(diversity_attr), max_per_group, cutoff_max_groups, cutoff_strict, result, fragments);
- }
- }
-}
-
+/**
+ * Collects posting entries between lower and upper into array, filtered for
+ * diversity on diversity_attr, walking the dictionary forwards or backwards
+ * depending on `forward`.
+ *
+ * When no filter can be created for diversity_attr (DiversityFilter::create()
+ * returns an empty pointer for attributes that are neither enum, integer nor
+ * floating point) the function returns without touching array or fragments,
+ * which is what the code did before the filter was factored out.
+ */
template <typename DictItr, typename PostingStore, typename Result>
void diversify(bool forward, const DictItr &lower, const DictItr &upper, const PostingStore &posting, size_t wanted_hits,
const IAttributeVector &diversity_attr, size_t max_per_group,
size_t cutoff_max_groups, bool cutoff_strict,
Result &array, std::vector<size_t> &fragments)
{
+ auto filter = DiversityFilter::create(diversity_attr, wanted_hits, max_per_group, cutoff_max_groups, cutoff_strict);
+    if (!filter) {
+        // Unsupported attribute type: dereferencing the empty pointer below would be undefined behaviour.
+        return;
+    }
if (forward) {
- diversify_2(ForwardRange<DictItr>(lower, upper), posting, wanted_hits,
- diversity_attr, max_per_group, cutoff_max_groups, cutoff_strict, array, fragments);
+ diversify_2(ForwardRange<DictItr>(lower, upper), posting, *filter, array, fragments);
} else {
- diversify_2(ReverseRange<DictItr>(lower, upper), posting, wanted_hits,
- diversity_attr, max_per_group, cutoff_max_groups, cutoff_strict, array, fragments);
+ diversify_2(ReverseRange<DictItr>(lower, upper), posting, *filter, array, fragments);
}
}
diff --git a/searchlib/src/vespa/searchlib/queryeval/idiversifier.h b/searchlib/src/vespa/searchlib/queryeval/idiversifier.h
new file mode 100644
index 00000000000..e77cb959eeb
--- /dev/null
+++ b/searchlib/src/vespa/searchlib/queryeval/idiversifier.h
@@ -0,0 +1,16 @@
+// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+
+#pragma once
+
+#include <cstdint>
+
+namespace search::queryeval {
+
+/**
+ * Interface for components that decide, one document at a time, whether a
+ * document is kept.
+ */
+struct IDiversifier {
+    virtual ~IDiversifier() = default;
+    /**
+     * Returns whether the given document should be kept. The call is
+     * non-const: implementations update their state so that it influences
+     * the filtering of later documents.
+     */
+    virtual bool accepted(uint32_t docId) = 0;
+};
+}
diff --git a/searchlib/src/vespa/searchlib/queryeval/isourceselector.cpp b/searchlib/src/vespa/searchlib/queryeval/isourceselector.cpp
index fa8f465500e..1e0659c92e3 100644
--- a/searchlib/src/vespa/searchlib/queryeval/isourceselector.cpp
+++ b/searchlib/src/vespa/searchlib/queryeval/isourceselector.cpp
@@ -1,8 +1,7 @@
// Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
-#include <vespa/searchlib/queryeval/isourceselector.h>
+#include "isourceselector.h"
-namespace search {
-namespace queryeval {
+namespace search::queryeval {
ISourceSelector::ISourceSelector(Source defaultSource) :
_baseId(0),
@@ -12,5 +11,3 @@ ISourceSelector::ISourceSelector(Source defaultSource) :
}
}
-
-}
diff --git a/searchlib/src/vespa/searchlib/queryeval/isourceselector.h b/searchlib/src/vespa/searchlib/queryeval/isourceselector.h
index a3eac806558..88a3cb57a8a 100644
--- a/searchlib/src/vespa/searchlib/queryeval/isourceselector.h
+++ b/searchlib/src/vespa/searchlib/queryeval/isourceselector.h
@@ -2,11 +2,9 @@
#pragma once
-#include <stdint.h>
#include <vespa/searchlib/attribute/singlenumericattribute.h>
-namespace search {
-namespace queryeval {
+namespace search::queryeval {
typedef uint8_t Source;
@@ -91,6 +89,4 @@ private:
Source _defaultSource;
};
-} // namespace queryeval
-} // namespace search
-
+}