aboutsummaryrefslogtreecommitdiffstats
path: root/fnet/src/examples/frt/rpc/rpc_callback_server.cpp
blob: c0504f49c2bb6e886cac6b10bb38176b67e5757a (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.

#include <vespa/fnet/frt/supervisor.h>
#include <vespa/fnet/frt/rpcrequest.h>
#include <vespa/fnet/signalshutdown.h>
#include <vespa/fnet/transport.h>

#include <vespa/vespalib/util/signalhandler.h>
#include <cassert>
#include <condition_variable>
#include <cstdio>
#include <exception>
#include <future>
#include <mutex>
#include <stdexcept>
#include <thread>

#include <vespa/log/log.h>
LOG_SETUP("rpc_callback_server");

/**
 * Class keeping track of 'detached' threads in order to wait for
 * their completion on program shutdown. Threads are not actually
 * detached, but perform co-operative auto-joining on completion:
 * a finishing thread parks its own std::thread handle in _thread and
 * joins the handle that was parked there before it. The last parked
 * handle is joined by close_and_wait().
 **/
class AutoJoiner
{
private:
    std::mutex              _lock;
    std::condition_variable _cond;
    bool                    _closed;
    size_t                  _pending; // threads registered by start() that have not yet finished
    std::thread             _thread;  // most recently finished thread, not yet joined

    // Joins the held thread (if any) when destroyed. Callers declare it
    // before their lock guard so the join runs after the lock is released.
    struct JoinGuard {
        std::thread thread;
        ~JoinGuard() {
            if (thread.joinable()) {
                assert(std::this_thread::get_id() != thread.get_id());
                thread.join();
            }
        }
    };

    // Registers a new thread; throws std::runtime_error once closed.
    void notify_start() {
        std::lock_guard guard(_lock);
        if (_closed) {
            throw std::runtime_error("no new threads allowed");
        }
        ++_pending;
    }

    // Undoes a notify_start() whose thread could not be created, so that
    // close_and_wait() does not wait for a thread that will never run.
    void notify_aborted() {
        std::lock_guard guard(_lock);
        if (--_pending == 0 && _closed) {
            _cond.notify_all();
        }
    }

    // Called by a finishing thread with its own handle: parks that handle
    // and joins (outside the lock) the previously parked one.
    void notify_done(std::thread thread) {
        JoinGuard join;
        std::unique_lock guard(_lock);
        join.thread = std::move(_thread);
        _thread = std::move(thread);
        if (--_pending == 0 && _closed) {
            _cond.notify_all();
        }
    }

    // Wraps 'task' so the new thread first receives its own std::thread
    // handle (which only exists after construction) through 'promise',
    // runs the task, and then hands the handle back via notify_done().
    template <typename Task>
    auto wrap_task(Task task, std::promise<std::thread> &promise) {
        return [future = promise.get_future(), task = std::move(task), &owner = *this]() mutable
               {
                   auto thread = future.get();
                   assert(std::this_thread::get_id() == thread.get_id());
                   task();
                   owner.notify_done(std::move(thread));
               };
    }
public:
    AutoJoiner() : _lock(), _cond(), _closed(false), _pending(0), _thread() {}
    ~AutoJoiner() { close_and_wait(); }

    /**
     * Runs 'task' on a new thread that will be joined automatically.
     * Throws std::runtime_error after close_and_wait() has been called,
     * and propagates any exception from thread creation (leaving the
     * pending count unchanged).
     **/
    template <typename Task>
    void start(Task task) {
        notify_start();
        std::promise<std::thread> promise;
        std::thread thread;
        try {
            thread = std::thread(wrap_task(std::move(task), promise));
        } catch (...) {
            notify_aborted();
            throw;
        }
        promise.set_value(std::move(thread));
    }

    /**
     * Rejects further start() calls, waits until every started thread
     * has finished, and joins the last parked thread. Safe to call more
     * than once.
     **/
    void close_and_wait() {
        JoinGuard join;
        std::unique_lock guard(_lock);
        _closed = true;
        while (_pending > 0) {
            _cond.wait(guard);
        }
        std::swap(join.thread, _thread);
    }
};

/**
 * Process-wide AutoJoiner shared by the RPC handlers; it is created on
 * first use (function-local static).
 **/
AutoJoiner &
auto_joiner()
{
    static AutoJoiner the_joiner;
    return the_joiner;
}

/**
 * Holder for the RPC methods exposed by this server.
 **/
struct RPC : FRT_Invokable
{
    // Handler for the "callBack" method.
    void CallBack(FRT_RPCRequest *req);
    // Registers this object's methods with the given supervisor.
    void Init(FRT_Supervisor *s);
};

/**
 * Body run on a worker thread for RPC::CallBack. It creates a new request
 * for the method named by the original request's first (string) parameter
 * and invokes it synchronously on the connection the original request
 * arrived on (timeout 5.0, presumably seconds). Any error is printed, and
 * then the original request is returned.
 **/
void do_callback(FRT_RPCRequest *req) {
    FNET_Connection *const connection = req->GetConnection();
    FRT_RPCRequest *const reverse_req = new FRT_RPCRequest();
    const char *const method_name = req->GetParams()->GetValue(0)._string._str;
    reverse_req->SetMethodName(method_name);
    FRT_Supervisor::InvokeSync(connection->Owner(), connection, reverse_req, 5.0);
    if (reverse_req->IsError()) {
        printf("[error(%d): %s]\n", reverse_req->GetErrorCode(), reverse_req->GetErrorMessage());
    }
    // release our reference to the callback request
    reverse_req->internal_subref();
    req->Return();
}

/**
 * Handler for "callBack" (one string parameter: the name of the method
 * to invoke back on the calling connection). The request is detached and
 * completed from a separate thread, because do_callback() performs a
 * synchronous invocation (InvokeSync).
 **/
void
RPC::CallBack(FRT_RPCRequest *req)
{
    req->Detach();
    try {
        auto_joiner().start([req]{ do_callback(req); });
    } catch (const std::exception &e) {
        // No worker thread will complete the detached request (the joiner
        // is closed or thread creation failed). Answer it here instead of
        // letting the exception escape into the RPC framework.
        req->SetError(FRTE_RPC_METHOD_FAILED, e.what());
        req->Return();
    }
}

/**
 * Registers the "callBack" method (parameter types "s", no return values)
 * on the given supervisor via FRT_ReflectionBuilder, dispatching to
 * RPC::CallBack on this object.
 **/
void
RPC::Init(FRT_Supervisor *s)
{
    FRT_ReflectionBuilder builder(s);
    builder.DefineMethod("callBack", "s", "", FRT_METHOD(RPC::CallBack), this);
}


/**
 * Application wrapper. Its destructor waits for all outstanding callback
 * threads before the process exits.
 **/
class MyApp
{
public:
    ~MyApp();
    int main(int argc, char **argv);
};

MyApp::~MyApp()
{
    auto_joiner().close_and_wait();
}

/**
 * Runs the callback server. It registers the "callBack" method, listens on
 * the spec given in argv[1], and blocks until the transport finishes. An
 * FNET_SignalShutDown is attached to the transport, presumably so that a
 * hooked signal ends the wait.
 *
 * @return 0 after the transport finishes, 1 on bad usage or if listening fails
 **/
int
MyApp::main(int argc, char **argv)
{
    FNET_SignalShutDown::hookSignals();
    if (argc < 2) {
        printf("usage  : rpc_callback_server <listenspec>\n");
        return 1;
    }
    RPC rpc;
    fnet::frt::StandaloneFRT server;
    FRT_Supervisor & supervisor = server.supervisor();
    rpc.Init(&supervisor);
    // Without this check a bad listen spec would leave the process
    // waiting forever on a transport that serves nothing.
    if (!supervisor.Listen(argv[1])) {
        printf("could not listen at '%s'\n", argv[1]);
        return 1;
    }
    FNET_SignalShutDown ssd(*supervisor.GetTransport());
    supervisor.GetTransport()->WaitFinished();
    return 0;
}


int main(int argc, char **argv) {
    vespalib::SignalHandler::PIPE.ignore();
    MyApp myapp;
    return myapp.main(argc, argv);
}