aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
author: Håvard Pettersen <havardpe@oath.com> 2019-09-06 11:22:24 +0000
committer: Håvard Pettersen <havardpe@oath.com> 2019-09-06 11:22:24 +0000
commit: 835438109b27210193f077132b40b7f46bceb218 (patch)
tree: a3866d1a8277d9d2f58a5d70ad14ef7cab24e30a
parent: 51565a5d4fa530ba0313ed3d8e62d51cd39906c1 (diff)
detect if inversion as gbdt model
-rw-r--r-- eval/src/tests/eval/gbdt/gbdt_test.cpp 62
-rw-r--r-- eval/src/tests/eval/gbdt/model.cpp 20
-rw-r--r-- eval/src/vespa/eval/eval/basic_nodes.cpp 5
-rw-r--r-- eval/src/vespa/eval/eval/gbdt.cpp 15
-rw-r--r-- eval/src/vespa/eval/eval/gbdt.h 6
-rw-r--r-- eval/src/vespa/eval/eval/vm_forest.cpp 50
-rw-r--r-- searchlib/src/apps/vespa-ranking-expression-analyzer/vespa-ranking-expression-analyzer.cpp 17
7 files changed, 132 insertions, 43 deletions
diff --git a/eval/src/tests/eval/gbdt/gbdt_test.cpp b/eval/src/tests/eval/gbdt/gbdt_test.cpp
index 9cf5c31f76b..865b01b861b 100644
--- a/eval/src/tests/eval/gbdt/gbdt_test.cpp
+++ b/eval/src/tests/eval/gbdt/gbdt_test.cpp
@@ -37,13 +37,15 @@ TEST("require that tree stats can be calculated") {
EXPECT_EQUAL(4u, stats1.size);
EXPECT_EQUAL(1u, stats1.num_less_checks);
EXPECT_EQUAL(2u, stats1.num_in_checks);
+ EXPECT_EQUAL(0u, stats1.num_inverted_checks);
EXPECT_EQUAL(3u, stats1.max_set_size);
- TreeStats stats2(Function::parse("if((d in [1]),10.0,if((e<1),20.0,30.0))").root());
+ TreeStats stats2(Function::parse("if((d in [1]),10.0,if(!(e>=1),20.0,30.0))").root());
EXPECT_EQUAL(2u, stats2.num_params);
EXPECT_EQUAL(3u, stats2.size);
- EXPECT_EQUAL(1u, stats2.num_less_checks);
+ EXPECT_EQUAL(0u, stats2.num_less_checks);
EXPECT_EQUAL(1u, stats2.num_in_checks);
+ EXPECT_EQUAL(1u, stats2.num_inverted_checks);
EXPECT_EQUAL(1u, stats2.max_set_size);
}
@@ -63,8 +65,8 @@ TEST("require that trees can be extracted from forest") {
TEST("require that forest stats can be calculated") {
Function function = Function::parse("if((a<1),1.0,if((b in [1,2,3]),if((c in [1]),2.0,3.0),4.0))+"
- "if((d in [1]),10.0,if((e<1),20.0,30.0))+"
- "if((d in [1]),10.0,if((e<1),20.0,30.0))");
+ "if((d in [1]),10.0,if(!(e>=1),20.0,30.0))+"
+ "if((a<1),10.0,if(!(e>=1),20.0,30.0))");
std::vector<const Node *> trees = extract_trees(function.root());
ForestStats stats(trees);
EXPECT_EQUAL(5u, stats.num_params);
@@ -75,8 +77,9 @@ TEST("require that forest stats can be calculated") {
EXPECT_EQUAL(2u, stats.tree_sizes[0].count);
EXPECT_EQUAL(4u, stats.tree_sizes[1].size);
EXPECT_EQUAL(1u, stats.tree_sizes[1].count);
- EXPECT_EQUAL(3u, stats.total_less_checks);
- EXPECT_EQUAL(4u, stats.total_in_checks);
+ EXPECT_EQUAL(2u, stats.total_less_checks);
+ EXPECT_EQUAL(3u, stats.total_in_checks);
+ EXPECT_EQUAL(2u, stats.total_inverted_checks);
EXPECT_EQUAL(3u, stats.max_set_size);
}
@@ -261,9 +264,18 @@ TEST("require that models with in checks are rejected by less only vm optimizer"
EXPECT_TRUE(!Optimize::apply_chain(less_only_vm_chain, stats, trees).valid());
}
+TEST("require that models with inverted checks are rejected by less only vm optimizer") {
+ Function function = Function::parse(Model().less_percent(100).make_forest(300, 30)); // forest generated with less_percent(100) and the default invert_percent
+ auto trees = extract_trees(function.root());
+ ForestStats stats(trees);
+ EXPECT_TRUE(Optimize::apply_chain(less_only_vm_chain, stats, trees).valid()); // accepted with the stats as computed from the trees
+ stats.total_inverted_checks = 1; // only the stats are altered; the trees themselves are unchanged
+ EXPECT_TRUE(!Optimize::apply_chain(less_only_vm_chain, stats, trees).valid()); // a non-zero inverted count must make the less-only chain decline
+}
+
TEST("require that general VM tree optimizer works") {
Function function = Function::parse("if((a<1),1.0,if((b in [1,2,3]),if((c in [1]),2.0,3.0),4.0))+"
- "if((d in [1]),10.0,if((e<1),if((f<1),20.0,30.0),40.0))");
+ "if((d in [1]),10.0,if(!(e>=1),if((f<1),20.0,30.0),40.0))");
CompiledFunction compiled_function(function, PassParams::ARRAY, general_vm_chain);
EXPECT_EQUAL(1u, compiled_function.get_forests().size());
auto f = compiled_function.get_function();
@@ -301,21 +313,23 @@ TEST("require that forests evaluate to approximately the same for all evaluation
for (size_t tree_size: std::vector<size_t>({20})) {
for (size_t num_trees: std::vector<size_t>({10, 60})) {
for (size_t less_percent: std::vector<size_t>({100, 80})) {
- vespalib::string expression = Model().less_percent(less_percent).make_forest(num_trees, tree_size);
- Function function = Function::parse(expression);
- CompiledFunction none(function, pass_params, Optimize::none);
- CompiledFunction deinline(function, pass_params, DeinlineForest::optimize_chain);
- CompiledFunction vm_forest(function, pass_params, VMForest::optimize_chain);
- EXPECT_EQUAL(0u, none.get_forests().size());
- ASSERT_EQUAL(1u, deinline.get_forests().size());
- EXPECT_TRUE(dynamic_cast<DeinlineForest*>(deinline.get_forests()[0].get()) != nullptr);
- ASSERT_EQUAL(1u, vm_forest.get_forests().size());
- EXPECT_TRUE(dynamic_cast<VMForest*>(vm_forest.get_forests()[0].get()) != nullptr);
- std::vector<double> inputs(function.num_params(), 0.5);
- double expected = eval_double(function, inputs);
- EXPECT_APPROX(expected, eval_compiled(none, inputs), 1e-6);
- EXPECT_APPROX(expected, eval_compiled(deinline, inputs), 1e-6);
- EXPECT_APPROX(expected, eval_compiled(vm_forest, inputs), 1e-6);
+ for (size_t invert_percent: std::vector<size_t>({0, 50})) {
+ vespalib::string expression = Model().less_percent(less_percent).invert_percent(invert_percent).make_forest(num_trees, tree_size);
+ Function function = Function::parse(expression);
+ CompiledFunction none(function, pass_params, Optimize::none);
+ CompiledFunction deinline(function, pass_params, DeinlineForest::optimize_chain);
+ CompiledFunction vm_forest(function, pass_params, VMForest::optimize_chain);
+ EXPECT_EQUAL(0u, none.get_forests().size());
+ ASSERT_EQUAL(1u, deinline.get_forests().size());
+ EXPECT_TRUE(dynamic_cast<DeinlineForest*>(deinline.get_forests()[0].get()) != nullptr);
+ ASSERT_EQUAL(1u, vm_forest.get_forests().size());
+ EXPECT_TRUE(dynamic_cast<VMForest*>(vm_forest.get_forests()[0].get()) != nullptr);
+ std::vector<double> inputs(function.num_params(), 0.5);
+ double expected = eval_double(function, inputs);
+ EXPECT_EQUAL(expected, eval_compiled(none, inputs));
+ EXPECT_EQUAL(expected, eval_compiled(deinline, inputs));
+ EXPECT_EQUAL(expected, eval_compiled(vm_forest, inputs));
+ }
}
}
}
@@ -326,8 +340,8 @@ TEST("require that forests evaluate to approximately the same for all evaluation
TEST("require that GDBT expressions can be detected") {
Function function = Function::parse("if((a<1),1.0,if((b in [1,2,3]),if((c in [1]),2.0,3.0),4.0))+"
- "if((d in [1]),10.0,if((e<1),20.0,30.0))+"
- "if((d in [1]),10.0,if((e<1),20.0,30.0))");
+ "if((d in [1]),10.0,if(!(e>=1),20.0,30.0))+"
+ "if((d in [1]),10.0,if(!(e>=1),20.0,30.0))");
EXPECT_TRUE(contains_gbdt(function.root(), 9));
EXPECT_TRUE(!contains_gbdt(function.root(), 10));
}
diff --git a/eval/src/tests/eval/gbdt/model.cpp b/eval/src/tests/eval/gbdt/model.cpp
index 112f058fa2c..e531b327e89 100644
--- a/eval/src/tests/eval/gbdt/model.cpp
+++ b/eval/src/tests/eval/gbdt/model.cpp
@@ -14,6 +14,7 @@ class Model
private:
std::mt19937 _gen;
size_t _less_percent;
+ size_t _invert_percent;
size_t get_int(size_t min, size_t max) {
std::uniform_int_distribution<size_t> dist(min, max);
@@ -41,20 +42,31 @@ private:
get_int(0, 4) / 4.0,
get_int(0, 4) / 4.0);
} else {
- return make_string("(%s<%g)",
- make_feature_name().c_str(),
- get_real(0.0, 1.0));
+ if (get_int(1,100) > _invert_percent) {
+ return make_string("(%s<%g)",
+ make_feature_name().c_str(),
+ get_real(0.0, 1.0));
+ } else {
+ return make_string("(!(%s>=%g))",
+ make_feature_name().c_str(),
+ get_real(0.0, 1.0));
+ }
}
}
public:
- explicit Model(size_t seed = 5489u) : _gen(seed), _less_percent(80) {}
+ explicit Model(size_t seed = 5489u) : _gen(seed), _less_percent(80), _invert_percent(0) {}
Model &less_percent(size_t value) {
_less_percent = value;
return *this;
}
+ Model &invert_percent(size_t value) { // sets the threshold compared against get_int(1,100) when choosing between "(x<c)" and "(!(x>=c))"
+ _invert_percent = value;
+ return *this; // returns the model itself so setters can be chained
+ }
+
std::string make_tree(size_t size) {
assert(size > 0);
if (size == 1) {
diff --git a/eval/src/vespa/eval/eval/basic_nodes.cpp b/eval/src/vespa/eval/eval/basic_nodes.cpp
index 6138f9ac073..473deca1117 100644
--- a/eval/src/vespa/eval/eval/basic_nodes.cpp
+++ b/eval/src/vespa/eval/eval/basic_nodes.cpp
@@ -118,6 +118,7 @@ If::If(Node_UP cond_in, Node_UP true_expr_in, Node_UP false_expr_in, double p_tr
{
auto less = as<Less>(cond());
auto in = as<In>(cond());
+ auto inverted = as<Not>(cond());
bool true_is_subtree = (true_expr().is_tree() || true_expr().is_const());
bool false_is_subtree = (false_expr().is_tree() || false_expr().is_const());
if (true_is_subtree && false_is_subtree) {
@@ -125,6 +126,10 @@ If::If(Node_UP cond_in, Node_UP true_expr_in, Node_UP false_expr_in, double p_tr
_is_tree = (less->lhs().is_param() && less->rhs().is_const());
} else if (in) {
_is_tree = in->child().is_param();
+ } else if (inverted) {
+ if (auto ge = as<GreaterEqual>(inverted->child())) {
+ _is_tree = (ge->lhs().is_param() && ge->rhs().is_const());
+ }
}
}
}
diff --git a/eval/src/vespa/eval/eval/gbdt.cpp b/eval/src/vespa/eval/eval/gbdt.cpp
index 0edc42070ba..45a737996bd 100644
--- a/eval/src/vespa/eval/eval/gbdt.cpp
+++ b/eval/src/vespa/eval/eval/gbdt.cpp
@@ -41,6 +41,7 @@ TreeStats::TreeStats(const nodes::Node &tree)
: size(0),
num_less_checks(0),
num_in_checks(0),
+ num_inverted_checks(0),
num_tuned_checks(0),
max_set_size(0),
expected_path_length(0.0),
@@ -63,18 +64,26 @@ TreeStats::traverse(const nodes::Node &node, size_t depth, size_t &sum_path) {
double false_path = traverse(if_node->false_expr(), depth + 1, sum_path);
auto less = nodes::as<nodes::Less>(if_node->cond());
auto in = nodes::as<nodes::In>(if_node->cond());
+ auto inverted = nodes::as<nodes::Not>(if_node->cond());
if (less) {
auto symbol = nodes::as<nodes::Symbol>(less->lhs());
assert(symbol);
num_params = std::max(num_params, size_t(symbol->id() + 1));
++num_less_checks;
- } else {
- assert(in);
+ } else if (in) {
auto symbol = nodes::as<nodes::Symbol>(in->child());
assert(symbol);
num_params = std::max(num_params, size_t(symbol->id() + 1));
++num_in_checks;
max_set_size = std::max(max_set_size, in->num_entries());
+ } else {
+ assert(inverted);
+ auto ge = nodes::as<nodes::GreaterEqual>(inverted->child());
+ assert(ge);
+ auto symbol = nodes::as<nodes::Symbol>(ge->lhs());
+ assert(symbol);
+ num_params = std::max(num_params, size_t(symbol->id() + 1));
+ ++num_inverted_checks;
}
return 1.0 + (p_true * true_path) + ((1.0 - p_true) * false_path);
} else {
@@ -90,6 +99,7 @@ ForestStats::ForestStats(const std::vector<const nodes::Node *> &trees)
tree_sizes(),
total_less_checks(0),
total_in_checks(0),
+ total_inverted_checks(0),
total_tuned_checks(0),
max_set_size(0),
total_expected_path_length(0.0),
@@ -104,6 +114,7 @@ ForestStats::ForestStats(const std::vector<const nodes::Node *> &trees)
++size_map[stats.size];
total_less_checks += stats.num_less_checks;
total_in_checks += stats.num_in_checks;
+ total_inverted_checks += stats.num_inverted_checks;
total_tuned_checks += stats.num_tuned_checks;
max_set_size = std::max(max_set_size, stats.max_set_size);
total_expected_path_length += stats.expected_path_length;
diff --git a/eval/src/vespa/eval/eval/gbdt.h b/eval/src/vespa/eval/eval/gbdt.h
index 378b6cc4b08..8005c7c3fc3 100644
--- a/eval/src/vespa/eval/eval/gbdt.h
+++ b/eval/src/vespa/eval/eval/gbdt.h
@@ -24,8 +24,9 @@ std::vector<const nodes::Node *> extract_trees(const nodes::Node &node);
**/
struct TreeStats {
size_t size;
- size_t num_less_checks;
- size_t num_in_checks;
+ size_t num_less_checks; // foo < 2.5
+ size_t num_in_checks; // foo in [1,2,3]
+ size_t num_inverted_checks; // !(foo >= 2.5)
size_t num_tuned_checks;
size_t max_set_size;
double expected_path_length;
@@ -49,6 +50,7 @@ struct ForestStats {
std::vector<TreeSize> tree_sizes;
size_t total_less_checks;
size_t total_in_checks;
+ size_t total_inverted_checks;
size_t total_tuned_checks;
size_t max_set_size;
double total_expected_path_length;
diff --git a/eval/src/vespa/eval/eval/vm_forest.cpp b/eval/src/vespa/eval/eval/vm_forest.cpp
index c456c660af7..127114d0ca5 100644
--- a/eval/src/vespa/eval/eval/vm_forest.cpp
+++ b/eval/src/vespa/eval/eval/vm_forest.cpp
@@ -14,9 +14,10 @@ namespace {
//-----------------------------------------------------------------------------
-constexpr uint32_t LEAF = 0;
-constexpr uint32_t LESS = 1;
-constexpr uint32_t IN = 2;
+constexpr uint32_t LEAF = 0;
+constexpr uint32_t LESS = 1;
+constexpr uint32_t IN = 2;
+constexpr uint32_t INVERTED = 3;
// layout:
//
@@ -73,7 +74,7 @@ double general_find_leaf(const double *input, const uint32_t *pos, uint32_t node
if (node_type == LEAF) {
return *as_double_ptr(pos);
}
- } else {
+ } else if (node_type == IN) {
if (find_in(input[pos[0] >> 12], as_double_ptr(pos + 2),
as_double_ptr(pos + 2 + (2 * (pos[1] & 0xff)))))
{
@@ -86,6 +87,17 @@ double general_find_leaf(const double *input, const uint32_t *pos, uint32_t node
if (node_type == LEAF) {
return *as_double_ptr(pos);
}
+ } else {
+ if (input[pos[0] >> 12] >= *as_double_ptr(pos + 1)) {
+ node_type = (pos[0] & 0xf);
+ pos += 4 + pos[3];
+ } else {
+ node_type = (pos[0] & 0xf0) >> 4;
+ pos += 4;
+ }
+ if (node_type == LEAF) {
+ return *as_double_ptr(pos);
+ }
}
}
}
@@ -143,18 +155,42 @@ void encode_in(const nodes::In &in,
model_out[meta_idx] |= ((IN << 8) | (left_type << 4) | right_type);
}
+void encode_inverted(const nodes::Not &inverted, // appends the encoding of an if-node whose condition is !(symbol >= const) to model_out
+ const nodes::Node &left_child, const nodes::Node &right_child,
+ std::vector<uint32_t> &model_out)
+{
+ size_t meta_idx = model_out.size(); // position of the meta word; node/child type bits are OR-ed into it at the end
+ auto ge = nodes::as<nodes::GreaterEqual>(inverted.child());
+ assert(ge); // the operand of the Not must be a GreaterEqual
+ auto symbol = nodes::as<nodes::Symbol>(ge->lhs());
+ assert(symbol); // the left-hand side must be a Symbol
+ model_out.push_back(uint32_t(symbol->id()) << 12); // symbol id stored above the low 12 bits of the meta word
+ assert(ge->rhs().is_const());
+ encode_const(ge->rhs().get_const_value(), model_out);
+ size_t skip_idx = model_out.size();
+ model_out.push_back(0); // left child size placeholder
+ uint32_t left_type = encode_node(left_child, model_out);
+ model_out[skip_idx] = (model_out.size() - (skip_idx + 1)); // number of words appended for the left child
+ uint32_t right_type = encode_node(right_child, model_out);
+ model_out[meta_idx] |= ((INVERTED << 8) | (left_type << 4) | right_type); // bits 8+: node type, bits 4-7: left child type, bits 0-3: right child type
+}
+
uint32_t encode_node(const nodes::Node &node_in, std::vector<uint32_t> &model_out) {
auto if_node = nodes::as<nodes::If>(node_in);
if (if_node) {
auto less = nodes::as<nodes::Less>(if_node->cond());
auto in = nodes::as<nodes::In>(if_node->cond());
+ auto inverted = nodes::as<nodes::Not>(if_node->cond());
if (less) {
encode_less(*less, if_node->true_expr(), if_node->false_expr(), model_out);
return LESS;
- } else {
- assert(in);
+ } else if (in) {
encode_in(*in, if_node->true_expr(), if_node->false_expr(), model_out);
return IN;
+ } else {
+ assert(inverted);
+ encode_inverted(*inverted, if_node->true_expr(), if_node->false_expr(), model_out);
+ return INVERTED;
}
} else {
assert(node_in.is_const());
@@ -192,7 +228,7 @@ Optimize::Result
VMForest::less_only_optimize(const ForestStats &stats,
const std::vector<const nodes::Node *> &trees)
{
- if (stats.total_in_checks > 0) {
+ if ((stats.total_in_checks > 0) || (stats.total_inverted_checks > 0)) {
return Optimize::Result();
}
return optimize(trees, less_only_eval);
diff --git a/searchlib/src/apps/vespa-ranking-expression-analyzer/vespa-ranking-expression-analyzer.cpp b/searchlib/src/apps/vespa-ranking-expression-analyzer/vespa-ranking-expression-analyzer.cpp
index 1d316fe28d8..4fc28503fdd 100644
--- a/searchlib/src/apps/vespa-ranking-expression-analyzer/vespa-ranking-expression-analyzer.cpp
+++ b/searchlib/src/apps/vespa-ranking-expression-analyzer/vespa-ranking-expression-analyzer.cpp
@@ -113,10 +113,7 @@ struct FunctionInfo {
}
}
- void analyze_inputs(const Node &node) {
- for (size_t i = 0; i < node.num_children(); ++i) {
- analyze_inputs(node.get_child(i));
- }
+ void check_node(const Node &node) {
check_cmp(as<Equal>(node));
check_cmp(as<NotEqual>(node));
check_cmp(as<Approx>(node));
@@ -127,6 +124,18 @@ struct FunctionInfo {
check_in(as<In>(node));
}
+ void check_inverted(const Not *node) { // runs check_node on the operand of a Not; node is nullptr when the visited node is not a Not
+ if (node) { check_node(node->child()); } // guard needed: analyze_inputs passes as<Not>(node) for every node it visits
+ }
+
+ void analyze_inputs(const Node &node) { // walks the expression tree, handling all children before the node itself
+ for (size_t i = 0; i < node.num_children(); ++i) {
+ analyze_inputs(node.get_child(i));
+ }
+ check_node(node); // comparison and 'in' checks on this node
+ check_inverted(as<Not>(node)); // as<Not>() is applied to every node regardless of its actual type
+ }
+
FunctionInfo(const Function &function)
: expression_size(count_nodes(function.root())),
root_is_forest(function.root().is_forest()),