aboutsummaryrefslogtreecommitdiffstats
path: root/eval
diff options
context:
space:
mode:
authorHaavard <havardpe@yahoo-inc.com>2017-02-21 13:10:38 +0000
committerHaavard <havardpe@yahoo-inc.com>2017-02-21 13:10:38 +0000
commit942dad04d13bd85822bd622af5ba08572617f2cc (patch)
tree4d5b88b40cff8f06c2742495d2d1a8985de4750d /eval
parent7a8614113524830b41485101f247dae0127aabb0 (diff)
support gbdt forest optimizations with lazy parameters
Diffstat (limited to 'eval')
-rw-r--r--eval/src/tests/eval/gbdt/gbdt_test.cpp134
-rw-r--r--eval/src/tests/eval/gbdt/model.cpp2
-rw-r--r--eval/src/vespa/eval/eval/gbdt.cpp16
-rw-r--r--eval/src/vespa/eval/eval/gbdt.h2
-rw-r--r--eval/src/vespa/eval/eval/llvm/deinline_forest.cpp3
-rw-r--r--eval/src/vespa/eval/eval/llvm/llvm_wrapper.cpp98
-rw-r--r--eval/src/vespa/eval/eval/llvm/llvm_wrapper.h2
7 files changed, 196 insertions, 61 deletions
diff --git a/eval/src/tests/eval/gbdt/gbdt_test.cpp b/eval/src/tests/eval/gbdt/gbdt_test.cpp
index 12e79941b44..b5ffb046b22 100644
--- a/eval/src/tests/eval/gbdt/gbdt_test.cpp
+++ b/eval/src/tests/eval/gbdt/gbdt_test.cpp
@@ -24,6 +24,8 @@ double eval_double(const Function &function, const std::vector<double> &params)
return ifun.eval(ctx).as_double();
}
+double my_resolve(void *ctx, size_t idx) { return ((double*)ctx)[idx]; }
+
//-----------------------------------------------------------------------------
TEST("require that tree stats can be calculated") {
@@ -32,12 +34,14 @@ TEST("require that tree stats can be calculated") {
}
TreeStats stats1(Function::parse("if((a<1),1.0,if((b in [1,2,3]),if((c in 1),2.0,3.0),4.0))").root());
+ EXPECT_EQUAL(3u, stats1.num_params);
EXPECT_EQUAL(4u, stats1.size);
EXPECT_EQUAL(1u, stats1.num_less_checks);
EXPECT_EQUAL(2u, stats1.num_in_checks);
EXPECT_EQUAL(3u, stats1.max_set_size);
TreeStats stats2(Function::parse("if((d in 1),10.0,if((e<1),20.0,30.0))").root());
+ EXPECT_EQUAL(2u, stats2.num_params);
EXPECT_EQUAL(3u, stats2.size);
EXPECT_EQUAL(1u, stats2.num_less_checks);
EXPECT_EQUAL(1u, stats2.num_in_checks);
@@ -64,6 +68,7 @@ TEST("require that forest stats can be calculated") {
"if((d in 1),10.0,if((e<1),20.0,30.0))");
std::vector<const Node *> trees = extract_trees(function.root());
ForestStats stats(trees);
+ EXPECT_EQUAL(5u, stats.num_params);
EXPECT_EQUAL(3u, stats.num_trees);
EXPECT_EQUAL(10u, stats.total_size);
ASSERT_EQUAL(2u, stats.tree_sizes.size());
@@ -125,6 +130,15 @@ TEST("require that tuned checks are counted correctly") {
//-----------------------------------------------------------------------------
+struct DummyForest0 : public Forest {
+ static double eval(const Forest *, const double *) { return 1234.0; }
+ static Optimize::Result optimize(const ForestStats &, const std::vector<const nodes::Node *> &) {
+ return Optimize::Result(Forest::UP(new DummyForest0()), eval);
+ }
+};
+
+//-----------------------------------------------------------------------------
+
struct DummyForest1 : public Forest {
size_t num_trees;
explicit DummyForest1(size_t num_trees_in) : num_trees(num_trees_in) {}
@@ -161,7 +175,26 @@ struct DummyForest2 : public Forest {
//-----------------------------------------------------------------------------
-TEST("require that trees can be optimized by a forest optimizer") {
+TEST("require that trees cannot be optimized by a forest optimizer when using SEPARATE params") {
+ Optimize::Chain chain({DummyForest0::optimize});
+ Function function = Function::parse("if((a<1),1.0,if((b<1),if((c<1),2.0,3.0),4.0))+"
+ "if((d<1),10.0,if((e<1),if((f<1),20.0,30.0),40.0))");
+ CompiledFunction compiled_function(function, PassParams::SEPARATE, chain);
+ CompiledFunction compiled_function_array(function, PassParams::ARRAY, chain);
+ CompiledFunction compiled_function_lazy(function, PassParams::LAZY, chain);
+ EXPECT_EQUAL(0u, compiled_function.get_forests().size());
+ EXPECT_EQUAL(1u, compiled_function_array.get_forests().size());
+ EXPECT_EQUAL(1u, compiled_function_lazy.get_forests().size());
+ auto f = compiled_function.get_function<6>();
+ auto f_array = compiled_function_array.get_function();
+ auto f_lazy = compiled_function_lazy.get_lazy_function();
+ std::vector<double> params = {1.5, 0.5, 0.5, 1.5, 0.5, 0.5};
+ EXPECT_EQUAL(22.0, f(params[0], params[1], params[2], params[3], params[4], params[5]));
+ EXPECT_EQUAL(1234.0, f_array(&params[0]));
+ EXPECT_EQUAL(1234.0, f_lazy(my_resolve, &params[0]));
+}
+
+TEST("require that trees can be optimized by a forest optimizer when using ARRAY params") {
Optimize::Chain chain({DummyForest1::optimize, DummyForest2::optimize});
size_t tree_size = 20;
for (size_t forest_size = 10; forest_size <= 100; forest_size += 10) {
@@ -170,15 +203,39 @@ TEST("require that trees can be optimized by a forest optimizer") {
CompiledFunction compiled_function(function, PassParams::ARRAY, chain);
std::vector<double> inputs(function.num_params(), 0.5);
if (forest_size < 25) {
+ EXPECT_EQUAL(0u, compiled_function.get_forests().size());
EXPECT_EQUAL(eval_double(function, inputs), compiled_function.get_function()(&inputs[0]));
} else if (forest_size < 50) {
+ EXPECT_EQUAL(1u, compiled_function.get_forests().size());
EXPECT_EQUAL(double(forest_size), compiled_function.get_function()(&inputs[0]));
} else {
+ EXPECT_EQUAL(1u, compiled_function.get_forests().size());
EXPECT_EQUAL(double(2 * forest_size), compiled_function.get_function()(&inputs[0]));
}
}
}
+TEST("require that trees can be optimized by a forest optimizer when using LAZY params") {
+ Optimize::Chain chain({DummyForest1::optimize, DummyForest2::optimize});
+ size_t tree_size = 20;
+ for (size_t forest_size = 10; forest_size <= 100; forest_size += 10) {
+ vespalib::string expression = Model().make_forest(forest_size, tree_size);
+ Function function = Function::parse(expression);
+ CompiledFunction compiled_function(function, PassParams::LAZY, chain);
+ std::vector<double> inputs(function.num_params(), 0.5);
+ if (forest_size < 25) {
+ EXPECT_EQUAL(0u, compiled_function.get_forests().size());
+ EXPECT_EQUAL(eval_double(function, inputs), compiled_function.get_lazy_function()(my_resolve, &inputs[0]));
+ } else if (forest_size < 50) {
+ EXPECT_EQUAL(1u, compiled_function.get_forests().size());
+ EXPECT_EQUAL(double(forest_size), compiled_function.get_lazy_function()(my_resolve, &inputs[0]));
+ } else {
+ EXPECT_EQUAL(1u, compiled_function.get_forests().size());
+ EXPECT_EQUAL(double(2 * forest_size), compiled_function.get_lazy_function()(my_resolve, &inputs[0]));
+ }
+ }
+}
+
//-----------------------------------------------------------------------------
Optimize::Chain less_only_vm_chain({VMForest::less_only_optimize});
@@ -187,12 +244,13 @@ Optimize::Chain general_vm_chain({VMForest::general_optimize});
TEST("require that less only VM tree optimizer works") {
Function function = Function::parse("if((a<1),1.0,if((b<1),if((c<1),2.0,3.0),4.0))+"
"if((d<1),10.0,if((e<1),if((f<1),20.0,30.0),40.0))");
- CompiledFunction compiled_function(function, PassParams::SEPARATE, less_only_vm_chain);
- auto f = compiled_function.get_function<6>();
- EXPECT_EQUAL(11.0, f(0.5, 0.0, 0.0, 0.5, 0.0, 0.0));
- EXPECT_EQUAL(22.0, f(1.5, 0.5, 0.5, 1.5, 0.5, 0.5));
- EXPECT_EQUAL(33.0, f(1.5, 0.5, 1.5, 1.5, 0.5, 1.5));
- EXPECT_EQUAL(44.0, f(1.5, 1.5, 0.0, 1.5, 1.5, 0.0));
+ CompiledFunction compiled_function(function, PassParams::ARRAY, less_only_vm_chain);
+ EXPECT_EQUAL(1u, compiled_function.get_forests().size());
+ auto f = compiled_function.get_function();
+ EXPECT_EQUAL(11.0, f(&std::vector<double>({0.5, 0.0, 0.0, 0.5, 0.0, 0.0})[0]));
+ EXPECT_EQUAL(22.0, f(&std::vector<double>({1.5, 0.5, 0.5, 1.5, 0.5, 0.5})[0]));
+ EXPECT_EQUAL(33.0, f(&std::vector<double>({1.5, 0.5, 1.5, 1.5, 0.5, 1.5})[0]));
+ EXPECT_EQUAL(44.0, f(&std::vector<double>({1.5, 1.5, 0.0, 1.5, 1.5, 0.0})[0]));
}
TEST("require that models with in checks are rejected by less only vm optimizer") {
@@ -207,12 +265,13 @@ TEST("require that models with in checks are rejected by less only vm optimizer"
TEST("require that general VM tree optimizer works") {
Function function = Function::parse("if((a<1),1.0,if((b in [1,2,3]),if((c in 1),2.0,3.0),4.0))+"
"if((d in 1),10.0,if((e<1),if((f<1),20.0,30.0),40.0))");
- CompiledFunction compiled_function(function, PassParams::SEPARATE, general_vm_chain);
- auto f = compiled_function.get_function<6>();
- EXPECT_EQUAL(11.0, f(0.5, 0.0, 0.0, 1.0, 0.0, 0.0));
- EXPECT_EQUAL(22.0, f(1.5, 2.0, 1.0, 2.0, 0.5, 0.5));
- EXPECT_EQUAL(33.0, f(1.5, 2.0, 2.0, 2.0, 0.5, 1.5));
- EXPECT_EQUAL(44.0, f(1.5, 5.0, 0.0, 2.0, 1.5, 0.0));
+ CompiledFunction compiled_function(function, PassParams::ARRAY, general_vm_chain);
+ EXPECT_EQUAL(1u, compiled_function.get_forests().size());
+ auto f = compiled_function.get_function();
+ EXPECT_EQUAL(11.0, f(&std::vector<double>({0.5, 0.0, 0.0, 1.0, 0.0, 0.0})[0]));
+ EXPECT_EQUAL(22.0, f(&std::vector<double>({1.5, 2.0, 1.0, 2.0, 0.5, 0.5})[0]));
+ EXPECT_EQUAL(33.0, f(&std::vector<double>({1.5, 2.0, 2.0, 2.0, 0.5, 1.5})[0]));
+ EXPECT_EQUAL(44.0, f(&std::vector<double>({1.5, 5.0, 0.0, 2.0, 1.5, 0.0})[0]));
}
TEST("require that models with too large sets are rejected by general vm optimizer") {
@@ -227,25 +286,38 @@ TEST("require that models with too large sets are rejected by general vm optimiz
//-----------------------------------------------------------------------------
+double eval_compiled(const CompiledFunction &cfun, std::vector<double> &params) {
+ ASSERT_EQUAL(params.size(), cfun.num_params());
+ if (cfun.pass_params() == PassParams::ARRAY) {
+ return cfun.get_function()(&params[0]);
+ }
+ if (cfun.pass_params() == PassParams::LAZY) {
+ return cfun.get_lazy_function()(my_resolve, &params[0]);
+ }
+ return 31212.0;
+}
+
TEST("require that forests evaluate to approximately the same for all evaluation options") {
- for (size_t tree_size: std::vector<size_t>({20})) {
- for (size_t num_trees: std::vector<size_t>({50})) {
- for (size_t less_percent: std::vector<size_t>({100, 80})) {
- vespalib::string expression = Model().less_percent(less_percent).make_forest(num_trees, tree_size);
- Function function = Function::parse(expression);
- CompiledFunction none(function, PassParams::ARRAY, Optimize::none);
- CompiledFunction deinline(function, PassParams::ARRAY, DeinlineForest::optimize_chain);
- CompiledFunction vm_forest(function, PassParams::ARRAY, VMForest::optimize_chain);
- EXPECT_EQUAL(0u, none.get_forests().size());
- ASSERT_EQUAL(1u, deinline.get_forests().size());
- EXPECT_TRUE(dynamic_cast<DeinlineForest*>(deinline.get_forests()[0].get()) != nullptr);
- ASSERT_EQUAL(1u, vm_forest.get_forests().size());
- EXPECT_TRUE(dynamic_cast<VMForest*>(vm_forest.get_forests()[0].get()) != nullptr);
- std::vector<double> inputs(function.num_params(), 0.5);
- double expected = eval_double(function, inputs);
- EXPECT_APPROX(expected, none.get_function()(&inputs[0]), 1e-6);
- EXPECT_APPROX(expected, deinline.get_function()(&inputs[0]), 1e-6);
- EXPECT_APPROX(expected, vm_forest.get_function()(&inputs[0]), 1e-6);
+ for (PassParams pass_params: {PassParams::ARRAY, PassParams::LAZY}) {
+ for (size_t tree_size: std::vector<size_t>({20})) {
+ for (size_t num_trees: std::vector<size_t>({10, 60})) {
+ for (size_t less_percent: std::vector<size_t>({100, 80})) {
+ vespalib::string expression = Model().less_percent(less_percent).make_forest(num_trees, tree_size);
+ Function function = Function::parse(expression);
+ CompiledFunction none(function, pass_params, Optimize::none);
+ CompiledFunction deinline(function, pass_params, DeinlineForest::optimize_chain);
+ CompiledFunction vm_forest(function, pass_params, VMForest::optimize_chain);
+ EXPECT_EQUAL(0u, none.get_forests().size());
+ ASSERT_EQUAL(1u, deinline.get_forests().size());
+ EXPECT_TRUE(dynamic_cast<DeinlineForest*>(deinline.get_forests()[0].get()) != nullptr);
+ ASSERT_EQUAL(1u, vm_forest.get_forests().size());
+ EXPECT_TRUE(dynamic_cast<VMForest*>(vm_forest.get_forests()[0].get()) != nullptr);
+ std::vector<double> inputs(function.num_params(), 0.5);
+ double expected = eval_double(function, inputs);
+ EXPECT_APPROX(expected, eval_compiled(none, inputs), 1e-6);
+ EXPECT_APPROX(expected, eval_compiled(deinline, inputs), 1e-6);
+ EXPECT_APPROX(expected, eval_compiled(vm_forest, inputs), 1e-6);
+ }
}
}
}
diff --git a/eval/src/tests/eval/gbdt/model.cpp b/eval/src/tests/eval/gbdt/model.cpp
index 245c111086e..06e3b657102 100644
--- a/eval/src/tests/eval/gbdt/model.cpp
+++ b/eval/src/tests/eval/gbdt/model.cpp
@@ -27,7 +27,7 @@ private:
std::string make_feature_name() {
size_t max_feature = 2;
- while ((max_feature < 1024) && (get_int(0, 99) < 50)) {
+ while ((max_feature < 1024) && (get_int(0, 99) < 55)) {
max_feature *= 2;
}
return make_string("feature_%zu", get_int(1, max_feature));
diff --git a/eval/src/vespa/eval/eval/gbdt.cpp b/eval/src/vespa/eval/eval/gbdt.cpp
index 9e9e6d09e59..26368b1e2e4 100644
--- a/eval/src/vespa/eval/eval/gbdt.cpp
+++ b/eval/src/vespa/eval/eval/gbdt.cpp
@@ -44,7 +44,8 @@ TreeStats::TreeStats(const nodes::Node &tree)
num_tuned_checks(0),
max_set_size(0),
expected_path_length(0.0),
- average_path_length(0.0)
+ average_path_length(0.0),
+ num_params(0)
{
size_t sum_path = 0.0;
expected_path_length = traverse(tree, 0, sum_path);
@@ -53,8 +54,7 @@ TreeStats::TreeStats(const nodes::Node &tree)
double
TreeStats::traverse(const nodes::Node &node, size_t depth, size_t &sum_path) {
- auto if_node = nodes::as<nodes::If>(node);
- if (if_node) {
+ if (auto if_node = nodes::as<nodes::If>(node)) {
double p_true = if_node->p_true();
if (p_true != 0.5) {
++num_tuned_checks;
@@ -64,9 +64,15 @@ TreeStats::traverse(const nodes::Node &node, size_t depth, size_t &sum_path) {
auto less = nodes::as<nodes::Less>(if_node->cond());
auto in = nodes::as<nodes::In>(if_node->cond());
if (less) {
+ auto symbol = nodes::as<nodes::Symbol>(less->lhs());
+ assert(symbol && (symbol->id() >= 0));
+ num_params = std::max(num_params, size_t(symbol->id() + 1));
++num_less_checks;
} else {
assert(in);
+ auto symbol = nodes::as<nodes::Symbol>(in->lhs());
+ assert(symbol && (symbol->id() >= 0));
+ num_params = std::max(num_params, size_t(symbol->id() + 1));
++num_in_checks;
auto array = nodes::as<nodes::Array>(in->rhs());
size_t array_size = (array) ? array->size() : 1;
@@ -89,11 +95,13 @@ ForestStats::ForestStats(const std::vector<const nodes::Node *> &trees)
total_tuned_checks(0),
max_set_size(0),
total_expected_path_length(0.0),
- total_average_path_length(0.0)
+ total_average_path_length(0.0),
+ num_params(0)
{
std::map<size_t,size_t> size_map;
for (const nodes::Node *tree: trees) {
TreeStats stats(*tree);
+ num_params = std::max(num_params, stats.num_params);
total_size += stats.size;
++size_map[stats.size];
total_less_checks += stats.num_less_checks;
diff --git a/eval/src/vespa/eval/eval/gbdt.h b/eval/src/vespa/eval/eval/gbdt.h
index c7ec59b603c..1c22c195d46 100644
--- a/eval/src/vespa/eval/eval/gbdt.h
+++ b/eval/src/vespa/eval/eval/gbdt.h
@@ -29,6 +29,7 @@ struct TreeStats {
size_t max_set_size;
double expected_path_length;
double average_path_length;
+ size_t num_params;
explicit TreeStats(const nodes::Node &tree);
private:
double traverse(const nodes::Node &tree, size_t depth, size_t &sum_path);
@@ -51,6 +52,7 @@ struct ForestStats {
size_t max_set_size;
double total_expected_path_length;
double total_average_path_length;
+ size_t num_params;
explicit ForestStats(const std::vector<const nodes::Node *> &trees);
};
diff --git a/eval/src/vespa/eval/eval/llvm/deinline_forest.cpp b/eval/src/vespa/eval/eval/llvm/deinline_forest.cpp
index b0b71ed1601..c976fb811c3 100644
--- a/eval/src/vespa/eval/eval/llvm/deinline_forest.cpp
+++ b/eval/src/vespa/eval/eval/llvm/deinline_forest.cpp
@@ -17,7 +17,8 @@ DeinlineForest::DeinlineForest(const std::vector<const nodes::Node *> &trees)
fragment_size += TreeStats(*trees[idx]).size;
fragment.push_back(trees[idx++]);
}
- void *address = _llvm_wrapper.compile_forest_fragment(fragment);
+ ForestStats stats(fragment);
+ void *address = _llvm_wrapper.compile_forest_fragment(stats.num_params, fragment);
_fragments.push_back((array_function)address);
}
}
diff --git a/eval/src/vespa/eval/eval/llvm/llvm_wrapper.cpp b/eval/src/vespa/eval/eval/llvm/llvm_wrapper.cpp
index 2e1627f2b97..0d404f2077b 100644
--- a/eval/src/vespa/eval/eval/llvm/llvm_wrapper.cpp
+++ b/eval/src/vespa/eval/eval/llvm/llvm_wrapper.cpp
@@ -25,6 +25,27 @@ double vespalib_eval_approx(double a, double b) { return (vespalib::approx_equal
double vespalib_eval_relu(double a) { return std::max(a, 0.0); }
double vespalib_eval_sigmoid(double a) { return 1.0 / (1.0 + std::exp(-1.0 * a)); }
+using vespalib::eval::gbdt::Forest;
+using resolve_function = double (*)(void *ctx, size_t idx);
+double vespalib_eval_forest_proxy(Forest::eval_function eval_forest, const Forest *forest,
+ resolve_function resolve, void *ctx, size_t num_params)
+{
+ if (num_params <= 64) {
+ double params[64];
+ for (size_t i = 0; i < num_params; ++i) {
+ params[i] = resolve(ctx, i);
+ }
+ return eval_forest(forest, &params[0]);
+ } else {
+ std::vector<double> params;
+ params.reserve(num_params);
+ for (size_t i = 0; i < num_params; ++i) {
+ params.push_back(resolve(ctx, i));
+ }
+ return eval_forest(forest, &params[0]);
+ }
+}
+
namespace vespalib {
namespace eval {
@@ -55,6 +76,7 @@ struct FunctionBuilder : public NodeVisitor, public NodeTraverser {
std::vector<llvm::Value*> values;
std::vector<llvm::Value*> let_values;
llvm::Function *function;
+ size_t num_params;
PassParams pass_params;
bool inside_forest;
const Node *forest_end;
@@ -62,6 +84,41 @@ struct FunctionBuilder : public NodeVisitor, public NodeTraverser {
std::vector<gbdt::Forest::UP> &forests;
std::vector<PluginState::UP> &plugin_state;
+ llvm::PointerType *make_eval_forest_funptr_t() {
+ std::vector<llvm::Type*> param_types;
+ param_types.push_back(builder.getVoidTy()->getPointerTo());
+ param_types.push_back(builder.getDoubleTy()->getPointerTo());
+ llvm::FunctionType *function_type = llvm::FunctionType::get(builder.getDoubleTy(), param_types, false);
+ return llvm::PointerType::get(function_type, 0);
+ }
+
+ llvm::PointerType *make_resolve_param_funptr_t() {
+ std::vector<llvm::Type*> param_types;
+ param_types.push_back(builder.getVoidTy()->getPointerTo());
+ param_types.push_back(builder.getInt64Ty());
+ llvm::FunctionType *function_type = llvm::FunctionType::get(builder.getDoubleTy(), param_types, false);
+ return llvm::PointerType::get(function_type, 0);
+ }
+
+ llvm::PointerType *make_eval_forest_proxy_funptr_t() {
+ std::vector<llvm::Type*> param_types;
+ param_types.push_back(make_eval_forest_funptr_t());
+ param_types.push_back(builder.getVoidTy()->getPointerTo());
+ param_types.push_back(make_resolve_param_funptr_t());
+ param_types.push_back(builder.getVoidTy()->getPointerTo());
+ param_types.push_back(builder.getInt64Ty());
+ llvm::FunctionType *function_type = llvm::FunctionType::get(builder.getDoubleTy(), param_types, false);
+ return llvm::PointerType::get(function_type, 0);
+ }
+
+ llvm::PointerType *make_check_membership_funptr_t() {
+ std::vector<llvm::Type*> param_types;
+ param_types.push_back(builder.getVoidTy()->getPointerTo());
+ param_types.push_back(builder.getDoubleTy());
+ llvm::FunctionType *function_type = llvm::FunctionType::get(builder.getInt1Ty(), param_types, false);
+ return llvm::PointerType::get(function_type, 0);
+ }
+
FunctionBuilder(llvm::ExecutionEngine &engine_in,
llvm::LLVMContext &context_in,
llvm::Module &module_in,
@@ -79,6 +136,7 @@ struct FunctionBuilder : public NodeVisitor, public NodeTraverser {
values(),
let_values(),
function(nullptr),
+ num_params(num_params_in),
pass_params(pass_params_in),
inside_forest(false),
forest_end(nullptr),
@@ -93,12 +151,7 @@ struct FunctionBuilder : public NodeVisitor, public NodeTraverser {
param_types.push_back(builder.getDoubleTy()->getPointerTo());
} else {
assert(pass_params == PassParams::LAZY);
- std::vector<llvm::Type*> callback_param_types;
- callback_param_types.push_back(builder.getVoidTy()->getPointerTo());
- callback_param_types.push_back(builder.getInt64Ty());
- llvm::FunctionType *callback_function_type = llvm::FunctionType::get(builder.getDoubleTy(), callback_param_types, false);
- llvm::PointerType *callback_function_pointer_type = llvm::PointerType::get(callback_function_type, 0);
- param_types.push_back(callback_function_pointer_type);
+ param_types.push_back(make_resolve_param_funptr_t());
param_types.push_back(builder.getVoidTy()->getPointerTo());
}
llvm::FunctionType *function_type = llvm::FunctionType::get(builder.getDoubleTy(), param_types, false);
@@ -114,6 +167,7 @@ struct FunctionBuilder : public NodeVisitor, public NodeTraverser {
//-------------------------------------------------------------------------
llvm::Value *get_param(size_t idx) {
+ assert(idx < num_params);
if (pass_params == PassParams::SEPARATE) {
assert(idx < params.size());
return params[idx];
@@ -173,15 +227,17 @@ struct FunctionBuilder : public NodeVisitor, public NodeTraverser {
forests.push_back(std::move(optimize_result.forest));
void *eval_ptr = (void *) optimize_result.eval;
gbdt::Forest *forest = forests.back().get();
- std::vector<llvm::Type*> param_types;
- param_types.push_back(builder.getVoidTy()->getPointerTo());
- param_types.push_back(builder.getDoubleTy()->getPointerTo());
- llvm::FunctionType *function_type = llvm::FunctionType::get(builder.getDoubleTy(), param_types, false);
- llvm::PointerType *function_pointer_type = llvm::PointerType::get(function_type, 0);
- llvm::Value *eval_fun = builder.CreateIntToPtr(builder.getInt64((uint64_t)eval_ptr), function_pointer_type, "inject_eval");
+ llvm::PointerType *eval_funptr_t = make_eval_forest_funptr_t();
+ llvm::Value *eval_fun = builder.CreateIntToPtr(builder.getInt64((uint64_t)eval_ptr), eval_funptr_t, "inject_eval");
llvm::Value *ctx = builder.CreateIntToPtr(builder.getInt64((uint64_t)forest), builder.getVoidTy()->getPointerTo(), "inject_ctx");
- assert(pass_params == PassParams::ARRAY);
- push(builder.CreateCall2(eval_fun, ctx, params[0], "call_eval"));
+ if (pass_params == PassParams::ARRAY) {
+ push(builder.CreateCall2(eval_fun, ctx, params[0], "call_eval"));
+ } else {
+ assert(pass_params == PassParams::LAZY);
+ llvm::PointerType *proxy_funptr_t = make_eval_forest_proxy_funptr_t();
+ llvm::Value *proxy_fun = builder.CreateIntToPtr(builder.getInt64((uint64_t)vespalib_eval_forest_proxy), proxy_funptr_t, "inject_eval_proxy");
+ push(builder.CreateCall5(proxy_fun, eval_fun, ctx, params[0], params[1], builder.getInt64(stats.num_params)));
+ }
return true;
}
@@ -192,7 +248,7 @@ struct FunctionBuilder : public NodeVisitor, public NodeTraverser {
push_double(node.get_const_value());
return false;
}
- if (!inside_forest && (pass_params == PassParams::ARRAY) && node.is_forest()) {
+ if (!inside_forest && (pass_params != PassParams::SEPARATE) && node.is_forest()) {
if (try_optimize_forest(node)) {
return false;
}
@@ -451,12 +507,8 @@ struct FunctionBuilder : public NodeVisitor, public NodeTraverser {
plugin_state.emplace_back(new SetMemberHash(*array));
void *call_ptr = (void *) SetMemberHash::check_membership;
PluginState *state = plugin_state.back().get();
- std::vector<llvm::Type*> param_types;
- param_types.push_back(builder.getVoidTy()->getPointerTo());
- param_types.push_back(builder.getDoubleTy());
- llvm::FunctionType *function_type = llvm::FunctionType::get(builder.getInt1Ty(), param_types, false);
- llvm::PointerType *function_pointer_type = llvm::PointerType::get(function_type, 0);
- llvm::Value *call_fun = builder.CreateIntToPtr(builder.getInt64((uint64_t)call_ptr), function_pointer_type, "inject_call_addr");
+ llvm::PointerType *funptr_t = make_check_membership_funptr_t();
+ llvm::Value *call_fun = builder.CreateIntToPtr(builder.getInt64((uint64_t)call_ptr), funptr_t, "inject_call_addr");
llvm::Value *ctx = builder.CreateIntToPtr(builder.getInt64((uint64_t)state), builder.getVoidTy()->getPointerTo(), "inject_ctx");
push(builder.CreateCall2(call_fun, ctx, lhs, "call_check_membership"));
} else {
@@ -618,12 +670,12 @@ LLVMWrapper::compile_function(size_t num_params, PassParams pass_params, const N
}
void *
-LLVMWrapper::compile_forest_fragment(const std::vector<const Node *> &fragment)
+LLVMWrapper::compile_forest_fragment(size_t num_params, const std::vector<const Node *> &fragment)
{
std::lock_guard<std::recursive_mutex> guard(_global_llvm_lock);
FunctionBuilder builder(*_engine, *_context, *_module,
vespalib::make_string("f%zu", ++_num_functions),
- 0, PassParams::ARRAY,
+ num_params, PassParams::ARRAY,
gbdt::Optimize::none, _forests, _plugin_state);
builder.build_forest_fragment(fragment);
return builder.compile();
diff --git a/eval/src/vespa/eval/eval/llvm/llvm_wrapper.h b/eval/src/vespa/eval/eval/llvm/llvm_wrapper.h
index d81f313a392..95092692675 100644
--- a/eval/src/vespa/eval/eval/llvm/llvm_wrapper.h
+++ b/eval/src/vespa/eval/eval/llvm/llvm_wrapper.h
@@ -57,7 +57,7 @@ public:
LLVMWrapper(LLVMWrapper &&rhs);
void *compile_function(size_t num_params, PassParams pass_params, const nodes::Node &root,
const gbdt::Optimize::Chain &forest_optimizers);
- void *compile_forest_fragment(const std::vector<const nodes::Node *> &fragment);
+ void *compile_forest_fragment(size_t num_params, const std::vector<const nodes::Node *> &fragment);
const std::vector<gbdt::Forest::UP> &get_forests() const { return _forests; }
void dump() const { _module->dump(); }
~LLVMWrapper();