summaryrefslogtreecommitdiffstats
path: root/staging_vespalib/src/vespa/vespalib/stllike/avl.h
blob: be8ccbd3497d0636899dfae35773144e5008df4d (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
// Copyright 2016 Yahoo Inc. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
#pragma once

#include <vector>
#include <iterator>
#include <bits/stl_algo.h>
#include <bits/stl_function.h>

#include "hash_fun.h"

namespace vespalib {

/**
   Yet another avl implementation. This one is justified by
   different memory management.

   The interface tries to stay similar to the stl version. However there
   are some major differences.  In order to avoid an allocation /
   deallocation for every object inserted / erased, it stores all
   objects in a std::vector. This should significantly speed up
   things. However it does remove properties that the stl versions
   have. The most obvious is that insert might invalidate iterators.
   That is due to possible resizing of memory store. That can be
   avoided by using a deque. However a deque is more complex and will
   be slower. Since speed is here preferred over functionality that is
   not yet done.
   The same trick could be done for tree based map/set.
   An entry can be valid or invalid(not set, or erased).
   The entry contains the element + a next index.

   After selecting the proper prime number for modulo operation, the
   vector reserves requested space. Resize (triggered by full vector)
   doubles size. Then the first 'prime' entries are initialized to
   invalid. Then the objects are filled in.  If the selected bucket is
   invalid the element is inserted directly. If not it is
   'push_back'ed, on the vector and linked in after the head element
   for that bucket. Erased elements are marked invalid, but not
   reused. They are reclaimed on resize.

   Advantage:
   - Significantly faster insert on average. No alloc/dealloc per element.
   Disadvantage:
   - insert spikes due to possible resize.
   - not fully stl-compliant.
   Conclusion:
   - Probably very good for typical use. Fx duplicate removal.
   Advice:
   - If you know how many elements you are going to put in, construct
   the hash with 2 times the amount. Since a hash will never be
   fully filled.
   ( hash_set<T>(2*num_expected_elements) ).
**/

/**
 * Non-template base that fixes the index type used to link tree nodes.
 */
class avl_base
{
public:
    /// Index of a node inside the node store.
    typedef unsigned int next_t;
};

/**
 * Ordered container implemented as an AVL tree whose nodes live in a single
 * std::vector and refer to each other by index instead of by pointer.
 *
 * - Key        : type the elements are ordered by.
 * - Value      : type stored in the tree.
 * - Compare    : strict weak ordering on Key; must be callable on a const object.
 * - KeyExtract : functor mapping a Value to its Key; must be callable on a
 *                const object.
 *
 * Elements with equivalent keys are all kept; find() returns the first of them
 * in iteration order and erase() removes exactly that one.
 *
 * Iterators store a node index. insert() never changes the index of an
 * existing element, but it may reallocate the store and thereby invalidate
 * references and pointers to elements. erase() moves the element stored last
 * into the freed slot, so an iterator to that element is invalidated.
 * Modifying the key of an element through an iterator breaks the ordering.
 */
template< typename Key, typename Value, typename Compare, typename KeyExtract >
class avl : public avl_base
{
protected:
    /// One tree node: the stored value, three index links and the subtree height.
    class Node {
    public:
        /// npos means "no node". invalid is reserved; nothing in this class uses it.
        enum {npos=-1u, invalid=-2u};
        Node() : _parent(npos), _left(npos), _right(npos), _height(1), _node() { }
        Node(const Value & node, next_t parent=npos, next_t left=npos, next_t right=npos) : _parent(parent), _left(left), _right(right), _height(1), _node(node) { }
        Value & getValue()             { return _node; }
        const Value & getValue() const { return _node; }
        next_t getParent()       const { return _parent; }
        void setParent(next_t v)       { _parent = v; }
        bool hasParent()         const { return _parent != npos; }
        next_t getLeft()         const { return _left; }
        void setLeft(next_t v)         { _left = v; }
        bool hasLeft()           const { return _left != npos; }
        next_t getRight()        const { return _right; }
        void setRight(next_t v)        { _right = v; }
        bool hasRight()          const { return _right != npos; }
        /// Height of the subtree rooted here; a node without children has height 1.
        int getHeight()          const { return _height; }
        void setHeight(int v)          { _height = v; }
    private:
        next_t  _parent;
        next_t  _left;
        next_t  _right;
        int     _height;
        Value   _node;
    };
    typedef std::vector<Node> NodeStore;
    /// Replaces the content of this tree with the values held by oldStore.
    /// NOTE(review): nothing in this header calls move(); the rebuild semantics
    /// are an assumption - confirm against the intended callers.
    virtual void move(const NodeStore & oldStore);
public:
    class const_iterator;
    /// Forward iterator visiting the elements in ascending key order.
    class iterator {
    public:
        iterator(avl * tree, next_t start) : _node(start), _avl(tree) { }
        Value & operator * ()  const { return _avl->get(_node); }
        Value * operator -> () const { return & _avl->get(_node); }
        iterator & operator ++ () {
            _node = _avl->getNextRight(_node);
            return *this;
        }
        iterator operator ++ (int) {
            iterator prev = *this;
            ++(*this);
            return prev;
        }
        bool operator==(const iterator& rhs) const { return (_node == rhs._node); }
        bool operator!=(const iterator& rhs) const { return (_node != rhs._node); }
        /// Index of the referenced node in the store; Node::npos for end().
        next_t getInternalIndex() const { return _node; }
    private:
        next_t   _node;
        avl    * _avl;

        friend class avl::const_iterator;
    };
    /// Read-only counterpart of iterator; implicitly constructible from one.
    class const_iterator {
    public:
        const_iterator(const avl * tree, next_t start) : _node(start), _avl(tree) { }
        const_iterator(const iterator &i) : _node(i._node), _avl(i._avl) {}
        const Value & operator * ()  const { return _avl->get(_node); }
        const Value * operator -> () const { return & _avl->get(_node); }
        const_iterator & operator ++ () {
            _node = _avl->getNextRight(_node);
            return *this;
        }
        const_iterator operator ++ (int) {
            const_iterator prev = *this;
            ++(*this);
            return prev;
        }
        bool operator==(const const_iterator& rhs) const { return (_node == rhs._node); }
        bool operator!=(const const_iterator& rhs) const { return (_node != rhs._node); }
        /// Index of the referenced node in the store; Node::npos for end().
        next_t getInternalIndex() const { return _node; }
    private:
        next_t       _node;
        const avl  * _avl;
    };

public:
    /// Creates an empty tree with room reserved for reservedSpace elements.
    avl(size_t reservedSpace);
    /// Virtual because the class exposes a virtual member function.
    virtual ~avl() { }
    iterator begin()             { return iterator(this, getLeftMost(_root)); }
    iterator end()               { return iterator(this, Node::npos); }
    const_iterator begin() const { return const_iterator(this, getLeftMost(_root)); }
    const_iterator end()   const { return const_iterator(this, Node::npos); }
    size_t capacity()      const { return _nodes.capacity(); }
    /// Number of stored elements; the store is kept dense, so this is its size.
    size_t size()          const { return _nodes.size(); }
    bool empty()           const { return _root == Node::npos; }
    /// Alternate-key lookup. Declared only: this header provides no active definition.
    template< typename AltKey, typename AltExtract, typename AltCompare >
    iterator find(const AltKey & key);
    /// Returns an iterator to the first element equivalent to key, or end().
    iterator find(const Key & key) { return iterator(this, internalFind(key)); }
    /// Alternate-key lookup. Declared only: this header provides no active definition.
    template< typename AltKey, typename AltExtract, typename AltCompare >
    const_iterator find(const AltKey & key) const;
    /// Returns an iterator to the first element equivalent to key, or end().
    const_iterator find(const Key & key) const { return const_iterator(this, internalFind(key)); }
    /// Adds a copy of node; elements with equivalent keys are kept side by side.
    void insert(const Value & node);
    /// Removes the first element equivalent to key; does nothing if there is none.
    void erase(const Key & key);
    /// Removes all elements. The root must be reset too, or it would dangle.
    void clear() { _nodes.clear(); _root = Node::npos; }
    void reserve(size_t newReserve) { _nodes.reserve(newReserve); }
    void swap(avl & rhs);
    /**
     * Get an approximate number of bytes consumed by the tree. Not including
     * any data Value would store outside of sizeof(Value) of course.
     */
    size_t getMemoryConsumption() const;

protected:
    /// These methods are only for the ones that know what they are doing.
    /// Valid indexes are the ones returned by insertInternal() or by
    /// iterator::getInternalIndex() on a non-end iterator.
    /// Inserts a copy of node and returns the index it was stored at.
    next_t insertInternal(const Value & node);
    Value & getByInternalIndex(size_t index)             { return _nodes[index].getValue(); }
    const Value & getByInternalIndex(size_t index) const { return _nodes[index].getValue(); }
    /// Like erase(key), but reports the relocation of the last stored element
    /// through moveHandler.move(from, to).
    /// NOTE(review): the move(from, to) contract is assumed - confirm with users.
    template <typename MoveHandler>
    void erase(MoveHandler & moveHandler, const Key & key);
private:
    /// Handler used by the public erase(); ignores relocations.
    class NoMoveHandler {
    public:
        void move(next_t from, next_t to) { (void) from; (void) to; }
    };

    next_t      _root;
    NodeStore   _nodes;
    Compare     _compare;
    KeyExtract  _keyExtractor;

    /// Index of the first element equivalent to key, or Node::npos.
    next_t internalFind(const Key & key) const;
    Value & get(size_t index)             { return _nodes[index].getValue(); }
    const Value & get(size_t index) const { return _nodes[index].getValue(); }

    /// Largest node of the subtree rooted at n; Node::npos for an empty subtree.
    next_t getRightMost(next_t n) const {
        if (n == Node::npos) {
            return n;
        }
        while (_nodes[n].hasRight()) {
            n = _nodes[n].getRight();
        }
        return n;
    }
    /// Smallest node of the subtree rooted at n; Node::npos for an empty subtree.
    next_t getLeftMost(next_t n) const {
        if (n == Node::npos) {
            return n;
        }
        while (_nodes[n].hasLeft()) {
            n = _nodes[n].getLeft();
        }
        return n;
    }
    /// In-order successor of n, or Node::npos when n is the last node or npos.
    next_t getNextRight(next_t n) const {
        if (n == Node::npos) {
            return n;
        }
        if (_nodes[n].hasRight()) {
            return getLeftMost(_nodes[n].getRight());
        }
        // No right subtree: climb until we leave a left subtree.
        next_t up = _nodes[n].getParent();
        while ((up != Node::npos) && (_nodes[up].getRight() == n)) {
            n = up;
            up = _nodes[up].getParent();
        }
        return up;
    }
    /// In-order predecessor of n, or Node::npos when n is the first node or npos.
    next_t getNextLeft(next_t n) const {
        if (n == Node::npos) {
            return n;
        }
        if (_nodes[n].hasLeft()) {
            return getRightMost(_nodes[n].getLeft());
        }
        // No left subtree: climb until we leave a right subtree.
        next_t up = _nodes[n].getParent();
        while ((up != Node::npos) && (_nodes[up].getLeft() == n)) {
            n = up;
            up = _nodes[up].getParent();
        }
        return up;
    }

    /// Height of the subtree rooted at n; 0 for an empty subtree.
    int heightOf(next_t n) const {
        return (n == Node::npos) ? 0 : _nodes[n].getHeight();
    }
    /// Left height minus right height of n.
    int skewOf(next_t n) const {
        return heightOf(_nodes[n].getLeft()) - heightOf(_nodes[n].getRight());
    }
    /// Recomputes the stored height of n from its children.
    void refreshHeight(next_t n) {
        const int l = heightOf(_nodes[n].getLeft());
        const int r = heightOf(_nodes[n].getRight());
        _nodes[n].setHeight(1 + ((l > r) ? l : r));
    }
    /// Makes newChild occupy the place oldChild has below parent (or as root
    /// when parent is npos) and points newChild back at parent.
    void replaceChild(next_t parent, next_t oldChild, next_t newChild) {
        if (parent == Node::npos) {
            _root = newChild;
        } else if (_nodes[parent].getLeft() == oldChild) {
            _nodes[parent].setLeft(newChild);
        } else {
            _nodes[parent].setRight(newChild);
        }
        if (newChild != Node::npos) {
            _nodes[newChild].setParent(parent);
        }
    }
    /// Rotates the subtree at x to the left; returns the new subtree root.
    next_t rotateLeft(next_t x) {
        const next_t y = _nodes[x].getRight();
        const next_t inner = _nodes[y].getLeft();
        _nodes[x].setRight(inner);
        if (inner != Node::npos) {
            _nodes[inner].setParent(x);
        }
        replaceChild(_nodes[x].getParent(), x, y); // must precede re-parenting x
        _nodes[y].setLeft(x);
        _nodes[x].setParent(y);
        refreshHeight(x);
        refreshHeight(y);
        return y;
    }
    /// Rotates the subtree at x to the right; returns the new subtree root.
    next_t rotateRight(next_t x) {
        const next_t y = _nodes[x].getLeft();
        const next_t inner = _nodes[y].getRight();
        _nodes[x].setLeft(inner);
        if (inner != Node::npos) {
            _nodes[inner].setParent(x);
        }
        replaceChild(_nodes[x].getParent(), x, y); // must precede re-parenting x
        _nodes[y].setRight(x);
        _nodes[x].setParent(y);
        refreshHeight(x);
        refreshHeight(y);
        return y;
    }
    /// Restores heights and the AVL balance on the path from n up to the root.
    void rebalanceUpFrom(next_t n) {
        while (n != Node::npos) {
            refreshHeight(n);
            const int skew = skewOf(n);
            if (skew > 1) {
                if (skewOf(_nodes[n].getLeft()) < 0) {
                    rotateLeft(_nodes[n].getLeft());
                }
                n = rotateRight(n);
            } else if (skew < -1) {
                if (skewOf(_nodes[n].getRight()) > 0) {
                    rotateRight(_nodes[n].getRight());
                }
                n = rotateLeft(n);
            }
            n = _nodes[n].getParent();
        }
    }
    /// Detaches z from the tree and rebalances. The slot of z keeps stale
    /// links afterwards; the caller is expected to overwrite or drop it.
    void unlink(next_t z) {
        const next_t zLeft = _nodes[z].getLeft();
        const next_t zRight = _nodes[z].getRight();
        next_t fixFrom;
        if ((zLeft != Node::npos) && (zRight != Node::npos)) {
            // Two children: the in-order successor takes the place of z.
            const next_t s = getLeftMost(zRight);
            const next_t sParent = _nodes[s].getParent();
            if (sParent != z) {
                // s is a left child somewhere below zRight; lift it out first.
                const next_t sRight = _nodes[s].getRight();
                _nodes[sParent].setLeft(sRight);
                if (sRight != Node::npos) {
                    _nodes[sRight].setParent(sParent);
                }
                _nodes[s].setRight(zRight);
                _nodes[zRight].setParent(s);
                fixFrom = sParent;
            } else {
                fixFrom = s;
            }
            _nodes[s].setLeft(zLeft);
            _nodes[zLeft].setParent(s);
            replaceChild(_nodes[z].getParent(), z, s);
        } else {
            const next_t child = (zLeft != Node::npos) ? zLeft : zRight;
            fixFrom = _nodes[z].getParent();
            replaceChild(fixFrom, z, child);
        }
        rebalanceUpFrom(fixFrom);
    }
};

/// Exchanges the complete state of the two trees.
template< typename Key, typename Value, typename Compare, typename KeyExtract >
void avl<Key, Value, Compare, KeyExtract>::swap(avl & rhs)
{
    std::swap(_root, rhs._root);
    _nodes.swap(rhs._nodes);
    std::swap(_compare, rhs._compare);
    std::swap(_keyExtractor, rhs._keyExtractor);
}

template< typename Key, typename Value, typename Compare, typename KeyExtract >
avl<Key, Value, Compare, KeyExtract>::avl(size_t reservedSpace) :
    _root(Node::npos),
    _nodes(),
    _compare(),
    _keyExtractor()
{
    if (reservedSpace > 0) {
        reserve(reservedSpace);
    }
}

template< typename Key, typename Value, typename Compare, typename KeyExtract >
void
avl<Key, Value, Compare, KeyExtract>::move(const NodeStore & oldStore)
{
    if (&oldStore == &_nodes) {
        return; // rebuilding from our own store would read what clear() destroys
    }
    clear();
    reserve(oldStore.size());
    for (size_t i = 0; i < oldStore.size(); ++i) {
        insertInternal(oldStore[i].getValue());
    }
}

template< typename Key, typename Value, typename Compare, typename KeyExtract >
typename avl<Key, Value, Compare, KeyExtract>::next_t
avl<Key, Value, Compare, KeyExtract>::internalFind(const Key & key) const
{
    next_t found = Node::npos; // Last node seen whose key is not less than key.

    for (next_t n(_root); n != Node::npos; ) {
        if (!_compare(_keyExtractor(_nodes[n].getValue()), key)) {
            found = n;
            n = _nodes[n].getLeft();
        } else {
            n = _nodes[n].getRight();
        }
    }
    if ((found != Node::npos) && ! _compare(key, _keyExtractor(_nodes[found].getValue()))) {
        return found;
    }
    return Node::npos;
}

template< typename Key, typename Value, typename Compare, typename KeyExtract >
void
avl<Key, Value, Compare, KeyExtract>::insert(const Value & node)
{
    insertInternal(node);
}

template< typename Key, typename Value, typename Compare, typename KeyExtract >
typename avl<Key, Value, Compare, KeyExtract>::next_t
avl<Key, Value, Compare, KeyExtract>::insertInternal(const Value & node)
{
    next_t parent = Node::npos;
    bool goLeft = false;
    // Descend to the leaf position; equivalent keys go right, so a new element
    // lands after the ones already stored. The extracted key is only used
    // before the push_back below, which may reallocate the store.
    for (next_t n(_root); n != Node::npos; ) {
        parent = n;
        goLeft = _compare(_keyExtractor(node), _keyExtractor(_nodes[n].getValue()));
        n = goLeft ? _nodes[n].getLeft() : _nodes[n].getRight();
    }

    const next_t newN = static_cast<next_t>(_nodes.size());
    _nodes.push_back(Node(node, parent));
    if (parent == Node::npos) {
        _root = newN;
    } else if (goLeft) {
        _nodes[parent].setLeft(newN);
    } else {
        _nodes[parent].setRight(newN);
    }
    rebalanceUpFrom(parent);
    return newN;
}

template< typename Key, typename Value, typename Compare, typename KeyExtract >
void
avl<Key, Value, Compare, KeyExtract>::erase(const Key & key)
{
    NoMoveHandler ignoreMoves;
    erase(ignoreMoves, key);
}

template< typename Key, typename Value, typename Compare, typename KeyExtract >
template< typename MoveHandler >
void
avl<Key, Value, Compare, KeyExtract>::erase(MoveHandler & moveHandler, const Key & key)
{
    const next_t found = internalFind(key);
    if (found == Node::npos) {
        return;
    }
    // Link out
    unlink(found);
    const next_t last = static_cast<next_t>(_nodes.size() - 1);
    if (found != last) {
        // Keep the store dense: the last node takes over the freed slot ...
        std::swap(_nodes[found], _nodes[last]);
        // ... so everything that pointed at 'last' must now point at 'found'.
        const next_t parent = _nodes[found].getParent();
        if (parent == Node::npos) {
            _root = found;
        } else if (_nodes[parent].getLeft() == last) {
            _nodes[parent].setLeft(found);
        } else {
            _nodes[parent].setRight(found);
        }
        if (_nodes[found].hasLeft()) {
            _nodes[_nodes[found].getLeft()].setParent(found);
        }
        if (_nodes[found].hasRight()) {
            _nodes[_nodes[found].getRight()].setParent(found);
        }
        moveHandler.move(last, found);
    }
    _nodes.pop_back();
}

template< typename Key, typename Value, typename Compare, typename KeyExtract >
size_t
avl<Key, Value, Compare, KeyExtract>::getMemoryConsumption() const
{
    return sizeof(*this) + _nodes.capacity() * sizeof(Node);
}

// Disabled: out-of-class definitions for the alternate-key find() overloads.
// The bodies below still probe hash buckets: they use _modulo, AltHash,
// Node::valid(), Node::getNext() and three-argument iterator constructors,
// none of which are declared in this header, and the second definition names
// avl with five template arguments (including Hash) although the class
// template takes four. They appear to be carried over from a hash table and
// must be rewritten as a tree descent before this block can be enabled.
#if 0
template< typename Key, typename Value, typename Compare, typename KeyExtract >
template< typename AltKey, typename AltExtract, typename AltCompare>
typename avl<Key, Value, Compare, KeyExtract>::const_iterator
avl<Key, Value, Compare, KeyExtract>::find(const AltKey & key) const
{
    if (_modulo > 0) {
        AltHash altHasher;
        next_t h = altHasher(key) % _modulo;
        if (_nodes[h].valid()) {
            next_t start(h);
            AltExtract altExtract;
            AltCompare altCompare;
            do {
                if (altCompare(altExtract(_keyExtractor(_nodes[h].getValue())), key)) {
                    return const_iterator(this, start, h);
                }
                h = _nodes[h].getNext();
            } while (h != Node::npos);
        }
    }
    return end();
}

template< typename Key, typename Value, typename Compare, typename KeyExtract >
template< typename AltKey, typename AltExtract, typename AltCompare>
typename avl<Key, Value, Hash, Compare, KeyExtract>::iterator
avl<Key, Value, Compare, KeyExtract>::find(const AltKey & key)
{
    if (_modulo > 0) {
        AltHash altHasher;
        next_t h = altHasher(key) % _modulo;
        if (_nodes[h].valid()) {
            next_t start(h);
            AltExtract altExtract;
            AltCompare altCompare;
            do {
                if (altCompare(altExtract(_keyExtractor(_nodes[h].getValue())), key)) {
                    return iterator(this, start, h);
                }
                h = _nodes[h].getNext();
            } while (h != Node::npos);
        }
    }
    return end();
}
#endif

}