aboutsummaryrefslogtreecommitdiffstats
path: root/eval/src/vespa/eval/eval/llvm/llvm_wrapper.h
blob: adc195f9c55edd9620f7ffc1337e64736885f18f (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.

#pragma once

#include <vespa/eval/eval/function.h>
#include <vespa/eval/eval/gbdt.h>

#include <llvm/IR/LLVMContext.h>
#include <llvm/IR/Module.h>
#include <llvm/ExecutionEngine/ExecutionEngine.h>
#include <mutex>

// Helper functions declared with C linkage, so their symbol names are
// not mangled. All of them take and return plain doubles.
// NOTE(review): presumably these are looked up by name from the
// generated machine code -- confirm against the implementation file.
extern "C" {
    // helpers taking two operands
    double vespalib_eval_ldexp(double a, double b);
    double vespalib_eval_min(double a, double b);
    double vespalib_eval_max(double a, double b);
    double vespalib_eval_approx(double a, double b);
    double vespalib_eval_bit(double a, double b);
    double vespalib_eval_hamming(double a, double b);

    // helpers taking a single operand
    double vespalib_eval_isnan(double a);
    double vespalib_eval_relu(double a);
    double vespalib_eval_sigmoid(double a);
    double vespalib_eval_elu(double a);
}

namespace vespalib::eval {

/**
 * Minimal polymorphic base for custom state that must be tracked and
 * cleaned up. The typical use is destructing native objects that the
 * generated machine code invokes during evaluation; for example, a
 * large set membership check against constant values is turned into
 * a lookup in a pre-generated hash table, which must be released
 * afterwards.
 *
 * The destructor is virtual so that concrete state can be destroyed
 * through a PluginState::UP handle.
 **/
struct PluginState {
    virtual ~PluginState() = default;

    // owning handle to any concrete plugin state
    using UP = std::unique_ptr<PluginState>;
};

/**
 * Stuff related to LLVM code generation is wrapped in this
 * class. This is mostly used by the CompiledFunction class.
 **/
class LLVMWrapper
{
private:
    std::unique_ptr<llvm::LLVMContext>     _context;
    std::unique_ptr<llvm::Module>          _module;
    std::unique_ptr<llvm::ExecutionEngine> _engine;
    std::vector<llvm::Function*>           _functions;
    std::vector<gbdt::Forest::UP>          _forests;
    std::vector<PluginState::UP>           _plugin_state;

    void compile(llvm::raw_ostream * dumpStream);
public:
    LLVMWrapper();
    LLVMWrapper(LLVMWrapper &&rhs) = default;

    size_t make_function(size_t num_params, PassParams pass_params, const nodes::Node &root,
                         const gbdt::Optimize::Chain &forest_optimizers);
    size_t make_forest_fragment(size_t num_params, const std::vector<const nodes::Node *> &fragment);
    const std::vector<gbdt::Forest::UP> &get_forests() const { return _forests; }
    void compile(llvm::raw_ostream & dumpStream) { compile(&dumpStream); }
    void compile() { compile(nullptr); }
    void *get_function_address(size_t function_id);
    ~LLVMWrapper();
};

}