aboutsummaryrefslogtreecommitdiffstats
path: root/eval/src/vespa/eval/eval/gbdt.h
blob: 7da44955386587b57646fd77dca91ac4fd545484 (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.

#pragma once

#include <cstddef>
#include <memory>
#include <utility>
#include <vector>

namespace vespalib {
namespace eval {

namespace nodes { struct Node; }

namespace gbdt {

//-----------------------------------------------------------------------------

/**
 * Function used to map out individual GBDT trees from a GBDT forest.
 *
 * Takes the root node of the expression to inspect and returns one
 * pointer per tree found. The returned pointers are raw, non-owning
 * pointers to const nodes; presumably they refer to sub-nodes of
 * 'node' and are only valid while it is alive (TODO confirm against
 * the implementation).
 **/
std::vector<const nodes::Node *> extract_trees(const nodes::Node &node);

/**
 * Statistics for a single GBDT tree. All fields are public data; the
 * explicit constructor takes the root node of the tree to inspect and
 * is defined outside this header, so the exact way each field is
 * calculated must be looked up in the implementation.
 **/
struct TreeStats {
    // size of the tree (unit not stated in this header; see the constructor)
    size_t size;
    size_t num_less_checks;     // number of checks of the form: foo < 2.5
    size_t num_in_checks;       // number of checks of the form: foo in [1,2,3]
    size_t num_inverted_checks; // number of checks of the form: !(foo >= 2.5)
    // NOTE(review): what makes a check 'tuned' is not visible here
    size_t num_tuned_checks;
    // presumably the largest set used by any 'in' check -- TODO confirm
    size_t max_set_size;
    // path length metrics; the weighting used for the 'expected' variant
    // is decided by the implementation (not visible here)
    double expected_path_length;
    double average_path_length;
    // presumably the number of parameters referenced by the tree -- TODO confirm
    size_t num_params;
    explicit TreeStats(const nodes::Node &tree);
private:
    // Internal helper taking a (sub-)tree, a depth and an in/out path
    // sum; presumably a recursive walk used by the constructor.
    double traverse(const nodes::Node &tree, size_t depth, size_t &sum_path);
};

/**
 * Statistics for a GBDT forest. The explicit constructor takes the
 * list of trees making up the forest (the same kind of list that
 * extract_trees returns) and is defined outside this header. The
 * field names mirror those of TreeStats; presumably the 'total'
 * fields are sums over all trees (TODO confirm).
 **/
struct ForestStats {
    // A tree size paired with a count; presumably the number of trees
    // having that size -- TODO confirm.
    struct TreeSize {
        size_t size;
        size_t count;
    };
    size_t num_trees;
    size_t total_size;
    std::vector<TreeSize> tree_sizes;
    size_t total_less_checks;
    size_t total_in_checks;
    size_t total_inverted_checks;
    size_t total_tuned_checks;
    size_t max_set_size;
    double total_expected_path_length;
    double total_average_path_length;
    size_t num_params;
    explicit ForestStats(const std::vector<const nodes::Node *> &trees);
};

//-----------------------------------------------------------------------------

/**
 * Check if the given sub-expression contains GBDT.
 *
 * 'node' is the root of the sub-expression to inspect and 'limit' is
 * the threshold to compare against. The function returns true if the
 * number of tree/forest nodes exceeds the given limit, and false
 * otherwise.
 **/
bool contains_gbdt(const nodes::Node &node, size_t limit);

//-----------------------------------------------------------------------------

/**
 * Base class for deletable, custom prepared state used when
 * evaluating a GBDT forest from within LLVM generated machine code.
 *
 * An evaluation function (see eval_function) receives the state as a
 * plain 'const Forest *' together with the argument array, and must
 * be handed exactly the Forest subclass it expects. To keep the two
 * in sync, Optimize::Result bundles the prepared state (the Forest
 * object) with the matching evaluation function reference; both are
 * chosen at the same time at the same place.
 **/
struct Forest {
    // owning handle to a prepared forest
    using UP = std::unique_ptr<Forest>;
    // signature of a function evaluating a prepared forest
    using eval_function = double (*)(const Forest *self, const double *args);
    // virtual, so that subclasses can be destroyed through a Forest::UP
    virtual ~Forest() = default;
};

/**
 * Definitions and helper functions related to custom GBDT forest
 * optimization. The optimization chain named 'best' is used by
 * default. The one named 'none' results in no special handling for
 * GBDT forests.
 **/
struct Optimize {
    /**
     * The outcome of an optimization attempt: the prepared state
     * (forest) bundled with the function that knows how to evaluate
     * it (eval). A default constructed Result holds no forest and is
     * reported as not valid. Result is move-only since it owns the
     * forest through a unique pointer.
     **/
    struct Result {
        Forest::UP forest;
        Forest::eval_function eval;
        Result() noexcept : forest(nullptr), eval(nullptr) {}
        Result(Forest::UP &&forest_in, Forest::eval_function eval_in) noexcept
            : forest(std::move(forest_in)), eval(eval_in) {}
        // Moving only transfers a unique pointer and copies a function
        // pointer, which cannot throw; stating this lets standard
        // containers and algorithms rely on a non-throwing move.
        Result(Result &&rhs) noexcept : forest(std::move(rhs.forest)), eval(rhs.eval) {}
        // a result is valid if it owns a forest
        bool valid() const noexcept { return (forest != nullptr); }
    };
    // An optimizer inspects the forest statistics and the individual
    // trees and returns either a valid Result or an invalid (empty) one.
    using optimize_function = Result (*)(const ForestStats &stats,
                                         const std::vector<const nodes::Node *> &trees);
    // An ordered list of optimizers; see apply_chain.
    using Chain = std::vector<optimize_function>;
    // Optimizer with the optimize_function signature (defined outside
    // this header).
    static Result select_best(const ForestStats &stats,
                              const std::vector<const nodes::Node *> &trees);
    static Chain best;
    static Chain none;
    /**
     * Try the optimizers in the given chain in order and return the
     * first valid result. If the chain is empty or no optimizer
     * produces a valid result, an invalid (default constructed)
     * Result is returned.
     **/
    static Result apply_chain(const Chain &chain,
                              const ForestStats &stats,
                              const std::vector<const nodes::Node *> &trees) {
        for (optimize_function optimize: chain) {
            Result result = optimize(stats, trees);
            if (result.valid()) {
                return result;
            }
        }
        return Result();
    }
    // Optimize() = delete;
};

//-----------------------------------------------------------------------------

} // namespace vespalib::eval::gbdt
} // namespace vespalib::eval
} // namespace vespalib