diff options
5 files changed, 185 insertions, 143 deletions
diff --git a/searchlib/src/vespa/searchlib/attribute/diversity.cpp b/searchlib/src/vespa/searchlib/attribute/diversity.cpp index 4c6fe054b12..65a84baa212 100644 --- a/searchlib/src/vespa/searchlib/attribute/diversity.cpp +++ b/searchlib/src/vespa/searchlib/attribute/diversity.cpp @@ -1,7 +1,143 @@ // Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. #include "diversity.hpp" +#include "singleenumattribute.h" +#include "singlenumericattribute.h" +#include <vespa/vespalib/stllike/hash_map.h> +using std::make_unique; namespace search::attribute::diversity { +template <typename T> +struct FetchNumberFast { + const T * const attr; + typedef typename T::LoadedValueType ValueType; + FetchNumberFast(const IAttributeVector &attr_in) : attr(dynamic_cast<const T *>(&attr_in)) {} + ValueType get(uint32_t docid) const { return attr->getFast(docid); } + bool valid() const { return (attr != nullptr); } +}; + +struct FetchEnumFast { + const SingleValueEnumAttributeBase * const attr; + typedef uint32_t ValueType; + FetchEnumFast(const IAttributeVector &attr_in) : attr(dynamic_cast<const SingleValueEnumAttributeBase *>(&attr_in)) {} + ValueType get(uint32_t docid) const { return attr->getE(docid); } + bool valid() const { return (attr != nullptr); } +}; + +struct FetchEnum { + const IAttributeVector *attr; + typedef uint32_t ValueType; + FetchEnum(const IAttributeVector & attr_in) : attr(&attr_in) {} + ValueType get(uint32_t docid) const { return attr->getEnum(docid); } +}; + +struct FetchInteger { + const IAttributeVector * attr; + typedef int64_t ValueType; + FetchInteger(const IAttributeVector & attr_in) : attr(&attr_in) {} + ValueType get(uint32_t docid) const { return attr->getInt(docid); } +}; + +struct FetchFloat { + const IAttributeVector * attr; + typedef double ValueType; + FetchFloat(const IAttributeVector & attr_in) : attr(&attr_in) {} + ValueType get(uint32_t docid) const { return attr->getFloat(docid); } +}; + +template <typename Fetcher> +class DiversityFilterT final : public DiversityFilter { +private: + size_t _total_count; + Fetcher _diversity; + size_t _max_per_group; + size_t _cutoff_max_groups; + bool _cutoff_strict; + + typedef vespalib::hash_map<typename Fetcher::ValueType, uint32_t> Diversity; + Diversity _seen; +public: + DiversityFilterT(Fetcher diversity, size_t max_per_group, size_t cutoff_max_groups, + bool cutoff_strict, size_t max_total) + : DiversityFilter(max_total), _total_count(0), _diversity(diversity), _max_per_group(max_per_group), + _cutoff_max_groups(cutoff_max_groups), _cutoff_strict(cutoff_strict), + _seen(std::min(cutoff_max_groups, 10000ul)*3) + { } + + bool accepted(uint32_t docId) override; +private: + bool add() { + ++_total_count; + return true; + } + bool conditional_add(uint32_t & group_count) { + if (group_count < _max_per_group) { + ++group_count; + add(); + return true; + } + return false; + } +}; + +template <typename Fetcher> +bool +DiversityFilterT<Fetcher>::accepted(uint32_t docId) { + if (_total_count < _max_total) { + if ((_seen.size() < _cutoff_max_groups) || _cutoff_strict) { + typename Fetcher::ValueType group = _diversity.get(docId); + if (_seen.size() < _cutoff_max_groups) { + return conditional_add(_seen[group]); + } else { + auto found = _seen.find(group); + return (found == _seen.end()) ? add() : conditional_add(found->second); + } + } else if ( !_cutoff_strict) { + return add(); + } + } + return false; +} + +std::unique_ptr<DiversityFilter> +DiversityFilter::create(const IAttributeVector &diversity_attr, size_t wanted_hits, + size_t max_per_group,size_t cutoff_max_groups, bool cutoff_strict) +{ + if (diversity_attr.hasEnum()) { // must handle enum first + FetchEnumFast fastEnum(diversity_attr); + if (fastEnum.valid()) { + return make_unique<DiversityFilterT<FetchEnumFast>> (fastEnum, max_per_group, cutoff_max_groups, cutoff_strict, wanted_hits); + } else { + return make_unique<DiversityFilterT<FetchEnum>>(FetchEnum(diversity_attr), max_per_group, cutoff_max_groups, cutoff_strict, wanted_hits); + } + } else if (diversity_attr.isIntegerType()) { + using FetchInt32Fast = FetchNumberFast<SingleValueNumericAttribute<IntegerAttributeTemplate<int32_t> > >; + using FetchInt64Fast = FetchNumberFast<SingleValueNumericAttribute<IntegerAttributeTemplate<int64_t> > >; + + FetchInt32Fast fastInt32(diversity_attr); + FetchInt64Fast fastInt64(diversity_attr); + if (fastInt32.valid()) { + return make_unique<DiversityFilterT<FetchInt32Fast>>(fastInt32, max_per_group, cutoff_max_groups, cutoff_strict, wanted_hits); + } else if (fastInt64.valid()) { + return make_unique<DiversityFilterT<FetchInt64Fast>>(fastInt64, max_per_group, cutoff_max_groups, cutoff_strict, wanted_hits); + } else { + return make_unique<DiversityFilterT<FetchInteger>>(FetchInteger(diversity_attr), max_per_group, cutoff_max_groups, cutoff_strict, wanted_hits); + } + } else if (diversity_attr.isFloatingPointType()) { + using FetchFloatFast = FetchNumberFast<SingleValueNumericAttribute<FloatingPointAttributeTemplate<float> > >; + using FetchDoubleFast = FetchNumberFast<SingleValueNumericAttribute<FloatingPointAttributeTemplate<double> > >; + FetchFloatFast fastFloat(diversity_attr); + FetchDoubleFast fastDouble(diversity_attr); + if (fastFloat.valid()) { + return make_unique<DiversityFilterT<FetchFloatFast>>(fastFloat, max_per_group, cutoff_max_groups, cutoff_strict, wanted_hits); + } else if (fastDouble.valid()) { + return make_unique<DiversityFilterT<FetchDoubleFast>>(fastDouble, max_per_group, cutoff_max_groups, cutoff_strict, wanted_hits); + } else { + return make_unique<DiversityFilterT<FetchFloat>>(FetchFloat(diversity_attr), max_per_group, cutoff_max_groups, cutoff_strict, wanted_hits); + } + } + return std::unique_ptr<DiversityFilter>(); +} + } diff --git a/searchlib/src/vespa/searchlib/attribute/diversity.h b/searchlib/src/vespa/searchlib/attribute/diversity.h index fe2874a65a1..da5df19f9e8 100644 --- a/searchlib/src/vespa/searchlib/attribute/diversity.h +++ b/searchlib/src/vespa/searchlib/attribute/diversity.h @@ -2,9 +2,8 @@ #pragma once -#include "singleenumattribute.h" -#include "singlenumericattribute.h" -#include <vespa/vespalib/stllike/hash_map.h> +#include <vespa/searchcommon/attribute/iattributevector.h> +#include <vespa/searchlib/queryeval/idiversifier.h> /** * This file contains low-level code used to implement diversified @@ -59,169 +58,67 @@ public: bool has_next() const { return _lower != _upper; } }; -template <typename T> -struct FetchNumberFast { - const T * const attr; - typedef typename T::LoadedValueType ValueType; - FetchNumberFast(const IAttributeVector &attr_in) : attr(dynamic_cast<const T *>(&attr_in)) {} - ValueType get(uint32_t docid) const { return attr->getFast(docid); } - bool valid() const { return (attr != nullptr); } -}; - -struct FetchEnumFast { - const SingleValueEnumAttributeBase * const attr; - typedef uint32_t ValueType; - FetchEnumFast(const IAttributeVector &attr_in) : attr(dynamic_cast<const SingleValueEnumAttributeBase *>(&attr_in)) {} - ValueType get(uint32_t docid) const { return attr->getE(docid); } - bool valid() const { return (attr != nullptr); } -}; - -struct FetchEnum { - const IAttributeVector &attr; - typedef uint32_t ValueType; - FetchEnum(const IAttributeVector &attr_in) : attr(attr_in) {} - ValueType get(uint32_t docid) const { return attr.getEnum(docid); } -}; - -struct FetchInteger { - const IAttributeVector &attr; - typedef int64_t ValueType; - FetchInteger(const IAttributeVector &attr_in) : attr(attr_in) {} - ValueType get(uint32_t docid) const { return attr.getInt(docid); } -}; - -struct FetchFloat { - const IAttributeVector &attr; - typedef double ValueType; - FetchFloat(const IAttributeVector &attr_in) : attr(attr_in) {} - ValueType get(uint32_t docid) const { return attr.getFloat(docid); } +class DiversityFilter : public queryeval::IDiversifier { +public: + DiversityFilter(size_t max_total) : _max_total(max_total) {} + size_t getMaxTotal() const { return _max_total; } + static std::unique_ptr<DiversityFilter> + create(const IAttributeVector &diversity_attr, size_t wanted_hits, + size_t max_per_group,size_t cutoff_max_groups, bool cutoff_strict); +protected: + size_t _max_total; }; -template <typename Fetcher, typename Result> -class DiversityFilter { +template <typename Result> +class DiversityRecorder { private: - size_t _total_count; - size_t _max_total; - const Fetcher &_diversity; - size_t _max_per_group; - size_t _cutoff_max_groups; - bool _cutoff_strict; - - typedef vespalib::hash_map<typename Fetcher::ValueType, uint32_t> Diversity; - Diversity _seen; + DiversityFilter & _filter; Result &_result; public: - DiversityFilter(const Fetcher &diversity, size_t max_per_group, - size_t cutoff_max_groups, bool cutoff_strict, - Result &result, size_t max_total) - : _total_count(0), _max_total(max_total), _diversity(diversity), _max_per_group(max_per_group), - _cutoff_max_groups(cutoff_max_groups), _cutoff_strict(cutoff_strict), - _seen(std::min(cutoff_max_groups, 10000ul)*3), _result(result) + DiversityRecorder(DiversityFilter & filter, Result &result) + : _filter(filter), _result(result) { } + template <typename Item> void push_back(Item item) { - if (_total_count < _max_total) { - if ((_seen.size() < _cutoff_max_groups) || _cutoff_strict) { - typename Fetcher::ValueType group = _diversity.get(item._key); - if (_seen.size() < _cutoff_max_groups) { - conditional_add(_seen[group], item); - } else { - auto found = _seen.find(group); - if (found == _seen.end()) { - add(item); - } else { - conditional_add(found->second, item); - } - } - } else if ( !_cutoff_strict) { - add(item); - } - } - } -private: - template <typename Item> - void add(Item item) { - ++_total_count; - _result.push_back(item); - } - template <typename Item> - void conditional_add(uint32_t & group_count, Item item) { - if (group_count < _max_per_group) { - ++group_count; - add(item); + if (_filter.accepted(item._key)) { + _result.push_back(item); } } + }; -template <typename DictRange, typename PostingStore, typename Fetcher, typename Result> -void diversify_3(const DictRange &range_in, const PostingStore &posting, size_t wanted_hits, - const Fetcher &diversity, size_t max_per_group, - size_t cutoff_max_groups, bool cutoff_strict, +template <typename DictRange, typename PostingStore, typename Result> +void diversify_2(const DictRange &range_in, const PostingStore &posting, DiversityFilter & filter, Result &result, std::vector<size_t> &fragments) { + + DiversityRecorder<Result> recorder(filter, result); DictRange range(range_in); using DataType = typename PostingStore::DataType; using KeyDataType = typename PostingStore::KeyDataType; - DiversityFilter<Fetcher, Result> filter(diversity, max_per_group, cutoff_max_groups, cutoff_strict, result, wanted_hits); - while (range.has_next() && (result.size() < wanted_hits)) { + while (range.has_next() && (result.size() < filter.getMaxTotal())) { typename DictRange::Next dict_entry(range); posting.foreach_frozen(dict_entry.get().getData(), [&](uint32_t key, const DataType &data) - { filter.push_back(KeyDataType(key, data)); }); + { recorder.push_back(KeyDataType(key, data)); }); if (fragments.back() < result.size()) { fragments.push_back(result.size()); } } } -template <typename DictRange, typename PostingStore, typename Result> -void diversify_2(const DictRange &range_in, const PostingStore &posting, size_t wanted_hits, - const IAttributeVector &diversity_attr, size_t max_per_group, - size_t cutoff_max_groups, bool cutoff_strict, - Result &result, std::vector<size_t> &fragments) -{ - if (diversity_attr.hasEnum()) { // must handle enum first - FetchEnumFast fastEnum(diversity_attr); - if (fastEnum.valid()) { - diversify_3(range_in, posting, wanted_hits, fastEnum, max_per_group, cutoff_max_groups, cutoff_strict, result, fragments); - } else { - diversify_3(range_in, posting, wanted_hits, FetchEnum(diversity_attr), max_per_group, cutoff_max_groups, cutoff_strict, result, fragments); - } - } else if (diversity_attr.isIntegerType()) { - FetchNumberFast<SingleValueNumericAttribute<IntegerAttributeTemplate<int32_t> > > fastInt32(diversity_attr); - FetchNumberFast<SingleValueNumericAttribute<IntegerAttributeTemplate<int64_t> > > fastInt64(diversity_attr); - if (fastInt32.valid()) { - diversify_3(range_in, posting, wanted_hits, fastInt32, max_per_group, cutoff_max_groups, cutoff_strict, result, fragments); - } else if (fastInt64.valid()) { - diversify_3(range_in, posting, wanted_hits, fastInt64, max_per_group, cutoff_max_groups, cutoff_strict, result, fragments); - } else { - diversify_3(range_in, posting, wanted_hits, FetchInteger(diversity_attr), max_per_group, cutoff_max_groups, cutoff_strict, result, fragments); - } - } else if (diversity_attr.isFloatingPointType()) { - FetchNumberFast<SingleValueNumericAttribute<FloatingPointAttributeTemplate<float> > > fastFloat(diversity_attr); - FetchNumberFast<SingleValueNumericAttribute<FloatingPointAttributeTemplate<double> > > fastDouble(diversity_attr); - if (fastFloat.valid()) { - diversify_3(range_in, posting, wanted_hits, fastFloat, max_per_group, cutoff_max_groups, cutoff_strict, result, fragments); - } else if (fastDouble.valid()) { - diversify_3(range_in, posting, wanted_hits, fastDouble, max_per_group, cutoff_max_groups, cutoff_strict, result, fragments); - } else { - diversify_3(range_in, posting, wanted_hits, FetchFloat(diversity_attr), max_per_group, cutoff_max_groups, cutoff_strict, result, fragments); - } - } -} - template <typename DictItr, typename PostingStore, typename Result> void diversify(bool forward, const DictItr &lower, const DictItr &upper, const PostingStore &posting, size_t wanted_hits, const IAttributeVector &diversity_attr, size_t max_per_group, size_t cutoff_max_groups, bool cutoff_strict, Result &array, std::vector<size_t> &fragments) { + auto filter = DiversityFilter::create(diversity_attr, wanted_hits, max_per_group, cutoff_max_groups, cutoff_strict); if (forward) { - diversify_2(ForwardRange<DictItr>(lower, upper), posting, wanted_hits, - diversity_attr, max_per_group, cutoff_max_groups, cutoff_strict, array, fragments); + diversify_2(ForwardRange<DictItr>(lower, upper), posting, *filter, array, fragments); } else { - diversify_2(ReverseRange<DictItr>(lower, upper), posting, wanted_hits, - diversity_attr, max_per_group, cutoff_max_groups, cutoff_strict, array, fragments); + diversify_2(ReverseRange<DictItr>(lower, upper), posting, *filter, array, fragments); } } diff --git a/searchlib/src/vespa/searchlib/queryeval/idiversifier.h b/searchlib/src/vespa/searchlib/queryeval/idiversifier.h new file mode 100644 index 00000000000..e77cb959eeb --- /dev/null +++ b/searchlib/src/vespa/searchlib/queryeval/idiversifier.h @@ -0,0 +1,16 @@ +// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. + +#pragma once + +#include <cstdint> + +namespace search::queryeval { + +struct IDiversifier { + virtual ~IDiversifier() {} + /** + * Will tell if this document should be kept, and update state for further filtering. + */ + virtual bool accepted(uint32_t docId) = 0; +}; +} diff --git a/searchlib/src/vespa/searchlib/queryeval/isourceselector.cpp b/searchlib/src/vespa/searchlib/queryeval/isourceselector.cpp index fa8f465500e..1e0659c92e3 100644 --- a/searchlib/src/vespa/searchlib/queryeval/isourceselector.cpp +++ b/searchlib/src/vespa/searchlib/queryeval/isourceselector.cpp @@ -1,8 +1,7 @@ // Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. -#include <vespa/searchlib/queryeval/isourceselector.h> +#include "isourceselector.h" -namespace search { -namespace queryeval { +namespace search::queryeval { ISourceSelector::ISourceSelector(Source defaultSource) : _baseId(0), @@ -12,5 +11,3 @@ ISourceSelector::ISourceSelector(Source defaultSource) : } } - -} diff --git a/searchlib/src/vespa/searchlib/queryeval/isourceselector.h b/searchlib/src/vespa/searchlib/queryeval/isourceselector.h index a3eac806558..88a3cb57a8a 100644 --- a/searchlib/src/vespa/searchlib/queryeval/isourceselector.h +++ b/searchlib/src/vespa/searchlib/queryeval/isourceselector.h @@ -2,11 +2,9 @@ #pragma once -#include <stdint.h> #include <vespa/searchlib/attribute/singlenumericattribute.h> -namespace search { -namespace queryeval { +namespace search::queryeval { typedef uint8_t Source; @@ -91,6 +89,4 @@ private: Source _defaultSource; }; -} // namespace queryeval -} // namespace search - +} |