summaryrefslogtreecommitdiffstats
path: root/eval
diff options
context:
space:
mode:
authorHaavard <havardpe@yahoo-inc.com>2017-06-07 11:59:22 +0000
committerHaavard <havardpe@yahoo-inc.com>2017-06-07 13:53:05 +0000
commit7f86d55cb2ddfd481bcb9533d8e06aeb0fa3a746 (patch)
tree01322584ff27ff98f4c5e77b2344204f2086e01e /eval
parent78016e2530d5c4408bbe04ac492b9e73405d2b43 (diff)
auto-detect appropriate compiled function parameter lazyness
... but still allow config override
Diffstat (limited to 'eval')
-rw-r--r--eval/src/tests/eval/gbdt/gbdt_test.cpp30
-rw-r--r--eval/src/tests/eval/param_usage/param_usage_test.cpp11
-rw-r--r--eval/src/vespa/eval/eval/gbdt.cpp20
-rw-r--r--eval/src/vespa/eval/eval/gbdt.h9
-rw-r--r--eval/src/vespa/eval/eval/llvm/compiled_function.cpp18
-rw-r--r--eval/src/vespa/eval/eval/llvm/compiled_function.h1
6 files changed, 89 insertions, 0 deletions
diff --git a/eval/src/tests/eval/gbdt/gbdt_test.cpp b/eval/src/tests/eval/gbdt/gbdt_test.cpp
index c96779edeb0..0535c7280bf 100644
--- a/eval/src/tests/eval/gbdt/gbdt_test.cpp
+++ b/eval/src/tests/eval/gbdt/gbdt_test.cpp
@@ -323,4 +323,34 @@ TEST("require that forests evaluate to approximately the same for all evaluation
//-----------------------------------------------------------------------------
+TEST("require that GDBT expressions can be detected") {
+ Function function = Function::parse("if((a<1),1.0,if((b in [1,2,3]),if((c in 1),2.0,3.0),4.0))+"
+ "if((d in 1),10.0,if((e<1),20.0,30.0))+"
+ "if((d in 1),10.0,if((e<1),20.0,30.0))");
+ EXPECT_TRUE(contains_gbdt(function.root(), 9));
+ EXPECT_TRUE(!contains_gbdt(function.root(), 10));
+}
+
+TEST("require that wrapped GDBT expressions can be detected") {
+ Function function = Function::parse("10*(if((a<1),1.0,if((b in [1,2,3]),if((c in 1),2.0,3.0),4.0))+"
+ "if((d in 1),10.0,if((e<1),20.0,30.0))+"
+ "if((d in 1),10.0,if((e<1),20.0,30.0)))");
+ EXPECT_TRUE(contains_gbdt(function.root(), 9));
+ EXPECT_TRUE(!contains_gbdt(function.root(), 10));
+}
+
+TEST("require that lazy parameters are not suggested for GBDT models") {
+ Function function = Function::parse(Model().make_forest(10, 8));
+ EXPECT_TRUE(!CompiledFunction::should_use_lazy_params(function));
+}
+
+TEST("require that lazy parameters can be suggested for small GBDT models") {
+ Function function = Function::parse("if((a<1),1.0,if((b in [1,2,3]),if((c in 1),2.0,3.0),4.0))+"
+ "if((d in 1),10.0,if((e<1),20.0,30.0))+"
+ "if((d in 1),10.0,if((e<1),20.0,30.0))");
+ EXPECT_TRUE(CompiledFunction::should_use_lazy_params(function));
+}
+
+//-----------------------------------------------------------------------------
+
TEST_MAIN() { TEST_RUN_ALL(); }
diff --git a/eval/src/tests/eval/param_usage/param_usage_test.cpp b/eval/src/tests/eval/param_usage/param_usage_test.cpp
index ff0c6667279..24478046289 100644
--- a/eval/src/tests/eval/param_usage/param_usage_test.cpp
+++ b/eval/src/tests/eval/param_usage/param_usage_test.cpp
@@ -2,6 +2,7 @@
#include <vespa/vespalib/testkit/test_kit.h>
#include <vespa/eval/eval/function.h>
#include <vespa/eval/eval/param_usage.h>
+#include <vespa/eval/eval/llvm/compiled_function.h>
#include <vespa/vespalib/test/insertion_operators.h>
using vespalib::approx_equal;
@@ -62,4 +63,14 @@ TEST("require that multi-level if statements are combined correctly") {
EXPECT_EQUAL(List(check_param_usage(function)), List({1.0, 0.5, 1.0, 1.0}));
}
+TEST("require that lazy parameters are suggested for functions with parameters that might not be used") {
+ Function function = Function::parse("if(z,x,y)+if(w,y,x)");
+ EXPECT_TRUE(CompiledFunction::should_use_lazy_params(function));
+}
+
+TEST("require that lazy parameters are not suggested for functions where all parameters are always used") {
+ Function function = Function::parse("a*b*c");
+ EXPECT_TRUE(!CompiledFunction::should_use_lazy_params(function));
+}
+
TEST_MAIN() { TEST_RUN_ALL(); }
diff --git a/eval/src/vespa/eval/eval/gbdt.cpp b/eval/src/vespa/eval/eval/gbdt.cpp
index 4b0fc8bf066..b787a9a04f0 100644
--- a/eval/src/vespa/eval/eval/gbdt.cpp
+++ b/eval/src/vespa/eval/eval/gbdt.cpp
@@ -2,6 +2,7 @@
#include "gbdt.h"
#include "vm_forest.h"
+#include "node_traverser.h"
#include <vespa/eval/eval/basic_nodes.h>
#include <vespa/eval/eval/call_nodes.h>
#include <vespa/eval/eval/operator_nodes.h>
@@ -117,6 +118,25 @@ ForestStats::ForestStats(const std::vector<const nodes::Node *> &trees)
//-----------------------------------------------------------------------------
+bool contains_gbdt(const nodes::Node &node, size_t limit) {
+ struct FindGBDT : NodeTraverser {
+ size_t seen;
+ size_t limit;
+ explicit FindGBDT(size_t limit_in) : seen(0), limit(limit_in) {}
+ bool found() const { return (seen >= limit); }
+ bool open(const nodes::Node &) override { return !found(); }
+ void close(const nodes::Node &node) override {
+ if (node.is_tree() || node.is_forest()) {
+ ++seen;
+ }
+ }
+ } findGBDT(limit);
+ node.traverse(findGBDT);
+ return findGBDT.found();
+}
+
+//-----------------------------------------------------------------------------
+
Optimize::Result
Optimize::select_best(const ForestStats &stats,
const std::vector<const nodes::Node *> &trees)
diff --git a/eval/src/vespa/eval/eval/gbdt.h b/eval/src/vespa/eval/eval/gbdt.h
index 6cf48528dff..eda0d16229f 100644
--- a/eval/src/vespa/eval/eval/gbdt.h
+++ b/eval/src/vespa/eval/eval/gbdt.h
@@ -60,6 +60,15 @@ struct ForestStats {
//-----------------------------------------------------------------------------
/**
+ * Check if the given sub-expression contains GBDT. This function
+ * returns true if the number of tree/forest nodes exceeds the given
+ * limit.
+ **/
+bool contains_gbdt(const nodes::Node &node, size_t limit);
+
+//-----------------------------------------------------------------------------
+
+/**
* A Forest object represents deletable custom prepared state that may
* be used to evaluate a GBDT forest from within LLVM generated
* machine code. It is very important that the evaluation function
diff --git a/eval/src/vespa/eval/eval/llvm/compiled_function.cpp b/eval/src/vespa/eval/eval/llvm/compiled_function.cpp
index f56272231cb..30d21a987e0 100644
--- a/eval/src/vespa/eval/eval/llvm/compiled_function.cpp
+++ b/eval/src/vespa/eval/eval/llvm/compiled_function.cpp
@@ -1,11 +1,14 @@
// Copyright 2016 Yahoo Inc. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
#include "compiled_function.h"
+#include <vespa/eval/eval/param_usage.h>
+#include <vespa/eval/eval/gbdt.h>
#include <vespa/eval/eval/node_traverser.h>
#include <vespa/eval/eval/check_type.h>
#include <vespa/eval/eval/tensor_nodes.h>
#include <vespa/vespalib/util/classname.h>
#include <vespa/vespalib/util/benchmark_timer.h>
+#include <vespa/vespalib/util/approx.h>
namespace vespalib {
namespace eval {
@@ -136,5 +139,20 @@ CompiledFunction::detect_issues(const Function &function)
return Function::Issues(std::move(checker.issues));
}
+bool
+CompiledFunction::should_use_lazy_params(const Function &function)
+{
+ if (gbdt::contains_gbdt(function.root(), 16)) {
+ return false; // contains gbdt
+ }
+ auto usage = vespalib::eval::check_param_usage(function);
+ for (double p_use: usage) {
+ if (!approx_equal(p_use, 1.0)) {
+ return true; // param not always used
+ }
+ }
+ return false; // all params always used
+}
+
} // namespace vespalib::eval
} // namespace vespalib
diff --git a/eval/src/vespa/eval/eval/llvm/compiled_function.h b/eval/src/vespa/eval/eval/llvm/compiled_function.h
index 912453f290a..13ef322e03b 100644
--- a/eval/src/vespa/eval/eval/llvm/compiled_function.h
+++ b/eval/src/vespa/eval/eval/llvm/compiled_function.h
@@ -63,6 +63,7 @@ public:
}
double estimate_cost_us(const std::vector<double> &params, double budget = 5.0) const;
static Function::Issues detect_issues(const Function &function);
+ static bool should_use_lazy_params(const Function &function);
};
} // namespace vespalib::eval