aboutsummaryrefslogtreecommitdiffstats
path: root/vespalib/src/tests/eval/gbdt/gbdt_test.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'vespalib/src/tests/eval/gbdt/gbdt_test.cpp')
-rw-r--r--vespalib/src/tests/eval/gbdt/gbdt_test.cpp256
1 files changed, 256 insertions, 0 deletions
diff --git a/vespalib/src/tests/eval/gbdt/gbdt_test.cpp b/vespalib/src/tests/eval/gbdt/gbdt_test.cpp
new file mode 100644
index 00000000000..c5643abfd85
--- /dev/null
+++ b/vespalib/src/tests/eval/gbdt/gbdt_test.cpp
@@ -0,0 +1,256 @@
+// Copyright 2016 Yahoo Inc. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+#include <vespa/vespalib/testkit/test_kit.h>
+#include <vespa/vespalib/eval/gbdt.h>
+#include <vespa/vespalib/eval/vm_forest.h>
+#include <vespa/vespalib/eval/deinline_forest.h>
+#include <vespa/vespalib/eval/function.h>
+#include <vespa/vespalib/eval/interpreted_function.h>
+#include <vespa/vespalib/eval/compiled_function.h>
+#include <vespa/vespalib/util/stringfmt.h>
+#include "model.cpp"
+
+using namespace vespalib::eval;
+using namespace vespalib::eval::nodes;
+using namespace vespalib::eval::gbdt;
+
+//-----------------------------------------------------------------------------
+
+double eval_double(const Function &function, const std::vector<double> &params) {
+ InterpretedFunction ifun(SimpleTensorEngine::ref(), function);
+ InterpretedFunction::Context ctx;
+ for (double param: params) {
+ ctx.add_param(param);
+ }
+ return ifun.eval(ctx).as_double();
+}
+
+//-----------------------------------------------------------------------------
+
+TEST("require that tree stats can be calculated") {
+ for (size_t tree_size = 2; tree_size < 64; ++tree_size) {
+ EXPECT_EQUAL(tree_size, TreeStats(Function::parse(Model().make_tree(tree_size)).root()).size);
+ }
+
+ TreeStats stats1(Function::parse("if((a<1),1.0,if((b in [1,2,3]),if((c in 1),2.0,3.0),4.0))").root());
+ EXPECT_EQUAL(4u, stats1.size);
+ EXPECT_EQUAL(1u, stats1.num_less_checks);
+ EXPECT_EQUAL(2u, stats1.num_in_checks);
+ EXPECT_EQUAL(3u, stats1.max_set_size);
+
+ TreeStats stats2(Function::parse("if((d in 1),10.0,if((e<1),20.0,30.0))").root());
+ EXPECT_EQUAL(3u, stats2.size);
+ EXPECT_EQUAL(1u, stats2.num_less_checks);
+ EXPECT_EQUAL(1u, stats2.num_in_checks);
+ EXPECT_EQUAL(1u, stats2.max_set_size);
+}
+
+TEST("require that trees can be extracted from forest") {
+ for (size_t tree_size = 10; tree_size < 20; ++tree_size) {
+ for (size_t forest_size = 10; forest_size < 20; ++forest_size) {
+ vespalib::string expression = Model().make_forest(forest_size, tree_size);
+ Function function = Function::parse(expression);
+ std::vector<const Node *> trees = extract_trees(function.root());
+ EXPECT_EQUAL(forest_size, trees.size());
+ for (const Node *tree: trees) {
+ EXPECT_EQUAL(tree_size, TreeStats(*tree).size);
+ }
+ }
+ }
+}
+
+TEST("require that forest stats can be calculated") {
+ Function function = Function::parse("if((a<1),1.0,if((b in [1,2,3]),if((c in 1),2.0,3.0),4.0))+"
+ "if((d in 1),10.0,if((e<1),20.0,30.0))+"
+ "if((d in 1),10.0,if((e<1),20.0,30.0))");
+ std::vector<const Node *> trees = extract_trees(function.root());
+ ForestStats stats(trees);
+ EXPECT_EQUAL(3u, stats.num_trees);
+ EXPECT_EQUAL(10u, stats.total_size);
+ ASSERT_EQUAL(2u, stats.tree_sizes.size());
+ EXPECT_EQUAL(3u, stats.tree_sizes[0].size);
+ EXPECT_EQUAL(2u, stats.tree_sizes[0].count);
+ EXPECT_EQUAL(4u, stats.tree_sizes[1].size);
+ EXPECT_EQUAL(1u, stats.tree_sizes[1].count);
+ EXPECT_EQUAL(3u, stats.total_less_checks);
+ EXPECT_EQUAL(4u, stats.total_in_checks);
+ EXPECT_EQUAL(3u, stats.max_set_size);
+}
+
+double expected_path(const vespalib::string &forest) {
+ return ForestStats(extract_trees(Function::parse(forest).root())).total_expected_path_length;
+}
+
+TEST("require that expected path length is calculated correctly") {
+ EXPECT_EQUAL(0.0, expected_path("1"));
+ EXPECT_EQUAL(0.0, expected_path("if(1,2,3)"));
+ EXPECT_EQUAL(1.0, expected_path("if(a<1,2,3)"));
+ EXPECT_EQUAL(1.0, expected_path("if(b in [1,2,3],2,3)"));
+ EXPECT_EQUAL(2.0, expected_path("if(a<1,2,3)+if(a<1,2,3)"));
+ EXPECT_EQUAL(3.0, expected_path("if(a<1,2,3)+if(a<1,2,3)+if(a<1,2,3)"));
+ EXPECT_EQUAL(0.50*1.0 + 0.50*2.0, expected_path("if(a<1,1,if(a<1,2,3))"));
+ EXPECT_EQUAL(0.25*1.0 + 0.75*2.0, expected_path("if(a<1,1,if(a<1,2,3),0.25)"));
+ EXPECT_EQUAL(0.75*1.0 + 0.25*2.0, expected_path("if(a<1,1,if(a<1,2,3),0.75)"));
+}
+
+double average_path(const vespalib::string &forest) {
+ return ForestStats(extract_trees(Function::parse(forest).root())).total_average_path_length;
+}
+
+TEST("require that average path length is calculated correctly") {
+ EXPECT_EQUAL(0.0, average_path("1"));
+ EXPECT_EQUAL(0.0, average_path("if(1,2,3)"));
+ EXPECT_EQUAL(1.0, average_path("if(a<1,2,3)"));
+ EXPECT_EQUAL(1.0, average_path("if(b in [1,2,3],2,3)"));
+ EXPECT_EQUAL(2.0, average_path("if(a<1,2,3)+if(a<1,2,3)"));
+ EXPECT_EQUAL(3.0, average_path("if(a<1,2,3)+if(a<1,2,3)+if(a<1,2,3)"));
+ EXPECT_EQUAL(5.0/3.0, average_path("if(a<1,1,if(a<1,2,3))"));
+ EXPECT_EQUAL(5.0/3.0, average_path("if(a<1,1,if(a<1,2,3),0.25)"));
+ EXPECT_EQUAL(5.0/3.0, average_path("if(a<1,1,if(a<1,2,3),0.75)"));
+}
+
+double count_tuned(const vespalib::string &forest) {
+ return ForestStats(extract_trees(Function::parse(forest).root())).total_tuned_checks;
+}
+
+TEST("require that tuned checks are counted correctly") {
+ EXPECT_EQUAL(0.0, count_tuned("if(a<1,2,3)"));
+ EXPECT_EQUAL(0.0, count_tuned("if(a<1,2,3,0.5)")); // NB: no explicit tuned flag
+ EXPECT_EQUAL(1.0, count_tuned("if(a<1,2,3,0.3)"));
+ EXPECT_EQUAL(1.0, count_tuned("if(b in [1,2,3],2,3,0.8)"));
+ EXPECT_EQUAL(2.0, count_tuned("if(a<1,2,3,0.3)+if(a<1,2,3,0.8)"));
+ EXPECT_EQUAL(3.0, count_tuned("if(a<1,2,3,0.3)+if(a<1,2,3,0.4)+if(a<1,2,3,0.9)"));
+ EXPECT_EQUAL(1.0, count_tuned("if(a<1,1,if(a<1,2,3),0.25)"));
+ EXPECT_EQUAL(2.0, count_tuned("if(a<1,1,if(a<1,2,3,0.2),0.25)"));
+}
+
+//-----------------------------------------------------------------------------
+
+struct DummyForest1 : public Forest {
+ size_t num_trees;
+ explicit DummyForest1(size_t num_trees_in) : num_trees(num_trees_in) {}
+ static double eval(const Forest *forest, const double *) {
+ const DummyForest1 &self = *((const DummyForest1 *)forest);
+ return double(self.num_trees * 2);
+ }
+ static Optimize::Result optimize(const ForestStats &stats,
+ const std::vector<const nodes::Node *> &trees)
+ {
+ if (stats.num_trees < 50) {
+ return Optimize::Result();
+ }
+ return Optimize::Result(Forest::UP(new DummyForest1(trees.size())), eval);
+ }
+};
+
+struct DummyForest2 : public Forest {
+ size_t num_trees;
+ explicit DummyForest2(size_t num_trees_in) : num_trees(num_trees_in) {}
+ static double eval(const Forest *forest, const double *) {
+ const DummyForest1 &self = *((const DummyForest1 *)forest);
+ return double(self.num_trees);
+ }
+ static Optimize::Result optimize(const ForestStats &stats,
+ const std::vector<const nodes::Node *> &trees)
+ {
+ if (stats.num_trees < 25) {
+ return Optimize::Result();
+ }
+ return Optimize::Result(Forest::UP(new DummyForest2(trees.size())), eval);
+ }
+};
+
+//-----------------------------------------------------------------------------
+
+TEST("require that trees can be optimized by a forest optimizer") {
+ Optimize::Chain chain({DummyForest1::optimize, DummyForest2::optimize});
+ size_t tree_size = 20;
+ for (size_t forest_size = 10; forest_size <= 100; forest_size += 10) {
+ vespalib::string expression = Model().make_forest(forest_size, tree_size);
+ Function function = Function::parse(expression);
+ CompiledFunction compiled_function(function, PassParams::ARRAY, chain);
+ std::vector<double> inputs(function.num_params(), 0.5);
+ if (forest_size < 25) {
+ EXPECT_EQUAL(eval_double(function, inputs), compiled_function.get_function()(&inputs[0]));
+ } else if (forest_size < 50) {
+ EXPECT_EQUAL(double(forest_size), compiled_function.get_function()(&inputs[0]));
+ } else {
+ EXPECT_EQUAL(double(2 * forest_size), compiled_function.get_function()(&inputs[0]));
+ }
+ }
+}
+
+//-----------------------------------------------------------------------------
+
+Optimize::Chain less_only_vm_chain({VMForest::less_only_optimize});
+Optimize::Chain general_vm_chain({VMForest::general_optimize});
+
+TEST("require that less only VM tree optimizer works") {
+ Function function = Function::parse("if((a<1),1.0,if((b<1),if((c<1),2.0,3.0),4.0))+"
+ "if((d<1),10.0,if((e<1),if((f<1),20.0,30.0),40.0))");
+ CompiledFunction compiled_function(function, PassParams::SEPARATE, less_only_vm_chain);
+ auto f = compiled_function.get_function<6>();
+ EXPECT_EQUAL(11.0, f(0.5, 0.0, 0.0, 0.5, 0.0, 0.0));
+ EXPECT_EQUAL(22.0, f(1.5, 0.5, 0.5, 1.5, 0.5, 0.5));
+ EXPECT_EQUAL(33.0, f(1.5, 0.5, 1.5, 1.5, 0.5, 1.5));
+ EXPECT_EQUAL(44.0, f(1.5, 1.5, 0.0, 1.5, 1.5, 0.0));
+}
+
+TEST("require that models with in checks are rejected by less only vm optimizer") {
+ Function function = Function::parse(Model().less_percent(100).make_forest(300, 30));
+ auto trees = extract_trees(function.root());
+ ForestStats stats(trees);
+ EXPECT_TRUE(Optimize::apply_chain(less_only_vm_chain, stats, trees).valid());
+ stats.total_in_checks = 1;
+ EXPECT_TRUE(!Optimize::apply_chain(less_only_vm_chain, stats, trees).valid());
+}
+
+TEST("require that general VM tree optimizer works") {
+ Function function = Function::parse("if((a<1),1.0,if((b in [1,2,3]),if((c in 1),2.0,3.0),4.0))+"
+ "if((d in 1),10.0,if((e<1),if((f<1),20.0,30.0),40.0))");
+ CompiledFunction compiled_function(function, PassParams::SEPARATE, general_vm_chain);
+ auto f = compiled_function.get_function<6>();
+ EXPECT_EQUAL(11.0, f(0.5, 0.0, 0.0, 1.0, 0.0, 0.0));
+ EXPECT_EQUAL(22.0, f(1.5, 2.0, 1.0, 2.0, 0.5, 0.5));
+ EXPECT_EQUAL(33.0, f(1.5, 2.0, 2.0, 2.0, 0.5, 1.5));
+ EXPECT_EQUAL(44.0, f(1.5, 5.0, 0.0, 2.0, 1.5, 0.0));
+}
+
+TEST("require that models with too large sets are rejected by general vm optimizer") {
+ Function function = Function::parse(Model().less_percent(80).make_forest(300, 30));
+ auto trees = extract_trees(function.root());
+ ForestStats stats(trees);
+ EXPECT_TRUE(stats.total_in_checks > 0);
+ EXPECT_TRUE(Optimize::apply_chain(general_vm_chain, stats, trees).valid());
+ stats.max_set_size = 256;
+ EXPECT_TRUE(!Optimize::apply_chain(general_vm_chain, stats, trees).valid());
+}
+
+//-----------------------------------------------------------------------------
+
+TEST("require that forests evaluate to approximately the same for all evaluation options") {
+ for (size_t tree_size: std::vector<size_t>({20})) {
+ for (size_t num_trees: std::vector<size_t>({50})) {
+ for (size_t less_percent: std::vector<size_t>({100, 80})) {
+ vespalib::string expression = Model().less_percent(less_percent).make_forest(num_trees, tree_size);
+ Function function = Function::parse(expression);
+ CompiledFunction none(function, PassParams::ARRAY, Optimize::none);
+ CompiledFunction deinline(function, PassParams::ARRAY, DeinlineForest::optimize_chain);
+ CompiledFunction vm_forest(function, PassParams::ARRAY, VMForest::optimize_chain);
+ EXPECT_EQUAL(0u, none.get_forests().size());
+ ASSERT_EQUAL(1u, deinline.get_forests().size());
+ EXPECT_TRUE(dynamic_cast<DeinlineForest*>(deinline.get_forests()[0].get()) != nullptr);
+ ASSERT_EQUAL(1u, vm_forest.get_forests().size());
+ EXPECT_TRUE(dynamic_cast<VMForest*>(vm_forest.get_forests()[0].get()) != nullptr);
+ std::vector<double> inputs(function.num_params(), 0.5);
+ double expected = eval_double(function, inputs);
+ EXPECT_APPROX(expected, none.get_function()(&inputs[0]), 1e-6);
+ EXPECT_APPROX(expected, deinline.get_function()(&inputs[0]), 1e-6);
+ EXPECT_APPROX(expected, vm_forest.get_function()(&inputs[0]), 1e-6);
+ }
+ }
+ }
+}
+
+//-----------------------------------------------------------------------------
+
+TEST_MAIN() { TEST_RUN_ALL(); }