diff options
author | Tor Egge <Tor.Egge@online.no> | 2024-05-31 15:33:28 +0200 |
---|---|---|
committer | Tor Egge <Tor.Egge@online.no> | 2024-05-31 15:33:28 +0200 |
commit | bc8cf5d97eb9b34a1123a65dc52d3921c809d1fb (patch) | |
tree | 6cb0dfb6be2d377bb7f06f2b08bccdd82321319d | |
parent | a47916388fc73d3d3fd2af66e882857c1932bd9c (diff) |
Implement second phase rank drop limit for hit collector.
3 files changed, 214 insertions, 20 deletions
diff --git a/searchlib/src/tests/hitcollector/hitcollector_test.cpp b/searchlib/src/tests/hitcollector/hitcollector_test.cpp index 3ac5f419228..12e06b0ff32 100644 --- a/searchlib/src/tests/hitcollector/hitcollector_test.cpp +++ b/searchlib/src/tests/hitcollector/hitcollector_test.cpp @@ -13,6 +13,8 @@ using namespace search::fef; using namespace search::queryeval; using ScoreMap = std::map<uint32_t, feature_t>; +using DocidVector = std::vector<uint32_t>; +using RankedHitVector = std::vector<RankedHit>; using Ranges = std::pair<Scores, Scores>; @@ -574,4 +576,75 @@ TEST(HitCollectorTest, require_that_hits_can_be_added_out_of_order_only_after_pa checkResult(*rs, nullptr); } +struct RankDropFixture { + uint32_t _docid_limit; + HitCollector _hc; + std::vector<uint32_t> _dropped; + RankDropFixture(uint32_t docid_limit, uint32_t max_hits_size) + : _docid_limit(docid_limit), + _hc(docid_limit, max_hits_size) + { + } + void add(std::vector<RankedHit> hits) { + for (const auto& hit : hits) { + _hc.addHit(hit.getDocId(), hit.getRank()); + } + } + void rerank(ScoreMap score_map, size_t count) { + PredefinedScorer scorer(score_map); + EXPECT_EQ(count, do_reRank(scorer, _hc, count)); + } + std::unique_ptr<BitVector> make_bv(DocidVector docids) { + auto bv = BitVector::create(_docid_limit); + for (auto& docid : docids) { + bv->setBit(docid); + } + return bv; + } + + void setup() { + add({{5, 1100},{10, 1200},{11, 1300},{12, 1400},{14, 500},{15, 900},{16,1000}}); + rerank({{11,14},{12,13}}, 2); + } + void check_result(std::optional<double> rank_drop_limit, RankedHitVector exp_array, + std::unique_ptr<BitVector> exp_bv, DocidVector exp_dropped) { + auto rs = _hc.get_result_set(rank_drop_limit, &_dropped); + checkResult(*rs, exp_array); + checkResult(*rs, exp_bv.get()); + EXPECT_EQ(exp_dropped, _dropped); + } +}; + +TEST(HitCollectorTest, require_that_second_phase_rank_drop_limit_is_enforced) +{ + RankDropFixture f(10000, 10); + f.setup(); + f.check_result(9.0, {{5,11},{10,12},{11,14},{12,13},{16,10}}, + {}, {14, 15}); +} + +TEST(HitCollectorTest, require_that_docid_vector_is_used) +{ + RankDropFixture f(10000, 4); + f.setup(); + f.check_result(13.0, {{11,14}}, + {}, {5,10,12,14,15,16}); +} + +TEST(HitCollectorTest, require_that_bitvector_is_not_dropped_without_rank_drop_limit) +{ + RankDropFixture f(20, 4); + f.setup(); + f.check_result(std::nullopt, {{5,11},{10,12},{11,14},{12,13}}, + f.make_bv({5,10,11,12,14,15,16}), {}); +} + +TEST(HitCollectorTest, require_that_bitvector_is_dropped_with_rank_drop_limit) +{ + RankDropFixture f(20, 4); + f.setup(); + f.check_result(9.0, {{5,11},{10,12},{11,14},{12,13}}, + {}, {14,15,16}); +} + GTEST_MAIN_RUN_ALL_TESTS() diff --git a/searchlib/src/vespa/searchlib/queryeval/hitcollector.cpp b/searchlib/src/vespa/searchlib/queryeval/hitcollector.cpp index 698593cfd8e..01587ef485a 100644 --- a/searchlib/src/vespa/searchlib/queryeval/hitcollector.cpp +++ b/searchlib/src/vespa/searchlib/queryeval/hitcollector.cpp @@ -219,57 +219,150 @@ public: } }; -template <typename Rescorer> +class SimpleHitAdder { +protected: + ResultSet& _rs; +public: + SimpleHitAdder(ResultSet& rs) + : _rs(rs) + { + } + void add(uint32_t docid, double rank_value) { + _rs.push_back({docid, rank_value}); + } +}; + +class ConditionalHitAdder : public SimpleHitAdder { +protected: + double _second_phase_rank_drop_limit; +public: + ConditionalHitAdder(ResultSet& rs, double second_phase_rank_drop_limit) + : SimpleHitAdder(rs), + _second_phase_rank_drop_limit(second_phase_rank_drop_limit) + { + } + void add(uint32_t docid, double rank_value) { + if (rank_value > _second_phase_rank_drop_limit) { + _rs.push_back({docid, rank_value}); + } + } +}; + +class TrackingConditionalHitAdder : public ConditionalHitAdder { + std::vector<uint32_t>& _dropped; +public: + TrackingConditionalHitAdder(ResultSet& rs, double second_phase_rank_drop_limit, std::vector<uint32_t>& dropped) + : ConditionalHitAdder(rs, second_phase_rank_drop_limit), + _dropped(dropped) + { + } + void add(uint32_t docid, double rank_value) { + if (rank_value > _second_phase_rank_drop_limit) { + _rs.push_back({docid, rank_value}); + } else { + _dropped.emplace_back(docid); + } + } +}; + +template <typename HitAdder, typename Rescorer> void -add_rescored_hits(ResultSet& rs, const std::vector<HitCollector::Hit>& hits, Rescorer rescorer) +add_rescored_hits(HitAdder hit_adder, const std::vector<HitCollector::Hit>& hits, Rescorer rescorer) { for (auto& hit : hits) { - rs.push_back({hit.first, rescorer.rescore(hit.first, hit.second)}); + hit_adder.add(hit.first, rescorer.rescore(hit.first, hit.second)); } } -template <typename Rescorer> +template <typename HitAdder, typename Rescorer> void -add_rescored_hits(ResultSet& rs, const std::vector<HitCollector::Hit>& hits, const std::vector<HitCollector::Hit>& reranked_hits, Rescorer rescorer) +add_rescored_hits(HitAdder hit_adder, const std::vector<HitCollector::Hit>& hits, const std::vector<HitCollector::Hit>& reranked_hits, Rescorer rescorer) { if (reranked_hits.empty()) { - add_rescored_hits(rs, hits, rescorer); + add_rescored_hits(hit_adder, hits, rescorer); } else { - add_rescored_hits(rs, hits, RerankRescorer(rescorer, reranked_hits)); + add_rescored_hits(hit_adder, hits, RerankRescorer(rescorer, reranked_hits)); } } template <typename Rescorer> void -mixin_rescored_hits(ResultSet& rs, const std::vector<HitCollector::Hit>& hits, const std::vector<uint32_t>& docids, double default_value, Rescorer rescorer) +add_rescored_hits(ResultSet& rs, const std::vector<HitCollector::Hit>& hits, const std::vector<HitCollector::Hit>& reranked_hits, std::optional<double> second_phase_rank_drop_limit, std::vector<uint32_t>* dropped, Rescorer rescorer) +{ + if (second_phase_rank_drop_limit.has_value()) { + if (dropped != nullptr) { + add_rescored_hits(TrackingConditionalHitAdder(rs, second_phase_rank_drop_limit.value(), *dropped), hits, reranked_hits, rescorer); + } else { + add_rescored_hits(ConditionalHitAdder(rs, second_phase_rank_drop_limit.value()), hits, reranked_hits, rescorer); + } + } else { + add_rescored_hits(SimpleHitAdder(rs), hits, reranked_hits, rescorer); + } +} + +template <typename HitAdder, typename Rescorer> +void +mixin_rescored_hits(HitAdder hit_adder, const std::vector<HitCollector::Hit>& hits, const std::vector<uint32_t>& docids, double default_value, Rescorer rescorer) { auto hits_cur = hits.begin(); auto hits_end = hits.end(); for (auto docid : docids) { if (hits_cur != hits_end && docid == hits_cur->first) { - rs.push_back({docid, rescorer.rescore(docid, hits_cur->second)}); + hit_adder.add(docid, rescorer.rescore(docid, hits_cur->second)); ++hits_cur; } else { - rs.push_back({docid, default_value}); + hit_adder.add(docid, default_value); } } } -template <typename Rescorer> +template <typename HitAdder, typename Rescorer> void -mixin_rescored_hits(ResultSet& rs, const std::vector<HitCollector::Hit>& hits, const std::vector<uint32_t>& docids, double default_value, const std::vector<HitCollector::Hit>& reranked_hits, Rescorer rescorer) +mixin_rescored_hits(HitAdder hit_adder, const std::vector<HitCollector::Hit>& hits, const std::vector<uint32_t>& docids, double default_value, const std::vector<HitCollector::Hit>& reranked_hits, Rescorer rescorer) { if (reranked_hits.empty()) { - mixin_rescored_hits(rs, hits, docids, default_value, rescorer); + mixin_rescored_hits(hit_adder, hits, docids, default_value, rescorer); } else { - mixin_rescored_hits(rs, hits, docids, default_value, RerankRescorer(rescorer, reranked_hits)); + mixin_rescored_hits(hit_adder, hits, docids, default_value, RerankRescorer(rescorer, reranked_hits)); + } +} + +template <typename Rescorer> +void +mixin_rescored_hits(ResultSet& rs, const std::vector<HitCollector::Hit>& hits, const std::vector<uint32_t>& docids, double default_value, const std::vector<HitCollector::Hit>& reranked_hits, std::optional<double> second_phase_rank_drop_limit, std::vector<uint32_t>* dropped, Rescorer rescorer) +{ + if (second_phase_rank_drop_limit.has_value()) { + if (dropped != nullptr) { + mixin_rescored_hits(TrackingConditionalHitAdder(rs, second_phase_rank_drop_limit.value(), *dropped), hits, docids, default_value, reranked_hits, rescorer); + } else { + mixin_rescored_hits(ConditionalHitAdder(rs, second_phase_rank_drop_limit.value()), hits, docids, default_value, reranked_hits, rescorer); + } + } else { + mixin_rescored_hits(SimpleHitAdder(rs), hits, docids, default_value, reranked_hits, rescorer); + } +} + +void +add_bitvector_to_dropped(std::vector<uint32_t>& dropped, vespalib::ConstArrayRef<RankedHit> hits, const BitVector& bv) +{ + auto hits_cur = hits.begin(); + auto hits_end = hits.end(); + auto docid = bv.getFirstTrueBit(); + auto docid_limit = bv.size(); + while (docid < docid_limit) { + if (hits_cur != hits_end && hits_cur->getDocId() == docid) { + ++hits_cur; + } else { + dropped.emplace_back(docid); + } + docid = bv.getNextTrueBit(docid + 1); } } } std::unique_ptr<ResultSet> -HitCollector::getResultSet() +HitCollector::get_result_set(std::optional<double> second_phase_rank_drop_limit, std::vector<uint32_t>* dropped) { /* * Use default_rank_value (i.e. -HUGE_VAL) when hit collector saves @@ -280,16 +373,26 @@ HitCollector::getResultSet() bool needReScore = FirstPhaseRescorer::need_rescore(_ranges); FirstPhaseRescorer rescorer(_ranges); + if (dropped != nullptr) { + dropped->clear(); + } + // destroys the heap property or score sort order sortHitsByDocId(); auto rs = std::make_unique<ResultSet>(); - if ( ! _collector->isDocIdCollector() ) { + if ( ! _collector->isDocIdCollector() || + (second_phase_rank_drop_limit.has_value() && + (_bitVector || dropped == nullptr))) { rs->allocArray(_hits.size()); + auto* dropped_or_null = dropped; + if (second_phase_rank_drop_limit.has_value() && _bitVector) { + dropped_or_null = nullptr; + } if (needReScore) { - add_rescored_hits(*rs, _hits, _reRankedHits, rescorer); + add_rescored_hits(*rs, _hits, _reRankedHits, second_phase_rank_drop_limit, dropped_or_null, rescorer); } else { - add_rescored_hits(*rs, _hits, _reRankedHits, NoRescorer()); + add_rescored_hits(*rs, _hits, _reRankedHits, second_phase_rank_drop_limit, dropped_or_null, NoRescorer()); } } else { if (_unordered) { @@ -297,12 +400,20 @@ HitCollector::getResultSet() } rs->allocArray(_docIdVector.size()); if (needReScore) { - mixin_rescored_hits(*rs, _hits, _docIdVector, default_value, _reRankedHits, rescorer); + mixin_rescored_hits(*rs, _hits, _docIdVector, default_value, _reRankedHits, second_phase_rank_drop_limit, dropped, rescorer); } else { - mixin_rescored_hits(*rs, _hits, _docIdVector, default_value, _reRankedHits, NoRescorer()); + mixin_rescored_hits(*rs, _hits, _docIdVector, default_value, _reRankedHits, second_phase_rank_drop_limit, dropped, NoRescorer()); } } + if (second_phase_rank_drop_limit.has_value() && _bitVector) { + if (dropped != nullptr) { + assert(dropped->empty()); + add_bitvector_to_dropped(*dropped, {rs->getArray(), rs->getArrayUsed()}, *_bitVector); + } + _bitVector.reset(); + } + if (_bitVector) { rs->setBitOverflow(std::move(_bitVector)); } @@ -310,4 +421,10 @@ HitCollector::getResultSet() return rs; } +std::unique_ptr<ResultSet> +HitCollector::getResultSet() +{ + return get_result_set(std::nullopt, nullptr); +} + } diff --git a/searchlib/src/vespa/searchlib/queryeval/hitcollector.h b/searchlib/src/vespa/searchlib/queryeval/hitcollector.h index fe1d486ff2a..c23fb0a6ef6 100644 --- a/searchlib/src/vespa/searchlib/queryeval/hitcollector.h +++ b/searchlib/src/vespa/searchlib/queryeval/hitcollector.h @@ -8,6 +8,7 @@ #include <vespa/searchlib/common/resultset.h> #include <vespa/vespalib/util/sort.h> #include <algorithm> +#include <optional> #include <vector> namespace search::queryeval { @@ -166,6 +167,9 @@ public: const std::pair<Scores, Scores> &getRanges() const { return _ranges; } void setRanges(const std::pair<Scores, Scores> &ranges); + std::unique_ptr<ResultSet> + get_result_set(std::optional<double> second_phase_rank_drop_limit, std::vector<uint32_t>* dropped); + /** * Returns a result set based on the content of this collector. * Invoking this method will destroy the heap property of the |