diff options
author | Håvard Pettersen <havardpe@oath.com> | 2019-09-06 11:22:24 +0000 |
---|---|---|
committer | Håvard Pettersen <havardpe@oath.com> | 2019-09-06 11:22:24 +0000 |
commit | 835438109b27210193f077132b40b7f46bceb218 (patch) | |
tree | a3866d1a8277d9d2f58a5d70ad14ef7cab24e30a /eval/src/tests/eval/gbdt/model.cpp | |
parent | 51565a5d4fa530ba0313ed3d8e62d51cd39906c1 (diff) |
detect if inversion as gbdt model
Diffstat (limited to 'eval/src/tests/eval/gbdt/model.cpp')
-rw-r--r-- | eval/src/tests/eval/gbdt/model.cpp | 20 |
1 files changed, 16 insertions, 4 deletions
diff --git a/eval/src/tests/eval/gbdt/model.cpp b/eval/src/tests/eval/gbdt/model.cpp index 112f058fa2c..e531b327e89 100644 --- a/eval/src/tests/eval/gbdt/model.cpp +++ b/eval/src/tests/eval/gbdt/model.cpp @@ -14,6 +14,7 @@ class Model private: std::mt19937 _gen; size_t _less_percent; + size_t _invert_percent; size_t get_int(size_t min, size_t max) { std::uniform_int_distribution<size_t> dist(min, max); @@ -41,20 +42,31 @@ private: get_int(0, 4) / 4.0, get_int(0, 4) / 4.0); } else { - return make_string("(%s<%g)", - make_feature_name().c_str(), - get_real(0.0, 1.0)); + if (get_int(1,100) > _invert_percent) { + return make_string("(%s<%g)", + make_feature_name().c_str(), + get_real(0.0, 1.0)); + } else { + return make_string("(!(%s>=%g))", + make_feature_name().c_str(), + get_real(0.0, 1.0)); + } } } public: - explicit Model(size_t seed = 5489u) : _gen(seed), _less_percent(80) {} + explicit Model(size_t seed = 5489u) : _gen(seed), _less_percent(80), _invert_percent(0) {} Model &less_percent(size_t value) { _less_percent = value; return *this; } + Model &invert_percent(size_t value) { + _invert_percent = value; + return *this; + } + std::string make_tree(size_t size) { assert(size > 0); if (size == 1) { |