diff options
author | Henning Baldersheim <balder@oath.com> | 2018-07-26 13:38:32 +0200 |
---|---|---|
committer | Henning Baldersheim <balder@oath.com> | 2018-07-26 13:38:32 +0200 |
commit | f2860f5825f39c51e83ce32657b6c24908a99d59 (patch) | |
tree | ef3bd4d0233f09078daa4bc66bc5710ef490529b /searchlib | |
parent | e78f728e723f15b2423c7d5556e4bf9bc47c2c6f (diff) |
Restructure for code reuse and hiding implementation.
Diffstat (limited to 'searchlib')
5 files changed, 150 insertions, 164 deletions
diff --git a/searchlib/src/vespa/searchlib/attribute/diversity.cpp b/searchlib/src/vespa/searchlib/attribute/diversity.cpp index 4c6fe054b12..65a84baa212 100644 --- a/searchlib/src/vespa/searchlib/attribute/diversity.cpp +++ b/searchlib/src/vespa/searchlib/attribute/diversity.cpp @@ -1,7 +1,143 @@ // Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. #include "diversity.hpp" +#include "singleenumattribute.h" +#include "singlenumericattribute.h" +#include <vespa/vespalib/stllike/hash_map.h> +using std::make_unique; namespace search::attribute::diversity { +template <typename T> +struct FetchNumberFast { + const T * const attr; + typedef typename T::LoadedValueType ValueType; + FetchNumberFast(const IAttributeVector &attr_in) : attr(dynamic_cast<const T *>(&attr_in)) {} + ValueType get(uint32_t docid) const { return attr->getFast(docid); } + bool valid() const { return (attr != nullptr); } +}; + +struct FetchEnumFast { + const SingleValueEnumAttributeBase * const attr; + typedef uint32_t ValueType; + FetchEnumFast(const IAttributeVector &attr_in) : attr(dynamic_cast<const SingleValueEnumAttributeBase *>(&attr_in)) {} + ValueType get(uint32_t docid) const { return attr->getE(docid); } + bool valid() const { return (attr != nullptr); } +}; + +struct FetchEnum { + const IAttributeVector *attr; + typedef uint32_t ValueType; + FetchEnum(const IAttributeVector & attr_in) : attr(&attr_in) {} + ValueType get(uint32_t docid) const { return attr->getEnum(docid); } +}; + +struct FetchInteger { + const IAttributeVector * attr; + typedef int64_t ValueType; + FetchInteger(const IAttributeVector & attr_in) : attr(&attr_in) {} + ValueType get(uint32_t docid) const { return attr->getInt(docid); } +}; + +struct FetchFloat { + const IAttributeVector * attr; + typedef double ValueType; + FetchFloat(const IAttributeVector & attr_in) : attr(&attr_in) {} + ValueType get(uint32_t docid) const { return attr->getFloat(docid); } +}; + +template <typename Fetcher> +class DiversityFilterT final : public DiversityFilter { +private: + size_t _total_count; + Fetcher _diversity; + size_t _max_per_group; + size_t _cutoff_max_groups; + bool _cutoff_strict; + + typedef vespalib::hash_map<typename Fetcher::ValueType, uint32_t> Diversity; + Diversity _seen; +public: + DiversityFilterT(Fetcher diversity, size_t max_per_group, size_t cutoff_max_groups, + bool cutoff_strict, size_t max_total) + : DiversityFilter(max_total), _total_count(0), _diversity(diversity), _max_per_group(max_per_group), + _cutoff_max_groups(cutoff_max_groups), _cutoff_strict(cutoff_strict), + _seen(std::min(cutoff_max_groups, 10000ul)*3) + { } + + bool accepted(uint32_t docId) override; +private: + bool add() { + ++_total_count; + return true; + } + bool conditional_add(uint32_t & group_count) { + if (group_count < _max_per_group) { + ++group_count; + add(); + return true; + } + return false; + } +}; + +template <typename Fetcher> +bool +DiversityFilterT<Fetcher>::accepted(uint32_t docId) { + if (_total_count < _max_total) { + if ((_seen.size() < _cutoff_max_groups) || _cutoff_strict) { + typename Fetcher::ValueType group = _diversity.get(docId); + if (_seen.size() < _cutoff_max_groups) { + return conditional_add(_seen[group]); + } else { + auto found = _seen.find(group); + return (found == _seen.end()) ? add() : conditional_add(found->second); + } + } else if ( !_cutoff_strict) { + return add(); + } + } + return false; +} + +std::unique_ptr<DiversityFilter> +DiversityFilter::create(const IAttributeVector &diversity_attr, size_t wanted_hits, + size_t max_per_group,size_t cutoff_max_groups, bool cutoff_strict) +{ + if (diversity_attr.hasEnum()) { // must handle enum first + FetchEnumFast fastEnum(diversity_attr); + if (fastEnum.valid()) { + return make_unique<DiversityFilterT<FetchEnumFast>> (fastEnum, max_per_group, cutoff_max_groups, cutoff_strict, wanted_hits); + } else { + return make_unique<DiversityFilterT<FetchEnum>>(FetchEnum(diversity_attr), max_per_group, cutoff_max_groups, cutoff_strict, wanted_hits); + } + } else if (diversity_attr.isIntegerType()) { + using FetchInt32Fast = FetchNumberFast<SingleValueNumericAttribute<IntegerAttributeTemplate<int32_t> > >; + using FetchInt64Fast = FetchNumberFast<SingleValueNumericAttribute<IntegerAttributeTemplate<int64_t> > >; + + FetchInt32Fast fastInt32(diversity_attr); + FetchInt64Fast fastInt64(diversity_attr); + if (fastInt32.valid()) { + return make_unique<DiversityFilterT<FetchInt32Fast>>(fastInt32, max_per_group, cutoff_max_groups, cutoff_strict, wanted_hits); + } else if (fastInt64.valid()) { + return make_unique<DiversityFilterT<FetchInt64Fast>>(fastInt64, max_per_group, cutoff_max_groups, cutoff_strict, wanted_hits); + } else { + return make_unique<DiversityFilterT<FetchInteger>>(FetchInteger(diversity_attr), max_per_group, cutoff_max_groups, cutoff_strict, wanted_hits); + } + } else if (diversity_attr.isFloatingPointType()) { + using FetchFloatFast = FetchNumberFast<SingleValueNumericAttribute<FloatingPointAttributeTemplate<float> > >; + using FetchDoubleFast = FetchNumberFast<SingleValueNumericAttribute<FloatingPointAttributeTemplate<double> > >; + FetchFloatFast fastFloat(diversity_attr); + FetchDoubleFast fastDouble(diversity_attr); + if (fastFloat.valid()) { + return make_unique<DiversityFilterT<FetchFloatFast>>(fastFloat, max_per_group, cutoff_max_groups, cutoff_strict, wanted_hits); + } else if (fastDouble.valid()) { + return make_unique<DiversityFilterT<FetchDoubleFast>>(fastDouble, max_per_group, cutoff_max_groups, cutoff_strict, wanted_hits); + } else { + return make_unique<DiversityFilterT<FetchFloat>>(FetchFloat(diversity_attr), max_per_group, cutoff_max_groups, cutoff_strict, wanted_hits); + } + } + return std::unique_ptr<DiversityFilter>(); +} + } diff --git a/searchlib/src/vespa/searchlib/attribute/diversity.h b/searchlib/src/vespa/searchlib/attribute/diversity.h index 90b499252c2..da5df19f9e8 100644 --- a/searchlib/src/vespa/searchlib/attribute/diversity.h +++ b/searchlib/src/vespa/searchlib/attribute/diversity.h @@ -2,9 +2,8 @@ #pragma once -#include "singleenumattribute.h" -#include "singlenumericattribute.h" -#include <vespa/vespalib/stllike/hash_map.h> +#include <vespa/searchcommon/attribute/iattributevector.h> +#include <vespa/searchlib/queryeval/idiversifier.h> /** * This file contains low-level code used to implement diversified @@ -59,89 +58,17 @@ public: bool has_next() const { return _lower != _upper; } }; -template <typename T> -struct FetchNumberFast { - const T * const attr; - typedef typename T::LoadedValueType ValueType; - FetchNumberFast(const IAttributeVector &attr_in) : attr(dynamic_cast<const T *>(&attr_in)) {} - ValueType get(uint32_t docid) const { return attr->getFast(docid); } - bool valid() const { return (attr != nullptr); } -}; - -struct FetchEnumFast { - const SingleValueEnumAttributeBase * const attr; - typedef uint32_t ValueType; - FetchEnumFast(const IAttributeVector &attr_in) : attr(dynamic_cast<const SingleValueEnumAttributeBase *>(&attr_in)) {} - ValueType get(uint32_t docid) const { return attr->getE(docid); } - bool valid() const { return (attr != nullptr); } -}; - -struct FetchEnum { - const IAttributeVector &attr; - typedef uint32_t ValueType; - FetchEnum(const IAttributeVector &attr_in) : attr(attr_in) {} - ValueType get(uint32_t docid) const { return attr.getEnum(docid); } -}; - -struct FetchInteger { - const IAttributeVector &attr; - typedef int64_t ValueType; - FetchInteger(const IAttributeVector &attr_in) : attr(attr_in) {} - ValueType get(uint32_t docid) const { return attr.getInt(docid); } -}; - -struct FetchFloat { - const IAttributeVector &attr; - typedef double ValueType; - FetchFloat(const IAttributeVector &attr_in) : attr(attr_in) {} - ValueType get(uint32_t docid) const { return attr.getFloat(docid); } -}; - -class DiversityFilter { +class DiversityFilter : public queryeval::IDiversifier { public: DiversityFilter(size_t max_total) : _max_total(max_total) {} - virtual ~DiversityFilter() {} - virtual bool accepted(uint32_t docId) = 0; size_t getMaxTotal() const { return _max_total; } + static std::unique_ptr<DiversityFilter> + create(const IAttributeVector &diversity_attr, size_t wanted_hits, + size_t max_per_group,size_t cutoff_max_groups, bool cutoff_strict); protected: size_t _max_total; }; -template <typename Fetcher> -class DiversityFilterT final : public DiversityFilter { -private: - size_t _total_count; - const Fetcher &_diversity; - size_t _max_per_group; - size_t _cutoff_max_groups; - bool _cutoff_strict; - - typedef vespalib::hash_map<typename Fetcher::ValueType, uint32_t> Diversity; - Diversity _seen; -public: - DiversityFilterT(const Fetcher &diversity, size_t max_per_group, - size_t cutoff_max_groups, bool cutoff_strict, size_t max_total) - : DiversityFilter(max_total), _total_count(0), _diversity(diversity), _max_per_group(max_per_group), - _cutoff_max_groups(cutoff_max_groups), _cutoff_strict(cutoff_strict), - _seen(std::min(cutoff_max_groups, 10000ul)*3) - { } - - bool accepted(uint32_t docId) override; -private: - bool add() { - ++_total_count; - return true; - } - bool conditional_add(uint32_t & group_count) { - if (group_count < _max_per_group) { - ++group_count; - add(); - return true; - } - return false; - } -}; - template <typename Result> class DiversityRecorder { private: @@ -162,7 +89,7 @@ public: }; template <typename DictRange, typename PostingStore, typename Result> -void diversify_3(const DictRange &range_in, const PostingStore &posting, DiversityFilter & filter, +void diversify_2(const DictRange &range_in, const PostingStore &posting, DiversityFilter & filter, Result &result, std::vector<size_t> &fragments) { @@ -181,67 +108,17 @@ void diversify_3(const DictRange &range_in, const PostingStore &posting, Diversi } } -template <typename DictRange, typename PostingStore, typename Result> -void diversify_2(const DictRange &range_in, const PostingStore &posting, size_t wanted_hits, - const IAttributeVector &diversity_attr, size_t max_per_group, - size_t cutoff_max_groups, bool cutoff_strict, - Result &result, std::vector<size_t> &fragments) -{ - if (diversity_attr.hasEnum()) { // must handle enum first - FetchEnumFast fastEnum(diversity_attr); - if (fastEnum.valid()) { - DiversityFilterT<FetchEnumFast> filter(fastEnum, max_per_group, cutoff_max_groups, cutoff_strict, wanted_hits); - diversify_3(range_in, posting, filter, result, fragments); - } else { - DiversityFilterT<FetchEnum> filter(FetchEnum(diversity_attr), max_per_group, cutoff_max_groups, cutoff_strict, wanted_hits); - diversify_3(range_in, posting, filter, result, fragments); - } - } else if (diversity_attr.isIntegerType()) { - using FetchInt32Fast = FetchNumberFast<SingleValueNumericAttribute<IntegerAttributeTemplate<int32_t> > >; - using FetchInt64Fast = FetchNumberFast<SingleValueNumericAttribute<IntegerAttributeTemplate<int64_t> > >; - - FetchInt32Fast fastInt32(diversity_attr); - FetchInt64Fast fastInt64(diversity_attr); - if (fastInt32.valid()) { - DiversityFilterT<FetchInt32Fast> filter(fastInt32, max_per_group, cutoff_max_groups, cutoff_strict, wanted_hits); - diversify_3(range_in, posting, filter, result, fragments); - } else if (fastInt64.valid()) { - DiversityFilterT<FetchInt64Fast> filter(fastInt64, max_per_group, cutoff_max_groups, cutoff_strict, wanted_hits); - diversify_3(range_in, posting, filter, result, fragments); - } else { - DiversityFilterT<FetchInteger> filter(FetchInteger(diversity_attr), max_per_group, cutoff_max_groups, cutoff_strict, wanted_hits); - diversify_3(range_in, posting, filter, result, fragments); - } - } else if (diversity_attr.isFloatingPointType()) { - using FetchFloatFast = FetchNumberFast<SingleValueNumericAttribute<FloatingPointAttributeTemplate<float> > >; - using FetchDoubleFast = FetchNumberFast<SingleValueNumericAttribute<FloatingPointAttributeTemplate<double> > >; - FetchFloatFast fastFloat(diversity_attr); - FetchDoubleFast fastDouble(diversity_attr); - if (fastFloat.valid()) { - DiversityFilterT<FetchFloatFast> filter(fastFloat, max_per_group, cutoff_max_groups, cutoff_strict, wanted_hits); - diversify_3(range_in, posting, filter, result, fragments); - } else if (fastDouble.valid()) { - DiversityFilterT<FetchDoubleFast> filter(fastDouble, max_per_group, cutoff_max_groups, cutoff_strict, wanted_hits); - diversify_3(range_in, posting, filter, result, fragments); - } else { - DiversityFilterT<FetchFloat> filter(FetchFloat(diversity_attr), max_per_group, cutoff_max_groups, cutoff_strict, wanted_hits); - diversify_3(range_in, posting, filter, result, fragments); - } - } -} - template <typename DictItr, typename PostingStore, typename Result> void diversify(bool forward, const DictItr &lower, const DictItr &upper, const PostingStore &posting, size_t wanted_hits, const IAttributeVector &diversity_attr, size_t max_per_group, size_t cutoff_max_groups, bool cutoff_strict, Result &array, std::vector<size_t> &fragments) { + auto filter = DiversityFilter::create(diversity_attr, wanted_hits, max_per_group, cutoff_max_groups, cutoff_strict); if (forward) { - diversify_2(ForwardRange<DictItr>(lower, upper), posting, wanted_hits, - diversity_attr, max_per_group, cutoff_max_groups, cutoff_strict, array, fragments); + diversify_2(ForwardRange<DictItr>(lower, upper), posting, *filter, array, fragments); } else { - diversify_2(ReverseRange<DictItr>(lower, upper), posting, wanted_hits, - diversity_attr, max_per_group, cutoff_max_groups, cutoff_strict, array, fragments); + diversify_2(ReverseRange<DictItr>(lower, upper), posting, *filter, array, fragments); } } diff --git a/searchlib/src/vespa/searchlib/attribute/diversity.hpp b/searchlib/src/vespa/searchlib/attribute/diversity.hpp index 52e6e5ede04..698f482dec1 100644 --- a/searchlib/src/vespa/searchlib/attribute/diversity.hpp +++ b/searchlib/src/vespa/searchlib/attribute/diversity.hpp @@ -31,24 +31,4 @@ ReverseRange<ITR>::ReverseRange(const ITR &lower, const ITR &upper) template <typename ITR> ReverseRange<ITR>::~ReverseRange() = default; -template <typename Fetcher> -bool -DiversityFilterT<Fetcher>::accepted(uint32_t docId) { - if (_total_count < _max_total) { - if ((_seen.size() < _cutoff_max_groups) || _cutoff_strict) { - typename Fetcher::ValueType group = _diversity.get(docId); - if (_seen.size() < _cutoff_max_groups) { - return conditional_add(_seen[group]); - } else { - auto found = _seen.find(group); - return (found == _seen.end()) ? add() : conditional_add(found->second); - } - } else if ( !_cutoff_strict) { - return add(); - } - } - return false; -} - - } diff --git a/searchlib/src/vespa/searchlib/queryeval/isourceselector.cpp b/searchlib/src/vespa/searchlib/queryeval/isourceselector.cpp index fa8f465500e..1e0659c92e3 100644 --- a/searchlib/src/vespa/searchlib/queryeval/isourceselector.cpp +++ b/searchlib/src/vespa/searchlib/queryeval/isourceselector.cpp @@ -1,8 +1,7 @@ // Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. -#include <vespa/searchlib/queryeval/isourceselector.h> +#include "isourceselector.h" -namespace search { -namespace queryeval { +namespace search::queryeval { ISourceSelector::ISourceSelector(Source defaultSource) : _baseId(0), @@ -12,5 +11,3 @@ ISourceSelector::ISourceSelector(Source defaultSource) : } } - -} diff --git a/searchlib/src/vespa/searchlib/queryeval/isourceselector.h b/searchlib/src/vespa/searchlib/queryeval/isourceselector.h index a3eac806558..88a3cb57a8a 100644 --- a/searchlib/src/vespa/searchlib/queryeval/isourceselector.h +++ b/searchlib/src/vespa/searchlib/queryeval/isourceselector.h @@ -2,11 +2,9 @@ #pragma once -#include <stdint.h> #include <vespa/searchlib/attribute/singlenumericattribute.h> -namespace search { -namespace queryeval { +namespace search::queryeval { typedef uint8_t Source; @@ -91,6 +89,4 @@ private: Source _defaultSource; }; -} // namespace queryeval -} // namespace search - +} |