diff options
author | Tor Egge <Tor.Egge@online.no> | 2024-05-29 17:09:51 +0200 |
---|---|---|
committer | Tor Egge <Tor.Egge@online.no> | 2024-05-29 17:09:51 +0200 |
commit | 0011b82100a989228cfdc8f9365ef71daf582828 (patch) | |
tree | 1c6dedb24e237c89a1cf6c8a5e6f0dea4b2eeeea /searchlib | |
parent | ba7973f86652e80e7b0d0d7ce1ea439a89555240 (diff) |
Add hidden RerankRescorer class and use it to get second phase scores
into the result set earlier.
Diffstat (limited to 'searchlib')
-rw-r--r-- | searchlib/src/vespa/searchlib/queryeval/hitcollector.cpp | 66 |
1 files changed, 48 insertions, 18 deletions
diff --git a/searchlib/src/vespa/searchlib/queryeval/hitcollector.cpp b/searchlib/src/vespa/searchlib/queryeval/hitcollector.cpp index 3154f95bbe1..c1d59463ad9 100644 --- a/searchlib/src/vespa/searchlib/queryeval/hitcollector.cpp +++ b/searchlib/src/vespa/searchlib/queryeval/hitcollector.cpp @@ -193,6 +193,33 @@ struct NoRescorer }; template <typename Rescorer> +class RerankRescorer { + Rescorer _rescorer; + using HitVector = std::vector<HitCollector::Hit>; + using Iterator = typename HitVector::const_iterator; + Iterator _reranked_cur; + Iterator _reranked_end; +public: + RerankRescorer(const Rescorer& rescorer, + const HitVector& reranked_hits) + : _rescorer(rescorer), + _reranked_cur(reranked_hits.begin()), + _reranked_end(reranked_hits.end()) + { + } + + double rescore(uint32_t docid, double score) noexcept { + if (_reranked_cur != _reranked_end && _reranked_cur->first == docid) { + double result = _reranked_cur->second; + ++_reranked_cur; + return result; + } else { + return _rescorer.rescore(docid, score); + } + } +}; + +template <typename Rescorer> void add_rescored_hits(ResultSet& rs, const std::vector<HitCollector::Hit>& hits, Rescorer rescorer) { @@ -203,6 +230,17 @@ add_rescored_hits(ResultSet& rs, const std::vector<HitCollector::Hit>& hits, Res template <typename Rescorer> void +add_rescored_hits(ResultSet& rs, const std::vector<HitCollector::Hit>& hits, const std::vector<HitCollector::Hit>& reranked_hits, Rescorer rescorer) +{ + if (reranked_hits.empty()) { + add_rescored_hits(rs, hits, rescorer); + } else { + add_rescored_hits(rs, hits, RerankRescorer(rescorer, reranked_hits)); + } +} + +template <typename Rescorer> +void mixin_rescored_hits(ResultSet& rs, const std::vector<HitCollector::Hit>& hits, const std::vector<uint32_t>& docids, double default_value, Rescorer rescorer) { auto hits_cur = hits.begin(); @@ -217,18 +255,14 @@ mixin_rescored_hits(ResultSet& rs, const std::vector<HitCollector::Hit>& hits, c } } +template <typename Rescorer> void -mergeHitsIntoResultSet(const std::vector<HitCollector::Hit> &hits, ResultSet &result) +mixin_rescored_hits(ResultSet& rs, const std::vector<HitCollector::Hit>& hits, const std::vector<uint32_t>& docids, double default_value, const std::vector<HitCollector::Hit>& reranked_hits, Rescorer rescorer) { - uint32_t rhCur(0); - uint32_t rhEnd(result.getArrayUsed()); - for (const auto &hit : hits) { - while (rhCur != rhEnd && result[rhCur].getDocId() != hit.first) { - // just set the iterators right - ++rhCur; - } - assert(rhCur != rhEnd); // the hits should be a subset of the hits in ranked hit array. - result[rhCur]._rankValue = hit.second; + if (reranked_hits.empty()) { + mixin_rescored_hits(rs, hits, docids, default_value, rescorer); + } else { + mixin_rescored_hits(rs, hits, docids, default_value, RerankRescorer(rescorer, reranked_hits)); } } @@ -247,9 +281,9 @@ HitCollector::getResultSet(HitRank default_value) if ( ! _collector->isDocIdCollector() ) { rs->allocArray(_hits.size()); if (needReScore) { - add_rescored_hits(*rs, _hits, rescorer); + add_rescored_hits(*rs, _hits, _reRankedHits, rescorer); } else { - add_rescored_hits(*rs, _hits, NoRescorer()); + add_rescored_hits(*rs, _hits, _reRankedHits, NoRescorer()); } } else { if (_unordered) { @@ -257,16 +291,12 @@ HitCollector::getResultSet(HitRank default_value) } rs->allocArray(_docIdVector.size()); if (needReScore) { - mixin_rescored_hits(*rs, _hits, _docIdVector, default_value, rescorer); + mixin_rescored_hits(*rs, _hits, _docIdVector, default_value, _reRankedHits, rescorer); } else { - mixin_rescored_hits(*rs, _hits, _docIdVector, default_value, NoRescorer()); + mixin_rescored_hits(*rs, _hits, _docIdVector, default_value, _reRankedHits, NoRescorer()); } } - if (!_reRankedHits.empty()) { - mergeHitsIntoResultSet(_reRankedHits, *rs); - } - if (_bitVector) { rs->setBitOverflow(std::move(_bitVector)); } |