summaryrefslogtreecommitdiffstats
path: root/searchlib/src/vespa/searchlib/attribute/diversity.h
blob: e42ad04ceb662dd24b7476c8465b4747f81fd31a (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
// Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.

#pragma once

#include <vespa/searchcommon/attribute/iattributevector.h>
#include <vespa/searchlib/queryeval/idiversifier.h>
#include <vespa/vespalib/datastore/entryref.h>

/**
 * This file contains low-level code used to implement diversified
 * limited attribute range searches. Terms of the form [;;100;foo;3]
 * are used to specify unbounded range searches in an attribute that
 * produce a limited number of results while also ensuring
 * diversified results based on a secondary attribute.
 **/

namespace search::attribute::diversity {

/**
 * Iterator range that is consumed from its lower end. The range is
 * exhausted once the lower iterator compares equal to the upper
 * iterator. The constructors and the destructor are only declared
 * here; their definitions live outside this part of the file.
 **/
template <typename ITR>
class ForwardRange
{
public:
    /**
     * Scoped handle for the entry currently at the front of a
     * ForwardRange. get() exposes the range's lower iterator, and
     * destroying the handle advances that iterator by one step, so
     * the entry must be used while the handle is still alive.
     **/
    class Next {
        ITR &_front; // refers to the owning range's lower iterator
    public:
        explicit Next(ForwardRange &range) : _front(range._lower) {}
        Next(const Next &) = delete;
        ~Next() { ++_front; }
        const ITR &get() const { return _front; }
    };

    ForwardRange(const ITR &lower, const ITR &upper);
    ForwardRange(const ForwardRange &);
    ~ForwardRange();

    // True while the lower iterator has not yet reached the upper iterator.
    bool has_next() const { return _lower != _upper; }

private:
    ITR _lower;
    ITR _upper;
};

/**
 * Iterator range that is consumed from its upper end. The range is
 * exhausted once the upper iterator compares equal to the lower
 * iterator. The constructors and the destructor are only declared
 * here; their definitions live outside this part of the file.
 **/
template <typename ITR>
class ReverseRange
{
public:
    /**
     * Handle for the entry at the back of a ReverseRange. Creating
     * the handle steps the range's upper iterator back by one, and
     * get() then exposes that iterator. Nothing happens when the
     * handle is destroyed.
     **/
    class Next {
        ITR &_back; // refers to the owning range's upper iterator
    public:
        explicit Next(ReverseRange &range)
            : _back(range._upper)
        {
            --_back;
        }
        Next(const Next &) = delete;
        const ITR &get() const { return _back; }
    };

    ReverseRange(const ITR &lower, const ITR &upper);
    ReverseRange(const ReverseRange &);
    ~ReverseRange();

    // True while the upper iterator has not yet reached the lower iterator.
    bool has_next() const { return _lower != _upper; }

private:
    ITR _lower;
    ITR _upper;
};

/**
 * Diversifier that additionally carries an upper bound on the total
 * number of hits. diversify_2() below stops pulling dictionary
 * entries once the result holds at least getMaxTotal() hits, and
 * DiversityRecorder asks accepted() (inherited from IDiversifier)
 * whether each candidate should be kept.
 **/
class DiversityFilter : public queryeval::IDiversifier {
public:
    /**
     * @param max_total value later reported by getMaxTotal().
     *
     * Explicit so that a plain size_t is never silently converted
     * into a filter.
     **/
    explicit DiversityFilter(size_t max_total) : _max_total(max_total) {}

    /** @return the current total hit limit. */
    size_t getMaxTotal() const { return _max_total; }

    /**
     * Factory for a concrete filter; only declared here, the
     * definition is elsewhere.
     *
     * NOTE(review): judging by the parameter names, the filter groups
     * hits by the value of diversity_attr and limits each group to
     * max_per_group, with cutoff_max_groups / cutoff_strict tuning
     * when grouping stops -- confirm against the implementation.
     **/
    static std::unique_ptr<DiversityFilter>
    create(const IAttributeVector &diversity_attr, size_t wanted_hits,
           size_t max_per_group, size_t cutoff_max_groups, bool cutoff_strict);
protected:
    // Protected (not private), so subclasses can read and modify the limit.
    size_t _max_total;
};

/**
 * Thin adapter placed in front of a result container: an item is
 * appended to the container only when the diversity filter accepts
 * the item's _key. Both the filter and the container are held by
 * reference and must outlive the recorder.
 **/
template <typename Result>
class DiversityRecorder {
    DiversityFilter &_filter;
    Result          &_result;
public:
    DiversityRecorder(DiversityFilter & filter, Result &result)
        : _filter(filter),
          _result(result)
    {
    }

    /**
     * Offer an item to the filter; rejected items are dropped,
     * accepted items are appended to the result.
     **/
    template <typename Item>
    void push_back(Item item) {
        if (!_filter.accepted(item._key)) {
            return;
        }
        _result.push_back(item);
    }
};

/**
 * Walk a dictionary range and collect filtered postings.
 *
 * For each dictionary entry handed out by the range, every posting
 * visited by posting.foreach_frozen() is offered to the filter via a
 * DiversityRecorder, and accepted postings are appended to result.
 * After each entry, the current result size is appended to fragments
 * if it is larger than the last recorded value. The walk ends when
 * the range is exhausted or when result already holds at least
 * filter.getMaxTotal() hits; the limit is only checked between
 * dictionary entries.
 *
 * fragments.back() is read without an emptiness check, so fragments
 * must be non-empty whenever at least one entry is processed.
 *
 * @param range_in  range to walk; copied, so the caller's range is
 *                  left untouched
 * @param posting   store providing foreach_frozen()
 * @param filter    decides which postings are kept and supplies the
 *                  total hit limit
 * @param result    container receiving accepted postings
 * @param fragments receives the result size after each dictionary
 *                  entry that contributed hits
 **/
template <typename DictRange, typename PostingStore, typename Result>
void diversify_2(const DictRange &range_in, const PostingStore &posting, DiversityFilter & filter,
                 Result &result, std::vector<size_t> &fragments)
{
    using DataType = typename PostingStore::DataType;
    using KeyDataType = typename PostingStore::KeyDataType;

    DiversityRecorder<Result> recorder(filter, result);
    DictRange remaining(range_in);
    auto offer_posting = [&](uint32_t key, const DataType &data) {
        recorder.push_back(KeyDataType(key, data));
    };

    while (remaining.has_next()) {
        if (result.size() >= filter.getMaxTotal()) {
            break;
        }
        // Must stay alive for the whole iteration: depending on the range
        // type, creating or destroying the handle moves the range along.
        typename DictRange::Next dict_entry(remaining);
        posting.foreach_frozen(datastore::EntryRef(dict_entry.get().getData()), offer_posting);
        const auto hits_so_far = result.size();
        if (fragments.back() < hits_so_far) {
            fragments.push_back(hits_so_far);
        }
    }
}

template <typename DictItr, typename PostingStore, typename Result>
void diversify(bool forward, const DictItr &lower, const DictItr &upper, const PostingStore &posting, size_t wanted_hits,
               const IAttributeVector &diversity_attr, size_t max_per_group,
               size_t cutoff_max_groups, bool cutoff_strict,
               Result &array, std::vector<size_t> &fragments)
{
    auto filter = DiversityFilter::create(diversity_attr, wanted_hits, max_per_group, cutoff_max_groups, cutoff_strict);
    if (forward) {
        diversify_2(ForwardRange<DictItr>(lower, upper), posting, *filter, array, fragments);
    } else {
        diversify_2(ReverseRange<DictItr>(lower, upper), posting, *filter, array, fragments);
    }
}

}