summaryrefslogtreecommitdiffstats
path: root/eval
diff options
context:
space:
mode:
authorHaavard <havardpe@yahoo-inc.com>2017-02-17 15:52:17 +0000
committerHaavard <havardpe@yahoo-inc.com>2017-02-17 15:52:17 +0000
commit93919968eef0d1ab16a6de77de0feec7d698dfb6 (patch)
treeec4484afa06fe15331ed84b11f58923e214b68b6 /eval
parent1d8791b9f5f4998fd09a37b2c70d4fd53e9359d7 (diff)
added support for lazy parameters to compiled functions
Diffstat (limited to 'eval')
-rw-r--r--eval/src/tests/eval/compile_cache/compile_cache_test.cpp3
-rw-r--r--eval/src/tests/eval/compiled_function/compiled_function_test.cpp11
-rw-r--r--eval/src/vespa/eval/eval/function.h2
-rw-r--r--eval/src/vespa/eval/eval/key_gen.cpp3
-rw-r--r--eval/src/vespa/eval/eval/key_gen.h2
-rw-r--r--eval/src/vespa/eval/eval/llvm/compiled_function.cpp2
-rw-r--r--eval/src/vespa/eval/eval/llvm/compiled_function.h7
-rw-r--r--eval/src/vespa/eval/eval/llvm/llvm_wrapper.cpp44
-rw-r--r--eval/src/vespa/eval/eval/llvm/llvm_wrapper.h2
9 files changed, 54 insertions, 22 deletions
diff --git a/eval/src/tests/eval/compile_cache/compile_cache_test.cpp b/eval/src/tests/eval/compile_cache/compile_cache_test.cpp
index ffcb623c00e..dad8973fb63 100644
--- a/eval/src/tests/eval/compile_cache/compile_cache_test.cpp
+++ b/eval/src/tests/eval/compile_cache/compile_cache_test.cpp
@@ -72,7 +72,8 @@ struct CheckKeys : test::EvalSpec::EvalTest {
Function function = Function::parse(param_names, expression);
if (!CompiledFunction::detect_issues(function)) {
if (check_key(gen_key(function, PassParams::ARRAY)) ||
- check_key(gen_key(function, PassParams::SEPARATE)))
+ check_key(gen_key(function, PassParams::SEPARATE)) ||
+ check_key(gen_key(function, PassParams::LAZY)))
{
failed = true;
fprintf(stderr, "key collision for: %s\n", expression.c_str());
diff --git a/eval/src/tests/eval/compiled_function/compiled_function_test.cpp b/eval/src/tests/eval/compiled_function/compiled_function_test.cpp
index 953fe74ae6b..6a1b1f587e3 100644
--- a/eval/src/tests/eval/compiled_function/compiled_function_test.cpp
+++ b/eval/src/tests/eval/compiled_function/compiled_function_test.cpp
@@ -35,6 +35,17 @@ TEST("require that array parameter passing works") {
EXPECT_EQUAL(45.0, arr_fun(&std::vector<double>({9.0, 8.0, 7.0, 6.0, 5.0, 4.0, 3.0, 2.0, 1.0, 0.0})[0]));
}
+double my_resolve(void *ctx, size_t idx) { return ((double *)ctx)[idx]; }
+
+TEST("require that lazy parameter passing works") {
+ CompiledFunction lazy_cf(Function::parse(params_10, expr_10), PassParams::LAZY);
+ auto lazy_fun = lazy_cf.get_lazy_function();
+ EXPECT_EQUAL(10.0, lazy_fun(my_resolve, &std::vector<double>({1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0})[0]));
+ EXPECT_EQUAL(50.0, lazy_fun(my_resolve, &std::vector<double>({5.0, 5.0, 5.0, 5.0, 5.0, 5.0, 5.0, 5.0, 5.0, 5.0})[0]));
+ EXPECT_EQUAL(45.0, lazy_fun(my_resolve, &std::vector<double>({0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0})[0]));
+ EXPECT_EQUAL(45.0, lazy_fun(my_resolve, &std::vector<double>({9.0, 8.0, 7.0, 6.0, 5.0, 4.0, 3.0, 2.0, 1.0, 0.0})[0]));
+}
+
//-----------------------------------------------------------------------------
std::vector<vespalib::string> unsupported = {
diff --git a/eval/src/vespa/eval/eval/function.h b/eval/src/vespa/eval/eval/function.h
index 35a89ce6512..c1141671e49 100644
--- a/eval/src/vespa/eval/eval/function.h
+++ b/eval/src/vespa/eval/eval/function.h
@@ -14,7 +14,7 @@
namespace vespalib {
namespace eval {
-enum class PassParams { SEPARATE, ARRAY };
+enum class PassParams : uint8_t { SEPARATE, ARRAY, LAZY };
/**
* Interface used to perform custom symbol extraction. This is
diff --git a/eval/src/vespa/eval/eval/key_gen.cpp b/eval/src/vespa/eval/eval/key_gen.cpp
index 3d0f1f67e29..861bcd9b904 100644
--- a/eval/src/vespa/eval/eval/key_gen.cpp
+++ b/eval/src/vespa/eval/eval/key_gen.cpp
@@ -22,7 +22,6 @@ struct KeyGen : public NodeVisitor, public NodeTraverser {
void add_int(int value) { key.append(&value, sizeof(value)); }
void add_hash(uint32_t value) { key.append(&value, sizeof(value)); }
void add_byte(uint8_t value) { key.append(&value, sizeof(value)); }
- void add_bool(bool value) { key.push_back(value ? '1' : '0'); }
// visit
virtual void visit(const Number &node) { add_byte( 1); add_double(node.value()); }
@@ -92,7 +91,7 @@ struct KeyGen : public NodeVisitor, public NodeTraverser {
vespalib::string gen_key(const Function &function, PassParams pass_params)
{
KeyGen key_gen;
- key_gen.add_bool(pass_params == PassParams::ARRAY);
+ key_gen.add_byte(uint8_t(pass_params));
key_gen.add_size(function.num_params());
function.root().traverse(key_gen);
return key_gen.key;
diff --git a/eval/src/vespa/eval/eval/key_gen.h b/eval/src/vespa/eval/eval/key_gen.h
index c8479b1b457..d8dcf4f2f04 100644
--- a/eval/src/vespa/eval/eval/key_gen.h
+++ b/eval/src/vespa/eval/eval/key_gen.h
@@ -8,7 +8,7 @@ namespace vespalib {
namespace eval {
class Function;
-enum class PassParams;
+enum class PassParams : uint8_t;
/**
* Function used to generate a binary key that may be used to query
diff --git a/eval/src/vespa/eval/eval/llvm/compiled_function.cpp b/eval/src/vespa/eval/eval/llvm/compiled_function.cpp
index 694613ea354..8bf9b08a1bd 100644
--- a/eval/src/vespa/eval/eval/llvm/compiled_function.cpp
+++ b/eval/src/vespa/eval/eval/llvm/compiled_function.cpp
@@ -25,7 +25,7 @@ CompiledFunction::CompiledFunction(const Function &function_in, PassParams pass_
_pass_params(pass_params_in)
{
_address = _llvm_wrapper.compile_function(function_in.num_params(),
- (_pass_params == PassParams::ARRAY),
+ _pass_params,
function_in.root(),
forest_optimizers);
}
diff --git a/eval/src/vespa/eval/eval/llvm/compiled_function.h b/eval/src/vespa/eval/eval/llvm/compiled_function.h
index 41754cfb413..7f9e7c7657c 100644
--- a/eval/src/vespa/eval/eval/llvm/compiled_function.h
+++ b/eval/src/vespa/eval/eval/llvm/compiled_function.h
@@ -26,6 +26,9 @@ public:
using array_function = double (*)(const double *);
+ using resolve_function = double (*)(void *ctx, size_t idx);
+ using lazy_function = double (*)(resolve_function, void *ctx);
+
private:
LLVMWrapper _llvm_wrapper;
void *_address;
@@ -51,6 +54,10 @@ public:
assert(_pass_params == PassParams::ARRAY);
return ((array_function)_address);
}
+ lazy_function get_lazy_function() const {
+ assert(_pass_params == PassParams::LAZY);
+ return ((lazy_function)_address);
+ }
const std::vector<gbdt::Forest::UP> &get_forests() const {
return _llvm_wrapper.get_forests();
}
diff --git a/eval/src/vespa/eval/eval/llvm/llvm_wrapper.cpp b/eval/src/vespa/eval/eval/llvm/llvm_wrapper.cpp
index e917261240f..2e1627f2b97 100644
--- a/eval/src/vespa/eval/eval/llvm/llvm_wrapper.cpp
+++ b/eval/src/vespa/eval/eval/llvm/llvm_wrapper.cpp
@@ -55,7 +55,7 @@ struct FunctionBuilder : public NodeVisitor, public NodeTraverser {
std::vector<llvm::Value*> values;
std::vector<llvm::Value*> let_values;
llvm::Function *function;
- bool use_array;
+ PassParams pass_params;
bool inside_forest;
const Node *forest_end;
const gbdt::Optimize::Chain &forest_optimizers;
@@ -67,7 +67,7 @@ struct FunctionBuilder : public NodeVisitor, public NodeTraverser {
llvm::Module &module_in,
const vespalib::string &name_in,
size_t num_params_in,
- bool use_array_in,
+ PassParams pass_params_in,
const gbdt::Optimize::Chain &forest_optimizers_in,
std::vector<gbdt::Forest::UP> &forests_out,
std::vector<PluginState::UP> &plugin_state_out)
@@ -79,7 +79,7 @@ struct FunctionBuilder : public NodeVisitor, public NodeTraverser {
values(),
let_values(),
function(nullptr),
- use_array(use_array_in),
+ pass_params(pass_params_in),
inside_forest(false),
forest_end(nullptr),
forest_optimizers(forest_optimizers_in),
@@ -87,10 +87,19 @@ struct FunctionBuilder : public NodeVisitor, public NodeTraverser {
plugin_state(plugin_state_out)
{
std::vector<llvm::Type*> param_types;
- if (use_array_in) {
+ if (pass_params == PassParams::SEPARATE) {
+ param_types.resize(num_params_in, builder.getDoubleTy());
+ } else if (pass_params == PassParams::ARRAY) {
param_types.push_back(builder.getDoubleTy()->getPointerTo());
} else {
- param_types.resize(num_params_in, builder.getDoubleTy());
+ assert(pass_params == PassParams::LAZY);
+ std::vector<llvm::Type*> callback_param_types;
+ callback_param_types.push_back(builder.getVoidTy()->getPointerTo());
+ callback_param_types.push_back(builder.getInt64Ty());
+ llvm::FunctionType *callback_function_type = llvm::FunctionType::get(builder.getDoubleTy(), callback_param_types, false);
+ llvm::PointerType *callback_function_pointer_type = llvm::PointerType::get(callback_function_type, 0);
+ param_types.push_back(callback_function_pointer_type);
+ param_types.push_back(builder.getVoidTy()->getPointerTo());
}
llvm::FunctionType *function_type = llvm::FunctionType::get(builder.getDoubleTy(), param_types, false);
function = llvm::Function::Create(function_type, llvm::Function::ExternalLinkage, name_in.c_str(), &module);
@@ -105,14 +114,18 @@ struct FunctionBuilder : public NodeVisitor, public NodeTraverser {
//-------------------------------------------------------------------------
llvm::Value *get_param(size_t idx) {
- if (!use_array) {
+ if (pass_params == PassParams::SEPARATE) {
assert(idx < params.size());
return params[idx];
+ } else if (pass_params == PassParams::ARRAY) {
+ assert(params.size() == 1);
+ llvm::Value *param_array = params[0];
+ llvm::Value *addr = builder.CreateGEP(param_array, builder.getInt64(idx));
+ return builder.CreateLoad(addr);
}
- assert(params.size() == 1);
- llvm::Value *param_array = params[0];
- llvm::Value *addr = builder.CreateGEP(param_array, builder.getInt64(idx));
- return builder.CreateLoad(addr);
+ assert(pass_params == PassParams::LAZY);
+ assert(params.size() == 2);
+ return builder.CreateCall2(params[0], params[1], builder.getInt64(idx), "resolve_param");
}
//-------------------------------------------------------------------------
@@ -167,7 +180,8 @@ struct FunctionBuilder : public NodeVisitor, public NodeTraverser {
llvm::PointerType *function_pointer_type = llvm::PointerType::get(function_type, 0);
llvm::Value *eval_fun = builder.CreateIntToPtr(builder.getInt64((uint64_t)eval_ptr), function_pointer_type, "inject_eval");
llvm::Value *ctx = builder.CreateIntToPtr(builder.getInt64((uint64_t)forest), builder.getVoidTy()->getPointerTo(), "inject_ctx");
- push(builder.CreateCall2(eval_fun, ctx, function->arg_begin(), "call_eval"));
+ assert(pass_params == PassParams::ARRAY);
+ push(builder.CreateCall2(eval_fun, ctx, params[0], "call_eval"));
return true;
}
@@ -178,7 +192,7 @@ struct FunctionBuilder : public NodeVisitor, public NodeTraverser {
push_double(node.get_const_value());
return false;
}
- if (!inside_forest && use_array && node.is_forest()) {
+ if (!inside_forest && (pass_params == PassParams::ARRAY) && node.is_forest()) {
if (try_optimize_forest(node)) {
return false;
}
@@ -591,13 +605,13 @@ LLVMWrapper::LLVMWrapper(LLVMWrapper &&rhs)
}
void *
-LLVMWrapper::compile_function(size_t num_params, bool use_array, const Node &root,
+LLVMWrapper::compile_function(size_t num_params, PassParams pass_params, const Node &root,
const gbdt::Optimize::Chain &forest_optimizers)
{
std::lock_guard<std::recursive_mutex> guard(_global_llvm_lock);
FunctionBuilder builder(*_engine, *_context, *_module,
vespalib::make_string("f%zu", ++_num_functions),
- num_params, use_array,
+ num_params, pass_params,
forest_optimizers, _forests, _plugin_state);
builder.build_root(root);
return builder.compile();
@@ -609,7 +623,7 @@ LLVMWrapper::compile_forest_fragment(const std::vector<const Node *> &fragment)
std::lock_guard<std::recursive_mutex> guard(_global_llvm_lock);
FunctionBuilder builder(*_engine, *_context, *_module,
vespalib::make_string("f%zu", ++_num_functions),
- 0, true,
+ 0, PassParams::ARRAY,
gbdt::Optimize::none, _forests, _plugin_state);
builder.build_forest_fragment(fragment);
return builder.compile();
diff --git a/eval/src/vespa/eval/eval/llvm/llvm_wrapper.h b/eval/src/vespa/eval/eval/llvm/llvm_wrapper.h
index 6eb7ddf2f2b..d81f313a392 100644
--- a/eval/src/vespa/eval/eval/llvm/llvm_wrapper.h
+++ b/eval/src/vespa/eval/eval/llvm/llvm_wrapper.h
@@ -55,7 +55,7 @@ private:
public:
LLVMWrapper();
LLVMWrapper(LLVMWrapper &&rhs);
- void *compile_function(size_t num_params, bool use_array, const nodes::Node &root,
+ void *compile_function(size_t num_params, PassParams pass_params, const nodes::Node &root,
const gbdt::Optimize::Chain &forest_optimizers);
void *compile_forest_fragment(const std::vector<const nodes::Node *> &fragment);
const std::vector<gbdt::Forest::UP> &get_forests() const { return _forests; }