diff options
author | Tor Brede Vekterli <vekterli@yahooinc.com> | 2022-03-14 11:43:59 +0000 |
---|---|---|
committer | Tor Brede Vekterli <vekterli@yahooinc.com> | 2022-03-14 15:50:37 +0000 |
commit | 5562f675ea89145be65488026e5fc929a398cfeb (patch) | |
tree | d283611ae6d44f05a37c22754fadc94fe512b0b2 /messagebus/src | |
parent | 0a8caeae3ea37d513a649a156e544bd59ab0545a (diff) |
Make SourceSession pending count atomic to avoid data races
Could have used the existing mutex, but this value is polled frequently
by visitor threads so avoiding having to take a lock every time makes sense.
Diffstat (limited to 'messagebus/src')
-rw-r--r-- | messagebus/src/vespa/messagebus/sourcesession.cpp | 22 | ||||
-rw-r--r-- | messagebus/src/vespa/messagebus/sourcesession.h | 23 |
2 files changed, 25 insertions, 20 deletions
diff --git a/messagebus/src/vespa/messagebus/sourcesession.cpp b/messagebus/src/vespa/messagebus/sourcesession.cpp index d4440b60895..0691e0c07f9 100644 --- a/messagebus/src/vespa/messagebus/sourcesession.cpp +++ b/messagebus/src/vespa/messagebus/sourcesession.cpp @@ -9,7 +9,6 @@ using vespalib::make_string; - namespace mbus { SourceSession::SourceSession(MessageBus &mbus, const SourceSessionParams ¶ms) @@ -81,16 +80,18 @@ SourceSession::send(Message::UP msg) if (_closed) { return Result(Error(ErrorCode::SEND_QUEUE_CLOSED, "Source session is closed."), std::move(msg)); } - if (_throttlePolicy && !_throttlePolicy->canSend(*msg, _pendingCount)) { + my_pending_count = getPendingCount(); + if (_throttlePolicy && !_throttlePolicy->canSend(*msg, my_pending_count)) { return Result(Error(ErrorCode::SEND_QUEUE_FULL, - make_string("Too much pending data (%d messages).", _pendingCount)), + make_string("Too much pending data (%d messages).", my_pending_count)), std::move(msg)); } msg->pushHandler(_replyHandler); if (_throttlePolicy) { _throttlePolicy->processMessage(*msg); } - my_pending_count = ++_pendingCount; + ++my_pending_count; + _pendingCount.store(my_pending_count, std::memory_order_relaxed); } if (msg->getTrace().shouldTrace(TraceLevel::COMPONENT)) { msg->getTrace().trace(TraceLevel::COMPONENT, @@ -109,13 +110,14 @@ SourceSession::handleReply(Reply::UP reply) uint32_t my_pending_count = 0; { std::lock_guard guard(_lock); - assert(_pendingCount > 0); - --_pendingCount; + my_pending_count = getPendingCount(); + assert(my_pending_count > 0); + --my_pending_count; + _pendingCount.store(my_pending_count, std::memory_order_relaxed); if (_throttlePolicy) { _throttlePolicy->processReply(*reply); } - my_pending_count = _pendingCount; - done = (_closed && _pendingCount == 0); + done = (_closed && my_pending_count == 0); } if (reply->getTrace().shouldTrace(TraceLevel::COMPONENT)) { reply->getTrace().trace(TraceLevel::COMPONENT, @@ -126,7 +128,7 @@ SourceSession::handleReply(Reply::UP reply) if (done) { { std::lock_guard guard(_lock); - assert(_pendingCount == 0); + assert(getPendingCount() == 0); assert(_closed); _done = true; } @@ -139,7 +141,7 @@ SourceSession::close() { std::unique_lock guard(_lock); _closed = true; - if (_pendingCount == 0) { + if (getPendingCount() == 0) { _done = true; } while (!_done) { diff --git a/messagebus/src/vespa/messagebus/sourcesession.h b/messagebus/src/vespa/messagebus/sourcesession.h index f75f41e2d20..364533ece17 100644 --- a/messagebus/src/vespa/messagebus/sourcesession.h +++ b/messagebus/src/vespa/messagebus/sourcesession.h @@ -5,6 +5,7 @@ #include "result.h" #include "sequencer.h" #include "sourcesessionparams.h" +#include <atomic> #include <condition_variable> namespace mbus { @@ -23,15 +24,15 @@ private: std::mutex _lock; std::condition_variable _cond; - MessageBus &_mbus; - ReplyGate *_gate; - Sequencer _sequencer; - IReplyHandler &_replyHandler; - IThrottlePolicy::SP _throttlePolicy; - duration _timeout; - uint32_t _pendingCount; - bool _closed; - bool _done; + MessageBus &_mbus; + ReplyGate *_gate; + Sequencer _sequencer; + IReplyHandler &_replyHandler; + IThrottlePolicy::SP _throttlePolicy; + duration _timeout; + std::atomic<uint32_t> _pendingCount; + bool _closed; + bool _done; private: /** @@ -113,7 +114,9 @@ public: * * @return The pending count. */ - uint32_t getPendingCount() const { return _pendingCount; } + [[nodiscard]] uint32_t getPendingCount() const noexcept { + return _pendingCount.load(std::memory_order_relaxed); + } /** * Sets the number of seconds a message can be attempted sent until it times out. |