// Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. #include "diversity.hpp" #include "singlenumericattribute.h" #include using std::make_unique; namespace search::attribute::diversity { template struct FetchNumberFast { const T * const attr; using ValueType = typename T::LoadedValueType; FetchNumberFast(const IAttributeVector &attr_in) : attr(dynamic_cast(&attr_in)) {} ValueType get(uint32_t docid) const { return attr->getFast(docid); } bool valid() const { return (attr != nullptr); } }; struct FetchEnumFast { IAttributeVector::EnumRefs enumRefs; using ValueType = uint32_t; FetchEnumFast(const IAttributeVector &attr) : enumRefs(attr.make_enum_read_view()) {} ValueType get(uint32_t docid) const { return enumRefs[docid].load_relaxed().ref(); } bool valid() const { return ! enumRefs.empty(); } }; struct FetchEnum { const IAttributeVector *attr; using ValueType = uint32_t; FetchEnum(const IAttributeVector & attr_in) : attr(&attr_in) {} ValueType get(uint32_t docid) const { return attr->getEnum(docid); } }; struct FetchInteger { const IAttributeVector * attr; using ValueType = int64_t; FetchInteger(const IAttributeVector & attr_in) : attr(&attr_in) {} ValueType get(uint32_t docid) const { return attr->getInt(docid); } }; struct FetchFloat { const IAttributeVector * attr; using ValueType = double; FetchFloat(const IAttributeVector & attr_in) : attr(&attr_in) {} ValueType get(uint32_t docid) const { return attr->getFloat(docid); } }; template class DiversityFilterT final : public DiversityFilter { private: size_t _total_count; Fetcher _diversity; size_t _max_per_group; size_t _cutoff_max_groups; bool _cutoff_strict; using Diversity = vespalib::hash_map; Diversity _seen; public: DiversityFilterT(Fetcher diversity, size_t max_per_group, size_t cutoff_max_groups, bool cutoff_strict, size_t max_total) : DiversityFilter(max_total), _total_count(0), _diversity(diversity), _max_per_group(max_per_group), _cutoff_max_groups(cutoff_max_groups), _cutoff_strict(cutoff_strict), _seen(std::min(cutoff_max_groups, 10000ul)*3) { } bool accepted(uint32_t docId) override; private: bool add() { ++_total_count; return true; } bool conditional_add(uint32_t & group_count) { if (group_count < _max_per_group) { ++group_count; add(); return true; } return false; } }; template bool DiversityFilterT::accepted(uint32_t docId) { if (_total_count < _max_total) { if ((_seen.size() < _cutoff_max_groups) || _cutoff_strict) { typename Fetcher::ValueType group = _diversity.get(docId); if (_seen.size() < _cutoff_max_groups) { return conditional_add(_seen[group]); } else { auto found = _seen.find(group); return (found == _seen.end()) ? add() : conditional_add(found->second); } } else if ( !_cutoff_strict) { return add(); } } return false; } std::unique_ptr DiversityFilter::create(const IAttributeVector &diversity_attr, size_t wanted_hits, size_t max_per_group,size_t cutoff_max_groups, bool cutoff_strict) { if (diversity_attr.hasEnum()) { // must handle enum first FetchEnumFast fastEnum(diversity_attr); if (fastEnum.valid()) { return make_unique> (fastEnum, max_per_group, cutoff_max_groups, cutoff_strict, wanted_hits); } else { return make_unique>(FetchEnum(diversity_attr), max_per_group, cutoff_max_groups, cutoff_strict, wanted_hits); } } else if (diversity_attr.isIntegerType()) { using FetchInt32Fast = FetchNumberFast > >; using FetchInt64Fast = FetchNumberFast > >; FetchInt32Fast fastInt32(diversity_attr); FetchInt64Fast fastInt64(diversity_attr); if (fastInt32.valid()) { return make_unique>(fastInt32, max_per_group, cutoff_max_groups, cutoff_strict, wanted_hits); } else if (fastInt64.valid()) { return make_unique>(fastInt64, max_per_group, cutoff_max_groups, cutoff_strict, wanted_hits); } else { return make_unique>(FetchInteger(diversity_attr), max_per_group, cutoff_max_groups, cutoff_strict, wanted_hits); } } else if (diversity_attr.isFloatingPointType()) { using FetchFloatFast = FetchNumberFast > >; using FetchDoubleFast = FetchNumberFast > >; FetchFloatFast fastFloat(diversity_attr); FetchDoubleFast fastDouble(diversity_attr); if (fastFloat.valid()) { return make_unique>(fastFloat, max_per_group, cutoff_max_groups, cutoff_strict, wanted_hits); } else if (fastDouble.valid()) { return make_unique>(fastDouble, max_per_group, cutoff_max_groups, cutoff_strict, wanted_hits); } else { return make_unique>(FetchFloat(diversity_attr), max_per_group, cutoff_max_groups, cutoff_strict, wanted_hits); } } return std::unique_ptr(); } }