aboutsummaryrefslogtreecommitdiffstats
path: root/eval/src/vespa/eval/eval/fast_addr_map.h
blob: adc27f727693787058c5ae59a4b551ce7ed9270a (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
// Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.

#pragma once

#include "memory_usage_stuff.h"
#include <vespa/vespalib/util/arrayref.h>
#include <vespa/vespalib/util/string_id.h>
#include <vespa/vespalib/stllike/identity.h>
#include <vespa/vespalib/stllike/hashtable.h>

namespace vespalib::eval {

/**
 * A wrapper around vespalib::hashtable, using it to map a list of
 * labels (a sparse address) to an integer value (dense subspace
 * index). Labels are represented by string enum values (string_id)
 * stored and handled outside this class: the map only keeps a
 * reference to an external label vector in which the address of
 * subspace 'idx' occupies the 'addr_size' consecutive labels starting
 * at position 'idx * addr_size'. That label vector must outlive this
 * object.
 **/
class FastAddrMap
{
public:
    // label extracting functions (labels may be given by value or by pointer)
    static constexpr string_id self(string_id label) noexcept { return label; }
    static constexpr string_id self(const string_id *label) noexcept { return *label; }

    // label hashing functions
    static constexpr uint32_t hash_label(string_id label) noexcept { return label.value(); }
    static constexpr uint32_t hash_label(const string_id *label) noexcept { return label->value(); }
    static constexpr uint32_t combine_label_hash(uint32_t full_hash, uint32_t next_hash) noexcept {
        return ((full_hash * 31) + next_hash);
    }
    // Hash of a complete address. For a single-label address the result
    // is exactly the label value (0 * 31 + value); lookup_singledim and
    // the string_id overloads of Hash/Equal rely on this property.
    template <typename T>
    static constexpr uint32_t hash_labels(ConstArrayRef<T> addr) noexcept {
        uint32_t hash = 0;
        for (const T &label: addr) {
            hash = combine_label_hash(hash, hash_label(label));
        }
        return hash;
    }

    // typed uint32_t index used to identify sparse address/dense subspace
    struct Tag {
        uint32_t idx;
        static constexpr uint32_t npos() noexcept { return uint32_t(-1); }
        static constexpr Tag make_invalid() noexcept { return Tag{npos()}; }
        constexpr bool valid() const noexcept { return (idx != npos()); }
    };

    // sparse hash set entry: subspace index plus the cached hash of its address
    struct Entry {
        Tag tag;
        uint32_t hash;
    };

    // alternative key(s) used for lookup in sparse hash set
    template <typename T> struct AltKey {
        ConstArrayRef<T> key;
        uint32_t hash;
    };

    // view able to convert tags into sparse addresses
    struct LabelView {
        size_t addr_size;
        const StringIdVector &labels;
        LabelView(size_t num_mapped_dims, const StringIdVector &labels_in) noexcept
            : addr_size(num_mapped_dims), labels(labels_in) {}
        // address of subspace 'idx': labels[idx * addr_size, (idx + 1) * addr_size)
        ConstArrayRef<string_id> get_addr(size_t idx) const noexcept {
            return {labels.data() + (idx * addr_size), addr_size};
        }
    };

    // hashing functor for sparse hash set
    struct Hash {
        template <typename T>
        constexpr uint32_t operator()(const AltKey<T> &key) const noexcept { return key.hash; }
        constexpr uint32_t operator()(const Entry &entry) const noexcept { return entry.hash; }
        // single-label address: its hash is the label value itself (see hash_labels)
        constexpr uint32_t operator()(string_id label) const noexcept { return label.value(); }
    };

    // equality functor for sparse hash set
    struct Equal {
        const LabelView &label_view;
        Equal(const LabelView &label_view_in) noexcept : label_view(label_view_in) {}
        template <typename T>
        bool operator()(const Entry &a, const AltKey<T> &b) const noexcept {
            if (a.hash != b.hash) {
                return false;
            }
            auto a_key = label_view.get_addr(a.tag.idx);
            // An address with a different number of labels can never be
            // equal; checking this also prevents indexing past the end of
            // b.key when a mismatched lookup address collides on hash.
            if (a_key.size() != b.key.size()) {
                return false;
            }
            for (size_t i = 0; i < a_key.size(); ++i) {
                if (a_key[i] != self(b.key[i])) {
                    return false;
                }
            }
            return true;
        }
        // Single-label lookup: since the hash of a one-label address equals
        // the label value, comparing the stored hash is an exact match.
        // Only meaningful when addr_size == 1.
        bool operator()(const Entry &a, string_id b) const noexcept { return (a.hash == b.value()); }
    };

    using HashType = hashtable<Entry, Entry, Hash, Equal, Identity, hashtable_base::and_modulator>;

private:
    LabelView _labels;
    HashType _map;

public:
    FastAddrMap(size_t num_mapped_dims, const StringIdVector &labels_in, size_t expected_subspaces);
    ~FastAddrMap();
    FastAddrMap(const FastAddrMap &) = delete;
    FastAddrMap &operator=(const FastAddrMap &) = delete;
    FastAddrMap(FastAddrMap &&) = delete;
    FastAddrMap &operator=(FastAddrMap &&) = delete;
    // value returned by lookup functions when the address is not present
    static constexpr size_t npos() noexcept { return -1; }
    ConstArrayRef<string_id> get_addr(size_t idx) const noexcept { return _labels.get_addr(idx); }
    size_t size() const noexcept { return _map.size(); }
    constexpr size_t addr_size() const noexcept { return _labels.addr_size; }
    const StringIdVector &labels() const noexcept { return _labels.labels; }
    // Look up an address whose hash (as computed by hash_labels) is already
    // known. Returns the subspace index, or npos() if not found.
    template <typename T>
    size_t lookup(ConstArrayRef<T> addr, uint32_t hash) const noexcept {
        // assert(addr_size() == addr.size());
        AltKey<T> key{addr, hash};
        auto pos = _map.find(key);
        return (pos == _map.end()) ? npos() : pos->tag.idx;
    }
    // Fast path for maps with exactly one mapped dimension; matches on
    // the stored hash only (see Equal). Returns npos() if not found.
    size_t lookup_singledim(string_id addr) const noexcept {
        // assert(addr_size() == 1);
        auto pos = _map.find(addr);
        return (pos == _map.end()) ? npos() : pos->tag.idx;
    }
    // Look up an address, computing its hash. Single-label addresses take
    // the lookup_singledim fast path. Returns npos() if not found.
    template <typename T>
    size_t lookup(ConstArrayRef<T> addr) const noexcept {
        return (addr.size() == 1)
            ? lookup_singledim(self(addr[0]))
            : lookup(addr, hash_labels(addr));
    }
    // Register a new subspace with the given address hash. The subspace
    // gets the next index, equal to the number of mappings before this call.
    // The caller is presumably responsible for storing the address labels at
    // position idx * addr_size() in the external label vector before it is
    // looked up, and for not adding an address twice (force_insert is used).
    void add_mapping(uint32_t hash) {
        auto idx = static_cast<uint32_t>(_map.size());
        _map.force_insert(Entry{{idx}, hash});
    }
    // Invoke f(subspace_index, address_hash) for every entry, in whatever
    // order _map.for_each visits them.
    template <typename F>
    void each_map_entry(F &&f) const {
        _map.for_each([&](const auto &entry)
                      {
                          f(entry.tag.idx, entry.hash);
                      });
    }
    // Memory used/allocated by the hashtable beyond the hashtable object
    // itself. The figures reported by _map presumably include sizeof(_map),
    // which is already counted as part of this object, so that amount is
    // subtracted to avoid double-counting. A figure smaller than
    // sizeof(_map) is left unchanged.
    MemoryUsage estimate_extra_memory_usage() const {
        auto without_map_self = [map_self_size = sizeof(_map)](size_t bytes) noexcept {
            return (bytes >= map_self_size) ? (bytes - map_self_size) : bytes;
        };
        MemoryUsage extra_usage;
        extra_usage.incUsedBytes(without_map_self(_map.getMemoryUsed()));
        extra_usage.incAllocatedBytes(without_map_self(_map.getMemoryConsumption()));
        return extra_usage;
    }
};

}