aboutsummaryrefslogtreecommitdiffstats
path: root/searchlib/src/vespa/searchlib/predicate/predicate_tree_annotator.cpp
blob: 031b20c48f522f4a7e2bc08f7a46f69ccb635af1 (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
// Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.

#include "predicate_tree_annotator.h"
#include "predicate_range_expander.h"
#include "predicate_tree_analyzer.h"
#include "common.h"
#include <vespa/document/predicate/predicate.h>

using document::Predicate;
using std::map;
using std::string;
using vespalib::slime::Inspector;
using vespalib::Memory;

namespace search::predicate {

using predicate::MIN_INTERVAL;
using predicate::MAX_INTERVAL;

namespace {

// Stateful helper that walks a predicate tree once and appends the interval
// markers for every leaf to a PredicateTreeAnnotations instance.
class PredicateTreeAnnotatorImpl {
    // Interval assigned to the subtree currently being visited.
    uint32_t _begin;
    uint32_t _end;
    // Bumped once for every visited leaf, and once more for a negated leaf.
    uint32_t _left_weight;
    // Output that the traversal appends to.
    PredicateTreeAnnotations &_result;
    // Key under which the z-star intervals of negated leaves are stored.
    uint64_t _zStar_hash;
    // Toggled while the children of a negation node are visited.
    bool     _negated;
    // Set the first time getCEnd() hands out _interval_range - 1.
    bool     _final_range_used;
    // Looked up by crumb to get the size of a conjunction child.
    const std::map<std::string, int> &_size_map;
    // Crumb of the node being visited; used as key into _size_map.
    TreeCrumbs _crumbs;
    // Fallback limits for range features lacking an explicit min / max.
    int64_t    _lower_bound;
    int64_t    _upper_bound;
    // Total interval range of the tree being annotated.
    uint16_t   _interval_range;


    // Packs an interval into one word: begin is shifted into the upper
    // 16 bits and end is OR-ed in below it (end is not masked).
    uint32_t makeMarker(uint32_t begin, uint32_t end) {
        const uint32_t shifted_begin = begin << 16;
        return shifted_begin | end;
    }

    // Picks the end marker used for a negated leaf. The first leaf whose
    // interval ends at _interval_range gets _interval_range - 1; every other
    // call yields _left_weight + 1.
    uint32_t getCEnd() {
        const bool take_final_slot = !_final_range_used && (_end == _interval_range);
        if (!take_final_slot) {
            return _left_weight + 1;
        }
        _final_range_used = true;
        return _interval_range - 1;
    }

    void addZstarIntervalIfNegated(uint32_t cEnd);

public:
    PredicateTreeAnnotatorImpl(const std::map<std::string, int> &size_map,
                               PredicateTreeAnnotations &result,
                               int64_t lower,
                               int64_t upper,
                               uint16_t interval_range);

    // Recursively annotates the tree rooted at 'in'.
    void assignIntervalMarkers(const vespalib::slime::Inspector &in);
};

// For a negated leaf, appends the z-star interval(s) that accompany the
// leaf's own interval and consumes one extra unit of left weight.
// Does nothing when the current leaf is not negated.
void
PredicateTreeAnnotatorImpl::addZstarIntervalIfNegated(uint32_t cEnd) {
    if (!_negated) {
        return;
    }
    // insert() leaves an existing z-star entry untouched; the hash is only
    // listed as a feature the first time the entry is created.
    auto [entry, created] =
        _result.interval_map.insert(std::make_pair(_zStar_hash, std::vector<Interval>()));
    if (created) {
        _result.features.push_back(_zStar_hash);
    }
    auto &zstar_intervals = entry->second;
    zstar_intervals.push_back(Interval{ makeMarker(cEnd, _begin - 1) });
    // A second interval is added unless cEnd sits directly below _end.
    const bool needs_second_interval = (_end - cEnd) != 1;
    if (needs_second_interval) {
        zstar_intervals.push_back(Interval{ makeMarker(0, _end) });
    }
    ++_left_weight;
}

// Prepares the traversal state for a whole tree: the root interval spans
// [MIN_INTERVAL, interval_range], no leaf has been visited yet and the
// traversal starts out non-negated.
PredicateTreeAnnotatorImpl::PredicateTreeAnnotatorImpl(
        const std::map<std::string, int> &size_map,
        PredicateTreeAnnotations &result,
        int64_t lower_bound,
        int64_t upper_bound,
        uint16_t interval_range)
    : _begin(MIN_INTERVAL),
      _end(interval_range),
      _left_weight(0),
      _result(result),
      _zStar_hash(Constants::z_star_compressed_hash),
      _negated(false),
      _final_range_used(false),
      _size_map(size_map),
      _crumbs(),
      _lower_bound(lower_bound),
      _upper_bound(upper_bound),
      _interval_range(interval_range)
{
}

// Returns the node type stored in 'in'. While 'negated' is set, conjunction
// and disjunction are reported as each other; all other types are returned
// unchanged.
long
getType(const Inspector &in, bool negated) {
    const long node_type = in[Predicate::NODE_TYPE].asLong();
    if (!negated) {
        return node_type;
    }
    switch (node_type) {
    case Predicate::TYPE_CONJUNCTION:
        return Predicate::TYPE_DISJUNCTION;
    case Predicate::TYPE_DISJUNCTION:
        return Predicate::TYPE_CONJUNCTION;
    default:
        return node_type;
    }
}

// Recursively annotates the subtree rooted at 'in' and appends the resulting
// interval markers to _result.
//
// The node type is resolved through getType(), so conjunctions and
// disjunctions swap roles while _negated is set. Per node type:
//  - conjunction: the current [_begin, _end] is split among the children,
//  - disjunction: every child is visited with the parent's [_begin, _end],
//  - feature set / feature range: leaf; intervals are recorded,
//  - negation: _negated is flipped while the first child is visited.
// Any other node type falls through the switch without doing anything.
void
PredicateTreeAnnotatorImpl::assignIntervalMarkers(const Inspector &in) {
    const Inspector& in_children = in[Predicate::CHILDREN];
    switch (getType(in, _negated)) {
    case Predicate::TYPE_CONJUNCTION: {
        // Remember the crumb length so it can be truncated back after each child.
        int crumb_size = _crumbs.size();
        // Start of the interval handed to the next child.
        uint32_t curr = _begin;
        size_t child_count = in_children.children();
        // Interval of this node, saved because the members are overwritten below.
        uint32_t begin = _begin;
        uint32_t end = _end;
        for (size_t i = 0; i < child_count; ++i) {
            // Extend the crumb with child index i, tagged 'a' for this case.
            _crumbs.setChild(i, 'a');
            if (i == child_count - 1) {  // Last child (also the only child when child_count == 1)
                // The last child receives whatever remains up to this node's end.
                _begin = curr;
                _end = end;
                assignIntervalMarkers(in_children[i]);
                // No need to update/touch curr
            } else if (i == 0) {  // First child
                auto it = _size_map.find(_crumbs.getCrumb());
                assert (it != _size_map.end());
                uint32_t child_size = it->second;
                // The first child's end is derived from _left_weight rather than curr.
                uint32_t next = _left_weight + child_size + 1;
                _begin = curr;
                _end = next - 1;
                assignIntervalMarkers(in_children[i]);
                curr = next;
            } else {  // Middle children
                auto it = _size_map.find(_crumbs.getCrumb());
                assert (it != _size_map.end());
                uint32_t child_size = it->second;
                // Middle children are laid out back to back, child_size slots each.
                uint32_t next = curr + child_size;
                _begin = curr;
                _end = next - 1;
                assignIntervalMarkers(in_children[i]);
                curr = next;
            }
            _crumbs.resize(crumb_size);
        }
        // Restore this node's begin; _end was already reset to 'end' when the
        // last child was visited.
        _begin = begin;
        break;
    }
    case Predicate::TYPE_DISJUNCTION: {
        // All OR children will have the same {begin, end} values, and
        // the values will be same as that of the parent OR node
        int crumb_size = _crumbs.size();
        for (size_t i = 0; i < in_children.children(); ++i) {
            // Extend the crumb with child index i, tagged 'o' for this case.
            _crumbs.setChild(i, 'o');
            assignIntervalMarkers(in_children[i]);
            _crumbs.resize(crumb_size);
        }
        break;
    }
    case Predicate::TYPE_FEATURE_SET: {
        // cEnd is only used when negated; it replaces _end in the marker below.
        uint32_t cEnd = _negated? getCEnd() : 0;
        Memory label_mem = in[Predicate::KEY].asString();
        // Build "<key>=" once; each value is appended after truncating back to it.
        string label(label_mem.data, label_mem.size);
        label.push_back('=');
        const size_t prefix_size = label.size();
        const Inspector& in_set = in[Predicate::SET];
        for (size_t i = 0; i < in_set.children(); ++i) {
            Memory value = in_set[i].asString();
            label.resize(prefix_size);
            label.append(value.data, value.size);
            uint64_t hash = PredicateHash::hash64(label);
            // List the hash as a feature only the first time it is seen.
            if (_result.interval_map.find(hash) == _result.interval_map.end()) {
                _result.features.push_back(hash);
            }
            _result.interval_map[hash].push_back({ makeMarker(_begin, _negated? cEnd : _end) });
        }
        // A negated leaf additionally gets z-star intervals (and one more unit
        // of left weight, added inside the helper).
        addZstarIntervalIfNegated(cEnd);
        _left_weight += 1;
        break;
    }
    case Predicate::TYPE_FEATURE_RANGE: {
        // cEnd is only used when negated; it replaces _end in the markers below.
        uint32_t cEnd = _negated? getCEnd() : 0;
        // Pre-hashed partitions become plain intervals.
        const Inspector& in_hashed_partitions = in[Predicate::HASHED_PARTITIONS];
        for (size_t i = 0; i < in_hashed_partitions.children(); ++i) {
            uint64_t hash = in_hashed_partitions[i].asLong();
            _result.interval_map[hash].push_back({ makeMarker(_begin, _negated? cEnd : _end) });
        }
        // Edge partitions carry a payload and are stored in bounds_map instead.
        const Inspector& in_hashed_edges = in[Predicate::HASHED_EDGE_PARTITIONS];
        for (size_t i = 0; i < in_hashed_edges.children(); ++i){
            const Inspector& child = in_hashed_edges[i];
            uint64_t hash = child[Predicate::HASH].asLong();
            uint32_t payload = child[Predicate::PAYLOAD].asLong();
            _result.bounds_map[hash].push_back({ makeMarker(_begin, _negated? cEnd : _end), payload });
        }
        uint32_t hash_count = in_hashed_partitions.children() + in_hashed_edges.children();
        if (hash_count < 3) {  // three features take more space than one stored range.
            // Few hashes: list each one as an individual feature (no
            // duplicate check is done here, unlike for feature sets).
            for (size_t i = 0; i < in_hashed_partitions.children(); ++i) {
                _result.features.push_back(in_hashed_partitions[i].asLong());
            }
            for (size_t i = 0; i < in_hashed_edges.children(); ++i) {
                const Inspector& child = in_hashed_edges[i];
                uint64_t hash = child[Predicate::HASH].asLong();
                _result.features.push_back(hash);
            }
        } else {
            // Otherwise store the range itself; a missing min / max falls back
            // to the bounds given at construction.
            const Inspector& in_min = in[Predicate::RANGE_MIN];
            const Inspector& in_max = in[Predicate::RANGE_MAX];
            bool has_min = in_min.valid();
            bool has_max = in_max.valid();
            _result.range_features.push_back(
                    {in[Predicate::KEY].asString(),
                     has_min? in_min.asLong() : _lower_bound,
                     has_max? in_max.asLong() : _upper_bound
                     });
        }
        // A negated leaf additionally gets z-star intervals (and one more unit
        // of left weight, added inside the helper).
        addZstarIntervalIfNegated(cEnd);
        _left_weight += 1;
        break;
    }
    case Predicate::TYPE_NEGATION:
        // Only the first child is visited; the flag is restored afterwards.
        _negated = !_negated;
        assignIntervalMarkers(in_children[0]);
        _negated = !_negated;
        break;
    }  // switch
}
}  // namespace


// Stores the given min-feature (mf) and interval range (ir); no other member
// is explicitly initialized here.
PredicateTreeAnnotations::PredicateTreeAnnotations(uint32_t mf, uint16_t ir)
    : min_feature(mf),
      interval_range(ir)
{
}

// Compiler-generated destructor, defined out of line in this translation unit.
// NOTE(review): presumably kept here to avoid instantiating the members'
// destructors in every includer of the header - confirm.
PredicateTreeAnnotations::~PredicateTreeAnnotations() = default;

void
PredicateTreeAnnotator::annotate(const Inspector &in, PredicateTreeAnnotations &result, int64_t lower, int64_t upper)
{
    PredicateTreeAnalyzer analyzer(in);
    uint32_t min_feature = static_cast<uint32_t>(analyzer.getMinFeature());
    // Size is as interval range (tree size is lower bound for interval range)
    int size = analyzer.getSize();
    assert(size <= UINT16_MAX && size > 0);
    uint16_t interval_range = static_cast<uint16_t>(size);

    PredicateTreeAnnotatorImpl annotator(analyzer.getSizeMap(), result, lower, upper, interval_range);
    annotator.assignIntervalMarkers(in);
    result.min_feature = min_feature;
    result.interval_range = interval_range;
}

}