summaryrefslogtreecommitdiffstats
path: root/searchlib
diff options
context:
space:
mode:
authorTor Egge <Tor.Egge@online.no>2024-05-29 17:09:51 +0200
committerTor Egge <Tor.Egge@online.no>2024-05-29 17:09:51 +0200
commit0011b82100a989228cfdc8f9365ef71daf582828 (patch)
tree1c6dedb24e237c89a1cf6c8a5e6f0dea4b2eeeea /searchlib
parentba7973f86652e80e7b0d0d7ce1ea439a89555240 (diff)
Add hidden RerankRescorer class and use it to get second phase scores
into the result set earlier.
Diffstat (limited to 'searchlib')
-rw-r--r--searchlib/src/vespa/searchlib/queryeval/hitcollector.cpp66
1 files changed, 48 insertions, 18 deletions
diff --git a/searchlib/src/vespa/searchlib/queryeval/hitcollector.cpp b/searchlib/src/vespa/searchlib/queryeval/hitcollector.cpp
index 3154f95bbe1..c1d59463ad9 100644
--- a/searchlib/src/vespa/searchlib/queryeval/hitcollector.cpp
+++ b/searchlib/src/vespa/searchlib/queryeval/hitcollector.cpp
@@ -193,6 +193,33 @@ struct NoRescorer
};
template <typename Rescorer>
+class RerankRescorer {
+ Rescorer _rescorer;
+ using HitVector = std::vector<HitCollector::Hit>;
+ using Iterator = typename HitVector::const_iterator;
+ Iterator _reranked_cur;
+ Iterator _reranked_end;
+public:
+ RerankRescorer(const Rescorer& rescorer,
+ const HitVector& reranked_hits)
+ : _rescorer(rescorer),
+ _reranked_cur(reranked_hits.begin()),
+ _reranked_end(reranked_hits.end())
+ {
+ }
+
+ double rescore(uint32_t docid, double score) noexcept {
+ if (_reranked_cur != _reranked_end && _reranked_cur->first == docid) {
+ double result = _reranked_cur->second;
+ ++_reranked_cur;
+ return result;
+ } else {
+ return _rescorer.rescore(docid, score);
+ }
+ }
+};
+
+template <typename Rescorer>
void
add_rescored_hits(ResultSet& rs, const std::vector<HitCollector::Hit>& hits, Rescorer rescorer)
{
@@ -203,6 +230,17 @@ add_rescored_hits(ResultSet& rs, const std::vector<HitCollector::Hit>& hits, Res
template <typename Rescorer>
void
+add_rescored_hits(ResultSet& rs, const std::vector<HitCollector::Hit>& hits, const std::vector<HitCollector::Hit>& reranked_hits, Rescorer rescorer)
+{
+ if (reranked_hits.empty()) {
+ add_rescored_hits(rs, hits, rescorer);
+ } else {
+ add_rescored_hits(rs, hits, RerankRescorer(rescorer, reranked_hits));
+ }
+}
+
+template <typename Rescorer>
+void
mixin_rescored_hits(ResultSet& rs, const std::vector<HitCollector::Hit>& hits, const std::vector<uint32_t>& docids, double default_value, Rescorer rescorer)
{
auto hits_cur = hits.begin();
@@ -217,18 +255,14 @@ mixin_rescored_hits(ResultSet& rs, const std::vector<HitCollector::Hit>& hits, c
}
}
+template <typename Rescorer>
void
-mergeHitsIntoResultSet(const std::vector<HitCollector::Hit> &hits, ResultSet &result)
+mixin_rescored_hits(ResultSet& rs, const std::vector<HitCollector::Hit>& hits, const std::vector<uint32_t>& docids, double default_value, const std::vector<HitCollector::Hit>& reranked_hits, Rescorer rescorer)
{
- uint32_t rhCur(0);
- uint32_t rhEnd(result.getArrayUsed());
- for (const auto &hit : hits) {
- while (rhCur != rhEnd && result[rhCur].getDocId() != hit.first) {
- // just set the iterators right
- ++rhCur;
- }
- assert(rhCur != rhEnd); // the hits should be a subset of the hits in ranked hit array.
- result[rhCur]._rankValue = hit.second;
+ if (reranked_hits.empty()) {
+ mixin_rescored_hits(rs, hits, docids, default_value, rescorer);
+ } else {
+ mixin_rescored_hits(rs, hits, docids, default_value, RerankRescorer(rescorer, reranked_hits));
}
}
@@ -247,9 +281,9 @@ HitCollector::getResultSet(HitRank default_value)
if ( ! _collector->isDocIdCollector() ) {
rs->allocArray(_hits.size());
if (needReScore) {
- add_rescored_hits(*rs, _hits, rescorer);
+ add_rescored_hits(*rs, _hits, _reRankedHits, rescorer);
} else {
- add_rescored_hits(*rs, _hits, NoRescorer());
+ add_rescored_hits(*rs, _hits, _reRankedHits, NoRescorer());
}
} else {
if (_unordered) {
@@ -257,16 +291,12 @@ HitCollector::getResultSet(HitRank default_value)
}
rs->allocArray(_docIdVector.size());
if (needReScore) {
- mixin_rescored_hits(*rs, _hits, _docIdVector, default_value, rescorer);
+ mixin_rescored_hits(*rs, _hits, _docIdVector, default_value, _reRankedHits, rescorer);
} else {
- mixin_rescored_hits(*rs, _hits, _docIdVector, default_value, NoRescorer());
+ mixin_rescored_hits(*rs, _hits, _docIdVector, default_value, _reRankedHits, NoRescorer());
}
}
- if (!_reRankedHits.empty()) {
- mergeHitsIntoResultSet(_reRankedHits, *rs);
- }
-
if (_bitVector) {
rs->setBitOverflow(std::move(_bitVector));
}