aboutsummaryrefslogtreecommitdiffstats
path: root/messagebus/src/vespa/messagebus/sourcesession.cpp
blob: 2e687329c61f83f4f4fdfaf31d4a699fa800197b (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
// Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
#include "sourcesession.h"
#include "errorcode.h"
#include "messagebus.h"
#include "replygate.h"
#include "tracelevel.h"
#include <vespa/messagebus/routing/routingtable.h>
#include <vespa/vespalib/util/stringfmt.h>
#include <cassert>

using vespalib::make_string;
using vespalib::make_ref_counted;

namespace mbus {

SourceSession::SourceSession(MessageBus &mbus, const SourceSessionParams &params)
    : _lock(),
      _mbus(mbus),
      _gate(make_ref_counted<ReplyGate>(_mbus)),
      _sequencer(*_gate),
      _replyHandler(params.getReplyHandler()),
      _throttlePolicy(params.getThrottlePolicy()),
      _timeout(params.getTimeout()),
      _pendingCount(0),
      _closed(false),
      _done(false)
{
    assert(params.hasReplyHandler());
}

/**
 * Tears the session down.
 *
 * The reply gate is closed first so that no further replies propagate from the
 * message bus into this session. Then the bus is synced, which presumably waits
 * for work already in flight; verify against MessageBus::sync().
 */
SourceSession::~SourceSession()
{
    ReplyGate &gate = *_gate;
    gate.close();
    _mbus.sync();
}

/**
 * Sends a message along a route identified by name.
 *
 * The name is first looked up in the routing table for the message's protocol.
 * If it is not found there (or no table exists for that protocol), the name is
 * parsed as a literal route when @p parseIfNotFound is true. Otherwise the
 * message is returned in an ILLEGAL_ROUTE result.
 *
 * @param msg             message to send; ownership is taken
 * @param routeName       routing-table route name, or a parseable route string
 * @param parseIfNotFound whether to fall back to Route::parse(routeName)
 * @return the result of send(Message::UP), or an ILLEGAL_ROUTE error holding msg
 */
Result
SourceSession::send(Message::UP msg, const string &routeName, bool parseIfNotFound)
{
    // Keep the table alive for the rest of the call; 'known' points into it.
    RoutingTable::SP table = _mbus.getRoutingTable(msg->getProtocol());
    const Route *known = table ? table->getRoute(routeName) : nullptr;

    if (known != nullptr) {
        msg->setRoute(*known);
        return send(std::move(msg));
    }
    if (!parseIfNotFound) {
        // The message is still needed to build the text, so format it before moving msg out.
        string reason = table
            ? make_string("Route '%s' not found.", routeName.c_str())
            : make_string("No routing table available for protocol '%s'.", msg->getProtocol().c_str());
        return Result(Error(ErrorCode::ILLEGAL_ROUTE, reason), std::move(msg));
    }
    msg->setRoute(Route::parse(routeName));
    return send(std::move(msg));
}

/**
 * Sends a message along an explicitly supplied route.
 *
 * @param msg   message to send; ownership is taken
 * @param route route to assign to the message before it is sent
 * @return the result of send(Message::UP)
 */
Result
SourceSession::send(Message::UP msg, const Route &route)
{
    Message &target = *msg;
    target.setRoute(route);
    return send(std::move(msg));
}

/**
 * Sends a message using the route already set on it.
 *
 * The message gets a receive timestamp. If its remaining time is zero, it is
 * given the session's default timeout. The message is then rejected with
 * SEND_QUEUE_CLOSED if the session is closed, or with SEND_QUEUE_FULL if the
 * throttle policy refuses it. Otherwise the reply handler and this session are
 * pushed onto the message's call stack, the pending count is incremented, and
 * the message is handed to the sequencer.
 *
 * @param msg message to send; ownership is taken (returned inside the Result on rejection)
 * @return an empty Result on acceptance, or an error Result carrying the message
 */
Result
SourceSession::send(Message::UP msg)
{
    msg->setTimeReceivedNow();
    uint32_t my_pending_count = 0;
    {
        std::lock_guard guard(_lock);
        // _timeout is written by setTimeout() under _lock, so it must also be
        // read under _lock; reading it outside the lock was a data race.
        // This is applied before the closed/throttle checks, as before, so a
        // rejected message still carries its default timeout.
        if (msg->getTimeRemaining() == 0ms) {
            msg->setTimeRemaining(_timeout);
        }
        if (_closed) {
            return Result(Error(ErrorCode::SEND_QUEUE_CLOSED, "Source session is closed."), std::move(msg));
        }
        my_pending_count = getPendingCount();
        if (_throttlePolicy && !_throttlePolicy->canSend(*msg, my_pending_count)) {
            return Result(Error(ErrorCode::SEND_QUEUE_FULL,
                                make_string("Too much pending data (%d messages).", my_pending_count)),
                          std::move(msg));
        }
        msg->pushHandler(_replyHandler);
        if (_throttlePolicy) {
            _throttlePolicy->processMessage(*msg);
        }
        ++my_pending_count;
        _pendingCount.store(my_pending_count, std::memory_order_relaxed);
    }
    if (msg->getTrace().shouldTrace(TraceLevel::COMPONENT)) {
        msg->getTrace().trace(TraceLevel::COMPONENT,
                              make_string("Source session accepted a %d byte message. %d message(s) now pending.",
                                          msg->getApproxSize(), my_pending_count));
    }
    // Push this session last so it sees the reply first (see handleReply), before the user's handler.
    msg->pushHandler(*this);
    _sequencer.handleMessage(std::move(msg));
    return Result();
}

/**
 * Receives a reply for a message previously sent by this session.
 *
 * Under the lock: decrements the pending count (asserting it was non-zero),
 * lets the throttle policy see the reply, and records whether this was the last
 * outstanding reply of a closed session. The reply is then forwarded to the
 * next handler popped from its call stack. Only after that handler returns, if
 * this was the final reply after close(), is _done set and any thread waiting
 * in close() woken.
 *
 * @param reply the reply to process; ownership is passed on to the next handler
 */
void
SourceSession::handleReply(Reply::UP reply)
{
    uint32_t remaining = 0;
    bool lastAfterClose = false;
    {
        std::lock_guard guard(_lock);
        remaining = getPendingCount();
        assert(remaining > 0);
        remaining -= 1;
        _pendingCount.store(remaining, std::memory_order_relaxed);
        if (_throttlePolicy) {
            _throttlePolicy->processReply(*reply);
        }
        lastAfterClose = _closed && (remaining == 0);
    }

    auto &trace = reply->getTrace();
    if (trace.shouldTrace(TraceLevel::COMPONENT)) {
        trace.trace(TraceLevel::COMPONENT,
                    make_string("Source session received reply. %d message(s) now pending.", remaining));
    }

    IReplyHandler &next = reply->getCallStack().pop(*reply);
    next.handleReply(std::move(reply));

    if (!lastAfterClose) {
        return;
    }
    // Mark completion only after the final reply has been delivered, so close()
    // does not return while that reply is still being handed off.
    {
        std::lock_guard guard(_lock);
        assert(getPendingCount() == 0);
        assert(_closed);
        _done = true;
    }
    _cond.notify_all();
}

void
SourceSession::close()
{
    std::unique_lock guard(_lock);
    _closed = true;
    if (getPendingCount() == 0) {
        _done = true;
    }
    while (!_done) {
        _cond.wait(guard);
    }
}

/**
 * Replaces the default timeout for messages sent with no remaining time set.
 *
 * @param timeout new default timeout
 * @return this session, to allow chaining
 */
SourceSession &
SourceSession::setTimeout(duration timeout)
{
    {
        std::scoped_lock guard(_lock);
        _timeout = timeout;
    }
    return *this;
}

} // namespace mbus