blob: 08470cf57b9962a72ad8c1c16532efbfa1273582 (
plain) (
blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
|
// Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
#pragma once
#include <vespa/eval/eval/function.h>
#include <vespa/eval/eval/gbdt.h>
#include <llvm/IR/LLVMContext.h>
#include <llvm/IR/Module.h>
#include <llvm/ExecutionEngine/ExecutionEngine.h>
#include <mutex>
/**
 * C-linkage helper functions. Only their declarations appear in this
 * header; they are defined elsewhere.
 *
 * NOTE(review): the names suggest these back the corresponding eval
 * operations (ldexp, min, max, isnan, approx, relu, sigmoid, elu, bit,
 * hamming) and are given C linkage so LLVM-generated code can refer to
 * them by unmangled symbol name; confirm against the code generator.
 *
 * Every helper takes and returns double -- note that
 * vespalib_eval_isnan reports its result as a double, not a bool.
 **/
extern "C" {
double vespalib_eval_ldexp(double a, double b);
double vespalib_eval_min(double a, double b);
double vespalib_eval_max(double a, double b);
double vespalib_eval_isnan(double a);
double vespalib_eval_approx(double a, double b);
double vespalib_eval_relu(double a);
double vespalib_eval_sigmoid(double a);
double vespalib_eval_elu(double a);
double vespalib_eval_bit(double a, double b);
double vespalib_eval_hamming(double a, double b);
}
namespace vespalib::eval {
/**
 * Minimal polymorphic base for custom state that has to be tracked and
 * released together with generated code. Its typical job is to own
 * native objects that the generated machine code calls into during
 * evaluation. For example, large set-membership checks against constant
 * values are turned into lookups in a pre-generated hash table, and that
 * table lives in a PluginState subclass.
 *
 * Subclasses carry the actual state. Because the base destructor is
 * virtual, destroying a subclass through PluginState::UP runs the
 * subclass destructor.
 **/
struct PluginState {
    virtual ~PluginState() = default;
    // Owning handle used to store heterogeneous plugin state.
    using UP = std::unique_ptr<PluginState>;
};
/**
 * Stuff related to LLVM code generation is wrapped in this
 * class. This is mostly used by the CompiledFunction class.
 *
 * The wrapper owns:
 * - the LLVM context, module and execution engine;
 * - the GBDT forests exposed through get_forests();
 * - any PluginState objects kept alive for generated code.
 *
 * NOTE(review): the intended call sequence appears to be:
 * 1. make_function()/make_forest_fragment(), once or more;
 * 2. compile(), once;
 * 3. get_function_address() for each returned id.
 * Confirm against the implementation.
 *
 * Instances can be move-constructed, but not copied or assigned.
 **/
class LLVMWrapper
{
private:
    // Owned LLVM state. Keep this declaration order: implicit member
    // destruction runs in reverse declaration order (after the body of
    // ~LLVMWrapper), so _engine and _module are torn down before the
    // _context they were presumably created from.
    std::unique_ptr<llvm::LLVMContext> _context;
    std::unique_ptr<llvm::Module> _module;
    std::unique_ptr<llvm::ExecutionEngine> _engine;
    // Non-owning handles to generated functions; presumably indexed by the
    // size_t ids returned from make_function()/make_forest_fragment().
    std::vector<llvm::Function*> _functions;
    // Forests kept alive for the wrapper's lifetime (see get_forests()).
    std::vector<gbdt::Forest::UP> _forests;
    // Custom native state owned on behalf of the generated code.
    std::vector<PluginState::UP> _plugin_state;

    // Shared implementation behind the public compile() overloads;
    // dumpStream is nullptr when the caller did not supply a stream.
    void compile(llvm::raw_ostream * dumpStream);
public:
    LLVMWrapper();

    // Member-wise move; the moved-from wrapper is left holding null LLVM
    // pointers.
    LLVMWrapper(LLVMWrapper &&rhs) = default;

    // Not copyable: the LLVM objects are uniquely owned (std::unique_ptr).
    LLVMWrapper(const LLVMWrapper &) = delete;
    LLVMWrapper &operator=(const LLVMWrapper &) = delete;

    // Not move-assignable either. A defaulted move-assignment would assign
    // members in declaration order, releasing the old _context before the
    // old _module/_engine. Spell the deletion out rather than relying on
    // the operator being implicitly absent.
    LLVMWrapper &operator=(LLVMWrapper &&) = delete;

    // Adds a function taking num_params parameters whose body evaluates
    // 'root'. NOTE(review): forest_optimizers presumably applies to GBDT
    // sub-expressions, and the returned value is presumably the id to pass
    // to get_function_address(); verify in the implementation.
    size_t make_function(size_t num_params, PassParams pass_params, const nodes::Node &root,
                         const gbdt::Optimize::Chain &forest_optimizers);

    // Adds a function built from the given node fragment; returns an id
    // (presumably in the same id space as make_function()).
    size_t make_forest_fragment(size_t num_params, const std::vector<const nodes::Node *> &fragment);

    const std::vector<gbdt::Forest::UP> &get_forests() const { return _forests; }

    // Compiles with dumpStream handed to the implementation
    // (presumably receives a dump of the generated code; verify).
    void compile(llvm::raw_ostream & dumpStream) { compile(&dumpStream); }

    // Compiles without a dump stream.
    void compile() { compile(nullptr); }

    // Returns the native entry point for function_id. NOTE(review):
    // presumably only valid after compile(); confirm in the implementation.
    void *get_function_address(size_t function_id);

    ~LLVMWrapper();
};
}
|