aboutsummaryrefslogtreecommitdiffstats
path: root/searchlib/src/vespa/searchlib/attribute/diversity.h
blob: dff658d99d7a69c712a5fe1b3f78b62140b7078f (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
// Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.

#pragma once

#include "singleenumattribute.h"
#include "singlenumericattribute.h"
#include <vespa/vespalib/stllike/hash_map.h>

/**
 * This file contains low-level code used to implement diversified
 * limited attribute range searches. Terms of the form [;;100;foo;3]
 * are used to specify unbounded range searches in an attribute that
 * produce a limited number of results while also ensuring
 * diversified results based on a secondary attribute.
 **/

namespace search {
namespace attribute {
namespace diversity {

/**
 * Iterator range [lower, upper) that is consumed from the lower end.
 * The constructors and the destructor are only declared here; their
 * definitions live outside this header.
 **/
template <typename ITR>
class ForwardRange
{
private:
    ITR _lower;
    ITR _upper;
public:
    /**
     * Scoped accessor for the current position of a ForwardRange.
     * get() exposes the current lower iterator, and destroying the
     * accessor moves that iterator one step forward.
     **/
    class Next {
    private:
        ITR &_current; // bound to the _lower member of the owning range
    public:
        explicit Next(ForwardRange &range)
            : _current(range._lower)
        {
        }
        Next(const Next &) = delete;
        ~Next()
        {
            ++_current;
        }
        const ITR &get() const
        {
            return _current;
        }
    };
    ForwardRange(const ITR &lower, const ITR &upper);
    ForwardRange(const ForwardRange &);
    ~ForwardRange();
    // True while the lower iterator differs from the upper iterator.
    bool has_next() const
    {
        return _lower != _upper;
    }
};

/**
 * Iterator range [lower, upper) that is consumed from the upper end.
 * The constructors and the destructor are only declared here; their
 * definitions live outside this header.
 **/
template <typename ITR>
class ReverseRange
{
private:
    ITR _lower;
    ITR _upper;
public:
    /**
     * Scoped accessor for the current position of a ReverseRange.
     * Constructing the accessor moves the upper iterator one step
     * back; get() then exposes that iterator.
     **/
    class Next {
    private:
        ITR &_current; // bound to the _upper member of the owning range
    public:
        explicit Next(ReverseRange &range)
            : _current(range._upper)
        {
            --_current;
        }
        Next(const Next &) = delete;
        const ITR &get() const
        {
            return _current;
        }
    };
    ReverseRange(const ITR &lower, const ITR &upper);
    ReverseRange(const ReverseRange &);
    ~ReverseRange();
    // True while the lower iterator differs from the upper iterator.
    bool has_next() const
    {
        return _lower != _upper;
    }
};

/**
 * Fetcher reading values through the concrete attribute type T.
 * The attribute is obtained with a dynamic_cast, so 'attr' is a null
 * pointer when the cast fails. get() dereferences 'attr' without a
 * check and must only be used when valid() returns true.
 **/
template <typename T>
struct FetchNumberFast {
    using ValueType = typename T::LoadedValueType;
    const T * const attr;
    FetchNumberFast(const IAttributeVector &attr_in)
        : attr(dynamic_cast<const T *>(&attr_in))
    {
    }
    // True when the attribute passed to the constructor could be cast to T.
    bool valid() const
    {
        return (nullptr != attr);
    }
    // Value for the given document as returned by T::getFast.
    ValueType get(uint32_t docid) const
    {
        return attr->getFast(docid);
    }
};

/**
 * Fetcher reading enum values through SingleValueEnumAttributeBase.
 * The attribute is obtained with a dynamic_cast, so 'attr' is a null
 * pointer when the cast fails. get() dereferences 'attr' without a
 * check and must only be used when valid() returns true.
 **/
struct FetchEnumFast {
    using ValueType = uint32_t;
    const SingleValueEnumAttributeBase * const attr;
    FetchEnumFast(const IAttributeVector &attr_in)
        : attr(dynamic_cast<const SingleValueEnumAttributeBase *>(&attr_in))
    {
    }
    // True when the attribute passed to the constructor could be cast.
    bool valid() const
    {
        return (nullptr != attr);
    }
    // Value for the given document as returned by getE.
    ValueType get(uint32_t docid) const
    {
        return attr->getE(docid);
    }
};

/**
 * Fetcher reading enum values through the generic IAttributeVector
 * interface (getEnum). Holds a reference to the attribute, which
 * therefore has to outlive the fetcher.
 **/
struct FetchEnum {
    using ValueType = uint32_t;
    const IAttributeVector &attr;
    FetchEnum(const IAttributeVector &attr_in)
        : attr(attr_in)
    {
    }
    // Value for the given document as returned by getEnum.
    ValueType get(uint32_t docid) const
    {
        return attr.getEnum(docid);
    }
};

/**
 * Fetcher reading integer values through the generic IAttributeVector
 * interface (getInt). Holds a reference to the attribute, which
 * therefore has to outlive the fetcher.
 **/
struct FetchInteger {
    using ValueType = int64_t;
    const IAttributeVector &attr;
    FetchInteger(const IAttributeVector &attr_in)
        : attr(attr_in)
    {
    }
    // Value for the given document as returned by getInt.
    ValueType get(uint32_t docid) const
    {
        return attr.getInt(docid);
    }
};

/**
 * Fetcher reading floating point values through the generic
 * IAttributeVector interface (getFloat). Holds a reference to the
 * attribute, which therefore has to outlive the fetcher.
 **/
struct FetchFloat {
    using ValueType = double;
    const IAttributeVector &attr;
    FetchFloat(const IAttributeVector &attr_in)
        : attr(attr_in)
    {
    }
    // Value for the given document as returned by getFloat.
    ValueType get(uint32_t docid) const
    {
        return attr.getFloat(docid);
    }
};

/**
 * Filter placed in front of a result container. Items handed to
 * push_back() are forwarded to the result only while fewer than
 * 'max_total' items have been forwarded, and subject to a per-group
 * limit where the group of an item is diversity.get(item._key).
 *
 * At most 'cutoff_max_groups' distinct groups are tracked. Once that
 * many groups have been seen:
 *  - with cutoff_strict == false, every further item is forwarded
 *    without looking up its group.
 *  - with cutoff_strict == true, items belonging to an already
 *    tracked group are still limited to 'max_per_group', while items
 *    belonging to untracked groups are forwarded without a per-group
 *    limit.
 *
 * The fetcher and the result container are held by reference and
 * must outlive the filter.
 **/
template <typename Fetcher, typename Result>
class DiversityFilter {
private:
    size_t _total_count;       // number of items forwarded to _result so far
    size_t _max_total;         // upper bound for _total_count
    const Fetcher &_diversity; // maps item._key to the group value of the item
    size_t _max_per_group;     // max number of items accepted per tracked group
    size_t _cutoff_max_groups; // max number of distinct groups tracked in _seen
    bool   _cutoff_strict;     // selects the behavior after the group cutoff

    typedef vespalib::hash_map<typename Fetcher::ValueType, uint32_t> Diversity;
    Diversity _seen;           // number of accepted items for each tracked group
    Result &_result;
public:
    DiversityFilter(const Fetcher &diversity, size_t max_per_group,
                    size_t cutoff_max_groups, bool cutoff_strict,
                    Result &result, size_t max_total)
        : _total_count(0), _max_total(max_total), _diversity(diversity), _max_per_group(max_per_group),
          _cutoff_max_groups(cutoff_max_groups), _cutoff_strict(cutoff_strict),
          // The explicit <size_t> keeps std::min well-formed on platforms
          // where size_t is not the same type as unsigned long.
          _seen(std::min<size_t>(cutoff_max_groups, 10000)*3), _result(result)
    { }
    /**
     * Offer an item to the filter. The item is forwarded to the
     * result container if the total limit and the group rules
     * described for the class allow it; otherwise it is dropped.
     **/
    template <typename Item>
    void push_back(Item item) {
        if (_total_count < _max_total) {
            if ((_seen.size() < _cutoff_max_groups) || _cutoff_strict) {
                typename Fetcher::ValueType group = _diversity.get(item._key);
                if (_seen.size() < _cutoff_max_groups) {
                    // Below the cutoff: track the group (operator[] is assumed to
                    // create a zero count for a new group, like std::unordered_map).
                    conditional_add(_seen[group], item);
                } else {
                    // Strict cutoff reached: no new groups are tracked.
                    auto found = _seen.find(group);
                    if (found == _seen.end()) {
                        add(item);
                    } else {
                        conditional_add(found->second, item);
                    }
                }
            } else {
                // Non-strict cutoff reached: accept without consulting the group.
                add(item);
            }
        }
    }
private:
    // Forward the item to the result unconditionally and count it.
    template <typename Item>
    void add(Item item) {
        ++_total_count;
        _result.push_back(item);
    }
    // Forward the item only if its group has room left; updates the group count.
    template <typename Item>
    void conditional_add(uint32_t & group_count, Item item) {
        if (group_count  < _max_per_group) {
            ++group_count;
            add(item);
        }
    }
};

/**
 * Walk a dictionary range and collect posting entries into 'result'
 * through a DiversityFilter.
 *
 * A private copy of 'range_in' is stepped with DictRange::Next until
 * the range is exhausted or 'result' holds 'wanted_hits' entries. For
 * each dictionary entry, all entries visited by
 * posting.foreach_frozen() are offered to the filter as
 * KeyDataType(key, data). After each dictionary entry, the current
 * size of 'result' is appended to 'fragments' if it is larger than
 * the last recorded value.
 *
 * @param range_in          dictionary range to traverse (copied)
 * @param posting           posting store used to enumerate entries
 * @param wanted_hits       stop once 'result' has this many entries
 * @param diversity         fetcher mapping a key to its group value
 * @param max_per_group     per-group limit handed to the filter
 * @param cutoff_max_groups group cutoff handed to the filter
 * @param cutoff_strict     cutoff mode handed to the filter
 * @param result            container receiving the accepted entries
 * @param fragments         increasing sizes of 'result' recorded
 *                          after dictionary entries that added hits
 **/
template <typename DictRange, typename PostingStore, typename Fetcher, typename Result>
void diversify_3(const DictRange &range_in, const PostingStore &posting, size_t wanted_hits,
                 const Fetcher &diversity, size_t max_per_group,
                 size_t cutoff_max_groups, bool cutoff_strict,
                 Result &result, std::vector<size_t> &fragments)
{
    DictRange range(range_in);
    using DataType = typename PostingStore::DataType;
    using KeyDataType = typename PostingStore::KeyDataType;
    DiversityFilter<Fetcher, Result> filter(diversity, max_per_group, cutoff_max_groups, cutoff_strict, result, wanted_hits);
    if (fragments.empty()) {
        // fragments.back() below is undefined on an empty vector; seed it
        // with the size of 'result' before any entry is added here.
        fragments.push_back(result.size());
    }
    while (range.has_next() && (result.size() < wanted_hits)) {
        // The Next object performs the iterator step for this round.
        typename DictRange::Next dict_entry(range);
        posting.foreach_frozen(dict_entry.get().getData(),
                               [&](uint32_t key, const DataType &data)
                               { filter.push_back(KeyDataType(key, data)); });
        if (fragments.back() < result.size()) {
            fragments.push_back(result.size());
        }
    }
}

/**
 * Select a fetcher matching the diversity attribute and run
 * diversify_3 with it; all other arguments are passed on unchanged.
 *
 * The enum case is checked first, then integer, then floating point.
 * Within each case a fetcher bound to a concrete attribute class is
 * used when its cast succeeded, otherwise the fetcher going through
 * the generic IAttributeVector interface is used. Nothing is done
 * when the attribute matches none of the three cases.
 **/
template <typename DictRange, typename PostingStore, typename Result>
void diversify_2(const DictRange &range_in, const PostingStore &posting, size_t wanted_hits,
                 const IAttributeVector &diversity_attr, size_t max_per_group,
                 size_t cutoff_max_groups, bool cutoff_strict,
                 Result &result, std::vector<size_t> &fragments)
{
    // Enum attributes must be handled before the numeric cases.
    if (diversity_attr.hasEnum()) {
        FetchEnumFast fastEnum(diversity_attr);
        if (!fastEnum.valid()) {
            diversify_3(range_in, posting, wanted_hits, FetchEnum(diversity_attr),
                        max_per_group, cutoff_max_groups, cutoff_strict, result, fragments);
            return;
        }
        diversify_3(range_in, posting, wanted_hits, fastEnum,
                    max_per_group, cutoff_max_groups, cutoff_strict, result, fragments);
        return;
    }

    if (diversity_attr.isIntegerType()) {
        FetchNumberFast<SingleValueNumericAttribute<IntegerAttributeTemplate<int32_t> > > fastInt32(diversity_attr);
        FetchNumberFast<SingleValueNumericAttribute<IntegerAttributeTemplate<int64_t> > > fastInt64(diversity_attr);
        if (fastInt32.valid()) {
            diversify_3(range_in, posting, wanted_hits, fastInt32,
                        max_per_group, cutoff_max_groups, cutoff_strict, result, fragments);
            return;
        }
        if (fastInt64.valid()) {
            diversify_3(range_in, posting, wanted_hits, fastInt64,
                        max_per_group, cutoff_max_groups, cutoff_strict, result, fragments);
            return;
        }
        diversify_3(range_in, posting, wanted_hits, FetchInteger(diversity_attr),
                    max_per_group, cutoff_max_groups, cutoff_strict, result, fragments);
        return;
    }

    if (diversity_attr.isFloatingPointType()) {
        FetchNumberFast<SingleValueNumericAttribute<FloatingPointAttributeTemplate<float> > > fastFloat(diversity_attr);
        FetchNumberFast<SingleValueNumericAttribute<FloatingPointAttributeTemplate<double> > > fastDouble(diversity_attr);
        if (fastFloat.valid()) {
            diversify_3(range_in, posting, wanted_hits, fastFloat,
                        max_per_group, cutoff_max_groups, cutoff_strict, result, fragments);
            return;
        }
        if (fastDouble.valid()) {
            diversify_3(range_in, posting, wanted_hits, fastDouble,
                        max_per_group, cutoff_max_groups, cutoff_strict, result, fragments);
            return;
        }
        diversify_3(range_in, posting, wanted_hits, FetchFloat(diversity_attr),
                    max_per_group, cutoff_max_groups, cutoff_strict, result, fragments);
    }
}

/**
 * Entry point for diversified range collection. Wraps the iterator
 * pair [lower, upper) in a ForwardRange when 'forward' is true and in
 * a ReverseRange otherwise, then hands over to diversify_2 with all
 * remaining arguments passed on unchanged.
 **/
template <typename DictItr, typename PostingStore, typename Result>
void diversify(bool forward, const DictItr &lower, const DictItr &upper, const PostingStore &posting, size_t wanted_hits,
               const IAttributeVector &diversity_attr, size_t max_per_group,
               size_t cutoff_max_groups, bool cutoff_strict,
               Result &array, std::vector<size_t> &fragments)
{
    if (!forward) {
        // Consume the dictionary range from the upper end.
        diversify_2(ReverseRange<DictItr>(lower, upper), posting, wanted_hits,
                    diversity_attr, max_per_group, cutoff_max_groups, cutoff_strict, array, fragments);
        return;
    }
    // Consume the dictionary range from the lower end.
    diversify_2(ForwardRange<DictItr>(lower, upper), posting, wanted_hits,
                diversity_attr, max_per_group, cutoff_max_groups, cutoff_strict, array, fragments);
}

} // namespace search::attribute::diversity
} // namespace search::attribute
} // namespace search