about summary refs log tree commit diff stats
diff options
context:
space:
mode:
author Henning Baldersheim <balder@yahoo-inc.com> 2024-04-26 15:15:45 +0200
committer GitHub <noreply@github.com> 2024-04-26 15:15:45 +0200
commit 970e8747037a91dad423f073e08be2612bdeb71a (patch)
tree 446664ddbb21e1a2bf4159789909d631b4896197
parent a0c71c001b3efc96276ff518c96860eff52d6c24 (diff)
parent 87aa35392382a871a643248cb3d6efd05e2c4f4b (diff)
Merge pull request #31057 from vespa-engine/balder/allow-stat-in-scorer
Allow scorer for wand to carry state
-rw-r--r-- searchlib/src/tests/queryeval/weak_and/rise_wand.h 31
-rw-r--r-- searchlib/src/tests/queryeval/weak_and/rise_wand.hpp 35
-rw-r--r-- searchlib/src/tests/queryeval/weak_and_scorers/weak_and_scorers_test.cpp 10
-rw-r--r-- searchlib/src/vespa/searchlib/queryeval/wand/wand_parts.h 36
-rw-r--r-- searchlib/src/vespa/searchlib/queryeval/wand/weak_and_search.cpp 5
5 files changed, 50 insertions, 67 deletions
diff --git a/searchlib/src/tests/queryeval/weak_and/rise_wand.h b/searchlib/src/tests/queryeval/weak_and/rise_wand.h
index d4e66ec1907..4c7be54a6f0 100644
--- a/searchlib/src/tests/queryeval/weak_and/rise_wand.h
+++ b/searchlib/src/tests/queryeval/weak_and/rise_wand.h
@@ -15,8 +15,12 @@ namespace rise {
struct TermFreqScorer
{
- static int64_t calculateMaxScore(const wand::Term &term) {
- return TermFrequencyScorer::calculateMaxScore(term);
+ [[no_unique_address]] TermFrequencyScorer _termFrequencyScorer;
+ TermFreqScorer() noexcept
+ : _termFrequencyScorer()
+ { }
+ int64_t calculateMaxScore(const wand::Term &term) const noexcept {
+ return _termFrequencyScorer.calculateMaxScore(term);
}
static int64_t calculateScore(const wand::Term &term, uint32_t docId) {
term.search->unpack(docId);
@@ -43,9 +47,13 @@ private:
//const addr_t *const *_streamPayloads;
public:
- StreamComparator(const docid_t *streamDocIds);
+ explicit StreamComparator(const docid_t *streamDocIds) noexcept
+ : _streamDocIds(streamDocIds)
+ { }
//const addr_t *const *streamPayloads);
- inline bool operator()(const uint16_t a, const uint16_t b);
+ bool operator()(const uint16_t a, const uint16_t b) const noexcept {
+ return (_streamDocIds[a] < _streamDocIds[b]);
+ }
};
// number of streams present in the query
@@ -66,6 +74,7 @@ private:
// comparator that compares two streams
StreamComparator _streamComparator;
+ [[no_unique_address]] Scorer _scorer;
//-------------------------------------------------------------------------
// variables used for scoring and pruning
@@ -86,7 +95,7 @@ private:
*
* @return whether a valid pivot index is found
*/
- bool _findPivotFeatureIdx(const score_t threshold, uint32_t &pivotIdx);
+ bool _findPivotFeatureIdx(score_t threshold, uint32_t &pivotIdx);
/**
* let the first numStreamsToMove streams in the stream
@@ -94,7 +103,7 @@ private:
*
* @param numStreamsToMove the number of streams that should move
*/
- void _moveStreamsAndSort(const uint32_t numStreamsToMove);
+ void _moveStreamsAndSort(uint32_t numStreamsToMove);
/**
* let the first numStreamsToMove streams in the stream
@@ -106,7 +115,7 @@ private:
* @param desiredDocId desired doc id
*
*/
- void _moveStreamsToDocAndSort(const uint32_t numStreamsToMove, const docid_t desiredDocId);
+ void _moveStreamsToDocAndSort(uint32_t numStreamsToMove, docid_t desiredDocId);
/**
* do sort and merge for WAND
@@ -115,18 +124,18 @@ private:
* be sorted and then merge sort with the rest
*
*/
- void _sortMerge(const uint32_t numStreamsToSort);
+ void _sortMerge(uint32_t numStreamsToSort);
public:
RiseWand(const Terms &terms, uint32_t n);
- ~RiseWand();
+ ~RiseWand() override;
void next();
void doSeek(uint32_t docid) override;
void doUnpack(uint32_t docid) override;
};
-using TermFrequencyRiseWand = RiseWand<TermFreqScorer, std::greater_equal<uint64_t> >;
-using DotProductRiseWand = RiseWand<DotProductScorer, std::greater<uint64_t> >;
+using TermFrequencyRiseWand = RiseWand<TermFreqScorer, std::greater_equal<> >;
+using DotProductRiseWand = RiseWand<DotProductScorer, std::greater<> >;
} // namespacve rise
diff --git a/searchlib/src/tests/queryeval/weak_and/rise_wand.hpp b/searchlib/src/tests/queryeval/weak_and/rise_wand.hpp
index 32e17014f98..c477be5cc62 100644
--- a/searchlib/src/tests/queryeval/weak_and/rise_wand.hpp
+++ b/searchlib/src/tests/queryeval/weak_and/rise_wand.hpp
@@ -19,6 +19,7 @@ RiseWand<Scorer, Cmp>::RiseWand(const Terms &terms, uint32_t n)
_streamIndices(new uint16_t[terms.size()]),
_streamIndicesAux(new uint16_t[terms.size()]),
_streamComparator(_streamDocIds),
+ _scorer(),
_n(n),
_limit(1),
_streamScores(new score_t[terms.size()]),
@@ -26,7 +27,7 @@ RiseWand<Scorer, Cmp>::RiseWand(const Terms &terms, uint32_t n)
_terms(terms)
{
for (size_t i = 0; i < terms.size(); ++i) {
- _terms[i].maxScore = Scorer::calculateMaxScore(terms[i]);
+ _terms[i].maxScore = _scorer.calculateMaxScore(terms[i]);
_streamScores[i] = _terms[i].maxScore;
_streams.push_back(terms[i].search);
}
@@ -46,8 +47,8 @@ RiseWand<Scorer, Cmp>::RiseWand(const Terms &terms, uint32_t n)
template <typename Scorer, typename Cmp>
RiseWand<Scorer, Cmp>::~RiseWand()
{
- for (size_t i = 0; i < _streams.size(); ++i) {
- delete _streams[i];
+ for (auto * stream : _streams) {
+ delete stream;
}
delete [] _streamScores;
delete [] _streamIndicesAux;
@@ -137,8 +138,7 @@ RiseWand<Scorer, Cmp>::_moveStreamsAndSort(const uint32_t numStreamsToMove)
template <typename Scorer, typename Cmp>
void
-RiseWand<Scorer, Cmp>::_moveStreamsToDocAndSort(const uint32_t numStreamsToMove,
- const docid_t desiredDocId)
+RiseWand<Scorer, Cmp>::_moveStreamsToDocAndSort(const uint32_t numStreamsToMove, const docid_t desiredDocId)
{
for (uint32_t i=0; i<numStreamsToMove; ++i) {
_streams[_streamIndices[i]]->seek(desiredDocId);
@@ -195,7 +195,7 @@ RiseWand<Scorer, Cmp>::doUnpack(uint32_t docid)
{
score_t score = 0;
for (size_t i = 0; i <= _lastPivotIdx; ++i) {
- score += Scorer::calculateScore(_terms[_streamIndices[i]], docid);
+ score += _scorer.calculateScore(_terms[_streamIndices[i]], docid);
}
if (_scores.size() < _n || _scores.front() < score) {
_scores.push(score);
@@ -208,28 +208,5 @@ RiseWand<Scorer, Cmp>::doUnpack(uint32_t docid)
}
}
-/**
- ************ BEGIN STREAM COMPARTOR *********************
- */
-template <typename Scorer, typename Cmp>
-RiseWand<Scorer, Cmp>::StreamComparator::StreamComparator(
- const docid_t *streamDocIds)
- : _streamDocIds(streamDocIds)
-{
-}
-
-template <typename Scorer, typename Cmp>
-inline bool
-RiseWand<Scorer, Cmp>::StreamComparator::operator()(const uint16_t a,
- const uint16_t b)
-{
- if (_streamDocIds[a] < _streamDocIds[b]) return true;
- return false;
-}
-
-/**
- ************ END STREAM COMPARTOR *********************
- */
-
} // namespace rise
diff --git a/searchlib/src/tests/queryeval/weak_and_scorers/weak_and_scorers_test.cpp b/searchlib/src/tests/queryeval/weak_and_scorers/weak_and_scorers_test.cpp
index 528e117f976..e1f3f0805d9 100644
--- a/searchlib/src/tests/queryeval/weak_and_scorers/weak_and_scorers_test.cpp
+++ b/searchlib/src/tests/queryeval/weak_and_scorers/weak_and_scorers_test.cpp
@@ -25,18 +25,18 @@ struct TestIterator : public SearchIterator
_useInfo(useInfo),
_unpackDocId(0)
{}
- virtual void doSeek(uint32_t docId) override {
+ void doSeek(uint32_t docId) override {
(void) docId;
}
- virtual void doUnpack(uint32_t docId) override {
+ void doUnpack(uint32_t docId) override {
_unpackDocId = docId;
_tfmd.appendPosition(TermFieldMatchDataPosition(0, 0, _termWeight, 1));
}
- virtual const PostingInfo *getPostingInfo() const override {
- return (_useInfo ? &_info : NULL);
+ const PostingInfo *getPostingInfo() const override {
+ return (_useInfo ? &_info : nullptr);
}
static UP create(int32_t maxWeight, int32_t termWeight, bool useInfo) {
- return UP(new TestIterator(maxWeight, termWeight, useInfo));
+ return std::make_unique<TestIterator>(maxWeight, termWeight, useInfo);
}
};
diff --git a/searchlib/src/vespa/searchlib/queryeval/wand/wand_parts.h b/searchlib/src/vespa/searchlib/queryeval/wand/wand_parts.h
index ed8d4b4e4ac..4e781f8497b 100644
--- a/searchlib/src/vespa/searchlib/queryeval/wand/wand_parts.h
+++ b/searchlib/src/vespa/searchlib/queryeval/wand/wand_parts.h
@@ -163,7 +163,7 @@ public:
~VectorizedState();
template <typename Scorer, typename Input>
- std::vector<ref_t> init_state(const Input &input, uint32_t docIdLimit);
+ std::vector<ref_t> init_state(const Input &input, const Scorer & scorer, uint32_t docIdLimit);
docid_t *docId() { return &(_docId[0]); }
const int32_t *weight() const { return &(_weight[0]); }
@@ -202,14 +202,14 @@ VectorizedState<IteratorPack>::operator=(VectorizedState &&) noexcept = default;
template <typename IteratorPack>
template <typename Scorer, typename Input>
std::vector<ref_t>
-VectorizedState<IteratorPack>::init_state(const Input &input, uint32_t docIdLimit) {
+VectorizedState<IteratorPack>::init_state(const Input &input, const Scorer & scorer, uint32_t docIdLimit) {
std::vector<ref_t> order;
std::vector<score_t> max_scores;
order.reserve(input.size());
max_scores.reserve(input.size());
for (size_t i = 0; i < input.size(); ++i) {
order.push_back(i);
- max_scores.push_back(Scorer::calculate_max_score(input, i));
+ max_scores.push_back(scorer.calculate_max_score(input, i));
}
std::sort(order.begin(), order.end(), MaxSkipOrder<Input>(docIdLimit, input, max_scores));
_docId = assemble([&input](ref_t ref){ return input.get_initial_docid(ref); }, order);
@@ -238,7 +238,7 @@ private:
public:
template <typename Scorer>
- VectorizedIteratorTerms(const Terms &t, const Scorer &, uint32_t docIdLimit,
+ VectorizedIteratorTerms(const Terms &t, const Scorer & scorer, uint32_t docIdLimit,
fef::MatchData::UP childrenMatchData);
VectorizedIteratorTerms(VectorizedIteratorTerms &&) noexcept;
VectorizedIteratorTerms & operator=(VectorizedIteratorTerms &&) noexcept;
@@ -250,11 +250,11 @@ public:
};
template <typename Scorer>
-VectorizedIteratorTerms::VectorizedIteratorTerms(const Terms &t, const Scorer &, uint32_t docIdLimit,
+VectorizedIteratorTerms::VectorizedIteratorTerms(const Terms &t, const Scorer & scorer, uint32_t docIdLimit,
fef::MatchData::UP childrenMatchData)
: _terms()
{
- std::vector<ref_t> order = init_state<Scorer>(TermInput(t), docIdLimit);
+ std::vector<ref_t> order = init_state<Scorer>(TermInput(t), scorer, docIdLimit);
_terms = assemble([&t](ref_t ref){ return t[ref]; }, order);
iteratorPack() = SearchIteratorPack(assemble([&t](ref_t ref){ return t[ref].search; }, order),
assemble([&t](ref_t ref){ return t[ref].matchData; }, order),
@@ -268,10 +268,10 @@ struct VectorizedAttributeTerms : VectorizedState<DocidWithWeightIteratorPack> {
VectorizedAttributeTerms(const std::vector<int32_t> &weights,
const std::vector<IDirectPostingStore::LookupResult> &dict_entries,
const IDocidWithWeightPostingStore &attr,
- const Scorer &,
+ const Scorer & scorer,
docid_t docIdLimit)
{
- std::vector<ref_t> order = init_state<Scorer>(AttrInput(weights, dict_entries), docIdLimit);
+ std::vector<ref_t> order = init_state<Scorer>(AttrInput(weights, dict_entries), scorer, docIdLimit);
std::vector<DocidWithWeightIterator> iterators;
iterators.reserve(order.size());
for (size_t i = 0; i < order.size(); ++i) {
@@ -398,16 +398,16 @@ DualHeap<FutureHeap, PastHeap>::stringify() const {
struct TermFrequencyScorer
{
// weight * idf, scaled to fixedpoint
- static score_t calculateMaxScore(double estHits, double weight) noexcept {
+ score_t calculateMaxScore(double estHits, double weight) const noexcept {
return (score_t) (TermFrequencyScorer_TERM_SCORE_FACTOR * weight / (1.0 + log(1.0 + (estHits / 1000.0))));
}
- static score_t calculateMaxScore(const Term &term) noexcept {
+ score_t calculateMaxScore(const Term &term) const noexcept {
return calculateMaxScore(term.estHits, term.weight) + 1;
}
template <typename Input>
- static score_t calculate_max_score(const Input &input, ref_t ref) {
+ score_t calculate_max_score(const Input &input, ref_t ref) const noexcept {
return calculateMaxScore(input.get_est_hits(ref), input.get_weight(ref)) + 1;
}
};
@@ -521,10 +521,10 @@ private:
}
template <typename VectorizedTerms, typename Heaps, typename Scorer, typename AboveThreshold>
- bool check_present_score(VectorizedTerms &terms, Heaps &heaps, score_t &max_score, const Scorer &, AboveThreshold &&aboveThreshold) {
+ bool check_present_score(VectorizedTerms &terms, Heaps &heaps, score_t &max_score, const Scorer & scorer, AboveThreshold &&aboveThreshold) {
ref_t *end = heaps.present_end();
for (ref_t *ref = heaps.present_begin(); ref != end; ++ref) {
- score_t term_score = Scorer::calculateScore(terms, *ref, _candidate);
+ score_t term_score = scorer.calculateScore(terms, *ref, _candidate);
_partial_score += term_score;
max_score -= (terms.maxScore(*ref) - term_score);
if (!aboveThreshold(max_score)) {
@@ -535,11 +535,11 @@ private:
}
template <typename VectorizedTerms, typename Heaps, typename Scorer, typename AboveThreshold>
- bool check_past_score(VectorizedTerms &terms, Heaps &heaps, score_t &max_score, const Scorer &, AboveThreshold &&aboveThreshold) {
+ bool check_past_score(VectorizedTerms &terms, Heaps &heaps, score_t &max_score, const Scorer & scorer, AboveThreshold &&aboveThreshold) {
while (heaps.has_past() && !aboveThreshold(_partial_score)) {
heaps.pop_past();
if (step_term(terms, heaps.last_present())) {
- score_t term_score = Scorer::calculateScore(terms, heaps.last_present(), _candidate);
+ score_t term_score = scorer.calculateScore(terms, heaps.last_present(), _candidate);
_partial_score += term_score;
max_score -= (terms.maxScore(heaps.last_present()) - term_score);
} else {
@@ -618,7 +618,7 @@ public:
}
template <typename VectorizedTerms, typename Heaps, typename Scorer, typename AboveThreshold>
- bool check_score(VectorizedTerms &terms, Heaps &heaps, Scorer &&scorer, AboveThreshold &&aboveThreshold) {
+ bool check_score(VectorizedTerms &terms, Heaps &heaps, const Scorer &scorer, AboveThreshold &&aboveThreshold) {
_partial_score = 0;
score_t max_score = _maxUpperBound;
if (check_present_score(terms, heaps, max_score, scorer, aboveThreshold)) {
@@ -630,12 +630,12 @@ public:
}
template <typename VectorizedTerms, typename Heaps, typename Scorer>
- score_t get_full_score(VectorizedTerms &terms, Heaps &heaps, Scorer &&) {
+ score_t get_full_score(VectorizedTerms &terms, Heaps &heaps, const Scorer & scorer) {
score_t score = _partial_score;
while (heaps.has_past()) {
heaps.pop_any_past();
if (step_term(terms, heaps.last_present())) {
- score += Scorer::calculateScore(terms, heaps.last_present(), _candidate);
+ score += scorer.calculateScore(terms, heaps.last_present(), _candidate);
} else {
evict_last_present(terms, heaps);
}
diff --git a/searchlib/src/vespa/searchlib/queryeval/wand/weak_and_search.cpp b/searchlib/src/vespa/searchlib/queryeval/wand/weak_and_search.cpp
index 375a6598b49..04b1cb75da4 100644
--- a/searchlib/src/vespa/searchlib/queryeval/wand/weak_and_search.cpp
+++ b/searchlib/src/vespa/searchlib/queryeval/wand/weak_and_search.cpp
@@ -43,10 +43,7 @@ private:
public:
WeakAndSearchLR(const Terms &terms, uint32_t n)
- : _terms(terms,
- TermFrequencyScorer(),
- 0,
- fef::MatchData::UP()),
+ : _terms(terms, TermFrequencyScorer(), 0, {}),
_heaps(DocIdOrder(_terms.docId()), _terms.size()),
_algo(),
_threshold(1),