summaryrefslogtreecommitdiffstats
path: root/eval/src/tests/eval/gbdt/model.cpp
blob: 2659c07b6013b244355fe39bae50a4f38e2c32c7 (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
#pragma once

#include <vespa/eval/eval/function.h>
#include <vespa/vespalib/util/stringfmt.h>
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <random>
#include <string>

using vespalib::make_string;
using vespalib::eval::Function;

//-----------------------------------------------------------------------------

/**
 * Seeded random generator for the source text of GBDT-style expressions.
 *
 * make_tree() produces one decision tree as nested if(cond,left,right)
 * expressions with numeric leaves; make_forest() joins several trees with '+'.
 * Each split condition is either a set-membership test "(feature in [a,b,c])"
 * or a less-than test "(feature<x)". The less-than test is optionally written
 * in inverted form "(!(feature>=x))".
 *
 * All randomness comes from a std::mt19937 seeded in the constructor. Every
 * draw is sequenced explicitly (never several draws inside one call's
 * argument list), so the text generated for a given seed does not depend on
 * the compiler's unspecified argument evaluation order.
 */
class Model
{
private:
    std::mt19937 _gen;
    size_t _max_features;
    size_t _less_percent;
    size_t _invert_percent;

    // Uniform integer in the closed range [min, max]; requires min <= max.
    size_t get_int(size_t min, size_t max) {
        std::uniform_int_distribution<size_t> dist(min, max);
        return dist(_gen);
    }

    // Uniform real in [0, 1), used for leaf values and split points.
    double get_real() {
        std::uniform_real_distribution<double> dist(0.0, 1.0);
        double result = dist(_gen);
        // avoid different decisions based on using float vs. double split values
        while (float(result) == 0.5) {
            result = dist(_gen);
        }
        return result;
    }

    // Returns "feature_<k>" with k drawn uniformly from [1, bound].
    // The bound starts at 7 and is doubled with 55% probability per step, up
    // to _max_features, which favours low-numbered features. The bound is
    // clamped to _max_features (and to at least 1, so the range is never empty).
    std::string make_feature_name() {
        size_t max_feature = std::max<size_t>(1, std::min<size_t>(7, _max_features));
        while ((max_feature < _max_features) && (get_int(0, 99) < 55)) {
            max_feature = std::min(max_feature * 2, _max_features);
        }
        return make_string("feature_%zu", get_int(1, max_feature));
    }

    // Builds one split condition.
    // With probability (100 - _less_percent)% it is a set-membership test
    // against three values from {0, 0.25, 0.5, 0.75, 1}. Otherwise it is a
    // less-than test against get_real(), written as (!(f>=x)) in
    // _invert_percent% of those cases.
    std::string make_cond() {
        if (get_int(1,100) > _less_percent) {
            std::string feature = make_feature_name();
            double a = get_int(0, 4) / 4.0;
            double b = get_int(0, 4) / 4.0;
            double c = get_int(0, 4) / 4.0;
            return make_string("(%s in [%g,%g,%g])", feature.c_str(), a, b, c);
        }
        bool inverted = (get_int(1,100) <= _invert_percent);
        std::string feature = make_feature_name();
        double split = get_real();
        if (inverted) {
            return make_string("(!(%s>=%g))", feature.c_str(), split);
        }
        return make_string("(%s<%g)", feature.c_str(), split);
    }

public:
    /// Creates a generator seeded with `seed`.
    /// Defaults: max_features=1024, less_percent=80, invert_percent=0.
    explicit Model(size_t seed = 5489u) : _gen(seed), _max_features(1024), _less_percent(80), _invert_percent(0) {}

    /// Upper bound on the feature index used in generated names.
    Model &max_features(size_t value) {
        _max_features = value;
        return *this;
    }

    /// Percentage (0-100) of conditions that are less-than tests rather than
    /// set-membership tests.
    Model &less_percent(size_t value) {
        _less_percent = value;
        return *this;
    }

    /// Percentage (0-100) of less-than tests written as (!(f>=x)).
    Model &invert_percent(size_t value) {
        _invert_percent = value;
        return *this;
    }

    /// Generates a random tree with exactly `size` leaves; `size` must be > 0.
    std::string make_tree(size_t size) {
        assert(size > 0);
        if (size == 1) {
            return make_string("%g", get_real());
        }
        // Split the leaves between the two subtrees; both sides get at least one.
        size_t pivot = get_int(1, size - 1);
        std::string cond = make_cond();
        std::string left = make_tree(pivot);
        std::string right = make_tree(size - pivot);
        return make_string("if(%s,%s,%s)", cond.c_str(), left.c_str(), right.c_str());
    }

    /// Generates `num_trees` (> 0) trees of `tree_sizes` leaves each, joined with '+'.
    std::string make_forest(size_t num_trees, size_t tree_sizes) {
        assert(num_trees > 0);
        std::string forest = make_tree(tree_sizes);
        for (size_t i = 1; i < num_trees; ++i) {
            forest += '+';
            forest += make_tree(tree_sizes);
        }
        return forest;
    }
};

//-----------------------------------------------------------------------------

/**
 * Settings for one randomly generated forest:
 * model_seed   - seed handed to the Model generator,
 * less_percent - percentage of less-than (vs. set-membership) conditions,
 * tree_size    - value passed to Model::make_tree for every tree.
 */
struct ForestParams {
    size_t model_seed;
    size_t less_percent;
    size_t tree_size;

    ForestParams(size_t model_seed_in, size_t less_percent_in, size_t tree_size_in)
        : model_seed{model_seed_in},
          less_percent{less_percent_in},
          tree_size{tree_size_in}
    {}
};

//-----------------------------------------------------------------------------

// Generates a random forest of `num_trees` trees using a Model seeded and
// configured from `params`, and returns the result of Function::parse on it.
// Declared inline because this file is header-style (#pragma once). A plain
// definition would cause multiple-definition link errors if the file were
// included from more than one translation unit.
inline auto make_forest(const ForestParams &params, size_t num_trees) {
    return Function::parse(Model(params.model_seed)
                           .less_percent(params.less_percent)
                           .make_forest(num_trees, params.tree_size));
}

//-----------------------------------------------------------------------------