From 93919968eef0d1ab16a6de77de0feec7d698dfb6 Mon Sep 17 00:00:00 2001 From: Haavard Date: Fri, 17 Feb 2017 15:52:17 +0000 Subject: added support for lazy parameters to compiled functions --- .../eval/compile_cache/compile_cache_test.cpp | 3 +- .../compiled_function/compiled_function_test.cpp | 11 ++++++ eval/src/vespa/eval/eval/function.h | 2 +- eval/src/vespa/eval/eval/key_gen.cpp | 3 +- eval/src/vespa/eval/eval/key_gen.h | 2 +- .../src/vespa/eval/eval/llvm/compiled_function.cpp | 2 +- eval/src/vespa/eval/eval/llvm/compiled_function.h | 7 ++++ eval/src/vespa/eval/eval/llvm/llvm_wrapper.cpp | 44 ++++++++++++++-------- eval/src/vespa/eval/eval/llvm/llvm_wrapper.h | 2 +- 9 files changed, 54 insertions(+), 22 deletions(-) (limited to 'eval/src') diff --git a/eval/src/tests/eval/compile_cache/compile_cache_test.cpp b/eval/src/tests/eval/compile_cache/compile_cache_test.cpp index ffcb623c00e..dad8973fb63 100644 --- a/eval/src/tests/eval/compile_cache/compile_cache_test.cpp +++ b/eval/src/tests/eval/compile_cache/compile_cache_test.cpp @@ -72,7 +72,8 @@ struct CheckKeys : test::EvalSpec::EvalTest { Function function = Function::parse(param_names, expression); if (!CompiledFunction::detect_issues(function)) { if (check_key(gen_key(function, PassParams::ARRAY)) || - check_key(gen_key(function, PassParams::SEPARATE))) + check_key(gen_key(function, PassParams::SEPARATE)) || + check_key(gen_key(function, PassParams::LAZY))) { failed = true; fprintf(stderr, "key collision for: %s\n", expression.c_str()); diff --git a/eval/src/tests/eval/compiled_function/compiled_function_test.cpp b/eval/src/tests/eval/compiled_function/compiled_function_test.cpp index 953fe74ae6b..6a1b1f587e3 100644 --- a/eval/src/tests/eval/compiled_function/compiled_function_test.cpp +++ b/eval/src/tests/eval/compiled_function/compiled_function_test.cpp @@ -35,6 +35,17 @@ TEST("require that array parameter passing works") { EXPECT_EQUAL(45.0, arr_fun(&std::vector({9.0, 8.0, 7.0, 6.0, 5.0, 4.0, 3.0, 2.0, 1.0, 0.0})[0])); } +double my_resolve(void *ctx, size_t idx) { return ((double *)ctx)[idx]; } + +TEST("require that lazy parameter passing works") { + CompiledFunction lazy_cf(Function::parse(params_10, expr_10), PassParams::LAZY); + auto lazy_fun = lazy_cf.get_lazy_function(); + EXPECT_EQUAL(10.0, lazy_fun(my_resolve, &std::vector({1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0})[0])); + EXPECT_EQUAL(50.0, lazy_fun(my_resolve, &std::vector({5.0, 5.0, 5.0, 5.0, 5.0, 5.0, 5.0, 5.0, 5.0, 5.0})[0])); + EXPECT_EQUAL(45.0, lazy_fun(my_resolve, &std::vector({0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0})[0])); + EXPECT_EQUAL(45.0, lazy_fun(my_resolve, &std::vector({9.0, 8.0, 7.0, 6.0, 5.0, 4.0, 3.0, 2.0, 1.0, 0.0})[0])); +} + //----------------------------------------------------------------------------- std::vector unsupported = { diff --git a/eval/src/vespa/eval/eval/function.h b/eval/src/vespa/eval/eval/function.h index 35a89ce6512..c1141671e49 100644 --- a/eval/src/vespa/eval/eval/function.h +++ b/eval/src/vespa/eval/eval/function.h @@ -14,7 +14,7 @@ namespace vespalib { namespace eval { -enum class PassParams { SEPARATE, ARRAY }; +enum class PassParams : uint8_t { SEPARATE, ARRAY, LAZY }; /** * Interface used to perform custom symbol extraction. This is diff --git a/eval/src/vespa/eval/eval/key_gen.cpp b/eval/src/vespa/eval/eval/key_gen.cpp index 3d0f1f67e29..861bcd9b904 100644 --- a/eval/src/vespa/eval/eval/key_gen.cpp +++ b/eval/src/vespa/eval/eval/key_gen.cpp @@ -22,7 +22,6 @@ struct KeyGen : public NodeVisitor, public NodeTraverser { void add_int(int value) { key.append(&value, sizeof(value)); } void add_hash(uint32_t value) { key.append(&value, sizeof(value)); } void add_byte(uint8_t value) { key.append(&value, sizeof(value)); } - void add_bool(bool value) { key.push_back(value ? '1' : '0'); } // visit virtual void visit(const Number &node) { add_byte( 1); add_double(node.value()); } @@ -92,7 +91,7 @@ struct KeyGen : public NodeVisitor, public NodeTraverser { vespalib::string gen_key(const Function &function, PassParams pass_params) { KeyGen key_gen; - key_gen.add_bool(pass_params == PassParams::ARRAY); + key_gen.add_byte(uint8_t(pass_params)); key_gen.add_size(function.num_params()); function.root().traverse(key_gen); return key_gen.key; diff --git a/eval/src/vespa/eval/eval/key_gen.h b/eval/src/vespa/eval/eval/key_gen.h index c8479b1b457..d8dcf4f2f04 100644 --- a/eval/src/vespa/eval/eval/key_gen.h +++ b/eval/src/vespa/eval/eval/key_gen.h @@ -8,7 +8,7 @@ namespace vespalib { namespace eval { class Function; -enum class PassParams; +enum class PassParams : uint8_t; /** * Function used to generate a binary key that may be used to query diff --git a/eval/src/vespa/eval/eval/llvm/compiled_function.cpp b/eval/src/vespa/eval/eval/llvm/compiled_function.cpp index 694613ea354..8bf9b08a1bd 100644 --- a/eval/src/vespa/eval/eval/llvm/compiled_function.cpp +++ b/eval/src/vespa/eval/eval/llvm/compiled_function.cpp @@ -25,7 +25,7 @@ CompiledFunction::CompiledFunction(const Function &function_in, PassParams pass_ _pass_params(pass_params_in) { _address = _llvm_wrapper.compile_function(function_in.num_params(), - (_pass_params == PassParams::ARRAY), + _pass_params, function_in.root(), forest_optimizers); } diff --git a/eval/src/vespa/eval/eval/llvm/compiled_function.h b/eval/src/vespa/eval/eval/llvm/compiled_function.h index 41754cfb413..7f9e7c7657c 100644 --- a/eval/src/vespa/eval/eval/llvm/compiled_function.h +++ b/eval/src/vespa/eval/eval/llvm/compiled_function.h @@ -26,6 +26,9 @@ public: using array_function = double (*)(const double *); + using resolve_function = double (*)(void *ctx, size_t idx); + using lazy_function = double (*)(resolve_function, void *ctx); + private: LLVMWrapper _llvm_wrapper; void *_address; @@ -51,6 +54,10 @@ public: assert(_pass_params == PassParams::ARRAY); return ((array_function)_address); } + lazy_function get_lazy_function() const { + assert(_pass_params == PassParams::LAZY); + return ((lazy_function)_address); + } const std::vector &get_forests() const { return _llvm_wrapper.get_forests(); } diff --git a/eval/src/vespa/eval/eval/llvm/llvm_wrapper.cpp b/eval/src/vespa/eval/eval/llvm/llvm_wrapper.cpp index e917261240f..2e1627f2b97 100644 --- a/eval/src/vespa/eval/eval/llvm/llvm_wrapper.cpp +++ b/eval/src/vespa/eval/eval/llvm/llvm_wrapper.cpp @@ -55,7 +55,7 @@ struct FunctionBuilder : public NodeVisitor, public NodeTraverser { std::vector values; std::vector let_values; llvm::Function *function; - bool use_array; + PassParams pass_params; bool inside_forest; const Node *forest_end; const gbdt::Optimize::Chain &forest_optimizers; @@ -67,7 +67,7 @@ struct FunctionBuilder : public NodeVisitor, public NodeTraverser { llvm::Module &module_in, const vespalib::string &name_in, size_t num_params_in, - bool use_array_in, + PassParams pass_params_in, const gbdt::Optimize::Chain &forest_optimizers_in, std::vector &forests_out, std::vector &plugin_state_out) @@ -79,7 +79,7 @@ struct FunctionBuilder : public NodeVisitor, public NodeTraverser { values(), let_values(), function(nullptr), - use_array(use_array_in), + pass_params(pass_params_in), inside_forest(false), forest_end(nullptr), forest_optimizers(forest_optimizers_in), @@ -87,10 +87,19 @@ struct FunctionBuilder : public NodeVisitor, public NodeTraverser { plugin_state(plugin_state_out) { std::vector param_types; - if (use_array_in) { + if (pass_params == PassParams::SEPARATE) { + param_types.resize(num_params_in, builder.getDoubleTy()); + } else if (pass_params == PassParams::ARRAY) { param_types.push_back(builder.getDoubleTy()->getPointerTo()); } else { - param_types.resize(num_params_in, builder.getDoubleTy()); + assert(pass_params == PassParams::LAZY); + std::vector callback_param_types; + callback_param_types.push_back(builder.getVoidTy()->getPointerTo()); + callback_param_types.push_back(builder.getInt64Ty()); + llvm::FunctionType *callback_function_type = llvm::FunctionType::get(builder.getDoubleTy(), callback_param_types, false); + llvm::PointerType *callback_function_pointer_type = llvm::PointerType::get(callback_function_type, 0); + param_types.push_back(callback_function_pointer_type); + param_types.push_back(builder.getVoidTy()->getPointerTo()); } llvm::FunctionType *function_type = llvm::FunctionType::get(builder.getDoubleTy(), param_types, false); function = llvm::Function::Create(function_type, llvm::Function::ExternalLinkage, name_in.c_str(), &module); @@ -105,14 +114,18 @@ struct FunctionBuilder : public NodeVisitor, public NodeTraverser { //------------------------------------------------------------------------- llvm::Value *get_param(size_t idx) { - if (!use_array) { + if (pass_params == PassParams::SEPARATE) { assert(idx < params.size()); return params[idx]; + } else if (pass_params == PassParams::ARRAY) { + assert(params.size() == 1); + llvm::Value *param_array = params[0]; + llvm::Value *addr = builder.CreateGEP(param_array, builder.getInt64(idx)); + return builder.CreateLoad(addr); } - assert(params.size() == 1); - llvm::Value *param_array = params[0]; - llvm::Value *addr = builder.CreateGEP(param_array, builder.getInt64(idx)); - return builder.CreateLoad(addr); + assert(pass_params == PassParams::LAZY); + assert(params.size() == 2); + return builder.CreateCall2(params[0], params[1], builder.getInt64(idx), "resolve_param"); } //------------------------------------------------------------------------- @@ -167,7 +180,8 @@ struct FunctionBuilder : public NodeVisitor, public NodeTraverser { llvm::PointerType *function_pointer_type = llvm::PointerType::get(function_type, 0); llvm::Value *eval_fun = builder.CreateIntToPtr(builder.getInt64((uint64_t)eval_ptr), function_pointer_type, "inject_eval"); llvm::Value *ctx = builder.CreateIntToPtr(builder.getInt64((uint64_t)forest), builder.getVoidTy()->getPointerTo(), "inject_ctx"); - push(builder.CreateCall2(eval_fun, ctx, function->arg_begin(), "call_eval")); + assert(pass_params == PassParams::ARRAY); + push(builder.CreateCall2(eval_fun, ctx, params[0], "call_eval")); return true; } @@ -178,7 +192,7 @@ struct FunctionBuilder : public NodeVisitor, public NodeTraverser { push_double(node.get_const_value()); return false; } - if (!inside_forest && use_array && node.is_forest()) { + if (!inside_forest && (pass_params == PassParams::ARRAY) && node.is_forest()) { if (try_optimize_forest(node)) { return false; } @@ -591,13 +605,13 @@ LLVMWrapper::LLVMWrapper(LLVMWrapper &&rhs) } void * -LLVMWrapper::compile_function(size_t num_params, bool use_array, const Node &root, +LLVMWrapper::compile_function(size_t num_params, PassParams pass_params, const Node &root, const gbdt::Optimize::Chain &forest_optimizers) { std::lock_guard guard(_global_llvm_lock); FunctionBuilder builder(*_engine, *_context, *_module, vespalib::make_string("f%zu", ++_num_functions), - num_params, use_array, + num_params, pass_params, forest_optimizers, _forests, _plugin_state); builder.build_root(root); return builder.compile(); @@ -609,7 +623,7 @@ LLVMWrapper::compile_forest_fragment(const std::vector &fragment) std::lock_guard guard(_global_llvm_lock); FunctionBuilder builder(*_engine, *_context, *_module, vespalib::make_string("f%zu", ++_num_functions), - 0, true, + 0, PassParams::ARRAY, gbdt::Optimize::none, _forests, _plugin_state); builder.build_forest_fragment(fragment); return builder.compile(); diff --git a/eval/src/vespa/eval/eval/llvm/llvm_wrapper.h b/eval/src/vespa/eval/eval/llvm/llvm_wrapper.h index 6eb7ddf2f2b..d81f313a392 100644 --- a/eval/src/vespa/eval/eval/llvm/llvm_wrapper.h +++ b/eval/src/vespa/eval/eval/llvm/llvm_wrapper.h @@ -55,7 +55,7 @@ private: public: LLVMWrapper(); LLVMWrapper(LLVMWrapper &&rhs); - void *compile_function(size_t num_params, bool use_array, const nodes::Node &root, + void *compile_function(size_t num_params, PassParams pass_params, const nodes::Node &root, const gbdt::Optimize::Chain &forest_optimizers); void *compile_forest_fragment(const std::vector &fragment); const std::vector &get_forests() const { return _forests; } -- cgit v1.2.3