aboutsummaryrefslogtreecommitdiffstats
path: root/eval/src/vespa/eval/eval/vm_forest.cpp
blob: 340be24ea6ca93bbd5f732d2041bef07a49d4e60 (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
// Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.

#include "gbdt.h"
#include "vm_forest.h"
#include <vespa/eval/eval/basic_nodes.h>
#include <vespa/eval/eval/call_nodes.h>
#include <vespa/eval/eval/operator_nodes.h>

namespace vespalib {
namespace eval {
namespace gbdt {

namespace {

//-----------------------------------------------------------------------------

// Node type tags stored in the 4-bit type fields of an encoded node's
// first word (own type, left child type, right child type).
constexpr uint32_t LEAF     = 0; // single 32-bit float constant
constexpr uint32_t LESS     = 1; // 'feature < const' check
constexpr uint32_t IN       = 2; // 'feature in set' check
constexpr uint32_t INVERTED = 3; // '!(feature >= const)' check

// layout:
//
// <feature+types>: [feature ref|my type|left child type|right child type]
// bits:                      20       4               4                4
//
// LEAF:    [const]
// bits:         32
//
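// LESS:    [<feature+types>][const][skip]
// bits                    32     32    32
//
// INVERTED: same layout as LESS; the check is !(feature >= const)
//
// skip: number of 32-bit words occupied by the left (true) subtree,
// i.e. the distance from the end of this node to the right subtree.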
//
// IN:      [<feature+types>][skip|set size](set size)X[const]
// bits                    32    24        8                64

// Note: We need to use double for set membership checks (IN) due to
// string hashing.

// Load a 64-bit double stored in the two consecutive 32-bit words
// starting at 'pos'. The bytes are copied out with memcpy instead of
// being read through a double pointer.
double read_double(const uint32_t *pos) {
    double result = 0.0;
    static_assert(sizeof(result) == 2 * sizeof(uint32_t));
    memcpy(&result, pos, sizeof(double));
    return result;
}

// View the 32-bit word at 'pos' as a float. The returned pointer refers
// to the very same address; nothing is copied.
// NOTE(review): callers dereference the result, reading uint32_t storage
// through a float lvalue, whereas read_double copies with memcpy --
// confirm this type pun is acceptable for the supported toolchains.
const float *as_float_ptr(const uint32_t *pos) {
    const float *float_view = reinterpret_cast<const float *>(pos);
    return float_view;
}

// Linear scan for 'value' among the doubles stored in [set, end).
// Each entry occupies two 32-bit words, so the cursor advances by 2.
// Returns true on the first exact (==) match, false if none is found.
bool find_in(double value, const uint32_t *set, const uint32_t *end) {
    const uint32_t *entry = set;
    while (entry < end) {
        if (value == read_double(entry)) {
            return true;
        }
        entry += 2;
    }
    return false;
}

// Walk a tree consisting of LESS nodes only and return the value of the
// leaf that is reached.
//
// input:     feature values, indexed by the feature ref stored in the
//            upper 20 bits of each node's first word
// pos:       first word of the root node
// node_type: overwritten before it is first read, so the root is always
//            treated as a LESS node
double less_only_find_leaf(const double *input, const uint32_t *pos, uint32_t node_type) {
    do {
        const uint32_t meta = pos[0];
        const bool take_left = (input[meta >> 12] < *as_float_ptr(pos + 1));
        if (take_left) {
            // left child follows directly after the 3-word node
            node_type = (meta >> 4) & 0xf;
            pos += 3;
        } else {
            // right child is found by skipping the left subtree
            node_type = meta & 0xf;
            pos += 3 + pos[2];
        }
    } while (node_type != LEAF);
    return *as_float_ptr(pos);
}

// Walk a tree that may contain LESS, IN and INVERTED nodes and return
// the value of the leaf that is reached.
//
// input:     feature values, indexed by the feature ref stored in the
//            upper 20 bits of each node's first word
// pos:       first word of the current node
// node_type: type of the node at 'pos'; any value other than LESS or IN
//            is handled as INVERTED
double general_find_leaf(const double *input, const uint32_t *pos, uint32_t node_type) {
    for (;;) {
        const uint32_t meta = pos[0];
        const double feature = input[meta >> 12];
        const uint32_t left_type = (meta >> 4) & 0xf;
        const uint32_t right_type = meta & 0xf;
        switch (node_type) {
        case LESS:
            if (feature < *as_float_ptr(pos + 1)) {
                node_type = left_type;
                pos += 3;
            } else {
                node_type = right_type;
                pos += 3 + pos[2];
            }
            break;
        case IN: {
            // second word: skip in the upper 24 bits, set size in the lower 8
            const uint32_t packed = pos[1];
            const uint32_t *set_begin = pos + 2;
            const uint32_t *set_end = set_begin + (2 * (packed & 0xff));
            if (find_in(feature, set_begin, set_end)) {
                node_type = left_type;
                pos = set_end;
            } else {
                node_type = right_type;
                pos = set_end + (packed >> 8);
            }
            break;
        }
        default:
            // INVERTED: !(feature >= const) selects the left child
            if (feature >= *as_float_ptr(pos + 1)) {
                node_type = right_type;
                pos += 3 + pos[2];
            } else {
                node_type = left_type;
                pos += 3;
            }
            break;
        }
        if (node_type == LEAF) {
            return *as_float_ptr(pos);
        }
    }
}

//-----------------------------------------------------------------------------

// Append a double to the model as two consecutive 32-bit words holding
// its raw bytes, in memory order.
void encode_large_const(double value, std::vector<uint32_t> &model_out) {
    uint32_t words[2];
    static_assert(sizeof(words) == sizeof(value));
    memcpy(words, &value, sizeof(words));
    for (uint32_t word : words) {
        model_out.push_back(word);
    }
}

// Append a float to the model as a single 32-bit word holding its raw
// bytes.
void encode_const(float value, std::vector<uint32_t> &model_out) {
    static_assert(sizeof(uint32_t) == sizeof(value));
    uint32_t bits = 0;
    memcpy(&bits, &value, sizeof(bits));
    model_out.push_back(bits);
}

uint32_t encode_node(const nodes::Node &node_in, std::vector<uint32_t> &model_out);

// Append a LESS node followed by both of its subtrees.
//
// Emitted words: meta word (feature id shifted into the upper 20 bits),
// the threshold as a float, a skip word, the left subtree and finally
// the right subtree. The skip word is patched with the number of words
// used by the left subtree, and the meta word with the node type and
// the types of both children, once those are known.
void encode_less(const nodes::Less &less,
                 const nodes::Node &left_child, const nodes::Node &right_child,
                 std::vector<uint32_t> &model_out)
{
    const size_t meta_idx = model_out.size();
    auto feature = nodes::as<nodes::Symbol>(less.lhs());
    assert(feature);
    model_out.push_back(uint32_t(feature->id()) << 12);
    assert(less.rhs().is_const_double());
    encode_const(less.rhs().get_const_double_value(), model_out);
    const size_t skip_idx = model_out.size();
    model_out.push_back(0); // patched below with the left subtree size
    const uint32_t left_type = encode_node(left_child, model_out);
    const size_t left_begin = skip_idx + 1;
    model_out[skip_idx] = (model_out.size() - left_begin);
    const uint32_t right_type = encode_node(right_child, model_out);
    const uint32_t type_bits = ((LESS << 8) | (left_type << 4) | right_type);
    model_out[meta_idx] |= type_bits;
}

// Append an IN node followed by both of its subtrees.
//
// Emitted words: meta word (feature id shifted into the upper 20 bits),
// a word with the set size in its low bits, every set entry as a
// 64-bit double, the left subtree and finally the right subtree. The
// number of words used by the left subtree is OR-ed into the set size
// word shifted left by 8, and the meta word is patched with the node
// type and the types of both children.
void encode_in(const nodes::In &in,
               const nodes::Node &left_child, const nodes::Node &right_child,
               std::vector<uint32_t> &model_out)
{
    const size_t meta_idx = model_out.size();
    auto feature = nodes::as<nodes::Symbol>(in.child());
    assert(feature);
    model_out.push_back(uint32_t(feature->id()) << 12);
    const size_t packed_idx = model_out.size();
    model_out.push_back(in.num_entries());
    size_t entry = 0;
    while (entry < in.num_entries()) {
        encode_large_const(in.get_entry(entry).get_const_double_value(), model_out);
        ++entry;
    }
    const size_t left_begin = model_out.size();
    const uint32_t left_type = encode_node(left_child, model_out);
    const size_t left_words = model_out.size() - left_begin;
    model_out[packed_idx] |= left_words << 8;
    const uint32_t right_type = encode_node(right_child, model_out);
    const uint32_t type_bits = ((IN << 8) | (left_type << 4) | right_type);
    model_out[meta_idx] |= type_bits;
}

// Append an INVERTED node, i.e. a '!(feature >= const)' check, followed
// by both of its subtrees.
//
// The child of 'inverted' must be a GreaterEqual node with a symbol on
// the left and a constant on the right (asserted). The emitted words
// have the same shape as for a LESS node: meta word, float threshold,
// skip word, left subtree, right subtree.
void encode_inverted(const nodes::Not &inverted,
                     const nodes::Node &left_child, const nodes::Node &right_child,
                     std::vector<uint32_t> &model_out)
{
    const size_t meta_idx = model_out.size();
    auto compare = nodes::as<nodes::GreaterEqual>(inverted.child());
    assert(compare);
    auto feature = nodes::as<nodes::Symbol>(compare->lhs());
    assert(feature);
    model_out.push_back(uint32_t(feature->id()) << 12);
    assert(compare->rhs().is_const_double());
    encode_const(compare->rhs().get_const_double_value(), model_out);
    const size_t skip_idx = model_out.size();
    model_out.push_back(0); // patched below with the left subtree size
    const uint32_t left_type = encode_node(left_child, model_out);
    const size_t left_begin = skip_idx + 1;
    model_out[skip_idx] = (model_out.size() - left_begin);
    const uint32_t right_type = encode_node(right_child, model_out);
    const uint32_t type_bits = ((INVERTED << 8) | (left_type << 4) | right_type);
    model_out[meta_idx] |= type_bits;
}

// Append the encoding of 'node_in' (and everything below it) to the
// model and return the type tag of the node that was encoded.
//
// A node that is not an If must be a constant (asserted) and becomes a
// LEAF. An If node is dispatched on its condition: Less -> LESS,
// In -> IN, otherwise it must be a Not (asserted) -> INVERTED.
uint32_t encode_node(const nodes::Node &node_in, std::vector<uint32_t> &model_out) {
    auto branch = nodes::as<nodes::If>(node_in);
    if (!branch) {
        assert(node_in.is_const_double());
        encode_const(node_in.get_const_double_value(), model_out);
        return LEAF;
    }
    if (auto less = nodes::as<nodes::Less>(branch->cond())) {
        encode_less(*less, branch->true_expr(), branch->false_expr(), model_out);
        return LESS;
    }
    if (auto in = nodes::as<nodes::In>(branch->cond())) {
        encode_in(*in, branch->true_expr(), branch->false_expr(), model_out);
        return IN;
    }
    auto inverted = nodes::as<nodes::Not>(branch->cond());
    assert(inverted);
    encode_inverted(*inverted, branch->true_expr(), branch->false_expr(), model_out);
    return INVERTED;
}

// Append a complete tree to the model: one header word holding the
// number of words used by the tree, followed by the encoded nodes. The
// type tag returned for the root node is not stored separately.
void encode_tree(const nodes::Node &root_in, std::vector<uint32_t> &model_out) {
    const size_t header_idx = model_out.size();
    model_out.push_back(0); // patched below once the tree size is known
    encode_node(root_in, model_out);
    const size_t first_node_idx = header_idx + 1;
    model_out[header_idx] = (model_out.size() - first_node_idx);
}

//-----------------------------------------------------------------------------

// Encode all trees back to back into a single model, wrap the model in
// a VMForest and pair it with the evaluation function to use.
Optimize::Result optimize(const std::vector<const nodes::Node *> &trees,
                          Forest::eval_function eval)
{
    std::vector<uint32_t> encoded;
    for (size_t i = 0; i < trees.size(); ++i) {
        encode_tree(*trees[i], encoded);
    }
    Forest::UP forest(new VMForest(std::move(encoded)));
    return Optimize::Result(std::move(forest), eval);
}

//-----------------------------------------------------------------------------

} // namespace vespalib::eval::gbdt::<unnamed>

//-----------------------------------------------------------------------------

// Optimizer for forests that use 'less' checks only. Yields a
// default-constructed result when the stats report any 'in' or
// 'inverted' checks; otherwise the trees are encoded for evaluation
// with less_only_eval.
Optimize::Result
VMForest::less_only_optimize(const ForestStats &stats,
                             const std::vector<const nodes::Node *> &trees)
{
    const bool has_in_checks = (stats.total_in_checks > 0);
    const bool has_inverted_checks = (stats.total_inverted_checks > 0);
    if (has_in_checks || has_inverted_checks) {
        return Optimize::Result();
    }
    return optimize(trees, less_only_eval);
}

// Evaluate a forest whose trees contain LESS nodes only.
//
// forest: must point to a VMForest (it is downcast without a check)
// input:  feature values, indexed by feature ref
//
// Returns the sum of the leaf values selected in each tree; 0.0 for a
// forest without trees.
double
VMForest::less_only_eval(const Forest *forest, const double *input)
{
    const VMForest &self = *static_cast<const VMForest *>(forest);
    // data() is well-defined for an empty model, unlike &_model[0]
    const uint32_t *pos = self._model.data();
    const uint32_t *end = pos + self._model.size();
    double sum = 0.0;
    while (pos < end) {
        // each tree starts with a header word holding its size in words
        const uint32_t tree_size = *pos++;
        // the root node type lives in bits 8..11 of the root's first word
        sum += less_only_find_leaf(input, pos, (*pos & 0xf00) >> 8);
        pos += tree_size;
    }
    return sum;
}

// Optimizer for forests with any mix of supported checks. Yields a
// default-constructed result when the largest set exceeds 255 entries
// (the IN node layout stores the set size in 8 bits); otherwise the
// trees are encoded for evaluation with general_eval.
Optimize::Result
VMForest::general_optimize(const ForestStats &stats,
                           const std::vector<const nodes::Node *> &trees)
{
    const bool set_too_large = (stats.max_set_size > 255);
    if (!set_too_large) {
        return optimize(trees, general_eval);
    }
    return Optimize::Result();
}

// Evaluate a forest whose trees may contain LESS, IN and INVERTED
// nodes.
//
// forest: must point to a VMForest (it is downcast without a check)
// input:  feature values, indexed by feature ref
//
// Returns the sum of the leaf values selected in each tree; 0.0 for a
// forest without trees.
double
VMForest::general_eval(const Forest *forest, const double *input)
{
    const VMForest &self = *static_cast<const VMForest *>(forest);
    // data() is well-defined for an empty model, unlike &_model[0]
    const uint32_t *pos = self._model.data();
    const uint32_t *end = pos + self._model.size();
    double sum = 0.0;
    while (pos < end) {
        // each tree starts with a header word holding its size in words
        const uint32_t tree_size = *pos++;
        // the root node type lives in bits 8..11 of the root's first word
        sum += general_find_leaf(input, pos, (*pos & 0xf00) >> 8);
        pos += tree_size;
    }
    return sum;
}

// Optimizers offered by VMForest: the LESS-only variant is listed
// first, followed by the general variant.
// NOTE(review): presumably the chain is tried in this order until one
// optimizer yields a valid result -- confirm against Optimize::Chain.
Optimize::Chain VMForest::optimize_chain({less_only_optimize, general_optimize});

//-----------------------------------------------------------------------------

} // namespace vespalib::eval::gbdt
} // namespace vespalib::eval
} // namespace vespalib