diff options
-rw-r--r-- | eval/src/tests/eval/gbdt/gbdt_benchmark.cpp | 40 | ||||
-rw-r--r-- | eval/src/tests/eval/gbdt/gbdt_test.cpp | 134 | ||||
-rw-r--r-- | eval/src/tests/eval/gbdt/model.cpp | 2 | ||||
-rw-r--r-- | eval/src/vespa/eval/eval/gbdt.cpp | 16 | ||||
-rw-r--r-- | eval/src/vespa/eval/eval/gbdt.h | 2 | ||||
-rw-r--r-- | eval/src/vespa/eval/eval/llvm/deinline_forest.cpp | 3 | ||||
-rw-r--r-- | eval/src/vespa/eval/eval/llvm/llvm_wrapper.cpp | 98 | ||||
-rw-r--r-- | eval/src/vespa/eval/eval/llvm/llvm_wrapper.h | 2 |
8 files changed, 224 insertions, 73 deletions
diff --git a/eval/src/tests/eval/gbdt/gbdt_benchmark.cpp b/eval/src/tests/eval/gbdt/gbdt_benchmark.cpp index 1f8cde5b1bb..820b8ec692d 100644 --- a/eval/src/tests/eval/gbdt/gbdt_benchmark.cpp +++ b/eval/src/tests/eval/gbdt/gbdt_benchmark.cpp @@ -13,12 +13,15 @@ using namespace vespalib::eval; using namespace vespalib::eval::nodes; using namespace vespalib::eval::gbdt; +double budget = 2.0; + //----------------------------------------------------------------------------- struct CompileStrategy { virtual const char *name() const = 0; virtual const char *code_name() const = 0; virtual CompiledFunction compile(const Function &function) const = 0; + virtual CompiledFunction compile_lazy(const Function &function) const = 0; bool is_same(const CompileStrategy &rhs) const { return (this == &rhs); } @@ -26,41 +29,50 @@ struct CompileStrategy { }; struct NullStrategy : CompileStrategy { - virtual const char *name() const { + const char *name() const override { return "none"; } - virtual const char *code_name() const { + const char *code_name() const override { return "Optimize::none"; } - virtual CompiledFunction compile(const Function &function) const { + CompiledFunction compile(const Function &function) const override { return CompiledFunction(function, PassParams::ARRAY, Optimize::none); } + CompiledFunction compile_lazy(const Function &function) const override { + return CompiledFunction(function, PassParams::LAZY, Optimize::none); + } }; NullStrategy none; struct VMForestStrategy : CompileStrategy { - virtual const char *name() const { + const char *name() const override { return "vm-forest"; } - virtual const char *code_name() const { + const char *code_name() const override { return "VMForest::optimize_chain"; } - virtual CompiledFunction compile(const Function &function) const { + CompiledFunction compile(const Function &function) const override { return CompiledFunction(function, PassParams::ARRAY, VMForest::optimize_chain); } + CompiledFunction compile_lazy(const Function &function) const override { + return CompiledFunction(function, PassParams::LAZY, VMForest::optimize_chain); + } }; VMForestStrategy vm_forest; struct DeinlineForestStrategy : CompileStrategy { - virtual const char *name() const { + const char *name() const override { return "deinline-forest"; } - virtual const char *code_name() const { + const char *code_name() const override { return "DeinlineForest::optimize_chain"; } - virtual CompiledFunction compile(const Function &function) const { + CompiledFunction compile(const Function &function) const override { return CompiledFunction(function, PassParams::ARRAY, DeinlineForest::optimize_chain); } + CompiledFunction compile_lazy(const Function &function) const override { + return CompiledFunction(function, PassParams::LAZY, DeinlineForest::optimize_chain); + } }; DeinlineForestStrategy deinline_forest; @@ -72,6 +84,7 @@ struct Option { bool is_same(const Option &rhs) const { return strategy.is_same(rhs.strategy); } const char *name() const { return strategy.name(); } CompiledFunction compile(const Function &function) const { return strategy.compile(function); } + CompiledFunction compile_lazy(const Function &function) const { return strategy.compile_lazy(function); } const char *code_name() const { return strategy.code_name(); } }; @@ -153,11 +166,14 @@ std::vector<Option> find_order(const ForestParams ¶ms, Function forest = make_forest(params, num_trees); for (size_t i = 0; i < options.size(); ++i) { CompiledFunction compiled_function = options[i].compile(forest); + CompiledFunction compiled_function_lazy = options[i].compile_lazy(forest); std::vector<double> inputs(compiled_function.num_params(), 0.5); - results.push_back({compiled_function.estimate_cost_us(inputs), i}); - fprintf(stderr, " %20s@%6zu: %16g us (inputs: %zu)\n", + results.push_back({compiled_function.estimate_cost_us(inputs, budget), i}); + double lazy_time = compiled_function_lazy.estimate_cost_us(inputs, budget); + double lazy_factor = lazy_time / results.back().us; + fprintf(stderr, " %20s@%6zu: %16g us (inputs: %zu) [lazy: %g us, factor: %g]\n", options[i].name(), num_trees, results.back().us, - inputs.size()); + inputs.size(), lazy_time, lazy_factor); } std::sort(results.begin(), results.end()); std::vector<Option> ret; diff --git a/eval/src/tests/eval/gbdt/gbdt_test.cpp b/eval/src/tests/eval/gbdt/gbdt_test.cpp index 12e79941b44..b5ffb046b22 100644 --- a/eval/src/tests/eval/gbdt/gbdt_test.cpp +++ b/eval/src/tests/eval/gbdt/gbdt_test.cpp @@ -24,6 +24,8 @@ double eval_double(const Function &function, const std::vector<double> ¶ms) return ifun.eval(ctx).as_double(); } +double my_resolve(void *ctx, size_t idx) { return ((double*)ctx)[idx]; } + //----------------------------------------------------------------------------- TEST("require that tree stats can be calculated") { @@ -32,12 +34,14 @@ TEST("require that tree stats can be calculated") { } TreeStats stats1(Function::parse("if((a<1),1.0,if((b in [1,2,3]),if((c in 1),2.0,3.0),4.0))").root()); + EXPECT_EQUAL(3u, stats1.num_params); EXPECT_EQUAL(4u, stats1.size); EXPECT_EQUAL(1u, stats1.num_less_checks); EXPECT_EQUAL(2u, stats1.num_in_checks); EXPECT_EQUAL(3u, stats1.max_set_size); TreeStats stats2(Function::parse("if((d in 1),10.0,if((e<1),20.0,30.0))").root()); + EXPECT_EQUAL(2u, stats2.num_params); EXPECT_EQUAL(3u, stats2.size); EXPECT_EQUAL(1u, stats2.num_less_checks); EXPECT_EQUAL(1u, stats2.num_in_checks); @@ -64,6 +68,7 @@ TEST("require that forest stats can be calculated") { "if((d in 1),10.0,if((e<1),20.0,30.0))"); std::vector<const Node *> trees = extract_trees(function.root()); ForestStats stats(trees); + EXPECT_EQUAL(5u, stats.num_params); EXPECT_EQUAL(3u, stats.num_trees); EXPECT_EQUAL(10u, stats.total_size); ASSERT_EQUAL(2u, stats.tree_sizes.size()); @@ -125,6 +130,15 @@ TEST("require that tuned checks are counted correctly") { //----------------------------------------------------------------------------- +struct DummyForest0 : public Forest { + static double eval(const Forest *, const double *) { return 1234.0; } + static Optimize::Result optimize(const ForestStats &, const std::vector<const nodes::Node *> &) { + return Optimize::Result(Forest::UP(new DummyForest0()), eval); + } +}; + +//----------------------------------------------------------------------------- + struct DummyForest1 : public Forest { size_t num_trees; explicit DummyForest1(size_t num_trees_in) : num_trees(num_trees_in) {} @@ -161,7 +175,26 @@ struct DummyForest2 : public Forest { //----------------------------------------------------------------------------- -TEST("require that trees can be optimized by a forest optimizer") { +TEST("require that trees cannot be optimized by a forest optimizer when using SEPARATE params") { + Optimize::Chain chain({DummyForest0::optimize}); + Function function = Function::parse("if((a<1),1.0,if((b<1),if((c<1),2.0,3.0),4.0))+" + "if((d<1),10.0,if((e<1),if((f<1),20.0,30.0),40.0))"); + CompiledFunction compiled_function(function, PassParams::SEPARATE, chain); + CompiledFunction compiled_function_array(function, PassParams::ARRAY, chain); + CompiledFunction compiled_function_lazy(function, PassParams::LAZY, chain); + EXPECT_EQUAL(0u, compiled_function.get_forests().size()); + EXPECT_EQUAL(1u, compiled_function_array.get_forests().size()); + EXPECT_EQUAL(1u, compiled_function_lazy.get_forests().size()); + auto f = compiled_function.get_function<6>(); + auto f_array = compiled_function_array.get_function(); + auto f_lazy = compiled_function_lazy.get_lazy_function(); + std::vector<double> params = {1.5, 0.5, 0.5, 1.5, 0.5, 0.5}; + EXPECT_EQUAL(22.0, f(params[0], params[1], params[2], params[3], params[4], params[5])); + EXPECT_EQUAL(1234.0, f_array(¶ms[0])); + EXPECT_EQUAL(1234.0, f_lazy(my_resolve, ¶ms[0])); +} + +TEST("require that trees can be optimized by a forest optimizer when using ARRAY params") { Optimize::Chain chain({DummyForest1::optimize, DummyForest2::optimize}); size_t tree_size = 20; for (size_t forest_size = 10; forest_size <= 100; forest_size += 10) { @@ -170,15 +203,39 @@ TEST("require that trees can be optimized by a forest optimizer") { CompiledFunction compiled_function(function, PassParams::ARRAY, chain); std::vector<double> inputs(function.num_params(), 0.5); if (forest_size < 25) { + EXPECT_EQUAL(0u, compiled_function.get_forests().size()); EXPECT_EQUAL(eval_double(function, inputs), compiled_function.get_function()(&inputs[0])); } else if (forest_size < 50) { + EXPECT_EQUAL(1u, compiled_function.get_forests().size()); EXPECT_EQUAL(double(forest_size), compiled_function.get_function()(&inputs[0])); } else { + EXPECT_EQUAL(1u, compiled_function.get_forests().size()); EXPECT_EQUAL(double(2 * forest_size), compiled_function.get_function()(&inputs[0])); } } } +TEST("require that trees can be optimized by a forest optimizer when using LAZY params") { + Optimize::Chain chain({DummyForest1::optimize, DummyForest2::optimize}); + size_t tree_size = 20; + for (size_t forest_size = 10; forest_size <= 100; forest_size += 10) { + vespalib::string expression = Model().make_forest(forest_size, tree_size); + Function function = Function::parse(expression); + CompiledFunction compiled_function(function, PassParams::LAZY, chain); + std::vector<double> inputs(function.num_params(), 0.5); + if (forest_size < 25) { + EXPECT_EQUAL(0u, compiled_function.get_forests().size()); + EXPECT_EQUAL(eval_double(function, inputs), compiled_function.get_lazy_function()(my_resolve, &inputs[0])); + } else if (forest_size < 50) { + EXPECT_EQUAL(1u, compiled_function.get_forests().size()); + EXPECT_EQUAL(double(forest_size), compiled_function.get_lazy_function()(my_resolve, &inputs[0])); + } else { + EXPECT_EQUAL(1u, compiled_function.get_forests().size()); + EXPECT_EQUAL(double(2 * forest_size), compiled_function.get_lazy_function()(my_resolve, &inputs[0])); + } + } +} + //----------------------------------------------------------------------------- Optimize::Chain less_only_vm_chain({VMForest::less_only_optimize}); @@ -187,12 +244,13 @@ Optimize::Chain general_vm_chain({VMForest::general_optimize}); TEST("require that less only VM tree optimizer works") { Function function = Function::parse("if((a<1),1.0,if((b<1),if((c<1),2.0,3.0),4.0))+" "if((d<1),10.0,if((e<1),if((f<1),20.0,30.0),40.0))"); - CompiledFunction compiled_function(function, PassParams::SEPARATE, less_only_vm_chain); - auto f = compiled_function.get_function<6>(); - EXPECT_EQUAL(11.0, f(0.5, 0.0, 0.0, 0.5, 0.0, 0.0)); - EXPECT_EQUAL(22.0, f(1.5, 0.5, 0.5, 1.5, 0.5, 0.5)); - EXPECT_EQUAL(33.0, f(1.5, 0.5, 1.5, 1.5, 0.5, 1.5)); - EXPECT_EQUAL(44.0, f(1.5, 1.5, 0.0, 1.5, 1.5, 0.0)); + CompiledFunction compiled_function(function, PassParams::ARRAY, less_only_vm_chain); + EXPECT_EQUAL(1u, compiled_function.get_forests().size()); + auto f = compiled_function.get_function(); + EXPECT_EQUAL(11.0, f(&std::vector<double>({0.5, 0.0, 0.0, 0.5, 0.0, 0.0})[0])); + EXPECT_EQUAL(22.0, f(&std::vector<double>({1.5, 0.5, 0.5, 1.5, 0.5, 0.5})[0])); + EXPECT_EQUAL(33.0, f(&std::vector<double>({1.5, 0.5, 1.5, 1.5, 0.5, 1.5})[0])); + EXPECT_EQUAL(44.0, f(&std::vector<double>({1.5, 1.5, 0.0, 1.5, 1.5, 0.0})[0])); } TEST("require that models with in checks are rejected by less only vm optimizer") { @@ -207,12 +265,13 @@ TEST("require that models with in checks are rejected by less only vm optimizer" TEST("require that general VM tree optimizer works") { Function function = Function::parse("if((a<1),1.0,if((b in [1,2,3]),if((c in 1),2.0,3.0),4.0))+" "if((d in 1),10.0,if((e<1),if((f<1),20.0,30.0),40.0))"); - CompiledFunction compiled_function(function, PassParams::SEPARATE, general_vm_chain); - auto f = compiled_function.get_function<6>(); - EXPECT_EQUAL(11.0, f(0.5, 0.0, 0.0, 1.0, 0.0, 0.0)); - EXPECT_EQUAL(22.0, f(1.5, 2.0, 1.0, 2.0, 0.5, 0.5)); - EXPECT_EQUAL(33.0, f(1.5, 2.0, 2.0, 2.0, 0.5, 1.5)); - EXPECT_EQUAL(44.0, f(1.5, 5.0, 0.0, 2.0, 1.5, 0.0)); + CompiledFunction compiled_function(function, PassParams::ARRAY, general_vm_chain); + EXPECT_EQUAL(1u, compiled_function.get_forests().size()); + auto f = compiled_function.get_function(); + EXPECT_EQUAL(11.0, f(&std::vector<double>({0.5, 0.0, 0.0, 1.0, 0.0, 0.0})[0])); + EXPECT_EQUAL(22.0, f(&std::vector<double>({1.5, 2.0, 1.0, 2.0, 0.5, 0.5})[0])); + EXPECT_EQUAL(33.0, f(&std::vector<double>({1.5, 2.0, 2.0, 2.0, 0.5, 1.5})[0])); + EXPECT_EQUAL(44.0, f(&std::vector<double>({1.5, 5.0, 0.0, 2.0, 1.5, 0.0})[0])); } TEST("require that models with too large sets are rejected by general vm optimizer") { @@ -227,25 +286,38 @@ TEST("require that models with too large sets are rejected by general vm optimiz //----------------------------------------------------------------------------- +double eval_compiled(const CompiledFunction &cfun, std::vector<double> ¶ms) { + ASSERT_EQUAL(params.size(), cfun.num_params()); + if (cfun.pass_params() == PassParams::ARRAY) { + return cfun.get_function()(¶ms[0]); + } + if (cfun.pass_params() == PassParams::LAZY) { + return cfun.get_lazy_function()(my_resolve, ¶ms[0]); + } + return 31212.0; +} + TEST("require that forests evaluate to approximately the same for all evaluation options") { - for (size_t tree_size: std::vector<size_t>({20})) { - for (size_t num_trees: std::vector<size_t>({50})) { - for (size_t less_percent: std::vector<size_t>({100, 80})) { - vespalib::string expression = Model().less_percent(less_percent).make_forest(num_trees, tree_size); - Function function = Function::parse(expression); - CompiledFunction none(function, PassParams::ARRAY, Optimize::none); - CompiledFunction deinline(function, PassParams::ARRAY, DeinlineForest::optimize_chain); - CompiledFunction vm_forest(function, PassParams::ARRAY, VMForest::optimize_chain); - EXPECT_EQUAL(0u, none.get_forests().size()); - ASSERT_EQUAL(1u, deinline.get_forests().size()); - EXPECT_TRUE(dynamic_cast<DeinlineForest*>(deinline.get_forests()[0].get()) != nullptr); - ASSERT_EQUAL(1u, vm_forest.get_forests().size()); - EXPECT_TRUE(dynamic_cast<VMForest*>(vm_forest.get_forests()[0].get()) != nullptr); - std::vector<double> inputs(function.num_params(), 0.5); - double expected = eval_double(function, inputs); - EXPECT_APPROX(expected, none.get_function()(&inputs[0]), 1e-6); - EXPECT_APPROX(expected, deinline.get_function()(&inputs[0]), 1e-6); - EXPECT_APPROX(expected, vm_forest.get_function()(&inputs[0]), 1e-6); + for (PassParams pass_params: {PassParams::ARRAY, PassParams::LAZY}) { + for (size_t tree_size: std::vector<size_t>({20})) { + for (size_t num_trees: std::vector<size_t>({10, 60})) { + for (size_t less_percent: std::vector<size_t>({100, 80})) { + vespalib::string expression = Model().less_percent(less_percent).make_forest(num_trees, tree_size); + Function function = Function::parse(expression); + CompiledFunction none(function, pass_params, Optimize::none); + CompiledFunction deinline(function, pass_params, DeinlineForest::optimize_chain); + CompiledFunction vm_forest(function, pass_params, VMForest::optimize_chain); + EXPECT_EQUAL(0u, none.get_forests().size()); + ASSERT_EQUAL(1u, deinline.get_forests().size()); + EXPECT_TRUE(dynamic_cast<DeinlineForest*>(deinline.get_forests()[0].get()) != nullptr); + ASSERT_EQUAL(1u, vm_forest.get_forests().size()); + EXPECT_TRUE(dynamic_cast<VMForest*>(vm_forest.get_forests()[0].get()) != nullptr); + std::vector<double> inputs(function.num_params(), 0.5); + double expected = eval_double(function, inputs); + EXPECT_APPROX(expected, eval_compiled(none, inputs), 1e-6); + EXPECT_APPROX(expected, eval_compiled(deinline, inputs), 1e-6); + EXPECT_APPROX(expected, eval_compiled(vm_forest, inputs), 1e-6); + } } } } diff --git a/eval/src/tests/eval/gbdt/model.cpp b/eval/src/tests/eval/gbdt/model.cpp index 245c111086e..06e3b657102 100644 --- a/eval/src/tests/eval/gbdt/model.cpp +++ b/eval/src/tests/eval/gbdt/model.cpp @@ -27,7 +27,7 @@ private: std::string make_feature_name() { size_t max_feature = 2; - while ((max_feature < 1024) && (get_int(0, 99) < 50)) { + while ((max_feature < 1024) && (get_int(0, 99) < 55)) { max_feature *= 2; } return make_string("feature_%zu", get_int(1, max_feature)); diff --git a/eval/src/vespa/eval/eval/gbdt.cpp b/eval/src/vespa/eval/eval/gbdt.cpp index 9e9e6d09e59..26368b1e2e4 100644 --- a/eval/src/vespa/eval/eval/gbdt.cpp +++ b/eval/src/vespa/eval/eval/gbdt.cpp @@ -44,7 +44,8 @@ TreeStats::TreeStats(const nodes::Node &tree) num_tuned_checks(0), max_set_size(0), expected_path_length(0.0), - average_path_length(0.0) + average_path_length(0.0), + num_params(0) { size_t sum_path = 0.0; expected_path_length = traverse(tree, 0, sum_path); @@ -53,8 +54,7 @@ TreeStats::TreeStats(const nodes::Node &tree) double TreeStats::traverse(const nodes::Node &node, size_t depth, size_t &sum_path) { - auto if_node = nodes::as<nodes::If>(node); - if (if_node) { + if (auto if_node = nodes::as<nodes::If>(node)) { double p_true = if_node->p_true(); if (p_true != 0.5) { ++num_tuned_checks; @@ -64,9 +64,15 @@ TreeStats::traverse(const nodes::Node &node, size_t depth, size_t &sum_path) { auto less = nodes::as<nodes::Less>(if_node->cond()); auto in = nodes::as<nodes::In>(if_node->cond()); if (less) { + auto symbol = nodes::as<nodes::Symbol>(less->lhs()); + assert(symbol && (symbol->id() >= 0)); + num_params = std::max(num_params, size_t(symbol->id() + 1)); ++num_less_checks; } else { assert(in); + auto symbol = nodes::as<nodes::Symbol>(in->lhs()); + assert(symbol && (symbol->id() >= 0)); + num_params = std::max(num_params, size_t(symbol->id() + 1)); ++num_in_checks; auto array = nodes::as<nodes::Array>(in->rhs()); size_t array_size = (array) ? array->size() : 1; @@ -89,11 +95,13 @@ ForestStats::ForestStats(const std::vector<const nodes::Node *> &trees) total_tuned_checks(0), max_set_size(0), total_expected_path_length(0.0), - total_average_path_length(0.0) + total_average_path_length(0.0), + num_params(0) { std::map<size_t,size_t> size_map; for (const nodes::Node *tree: trees) { TreeStats stats(*tree); + num_params = std::max(num_params, stats.num_params); total_size += stats.size; ++size_map[stats.size]; total_less_checks += stats.num_less_checks; diff --git a/eval/src/vespa/eval/eval/gbdt.h b/eval/src/vespa/eval/eval/gbdt.h index c7ec59b603c..1c22c195d46 100644 --- a/eval/src/vespa/eval/eval/gbdt.h +++ b/eval/src/vespa/eval/eval/gbdt.h @@ -29,6 +29,7 @@ struct TreeStats { size_t max_set_size; double expected_path_length; double average_path_length; + size_t num_params; explicit TreeStats(const nodes::Node &tree); private: double traverse(const nodes::Node &tree, size_t depth, size_t &sum_path); @@ -51,6 +52,7 @@ struct ForestStats { size_t max_set_size; double total_expected_path_length; double total_average_path_length; + size_t num_params; explicit ForestStats(const std::vector<const nodes::Node *> &trees); }; diff --git a/eval/src/vespa/eval/eval/llvm/deinline_forest.cpp b/eval/src/vespa/eval/eval/llvm/deinline_forest.cpp index b0b71ed1601..c976fb811c3 100644 --- a/eval/src/vespa/eval/eval/llvm/deinline_forest.cpp +++ b/eval/src/vespa/eval/eval/llvm/deinline_forest.cpp @@ -17,7 +17,8 @@ DeinlineForest::DeinlineForest(const std::vector<const nodes::Node *> &trees) fragment_size += TreeStats(*trees[idx]).size; fragment.push_back(trees[idx++]); } - void *address = _llvm_wrapper.compile_forest_fragment(fragment); + ForestStats stats(fragment); + void *address = _llvm_wrapper.compile_forest_fragment(stats.num_params, fragment); _fragments.push_back((array_function)address); } } diff --git a/eval/src/vespa/eval/eval/llvm/llvm_wrapper.cpp b/eval/src/vespa/eval/eval/llvm/llvm_wrapper.cpp index 2e1627f2b97..0d404f2077b 100644 --- a/eval/src/vespa/eval/eval/llvm/llvm_wrapper.cpp +++ b/eval/src/vespa/eval/eval/llvm/llvm_wrapper.cpp @@ -25,6 +25,27 @@ double vespalib_eval_approx(double a, double b) { return (vespalib::approx_equal double vespalib_eval_relu(double a) { return std::max(a, 0.0); } double vespalib_eval_sigmoid(double a) { return 1.0 / (1.0 + std::exp(-1.0 * a)); } +using vespalib::eval::gbdt::Forest; +using resolve_function = double (*)(void *ctx, size_t idx); +double vespalib_eval_forest_proxy(Forest::eval_function eval_forest, const Forest *forest, + resolve_function resolve, void *ctx, size_t num_params) +{ + if (num_params <= 64) { + double params[64]; + for (size_t i = 0; i < num_params; ++i) { + params[i] = resolve(ctx, i); + } + return eval_forest(forest, ¶ms[0]); + } else { + std::vector<double> params; + params.reserve(num_params); + for (size_t i = 0; i < num_params; ++i) { + params.push_back(resolve(ctx, i)); + } + return eval_forest(forest, ¶ms[0]); + } +} + namespace vespalib { namespace eval { @@ -55,6 +76,7 @@ struct FunctionBuilder : public NodeVisitor, public NodeTraverser { std::vector<llvm::Value*> values; std::vector<llvm::Value*> let_values; llvm::Function *function; + size_t num_params; PassParams pass_params; bool inside_forest; const Node *forest_end; @@ -62,6 +84,41 @@ struct FunctionBuilder : public NodeVisitor, public NodeTraverser { std::vector<gbdt::Forest::UP> &forests; std::vector<PluginState::UP> &plugin_state; + llvm::PointerType *make_eval_forest_funptr_t() { + std::vector<llvm::Type*> param_types; + param_types.push_back(builder.getVoidTy()->getPointerTo()); + param_types.push_back(builder.getDoubleTy()->getPointerTo()); + llvm::FunctionType *function_type = llvm::FunctionType::get(builder.getDoubleTy(), param_types, false); + return llvm::PointerType::get(function_type, 0); + } + + llvm::PointerType *make_resolve_param_funptr_t() { + std::vector<llvm::Type*> param_types; + param_types.push_back(builder.getVoidTy()->getPointerTo()); + param_types.push_back(builder.getInt64Ty()); + llvm::FunctionType *function_type = llvm::FunctionType::get(builder.getDoubleTy(), param_types, false); + return llvm::PointerType::get(function_type, 0); + } + + llvm::PointerType *make_eval_forest_proxy_funptr_t() { + std::vector<llvm::Type*> param_types; + param_types.push_back(make_eval_forest_funptr_t()); + param_types.push_back(builder.getVoidTy()->getPointerTo()); + param_types.push_back(make_resolve_param_funptr_t()); + param_types.push_back(builder.getVoidTy()->getPointerTo()); + param_types.push_back(builder.getInt64Ty()); + llvm::FunctionType *function_type = llvm::FunctionType::get(builder.getDoubleTy(), param_types, false); + return llvm::PointerType::get(function_type, 0); + } + + llvm::PointerType *make_check_membership_funptr_t() { + std::vector<llvm::Type*> param_types; + param_types.push_back(builder.getVoidTy()->getPointerTo()); + param_types.push_back(builder.getDoubleTy()); + llvm::FunctionType *function_type = llvm::FunctionType::get(builder.getInt1Ty(), param_types, false); + return llvm::PointerType::get(function_type, 0); + } + FunctionBuilder(llvm::ExecutionEngine &engine_in, llvm::LLVMContext &context_in, llvm::Module &module_in, @@ -79,6 +136,7 @@ struct FunctionBuilder : public NodeVisitor, public NodeTraverser { values(), let_values(), function(nullptr), + num_params(num_params_in), pass_params(pass_params_in), inside_forest(false), forest_end(nullptr), @@ -93,12 +151,7 @@ struct FunctionBuilder : public NodeVisitor, public NodeTraverser { param_types.push_back(builder.getDoubleTy()->getPointerTo()); } else { assert(pass_params == PassParams::LAZY); - std::vector<llvm::Type*> callback_param_types; - callback_param_types.push_back(builder.getVoidTy()->getPointerTo()); - callback_param_types.push_back(builder.getInt64Ty()); - llvm::FunctionType *callback_function_type = llvm::FunctionType::get(builder.getDoubleTy(), callback_param_types, false); - llvm::PointerType *callback_function_pointer_type = llvm::PointerType::get(callback_function_type, 0); - param_types.push_back(callback_function_pointer_type); + param_types.push_back(make_resolve_param_funptr_t()); param_types.push_back(builder.getVoidTy()->getPointerTo()); } llvm::FunctionType *function_type = llvm::FunctionType::get(builder.getDoubleTy(), param_types, false); @@ -114,6 +167,7 @@ struct FunctionBuilder : public NodeVisitor, public NodeTraverser { //------------------------------------------------------------------------- llvm::Value *get_param(size_t idx) { + assert(idx < num_params); if (pass_params == PassParams::SEPARATE) { assert(idx < params.size()); return params[idx]; @@ -173,15 +227,17 @@ struct FunctionBuilder : public NodeVisitor, public NodeTraverser { forests.push_back(std::move(optimize_result.forest)); void *eval_ptr = (void *) optimize_result.eval; gbdt::Forest *forest = forests.back().get(); - std::vector<llvm::Type*> param_types; - param_types.push_back(builder.getVoidTy()->getPointerTo()); - param_types.push_back(builder.getDoubleTy()->getPointerTo()); - llvm::FunctionType *function_type = llvm::FunctionType::get(builder.getDoubleTy(), param_types, false); - llvm::PointerType *function_pointer_type = llvm::PointerType::get(function_type, 0); - llvm::Value *eval_fun = builder.CreateIntToPtr(builder.getInt64((uint64_t)eval_ptr), function_pointer_type, "inject_eval"); + llvm::PointerType *eval_funptr_t = make_eval_forest_funptr_t(); + llvm::Value *eval_fun = builder.CreateIntToPtr(builder.getInt64((uint64_t)eval_ptr), eval_funptr_t, "inject_eval"); llvm::Value *ctx = builder.CreateIntToPtr(builder.getInt64((uint64_t)forest), builder.getVoidTy()->getPointerTo(), "inject_ctx"); - assert(pass_params == PassParams::ARRAY); - push(builder.CreateCall2(eval_fun, ctx, params[0], "call_eval")); + if (pass_params == PassParams::ARRAY) { + push(builder.CreateCall2(eval_fun, ctx, params[0], "call_eval")); + } else { + assert(pass_params == PassParams::LAZY); + llvm::PointerType *proxy_funptr_t = make_eval_forest_proxy_funptr_t(); + llvm::Value *proxy_fun = builder.CreateIntToPtr(builder.getInt64((uint64_t)vespalib_eval_forest_proxy), proxy_funptr_t, "inject_eval_proxy"); + push(builder.CreateCall5(proxy_fun, eval_fun, ctx, params[0], params[1], builder.getInt64(stats.num_params))); + } return true; } @@ -192,7 +248,7 @@ struct FunctionBuilder : public NodeVisitor, public NodeTraverser { push_double(node.get_const_value()); return false; } - if (!inside_forest && (pass_params == PassParams::ARRAY) && node.is_forest()) { + if (!inside_forest && (pass_params != PassParams::SEPARATE) && node.is_forest()) { if (try_optimize_forest(node)) { return false; } @@ -451,12 +507,8 @@ struct FunctionBuilder : public NodeVisitor, public NodeTraverser { plugin_state.emplace_back(new SetMemberHash(*array)); void *call_ptr = (void *) SetMemberHash::check_membership; PluginState *state = plugin_state.back().get(); - std::vector<llvm::Type*> param_types; - param_types.push_back(builder.getVoidTy()->getPointerTo()); - param_types.push_back(builder.getDoubleTy()); - llvm::FunctionType *function_type = llvm::FunctionType::get(builder.getInt1Ty(), param_types, false); - llvm::PointerType *function_pointer_type = llvm::PointerType::get(function_type, 0); - llvm::Value *call_fun = builder.CreateIntToPtr(builder.getInt64((uint64_t)call_ptr), function_pointer_type, "inject_call_addr"); + llvm::PointerType *funptr_t = make_check_membership_funptr_t(); + llvm::Value *call_fun = builder.CreateIntToPtr(builder.getInt64((uint64_t)call_ptr), funptr_t, "inject_call_addr"); llvm::Value *ctx = builder.CreateIntToPtr(builder.getInt64((uint64_t)state), builder.getVoidTy()->getPointerTo(), "inject_ctx"); push(builder.CreateCall2(call_fun, ctx, lhs, "call_check_membership")); } else { @@ -618,12 +670,12 @@ LLVMWrapper::compile_function(size_t num_params, PassParams pass_params, const N } void * -LLVMWrapper::compile_forest_fragment(const std::vector<const Node *> &fragment) +LLVMWrapper::compile_forest_fragment(size_t num_params, const std::vector<const Node *> &fragment) { std::lock_guard<std::recursive_mutex> guard(_global_llvm_lock); FunctionBuilder builder(*_engine, *_context, *_module, vespalib::make_string("f%zu", ++_num_functions), - 0, PassParams::ARRAY, + num_params, PassParams::ARRAY, gbdt::Optimize::none, _forests, _plugin_state); builder.build_forest_fragment(fragment); return builder.compile(); diff --git a/eval/src/vespa/eval/eval/llvm/llvm_wrapper.h b/eval/src/vespa/eval/eval/llvm/llvm_wrapper.h index d81f313a392..95092692675 100644 --- a/eval/src/vespa/eval/eval/llvm/llvm_wrapper.h +++ b/eval/src/vespa/eval/eval/llvm/llvm_wrapper.h @@ -57,7 +57,7 @@ public: LLVMWrapper(LLVMWrapper &&rhs); void *compile_function(size_t num_params, PassParams pass_params, const nodes::Node &root, const gbdt::Optimize::Chain &forest_optimizers); - void *compile_forest_fragment(const std::vector<const nodes::Node *> &fragment); + void *compile_forest_fragment(size_t num_params, const std::vector<const nodes::Node *> &fragment); const std::vector<gbdt::Forest::UP> &get_forests() const { return _forests; } void dump() const { _module->dump(); } ~LLVMWrapper(); |