diff options
Diffstat (limited to 'eval')
-rw-r--r-- | eval/src/tests/eval/gbdt/gbdt_test.cpp | 30 | ||||
-rw-r--r-- | eval/src/tests/eval/param_usage/param_usage_test.cpp | 11 | ||||
-rw-r--r-- | eval/src/vespa/eval/eval/gbdt.cpp | 20 | ||||
-rw-r--r-- | eval/src/vespa/eval/eval/gbdt.h | 9 | ||||
-rw-r--r-- | eval/src/vespa/eval/eval/llvm/compiled_function.cpp | 18 | ||||
-rw-r--r-- | eval/src/vespa/eval/eval/llvm/compiled_function.h | 1 |
6 files changed, 89 insertions, 0 deletions
diff --git a/eval/src/tests/eval/gbdt/gbdt_test.cpp b/eval/src/tests/eval/gbdt/gbdt_test.cpp index c96779edeb0..0535c7280bf 100644 --- a/eval/src/tests/eval/gbdt/gbdt_test.cpp +++ b/eval/src/tests/eval/gbdt/gbdt_test.cpp @@ -323,4 +323,34 @@ TEST("require that forests evaluate to approximately the same for all evaluation //----------------------------------------------------------------------------- +TEST("require that GDBT expressions can be detected") { + Function function = Function::parse("if((a<1),1.0,if((b in [1,2,3]),if((c in 1),2.0,3.0),4.0))+" + "if((d in 1),10.0,if((e<1),20.0,30.0))+" + "if((d in 1),10.0,if((e<1),20.0,30.0))"); + EXPECT_TRUE(contains_gbdt(function.root(), 9)); + EXPECT_TRUE(!contains_gbdt(function.root(), 10)); +} + +TEST("require that wrapped GDBT expressions can be detected") { + Function function = Function::parse("10*(if((a<1),1.0,if((b in [1,2,3]),if((c in 1),2.0,3.0),4.0))+" + "if((d in 1),10.0,if((e<1),20.0,30.0))+" + "if((d in 1),10.0,if((e<1),20.0,30.0)))"); + EXPECT_TRUE(contains_gbdt(function.root(), 9)); + EXPECT_TRUE(!contains_gbdt(function.root(), 10)); +} + +TEST("require that lazy parameters are not suggested for GBDT models") { + Function function = Function::parse(Model().make_forest(10, 8)); + EXPECT_TRUE(!CompiledFunction::should_use_lazy_params(function)); +} + +TEST("require that lazy parameters can be suggested for small GBDT models") { + Function function = Function::parse("if((a<1),1.0,if((b in [1,2,3]),if((c in 1),2.0,3.0),4.0))+" + "if((d in 1),10.0,if((e<1),20.0,30.0))+" + "if((d in 1),10.0,if((e<1),20.0,30.0))"); + EXPECT_TRUE(CompiledFunction::should_use_lazy_params(function)); +} + +//----------------------------------------------------------------------------- + TEST_MAIN() { TEST_RUN_ALL(); } diff --git a/eval/src/tests/eval/param_usage/param_usage_test.cpp b/eval/src/tests/eval/param_usage/param_usage_test.cpp index ff0c6667279..24478046289 100644 --- a/eval/src/tests/eval/param_usage/param_usage_test.cpp +++ b/eval/src/tests/eval/param_usage/param_usage_test.cpp @@ -2,6 +2,7 @@ #include <vespa/vespalib/testkit/test_kit.h> #include <vespa/eval/eval/function.h> #include <vespa/eval/eval/param_usage.h> +#include <vespa/eval/eval/llvm/compiled_function.h> #include <vespa/vespalib/test/insertion_operators.h> using vespalib::approx_equal; @@ -62,4 +63,14 @@ TEST("require that multi-level if statements are combined correctly") { EXPECT_EQUAL(List(check_param_usage(function)), List({1.0, 0.5, 1.0, 1.0})); } +TEST("require that lazy parameters are suggested for functions with parameters that might not be used") { + Function function = Function::parse("if(z,x,y)+if(w,y,x)"); + EXPECT_TRUE(CompiledFunction::should_use_lazy_params(function)); +} + +TEST("require that lazy parameters are not suggested for functions where all parameters are always used") { + Function function = Function::parse("a*b*c"); + EXPECT_TRUE(!CompiledFunction::should_use_lazy_params(function)); +} + TEST_MAIN() { TEST_RUN_ALL(); } diff --git a/eval/src/vespa/eval/eval/gbdt.cpp b/eval/src/vespa/eval/eval/gbdt.cpp index 4b0fc8bf066..b787a9a04f0 100644 --- a/eval/src/vespa/eval/eval/gbdt.cpp +++ b/eval/src/vespa/eval/eval/gbdt.cpp @@ -2,6 +2,7 @@ #include "gbdt.h" #include "vm_forest.h" +#include "node_traverser.h" #include <vespa/eval/eval/basic_nodes.h> #include <vespa/eval/eval/call_nodes.h> #include <vespa/eval/eval/operator_nodes.h> @@ -117,6 +118,25 @@ ForestStats::ForestStats(const std::vector<const nodes::Node *> &trees) //----------------------------------------------------------------------------- +bool contains_gbdt(const nodes::Node &node, size_t limit) { + struct FindGBDT : NodeTraverser { + size_t seen; + size_t limit; + explicit FindGBDT(size_t limit_in) : seen(0), limit(limit_in) {} + bool found() const { return (seen >= limit); } + bool open(const nodes::Node &) override { return !found(); } + void close(const nodes::Node &node) override { + if (node.is_tree() || node.is_forest()) { + ++seen; + } + } + } findGBDT(limit); + node.traverse(findGBDT); + return findGBDT.found(); +} + +//----------------------------------------------------------------------------- + Optimize::Result Optimize::select_best(const ForestStats &stats, const std::vector<const nodes::Node *> &trees) diff --git a/eval/src/vespa/eval/eval/gbdt.h b/eval/src/vespa/eval/eval/gbdt.h index 6cf48528dff..eda0d16229f 100644 --- a/eval/src/vespa/eval/eval/gbdt.h +++ b/eval/src/vespa/eval/eval/gbdt.h @@ -60,6 +60,15 @@ struct ForestStats { //----------------------------------------------------------------------------- /** + * Check if the given sub-expression contains GBDT. This function + * returns true if the number of tree/forest nodes exceeds the given + * limit. + **/ +bool contains_gbdt(const nodes::Node &node, size_t limit); + +//----------------------------------------------------------------------------- + +/** * A Forest object represents deletable custom prepared state that may * be used to evaluate a GBDT forest from within LLVM generated * machine code. It is very important that the evaluation function diff --git a/eval/src/vespa/eval/eval/llvm/compiled_function.cpp b/eval/src/vespa/eval/eval/llvm/compiled_function.cpp index f56272231cb..30d21a987e0 100644 --- a/eval/src/vespa/eval/eval/llvm/compiled_function.cpp +++ b/eval/src/vespa/eval/eval/llvm/compiled_function.cpp @@ -1,11 +1,14 @@ // Copyright 2016 Yahoo Inc. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. #include "compiled_function.h" +#include <vespa/eval/eval/param_usage.h> +#include <vespa/eval/eval/gbdt.h> #include <vespa/eval/eval/node_traverser.h> #include <vespa/eval/eval/check_type.h> #include <vespa/eval/eval/tensor_nodes.h> #include <vespa/vespalib/util/classname.h> #include <vespa/vespalib/util/benchmark_timer.h> +#include <vespa/vespalib/util/approx.h> namespace vespalib { namespace eval { @@ -136,5 +139,20 @@ CompiledFunction::detect_issues(const Function &function) return Function::Issues(std::move(checker.issues)); } +bool +CompiledFunction::should_use_lazy_params(const Function &function) +{ + if (gbdt::contains_gbdt(function.root(), 16)) { + return false; // contains gbdt + } + auto usage = vespalib::eval::check_param_usage(function); + for (double p_use: usage) { + if (!approx_equal(p_use, 1.0)) { + return true; // param not always used + } + } + return false; // all params always used +} + } // namespace vespalib::eval } // namespace vespalib diff --git a/eval/src/vespa/eval/eval/llvm/compiled_function.h b/eval/src/vespa/eval/eval/llvm/compiled_function.h index 912453f290a..13ef322e03b 100644 --- a/eval/src/vespa/eval/eval/llvm/compiled_function.h +++ b/eval/src/vespa/eval/eval/llvm/compiled_function.h @@ -63,6 +63,7 @@ public: } double estimate_cost_us(const std::vector<double> ¶ms, double budget = 5.0) const; static Function::Issues detect_issues(const Function &function); + static bool should_use_lazy_params(const Function &function); }; } // namespace vespalib::eval |