summaryrefslogtreecommitdiffstats
path: root/eval
diff options
context:
space:
mode:
Diffstat (limited to 'eval')
-rw-r--r--eval/src/vespa/eval/eval/llvm/llvm_wrapper.cpp61
1 files changed, 23 insertions, 38 deletions
diff --git a/eval/src/vespa/eval/eval/llvm/llvm_wrapper.cpp b/eval/src/vespa/eval/eval/llvm/llvm_wrapper.cpp
index a3e4e3df7fa..158bc91dd6a 100644
--- a/eval/src/vespa/eval/eval/llvm/llvm_wrapper.cpp
+++ b/eval/src/vespa/eval/eval/llvm/llvm_wrapper.cpp
@@ -16,9 +16,7 @@
#include <llvm/IR/DataLayout.h>
#include <llvm/Transforms/Scalar.h>
#include <llvm/Transforms/IPO/PassManagerBuilder.h>
-#if LLVM_VERSION_MAJOR > 9
#include <llvm/Support/ManagedStatic.h>
-#endif
#include <vespa/eval/eval/check_type.h>
#include <vespa/vespalib/stllike/hash_set.h>
#include <vespa/vespalib/util/approx.h>
@@ -106,39 +104,35 @@ struct FunctionBuilder : public NodeVisitor, public NodeTraverser {
return llvm::FunctionType::get(builder.getDoubleTy(), param_types, false);
}
- llvm::PointerType *make_eval_forest_funptr_t() {
+ llvm::FunctionType *make_eval_forest_fun_t() {
std::vector<llvm::Type*> param_types;
param_types.push_back(builder.getInt8Ty()->getPointerTo());
param_types.push_back(builder.getDoubleTy()->getPointerTo());
- llvm::FunctionType *function_type = llvm::FunctionType::get(builder.getDoubleTy(), param_types, false);
- return llvm::PointerType::get(function_type, 0);
+ return llvm::FunctionType::get(builder.getDoubleTy(), param_types, false);
}
- llvm::PointerType *make_resolve_param_funptr_t() {
+ llvm::FunctionType *make_resolve_param_fun_t() {
std::vector<llvm::Type*> param_types;
param_types.push_back(builder.getInt8Ty()->getPointerTo());
param_types.push_back(builder.getInt64Ty());
- llvm::FunctionType *function_type = llvm::FunctionType::get(builder.getDoubleTy(), param_types, false);
- return llvm::PointerType::get(function_type, 0);
+ return llvm::FunctionType::get(builder.getDoubleTy(), param_types, false);
}
- llvm::PointerType *make_eval_forest_proxy_funptr_t() {
+ llvm::FunctionType *make_eval_forest_proxy_fun_t() {
std::vector<llvm::Type*> param_types;
- param_types.push_back(make_eval_forest_funptr_t());
+ param_types.push_back(llvm::PointerType::get(make_eval_forest_fun_t(), 0));
param_types.push_back(builder.getInt8Ty()->getPointerTo());
- param_types.push_back(make_resolve_param_funptr_t());
+ param_types.push_back(llvm::PointerType::get(make_resolve_param_fun_t(), 0));
param_types.push_back(builder.getInt8Ty()->getPointerTo());
param_types.push_back(builder.getInt64Ty());
- llvm::FunctionType *function_type = llvm::FunctionType::get(builder.getDoubleTy(), param_types, false);
- return llvm::PointerType::get(function_type, 0);
+ return llvm::FunctionType::get(builder.getDoubleTy(), param_types, false);
}
- llvm::PointerType *make_check_membership_funptr_t() {
+ llvm::FunctionType *make_check_membership_fun_t() {
std::vector<llvm::Type*> param_types;
param_types.push_back(builder.getInt8Ty()->getPointerTo());
param_types.push_back(builder.getDoubleTy());
- llvm::FunctionType *function_type = llvm::FunctionType::get(builder.getInt1Ty(), param_types, false);
- return llvm::PointerType::get(function_type, 0);
+ return llvm::FunctionType::get(builder.getInt1Ty(), param_types, false);
}
FunctionBuilder(llvm::LLVMContext &context_in,
@@ -170,7 +164,7 @@ struct FunctionBuilder : public NodeVisitor, public NodeTraverser {
param_types.push_back(builder.getDoubleTy()->getPointerTo());
} else {
assert(pass_params == PassParams::LAZY);
- param_types.push_back(make_resolve_param_funptr_t());
+ param_types.push_back(llvm::PointerType::get(make_resolve_param_fun_t(), 0));
param_types.push_back(builder.getInt8Ty()->getPointerTo());
}
llvm::FunctionType *function_type = llvm::FunctionType::get(builder.getDoubleTy(), param_types, false);
@@ -194,12 +188,12 @@ struct FunctionBuilder : public NodeVisitor, public NodeTraverser {
} else if (pass_params == PassParams::ARRAY) {
assert(params.size() == 1);
llvm::Value *param_array = params[0];
- llvm::Value *addr = builder.CreateGEP(param_array->getType()->getScalarType()->getPointerElementType(), param_array, builder.getInt64(idx));
- return builder.CreateLoad(addr->getType()->getPointerElementType(), addr);
+ llvm::Value *addr = builder.CreateGEP(builder.getDoubleTy(), param_array, builder.getInt64(idx));
+ return builder.CreateLoad(builder.getDoubleTy(), addr);
}
assert(pass_params == PassParams::LAZY);
assert(params.size() == 2);
- return builder.CreateCall(llvm::cast<llvm::FunctionType>(params[0]->getType()->getPointerElementType()),
+ return builder.CreateCall(make_resolve_param_fun_t(),
params[0], {params[1], builder.getInt64(idx)}, "resolve_param");
}
@@ -248,17 +242,19 @@ struct FunctionBuilder : public NodeVisitor, public NodeTraverser {
forests.push_back(std::move(optimize_result.forest));
void *eval_ptr = (void *) optimize_result.eval;
gbdt::Forest *forest = forests.back().get();
- llvm::PointerType *eval_funptr_t = make_eval_forest_funptr_t();
+ llvm::FunctionType* eval_fun_t = make_eval_forest_fun_t();
+ llvm::PointerType *eval_funptr_t = llvm::PointerType::get(eval_fun_t, 0);
llvm::Value *eval_fun = builder.CreateIntToPtr(builder.getInt64((uint64_t)eval_ptr), eval_funptr_t, "inject_eval");
llvm::Value *ctx = builder.CreateIntToPtr(builder.getInt64((uint64_t)forest), builder.getInt8Ty()->getPointerTo(), "inject_ctx");
if (pass_params == PassParams::ARRAY) {
- push(builder.CreateCall(llvm::cast<llvm::FunctionType>(eval_fun->getType()->getPointerElementType()),
+ push(builder.CreateCall(eval_fun_t,
eval_fun, {ctx, params[0]}, "call_eval"));
} else {
assert(pass_params == PassParams::LAZY);
- llvm::PointerType *proxy_funptr_t = make_eval_forest_proxy_funptr_t();
+ llvm::FunctionType* proxy_fun_t = make_eval_forest_proxy_fun_t();
+ llvm::PointerType *proxy_funptr_t = llvm::PointerType::get(proxy_fun_t, 0);
llvm::Value *proxy_fun = builder.CreateIntToPtr(builder.getInt64((uint64_t)vespalib_eval_forest_proxy), proxy_funptr_t, "inject_eval_proxy");
- push(builder.CreateCall(llvm::cast<llvm::FunctionType>(proxy_fun->getType()->getPointerElementType()),
+ push(builder.CreateCall(proxy_fun_t,
proxy_fun, {eval_fun, ctx, params[0], params[1], builder.getInt64(stats.num_params)}));
}
return true;
@@ -341,7 +337,6 @@ struct FunctionBuilder : public NodeVisitor, public NodeTraverser {
llvm::Value *a = pop_double();
push(builder.CreateCall(fun, a));
}
-#if LLVM_VERSION_MAJOR >= 9
void make_call_1(llvm::FunctionCallee fun) {
if (!fun || fun.getFunctionType()->getNumParams() != 1) {
return make_error(1);
@@ -349,16 +344,11 @@ struct FunctionBuilder : public NodeVisitor, public NodeTraverser {
llvm::Value *a = pop_double();
push(builder.CreateCall(fun, a));
}
-#endif
void make_call_1(const llvm::Intrinsic::ID &id) {
make_call_1(llvm::Intrinsic::getDeclaration(&module, id, builder.getDoubleTy()));
}
void make_call_1(const char *name) {
-#if LLVM_VERSION_MAJOR >= 9
make_call_1(module.getOrInsertFunction(name, make_call_1_fun_t()));
-#else
- make_call_1(llvm::dyn_cast<llvm::Function>(module.getOrInsertFunction(name, make_call_1_fun_t())));
-#endif
}
void make_call_2(llvm::Function *fun) {
@@ -369,7 +359,6 @@ struct FunctionBuilder : public NodeVisitor, public NodeTraverser {
llvm::Value *a = pop_double();
push(builder.CreateCall(fun, {a, b}));
}
-#if LLVM_VERSION_MAJOR >= 9
void make_call_2(llvm::FunctionCallee fun) {
if (!fun || fun.getFunctionType()->getNumParams() != 2) {
return make_error(2);
@@ -378,16 +367,11 @@ struct FunctionBuilder : public NodeVisitor, public NodeTraverser {
llvm::Value *a = pop_double();
push(builder.CreateCall(fun, {a, b}));
}
-#endif
void make_call_2(const llvm::Intrinsic::ID &id) {
make_call_2(llvm::Intrinsic::getDeclaration(&module, id, builder.getDoubleTy()));
}
void make_call_2(const char *name) {
-#if LLVM_VERSION_MAJOR >= 9
make_call_2(module.getOrInsertFunction(name, make_call_2_fun_t()));
-#else
- make_call_2(llvm::dyn_cast<llvm::Function>(module.getOrInsertFunction(name, make_call_2_fun_t())));
-#endif
}
//-------------------------------------------------------------------------
@@ -410,10 +394,11 @@ struct FunctionBuilder : public NodeVisitor, public NodeTraverser {
plugin_state.emplace_back(new SetMemberHash(item));
void *call_ptr = (void *) SetMemberHash::check_membership;
PluginState *state = plugin_state.back().get();
- llvm::PointerType *funptr_t = make_check_membership_funptr_t();
+ llvm::FunctionType *fun_t = make_check_membership_fun_t();
+ llvm::PointerType *funptr_t = llvm::PointerType::get(fun_t, 0);
llvm::Value *call_fun = builder.CreateIntToPtr(builder.getInt64((uint64_t)call_ptr), funptr_t, "inject_call_addr");
llvm::Value *ctx = builder.CreateIntToPtr(builder.getInt64((uint64_t)state), builder.getInt8Ty()->getPointerTo(), "inject_ctx");
- push(builder.CreateCall(llvm::cast<llvm::FunctionType>(call_fun->getType()->getPointerElementType()),
+ push(builder.CreateCall(fun_t,
call_fun, {ctx, lhs}, "call_check_membership"));
} else {
// build explicit code to check all set members