aboutsummaryrefslogtreecommitdiffstats
path: root/searchlib/src/tests/queryeval/weak_and/rise_wand.h
blob: d4e66ec1907ebf447756067cc38c80cea4a4c29b (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
// Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.

#pragma once

#include <vespa/searchlib/queryeval/wand/weak_and_search.h>
#include <vespa/searchlib/queryeval/wand/wand_parts.h>
#include <vespa/vespalib/util/priority_queue.h>
#include <functional>

using search::queryeval::wand::DotProductScorer;
using search::queryeval::wand::TermFrequencyScorer;
using namespace search::queryeval;

namespace rise {

/**
 * Term-frequency scorer for the test-only RISE WAND iterator.
 *
 * - calculateMaxScore() is a thin forward to wand::TermFrequencyScorer.
 * - calculateScore() unpacks the term's posting iterator at docId and then
 *   reports the term's stored maxScore. Every hit for a given term therefore
 *   contributes the same amount, independent of the unpacked match data.
 */
struct TermFreqScorer
{
    /// Max score for `term`, delegated to wand::TermFrequencyScorer::calculateMaxScore.
    static int64_t calculateMaxScore(const wand::Term &term) {
        return TermFrequencyScorer::calculateMaxScore(term);
    }

    /// Unpacks `term`'s posting iterator at `docId`, then returns `term.maxScore`.
    static int64_t calculateScore(const wand::Term &term, uint32_t docId) {
        auto &postings = *term.search;
        postings.unpack(docId);
        return term.maxScore;
    }
};

/**
 * WAND search iterator over a set of wand::Terms, kept as a test-side
 * reference implementation (declaration only; the member definitions live
 * out of line).
 *
 * @tparam Scorer  supplies per-term scores (see TermFreqScorer /
 *                 DotProductScorer for the instantiations used below).
 * @tparam Cmp     score comparator; presumably used when comparing scores
 *                 against the pruning limit — confirm in the implementation.
 *
 * NOTE(review): member order below determines constructor-initialization
 * order in the out-of-line constructor; do not reorder members without
 * updating it.
 */
template <typename Scorer, typename Cmp>
class RiseWand : public search::queryeval::SearchIterator
{
public:
    using docid_t = uint32_t;
    using score_t = uint64_t;
    using Terms = search::queryeval::wand::Terms;
    // Non-owning pointer type at the type level; ownership is by convention (see _streams).
    using PostingStreamPtr = search::queryeval::SearchIterator*;

private:
    // Comparator that orders two streams. The arguments a and b are
    // logically indices into the streams vector; the comparison presumably
    // uses the doc ids in _streamDocIds — confirm in the implementation.
    class StreamComparator
    {
    private:
        const docid_t *_streamDocIds;
        // NOTE(review): leftover from a payload-aware version; unused.
        //const addr_t *const *_streamPayloads;

    public:
        StreamComparator(const docid_t *streamDocIds);
        // NOTE(review): leftover second constructor parameter; unused.
        //const addr_t *const *streamPayloads);
        // NOTE(review): non-const call operator; making it const requires
        // changing the out-of-line definition as well.
        inline bool operator()(const uint16_t a, const uint16_t b);
    };

    // number of streams present in the query
    uint32_t _numStreams;

    // We own our substreams. They are held as raw pointers, so ownership is
    // not enforced by the type; the destructor (defined elsewhere) is
    // expected to release them.
    std::vector<PostingStreamPtr> _streams;

    // presumably the pivot index found in the previous iteration — confirm
    size_t _lastPivotIdx;

    // array of current doc ids for the various streams
    docid_t *_streamDocIds;

    // Two arrays of indices into the _streams vector, used for merging.
    // A copying merge between the two buffers is used instead of
    // std::inplace_merge, which is less efficient.
    uint16_t *_streamIndices;
    uint16_t *_streamIndicesAux;

    // comparator that compares two streams
    StreamComparator _streamComparator;

    //-------------------------------------------------------------------------
    // variables used for scoring and pruning
    // (presumably: _n = number of top hits tracked, taken from the
    //  constructor's n; _limit = current pruning threshold; _streamScores =
    //  per-stream score array; _scores = heap of the best scores seen so
    //  far — confirm in the implementation)

    size_t                           _n;
    score_t                          _limit;
    score_t                         *_streamScores;
    vespalib::PriorityQueue<score_t> _scores;
    Terms                            _terms;

    //-------------------------------------------------------------------------

    /**
     * Find the pivot feature index.
     *
     * @param threshold  score threshold
     * @param pivotIdx   out-parameter receiving the pivot index
     *
     * @return  whether a valid pivot index was found
     */
    bool _findPivotFeatureIdx(const score_t threshold, uint32_t &pivotIdx);

    /**
     * Advance the first numStreamsToMove streams in the stream vector to
     * their next doc, then re-sort them.
     *
     * @param numStreamsToMove  the number of streams that should move
     */
    void _moveStreamsAndSort(const uint32_t numStreamsToMove);

    /**
     * Advance the first numStreamsToMove streams in the stream vector to
     * desiredDocId, or to the first docId greater than desiredDocId if
     * desiredDocId does not exist in that stream, then re-sort them.
     *
     * @param numStreamsToMove  the number of streams that should move
     * @param desiredDocId      desired doc id
     */
    void _moveStreamsToDocAndSort(const uint32_t numStreamsToMove, const docid_t desiredDocId);

    /**
     * Sort and merge step for WAND.
     *
     * @param numStreamsToSort  the number of streams (starting from the first
     *                          one) that should be sorted and then
     *                          merge-sorted with the rest
     */
    void _sortMerge(const uint32_t numStreamsToSort);

public:
    RiseWand(const Terms &terms, uint32_t n);
    ~RiseWand();
    // Not part of the SearchIterator interface; presumably advances to the
    // next qualifying document — confirm in the implementation.
    void next();
    void doSeek(uint32_t docid) override;
    void doUnpack(uint32_t docid) override;
};

// Term-frequency variant: TermFreqScorer paired with a non-strict (>=) comparator.
using TermFrequencyRiseWand = RiseWand<TermFreqScorer, std::greater_equal<uint64_t>>;

// Dot-product variant: wand::DotProductScorer paired with a strict (>) comparator.
// NOTE(review): how Cmp is applied (e.g. score vs. pruning limit) is defined in the
// out-of-line implementation; confirm there before relying on tie semantics.
using DotProductRiseWand = RiseWand<DotProductScorer, std::greater<uint64_t>>;

} // namespace rise