summaryrefslogtreecommitdiffstats
path: root/messagebus
diff options
context:
space:
mode:
authorTor Brede Vekterli <vekterli@yahooinc.com>2022-03-14 11:43:59 +0000
committerTor Brede Vekterli <vekterli@yahooinc.com>2022-03-14 15:50:37 +0000
commit5562f675ea89145be65488026e5fc929a398cfeb (patch)
treed283611ae6d44f05a37c22754fadc94fe512b0b2 /messagebus
parent0a8caeae3ea37d513a649a156e544bd59ab0545a (diff)
Make SourceSession pending count atomic to avoid data races
Could have used the existing mutex, but this value is polled frequently by visitor threads so avoiding having to take a lock every time makes sense.
Diffstat (limited to 'messagebus')
-rw-r--r--messagebus/src/vespa/messagebus/sourcesession.cpp22
-rw-r--r--messagebus/src/vespa/messagebus/sourcesession.h23
2 files changed, 25 insertions, 20 deletions
diff --git a/messagebus/src/vespa/messagebus/sourcesession.cpp b/messagebus/src/vespa/messagebus/sourcesession.cpp
index d4440b60895..0691e0c07f9 100644
--- a/messagebus/src/vespa/messagebus/sourcesession.cpp
+++ b/messagebus/src/vespa/messagebus/sourcesession.cpp
@@ -9,7 +9,6 @@
using vespalib::make_string;
-
namespace mbus {
SourceSession::SourceSession(MessageBus &mbus, const SourceSessionParams &params)
@@ -81,16 +80,18 @@ SourceSession::send(Message::UP msg)
if (_closed) {
return Result(Error(ErrorCode::SEND_QUEUE_CLOSED, "Source session is closed."), std::move(msg));
}
- if (_throttlePolicy && !_throttlePolicy->canSend(*msg, _pendingCount)) {
+ my_pending_count = getPendingCount();
+ if (_throttlePolicy && !_throttlePolicy->canSend(*msg, my_pending_count)) {
return Result(Error(ErrorCode::SEND_QUEUE_FULL,
- make_string("Too much pending data (%d messages).", _pendingCount)),
+ make_string("Too much pending data (%d messages).", my_pending_count)),
std::move(msg));
}
msg->pushHandler(_replyHandler);
if (_throttlePolicy) {
_throttlePolicy->processMessage(*msg);
}
- my_pending_count = ++_pendingCount;
+ ++my_pending_count;
+ _pendingCount.store(my_pending_count, std::memory_order_relaxed);
}
if (msg->getTrace().shouldTrace(TraceLevel::COMPONENT)) {
msg->getTrace().trace(TraceLevel::COMPONENT,
@@ -109,13 +110,14 @@ SourceSession::handleReply(Reply::UP reply)
uint32_t my_pending_count = 0;
{
std::lock_guard guard(_lock);
- assert(_pendingCount > 0);
- --_pendingCount;
+ my_pending_count = getPendingCount();
+ assert(my_pending_count > 0);
+ --my_pending_count;
+ _pendingCount.store(my_pending_count, std::memory_order_relaxed);
if (_throttlePolicy) {
_throttlePolicy->processReply(*reply);
}
- my_pending_count = _pendingCount;
- done = (_closed && _pendingCount == 0);
+ done = (_closed && my_pending_count == 0);
}
if (reply->getTrace().shouldTrace(TraceLevel::COMPONENT)) {
reply->getTrace().trace(TraceLevel::COMPONENT,
@@ -126,7 +128,7 @@ SourceSession::handleReply(Reply::UP reply)
if (done) {
{
std::lock_guard guard(_lock);
- assert(_pendingCount == 0);
+ assert(getPendingCount() == 0);
assert(_closed);
_done = true;
}
@@ -139,7 +141,7 @@ SourceSession::close()
{
std::unique_lock guard(_lock);
_closed = true;
- if (_pendingCount == 0) {
+ if (getPendingCount() == 0) {
_done = true;
}
while (!_done) {
diff --git a/messagebus/src/vespa/messagebus/sourcesession.h b/messagebus/src/vespa/messagebus/sourcesession.h
index f75f41e2d20..364533ece17 100644
--- a/messagebus/src/vespa/messagebus/sourcesession.h
+++ b/messagebus/src/vespa/messagebus/sourcesession.h
@@ -5,6 +5,7 @@
#include "result.h"
#include "sequencer.h"
#include "sourcesessionparams.h"
+#include <atomic>
#include <condition_variable>
namespace mbus {
@@ -23,15 +24,15 @@ private:
std::mutex _lock;
std::condition_variable _cond;
- MessageBus &_mbus;
- ReplyGate *_gate;
- Sequencer _sequencer;
- IReplyHandler &_replyHandler;
- IThrottlePolicy::SP _throttlePolicy;
- duration _timeout;
- uint32_t _pendingCount;
- bool _closed;
- bool _done;
+ MessageBus &_mbus;
+ ReplyGate *_gate;
+ Sequencer _sequencer;
+ IReplyHandler &_replyHandler;
+ IThrottlePolicy::SP _throttlePolicy;
+ duration _timeout;
+ std::atomic<uint32_t> _pendingCount;
+ bool _closed;
+ bool _done;
private:
/**
@@ -113,7 +114,9 @@ public:
*
* @return The pending count.
*/
- uint32_t getPendingCount() const { return _pendingCount; }
+ [[nodiscard]] uint32_t getPendingCount() const noexcept {
+ return _pendingCount.load(std::memory_order_relaxed);
+ }
/**
* Sets the number of seconds a message can be attempted sent until it times out.