aboutsummaryrefslogtreecommitdiffstats
path: root/eval/src/tests/eval/gbdt/model.cpp
diff options
context:
space:
mode:
authorHåvard Pettersen <havardpe@oath.com>2019-09-06 11:22:24 +0000
committerHåvard Pettersen <havardpe@oath.com>2019-09-06 11:22:24 +0000
commit835438109b27210193f077132b40b7f46bceb218 (patch)
treea3866d1a8277d9d2f58a5d70ad14ef7cab24e30a /eval/src/tests/eval/gbdt/model.cpp
parent51565a5d4fa530ba0313ed3d8e62d51cd39906c1 (diff)
detect if inversion as gbdt model
Diffstat (limited to 'eval/src/tests/eval/gbdt/model.cpp')
-rw-r--r--eval/src/tests/eval/gbdt/model.cpp20
1 files changed, 16 insertions, 4 deletions
diff --git a/eval/src/tests/eval/gbdt/model.cpp b/eval/src/tests/eval/gbdt/model.cpp
index 112f058fa2c..e531b327e89 100644
--- a/eval/src/tests/eval/gbdt/model.cpp
+++ b/eval/src/tests/eval/gbdt/model.cpp
@@ -14,6 +14,7 @@ class Model
private:
std::mt19937 _gen;
size_t _less_percent;
+ size_t _invert_percent;
size_t get_int(size_t min, size_t max) {
std::uniform_int_distribution<size_t> dist(min, max);
@@ -41,20 +42,31 @@ private:
get_int(0, 4) / 4.0,
get_int(0, 4) / 4.0);
} else {
- return make_string("(%s<%g)",
- make_feature_name().c_str(),
- get_real(0.0, 1.0));
+ if (get_int(1,100) > _invert_percent) {
+ return make_string("(%s<%g)",
+ make_feature_name().c_str(),
+ get_real(0.0, 1.0));
+ } else {
+ return make_string("(!(%s>=%g))",
+ make_feature_name().c_str(),
+ get_real(0.0, 1.0));
+ }
}
}
public:
- explicit Model(size_t seed = 5489u) : _gen(seed), _less_percent(80) {}
+ explicit Model(size_t seed = 5489u) : _gen(seed), _less_percent(80), _invert_percent(0) {}
Model &less_percent(size_t value) {
_less_percent = value;
return *this;
}
+ Model &invert_percent(size_t value) {
+ _invert_percent = value;
+ return *this;
+ }
+
std::string make_tree(size_t size) {
assert(size > 0);
if (size == 1) {