diff options
author | Haavard <havardpe@yahoo-inc.com> | 2017-01-23 12:14:40 +0000 |
---|---|---|
committer | Haavard <havardpe@yahoo-inc.com> | 2017-01-23 12:14:40 +0000 |
commit | 145659f1d677face587b710726285df872a319c0 (patch) | |
tree | 074eafbf9d3b9ee030ff2ec584667b0386f37618 /eval/src/tests/eval/gbdt/model.cpp | |
parent | 31690a1baa64d046d7ba25510b4570aa20792134 (diff) |
move code
Diffstat (limited to 'eval/src/tests/eval/gbdt/model.cpp')
-rw-r--r-- | eval/src/tests/eval/gbdt/model.cpp | 99 |
1 files changed, 99 insertions, 0 deletions
diff --git a/eval/src/tests/eval/gbdt/model.cpp b/eval/src/tests/eval/gbdt/model.cpp new file mode 100644 index 00000000000..e125d9e77d2 --- /dev/null +++ b/eval/src/tests/eval/gbdt/model.cpp @@ -0,0 +1,99 @@ +// Copyright 2016 Yahoo Inc. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +#pragma once + +#include <random> +#include <vespa/vespalib/eval/function.h> + +using vespalib::make_string; +using vespalib::eval::Function; + +//----------------------------------------------------------------------------- + +class Model +{ +private: + std::mt19937 _gen; + size_t _less_percent; + + size_t get_int(size_t min, size_t max) { + std::uniform_int_distribution<size_t> dist(min, max); + return dist(_gen); + } + + double get_real(double min, double max) { + std::uniform_real_distribution<double> dist(min, max); + return dist(_gen); + } + + std::string make_feature_name() { + size_t max_feature = 2; + while ((max_feature < 1024) && (get_int(0, 99) < 50)) { + max_feature *= 2; + } + return make_string("feature_%zu", get_int(1, max_feature)); + } + + std::string make_cond() { + if (get_int(1,100) > _less_percent) { + return make_string("(%s in [%g,%g,%g])", + make_feature_name().c_str(), + get_int(0, 4) / 4.0, + get_int(0, 4) / 4.0, + get_int(0, 4) / 4.0); + } else { + return make_string("(%s<%g)", + make_feature_name().c_str(), + get_real(0.0, 1.0)); + } + } + +public: + explicit Model(size_t seed = 5489u) : _gen(seed), _less_percent(80) {} + + Model &less_percent(size_t value) { + _less_percent = value; + return *this; + } + + std::string make_tree(size_t size) { + assert(size > 0); + if (size == 1) { + return make_string("%g", get_real(0.0, 1.0)); + } + size_t pivot = get_int(1, size - 1); + return make_string("if(%s,%s,%s)", + make_cond().c_str(), + make_tree(pivot).c_str(), + make_tree(size - pivot).c_str()); + } + + std::string make_forest(size_t num_trees, size_t tree_sizes) { + assert(num_trees > 0); + vespalib::string forest = make_tree(tree_sizes); + for (size_t i = 1; i < num_trees; ++i) { + forest.append("+"); + forest.append(make_tree(tree_sizes)); + } + return forest; + } +}; + +//----------------------------------------------------------------------------- + +struct ForestParams { + size_t model_seed; + size_t less_percent; + size_t tree_size; + ForestParams(size_t model_seed_in, size_t less_percent_in, size_t tree_size_in) + : model_seed(model_seed_in), less_percent(less_percent_in), tree_size(tree_size_in) {} +}; + +//----------------------------------------------------------------------------- + +Function make_forest(const ForestParams ¶ms, size_t num_trees) { + return Function::parse(Model(params.model_seed) + .less_percent(params.less_percent) + .make_forest(num_trees, params.tree_size)); +} + +//----------------------------------------------------------------------------- |