diff options
-rw-r--r-- | vespalib/CMakeLists.txt | 1 | ||||
-rw-r--r-- | vespalib/src/tests/rw_spin_lock/CMakeLists.txt | 8 | ||||
-rw-r--r-- | vespalib/src/tests/rw_spin_lock/rw_spin_lock_test.cpp | 355 | ||||
-rw-r--r-- | vespalib/src/tests/shared_string_repo/shared_string_repo_test.cpp | 41 | ||||
-rw-r--r-- | vespalib/src/vespa/vespalib/test/thread_meets.cpp | 31 | ||||
-rw-r--r-- | vespalib/src/vespa/vespalib/test/thread_meets.h | 63 | ||||
-rw-r--r-- | vespalib/src/vespa/vespalib/util/rendezvous.h | 13 | ||||
-rw-r--r-- | vespalib/src/vespa/vespalib/util/rw_spin_lock.h | 189 |
8 files changed, 653 insertions, 48 deletions
diff --git a/vespalib/CMakeLists.txt b/vespalib/CMakeLists.txt index 6d19988b96b..c1f6f2cbbff 100644 --- a/vespalib/CMakeLists.txt +++ b/vespalib/CMakeLists.txt @@ -147,6 +147,7 @@ vespa_define_module( src/tests/require src/tests/runnable_pair src/tests/rusage + src/tests/rw_spin_lock src/tests/sequencedtaskexecutor src/tests/sha1 src/tests/shared_operation_throttler diff --git a/vespalib/src/tests/rw_spin_lock/CMakeLists.txt b/vespalib/src/tests/rw_spin_lock/CMakeLists.txt new file mode 100644 index 00000000000..d05322e79f6 --- /dev/null +++ b/vespalib/src/tests/rw_spin_lock/CMakeLists.txt @@ -0,0 +1,8 @@ +# Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +vespa_add_executable(vespalib_rw_spin_lock_test_app TEST + SOURCES + rw_spin_lock_test.cpp + DEPENDS + vespalib +) +vespa_add_test(NAME vespalib_rw_spin_lock_test_app COMMAND vespalib_rw_spin_lock_test_app) diff --git a/vespalib/src/tests/rw_spin_lock/rw_spin_lock_test.cpp b/vespalib/src/tests/rw_spin_lock/rw_spin_lock_test.cpp new file mode 100644 index 00000000000..207284d0db0 --- /dev/null +++ b/vespalib/src/tests/rw_spin_lock/rw_spin_lock_test.cpp @@ -0,0 +1,355 @@ +// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. + +#include <vespa/vespalib/util/spin_lock.h> +#include <vespa/vespalib/util/rw_spin_lock.h> +#include <vespa/vespalib/util/atomic.h> +#include <vespa/vespalib/util/time.h> +#include <vespa/vespalib/util/classname.h> +#include <vespa/vespalib/test/thread_meets.h> +#include <vespa/vespalib/testkit/test_kit.h> +#include <type_traits> +#include <ranges> +#include <random> +#include <array> + +using namespace vespalib; +using namespace vespalib::atomic; + +duration budget = 250ms; +constexpr size_t LOOP_CNT = 4096; +constexpr double LOOP_FACTOR = double(LOOP_CNT); + +//----------------------------------------------------------------------------- + +struct DummyLock { + constexpr DummyLock() noexcept {} + // BasicLockable + constexpr void lock() noexcept {} + constexpr void unlock() noexcept {} + // SharedLockable + constexpr void lock_shared() noexcept {} + [[nodiscard]] constexpr bool try_lock_shared() noexcept { return true; } + constexpr void unlock_shared() noexcept {} + // rw_upgrade_downgrade_lock + [[nodiscard]] constexpr bool try_convert_read_to_write() noexcept { return true; } + constexpr void convert_write_to_read() noexcept {} +}; + +//----------------------------------------------------------------------------- + +struct MyState { + static constexpr size_t SZ = 5; + std::array<size_t,SZ> state = {0,0,0,0,0}; + std::atomic<size_t> inconsistent_reads = 0; + void update() { + std::array<size_t,SZ> tmp; + for (size_t i = 0; i < SZ; ++i) { + tmp[i] = load_ref_relaxed(state[i]); + } + for (int n = 0; n < 1024; ++n) { + for (size_t i = 0; i < SZ; ++i) { + store_ref_relaxed(state[i], tmp[i] + 1); + } + } + } + void peek() { + std::array<size_t,SZ> tmp; + for (size_t i = 0; i < SZ; ++i) { + tmp[i] = load_ref_relaxed(state[i]); + } + for (int n = 0; n < 1024; ++n) { + for (size_t i = 0; i < SZ; ++i) { + if (load_ref_relaxed(state[i]) != tmp[i]) [[unlikely]] { + inconsistent_reads.fetch_add(1, std::memory_order_relaxed); + } + } + } + } + bool check(size_t expect) const { + if (inconsistent_reads > 0) { + return false; + } + for (const auto& value: state) { + if (load_ref_relaxed(value) != expect) { + return false; + } + } + return true; + } + void report(size_t expect, const char *name) const { + if (check(expect)) { + fprintf(stderr, "%s is thread safe\n", name); + } else { + fprintf(stderr, "%s is not thread safe\n", name); + fprintf(stderr, " inconsistent reads: %zu\n", inconsistent_reads.load()); + fprintf(stderr, " expected %zu, got [%zu,%zu,%zu,%zu,%zu]\n", + expect, state[0], state[1], state[2], state[3], state[4]); + } + } +}; + +// do work while waiting for other threads to be ready +class ActiveBarrier : Rendezvous<bool,bool> { +private: + std::atomic<uint32_t> _ready_cnt; + void mingle() override { + _ready_cnt.store(0, std::memory_order_relaxed); + } +public: + ActiveBarrier(size_t n) : Rendezvous<bool,bool>(n), _ready_cnt(0) {} + void operator()(auto &do_work) { + if (_ready_cnt.fetch_add(1, std::memory_order_relaxed) + 1 < size()) { + do_work(); + } + while (_ready_cnt.load(std::memory_order_relaxed) < size()) { + do_work(); + } + rendezvous(false); + } +}; + +// random generator used to make per-thread decisions +class Rnd { +private: + std::mt19937 _engine; + std::uniform_int_distribution<int> _dist; +public: + Rnd(uint32_t seed) : _engine(seed), _dist(0,9999) {} + bool operator()(int bp) { return _dist(_engine) < bp; } +}; + +void fork_join(auto &&thread_fun, size_t num_threads) { + assert(num_threads > 0); + ThreadPool pool; + for (size_t i = 1; i < num_threads; ++i) { + pool.start([i,&thread_fun]{ thread_fun(i); }); + } + thread_fun(0); + pool.join(); +} + +auto apply_merge(auto &&inputs, auto &&perform, auto &&merge) { + using output_t = std::decay_t<decltype(perform(*inputs.begin()))>; + std::mutex lock; + std::optional<output_t> output; + auto handle_result = [&](output_t result) { + std::lock_guard guard(lock); + if (output.has_value()) { + output = merge(std::move(output).value(), std::move(result)); + } else { + output = std::move(result); + } + }; + ThreadPool pool; + for (auto &&item: inputs) { + pool.start([item,&perform,&handle_result]{ handle_result(perform(item)); }); + } + pool.join(); + return output.value(); +} + +//----------------------------------------------------------------------------- + +template<typename T> +concept basic_lockable = requires(T a) { + { a.lock() } -> std::same_as<void>; + { a.unlock() } -> std::same_as<void>; +}; + +template<typename T> +concept lockable = requires(T a) { + { a.try_lock() } -> std::same_as<bool>; + { a.lock() } -> std::same_as<void>; + { a.unlock() } -> std::same_as<void>; +}; + +template<typename T> +concept shared_lockable = requires(T a) { + { a.try_lock_shared() } -> std::same_as<bool>; + { a.lock_shared() } -> std::same_as<void>; + { a.unlock_shared() } -> std::same_as<void>; +}; + +template<typename T> +concept can_upgrade = requires(std::shared_lock<T> a, std::unique_lock<T> b) { + { try_upgrade(std::move(a)) } -> std::same_as<std::unique_lock<T>>; + { downgrade(std::move(b)) } -> std::same_as<std::shared_lock<T>>; +}; + +//----------------------------------------------------------------------------- + +template <size_t N> +auto run_loop(auto &f) { + static_assert(N % 4 == 0); + for (size_t i = 0; i < N / 4; ++i) { + f(); f(); f(); f(); + } +} + +double measure_ns(auto &work) __attribute__((noinline)); +double measure_ns(auto &work) { + auto t0 = steady_clock::now(); + run_loop<LOOP_CNT>(work); + return count_ns(steady_clock::now() - t0) / LOOP_FACTOR; +} + +struct BenchmarkResult { + double cost_ns = std::numeric_limits<double>::max(); + double range_ns = 0.0; +}; + +struct Meets { + vespalib::test::ThreadMeets::Vote vote; + vespalib::test::ThreadMeets::Avg avg; + vespalib::test::ThreadMeets::Range<double> range; + ActiveBarrier active_wait; + Meets(size_t num_threads) + : vote(num_threads), avg(num_threads), range(num_threads), active_wait(num_threads) {} + ~Meets(); +}; +Meets::~Meets() = default; + +BenchmarkResult benchmark_ns(auto &&work, size_t num_threads = 1) { + Timer timer; + Meets meets(num_threads); + BenchmarkResult result; + auto hook = [&](size_t thread_id) { + for (bool once_more = true; meets.vote(once_more); once_more = (timer.elapsed() < budget)) { + auto my_ns = measure_ns(work); + meets.active_wait(work); + auto cost_ns = meets.avg(my_ns); + auto range_ns = meets.range(my_ns); + if (thread_id == 0 && cost_ns < result.cost_ns) { + result.cost_ns = cost_ns; + result.range_ns = range_ns; + } + } + }; + fork_join(hook, num_threads); + return result; +} + +//----------------------------------------------------------------------------- + +template <typename T> +void estimate_cost() { + T lock; + auto name = getClassName(lock); + static_assert(basic_lockable<T>); + fprintf(stderr, "%s unique lock/unlock: %g ns\n", name.c_str(), + benchmark_ns([&lock]{ lock.lock(); lock.unlock(); }).cost_ns); + if constexpr (shared_lockable<T>) { + fprintf(stderr, "%s shared lock/unlock: %g ns\n", name.c_str(), + benchmark_ns([&lock]{ lock.lock_shared(); lock.unlock_shared(); }).cost_ns); + } + if constexpr (can_upgrade<T>) { + auto guard = std::shared_lock(lock); + fprintf(stderr, "%s upgrade/downgrade: %g ns\n", name.c_str(), + benchmark_ns([&lock]{ + assert(lock.try_convert_read_to_write()); + lock.convert_write_to_read(); + }).cost_ns); + } +} + +//----------------------------------------------------------------------------- + +template <typename T> +size_t thread_safety_loop(T &lock, MyState &state, Meets &meets, int read_bp, size_t thread_id) { + Timer timer; + Rnd rnd(thread_id); + size_t write_cnt = 0; + BenchmarkResult result; + auto do_work = [&] + { + if (rnd(read_bp)) { + if constexpr (shared_lockable<T>) { + std::shared_lock guard(lock); + state.peek(); + } else { + std::lock_guard guard(lock); + state.peek(); + } + } else { + { + std::lock_guard guard(lock); + state.update(); + } + ++write_cnt; + } + }; + for (bool once_more = true; meets.vote(once_more); once_more = (timer.elapsed() < budget)) { + auto my_est = measure_ns(do_work); + meets.active_wait(do_work); + auto cost_ns = meets.avg(my_est); + auto range_ns = meets.range(my_est); + if (cost_ns < result.cost_ns) { + result.cost_ns = cost_ns; + result.range_ns = range_ns; + } + } + if (thread_id == 0) { + fprintf(stderr, "---> %s with %2zu threads (%5d bp r): %12.2f ns, range: %12.2f ns\n", + getClassName(lock).c_str(), meets.vote.size(), read_bp, result.cost_ns, result.range_ns); + } + return write_cnt; +} + +//----------------------------------------------------------------------------- + +TEST("require that rw spin locks can be used with lock_guard, unique_lock and shared_lock") { + static_assert(basic_lockable<RWSpinLock>); + static_assert(lockable<RWSpinLock>); + static_assert(shared_lockable<RWSpinLock>); + static_assert(can_upgrade<RWSpinLock>); + RWSpinLock lock; + { auto guard = std::lock_guard(lock); } + { auto guard = std::unique_lock(lock); } + { auto guard = std::shared_lock(lock); } +} + +TEST("estimate basic cost") { + Rnd rnd(123); + MyState state; + fprintf(stderr, " rnd cost: %8.2f ns\n", benchmark_ns([&]{ rnd(50); }).cost_ns); + fprintf(stderr, " peek cost: %8.2f ns\n", benchmark_ns([&]{ state.peek(); }).cost_ns); + fprintf(stderr, "update cost: %8.2f ns\n", benchmark_ns([&]{ state.update(); }).cost_ns); +} + +void benchmark_lock(auto &lock) { + size_t expect = 0; + auto state = std::make_unique<MyState>(); + for (size_t bp: {10000, 9999, 5000, 0}) { + for (size_t num_threads: {8, 4, 2, 1}) { + Meets meets(num_threads); + auto hook = [&](size_t thread_id) { + return thread_safety_loop(lock, *state, meets, bp, thread_id); + }; + expect += apply_merge(std::views::iota(size_t(0), num_threads), hook, [](auto a, auto b){ return a + b; }); + } + } + state->report(expect, getClassName(lock).c_str()); + EXPECT_TRUE(state->check(expect)); +} + +TEST_F("benchmark RWSpinLock", RWSpinLock()) { benchmark_lock(f1); } +TEST_F("benchmark std::shared_mutex", std::shared_mutex()) { benchmark_lock(f1); } +TEST_F("benchmark std::mutex", std::mutex()) { benchmark_lock(f1); } +TEST_F("benchmark SpinLock", SpinLock()) { benchmark_lock(f1); } + +TEST("estimate single-threaded lock/unlock cost") { + estimate_cost<DummyLock>(); + estimate_cost<SpinLock>(); + estimate_cost<std::mutex>(); + estimate_cost<RWSpinLock>(); + estimate_cost<std::shared_mutex>(); +} + +int main(int argc, char **argv) { + TEST_MASTER.init(__FILE__); + if ((argc == 2) && (argv[1] == std::string("bench"))) { + budget = 5s; + } + TEST_RUN_ALL(); + return (TEST_MASTER.fini() ? 0 : 1); +} diff --git a/vespalib/src/tests/shared_string_repo/shared_string_repo_test.cpp b/vespalib/src/tests/shared_string_repo/shared_string_repo_test.cpp index dfcba14ba63..910c2d017ba 100644 --- a/vespalib/src/tests/shared_string_repo/shared_string_repo_test.cpp +++ b/vespalib/src/tests/shared_string_repo/shared_string_repo_test.cpp @@ -1,7 +1,7 @@ // Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. #include <vespa/vespalib/util/shared_string_repo.h> -#include <vespa/vespalib/util/rendezvous.h> +#include <vespa/vespalib/test/thread_meets.h> #include <vespa/vespalib/util/time.h> #include <vespa/vespalib/util/size_literals.h> #include <vespa/vespalib/util/stringfmt.h> @@ -115,41 +115,8 @@ std::unique_ptr<StringIdVector> make_weak_handles(const Handles &handles) { //----------------------------------------------------------------------------- -struct Avg : Rendezvous<double, double> { - explicit Avg(size_t n) : Rendezvous<double, double>(n) {} - void mingle() override { - double sum = 0; - for (size_t i = 0; i < size(); ++i) { - sum += in(i); - } - double result = sum / size(); - for (size_t i = 0; i < size(); ++i) { - out(i) = result; - } - } - double operator()(double value) { return rendezvous(value); } -}; - -struct Vote : Rendezvous<bool, bool> { - explicit Vote(size_t n) : Rendezvous<bool, bool>(n) {} - void mingle() override { - size_t true_cnt = 0; - size_t false_cnt = 0; - for (size_t i = 0; i < size(); ++i) { - if (in(i)) { - ++true_cnt; - } else { - ++false_cnt; - } - } - bool result = (true_cnt > false_cnt); - for (size_t i = 0; i < size(); ++i) { - out(i) = result; - } - } - [[nodiscard]] size_t num_threads() const { return size(); } - bool operator()(bool flag) { return rendezvous(flag); } -}; +using Avg = vespalib::test::ThreadMeets::Avg; +using Vote = vespalib::test::ThreadMeets::Vote; //----------------------------------------------------------------------------- @@ -174,7 +141,7 @@ struct Fixture { : avg(num_threads), vote(num_threads), work(make_strings(work_size)), direct_work(make_direct_strings(work_size)), start_time(steady_clock::now()) {} ~Fixture() { if (verbose) { - fprintf(stderr, "benchmark results for %zu threads:\n", vote.num_threads()); + fprintf(stderr, "benchmark results for %zu threads:\n", vote.size()); for (const auto &[tag, ms_cost]: time_ms) { fprintf(stderr, " %s: %g ms\n", tag.c_str(), ms_cost); } diff --git a/vespalib/src/vespa/vespalib/test/thread_meets.cpp b/vespalib/src/vespa/vespalib/test/thread_meets.cpp index 9d23e0eab28..607179c53f9 100644 --- a/vespalib/src/vespa/vespalib/test/thread_meets.cpp +++ b/vespalib/src/vespa/vespalib/test/thread_meets.cpp @@ -9,4 +9,35 @@ ThreadMeets::Nop::mingle() { } +void +ThreadMeets::Avg::mingle() +{ + double sum = 0; + for (size_t i = 0; i < size(); ++i) { + sum += in(i); + } + double result = sum / size(); + for (size_t i = 0; i < size(); ++i) { + out(i) = result; + } +} + +void +ThreadMeets::Vote::mingle() +{ + size_t true_cnt = 0; + size_t false_cnt = 0; + for (size_t i = 0; i < size(); ++i) { + if (in(i)) { + ++true_cnt; + } else { + ++false_cnt; + } + } + bool result = (true_cnt > false_cnt); + for (size_t i = 0; i < size(); ++i) { + out(i) = result; + } +} + } diff --git a/vespalib/src/vespa/vespalib/test/thread_meets.h b/vespalib/src/vespa/vespalib/test/thread_meets.h index 62ca7779935..7ef4dcb9921 100644 --- a/vespalib/src/vespa/vespalib/test/thread_meets.h +++ b/vespalib/src/vespa/vespalib/test/thread_meets.h @@ -12,10 +12,67 @@ namespace vespalib::test { struct ThreadMeets { // can be used as a simple thread barrier struct Nop : vespalib::Rendezvous<bool,bool> { - Nop(size_t N) : vespalib::Rendezvous<bool,bool>(N) {} + explicit Nop(size_t N) : vespalib::Rendezvous<bool,bool>(N) {} void operator()() { rendezvous(false); } void mingle() override; }; + // calculate the average value across threads + struct Avg : Rendezvous<double, double> { + explicit Avg(size_t n) : Rendezvous<double, double>(n) {} + double operator()(double value) { return rendezvous(value); } + void mingle() override; + }; + // threads vote for true/false, majority wins (false on tie) + struct Vote : Rendezvous<bool, bool> { + explicit Vote(size_t n) : Rendezvous<bool, bool>(n) {} + bool operator()(bool flag) { return rendezvous(flag); } + void mingle() override; + }; + // sum of values across all threads + template <typename T> + struct Sum : vespalib::Rendezvous<T,T> { + using vespalib::Rendezvous<T,T>::in; + using vespalib::Rendezvous<T,T>::out; + using vespalib::Rendezvous<T,T>::size; + using vespalib::Rendezvous<T,T>::rendezvous; + explicit Sum(size_t N) : vespalib::Rendezvous<T,T>(N) {} + T operator()(T value) { return rendezvous(value); } + void mingle() override { + T acc{}; + for (size_t i = 0; i < size(); ++i) { + acc += in(i); + } + for (size_t i = 0; i < size(); ++i) { + out(i) = acc; + } + } + }; + // range of values across all threads + template <typename T> + struct Range : vespalib::Rendezvous<T,T> { + using vespalib::Rendezvous<T,T>::in; + using vespalib::Rendezvous<T,T>::out; + using vespalib::Rendezvous<T,T>::size; + using vespalib::Rendezvous<T,T>::rendezvous; + explicit Range(size_t N) : vespalib::Rendezvous<T,T>(N) {} + T operator()(T value) { return rendezvous(value); } + void mingle() override { + T min = in(0); + T max = in(0); + for (size_t i = 1; i < size(); ++i) { + if (in(i) < min) { + min = in(i); + } + if (in(i) > max) { + max = in(i); + } + } + T result = (max - min); + for (size_t i = 0; i < size(); ++i) { + out(i) = result; + } + } + }; // swap values between 2 threads template <typename T> struct Swap : vespalib::Rendezvous<T,T> { @@ -25,8 +82,8 @@ struct ThreadMeets { Swap() : vespalib::Rendezvous<T,T>(2) {} T operator()(T input) { return rendezvous(input); } void mingle() override { - out(1) = in(0); - out(0) = in(1); + out(1) = std::move(in(0)); + out(0) = std::move(in(1)); } }; }; diff --git a/vespalib/src/vespa/vespalib/util/rendezvous.h b/vespalib/src/vespa/vespalib/util/rendezvous.h index 2880f325d96..17a8729c54c 100644 --- a/vespalib/src/vespa/vespalib/util/rendezvous.h +++ b/vespalib/src/vespa/vespalib/util/rendezvous.h @@ -50,14 +50,6 @@ private: protected: /** - * Obtain the number of input and output values to be handled by - * mingle. This function is called by mingle. - * - * @return number of input and output values - **/ - size_t size() const { return _size; } - - /** * Obtain an input parameter. This function is called by mingle. * * @return reference to the appropriate input @@ -87,6 +79,11 @@ public: virtual ~Rendezvous(); /** + * @return number of participants + **/ + size_t size() const { return _size; } + + /** * Called by individual threads to synchronize execution and share * state with the mingle function. * diff --git a/vespalib/src/vespa/vespalib/util/rw_spin_lock.h b/vespalib/src/vespa/vespalib/util/rw_spin_lock.h new file mode 100644 index 00000000000..f2c15dcc0eb --- /dev/null +++ b/vespalib/src/vespa/vespalib/util/rw_spin_lock.h @@ -0,0 +1,189 @@ +// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. + +#pragma once + +#include <mutex> +#include <shared_mutex> +#include <atomic> +#include <thread> +#include <cassert> +#include <utility> + +namespace vespalib { + +/** + * A reader-writer spin lock implementation. + * + * reader: shared access for any number of readers + * writer: exclusive access for a single writer + * + * valid lock combinations: + * {} + * {N readers} + * {1 writer} + * + * Trying to obtain a write lock will lead to not granting new read + * locks. + * + * This lock is intended for use-cases that involves mostly reading, + * with a little bit of writing. + * + * This class implements the Lockable and SharedLockable named + * requirements from the standard library, making it directly usable + * with std::shared_lock (reader) and std::unique_lock (writer) + * + * There is also some special glue added for lock upgrading and + * downgrading. + * + * NOTE: this implementation is experimental, mostly intended for + * benchmarking and trying to identify use-cases that work with + * rw locks. Upgrade locks that do not block readers might be + * implementet in the future. + **/ +class RWSpinLock { +private: + // [31: num readers][1: pending writer] + // a reader gets the lock by: + // increasing the number of readers while the pending writer bit is not set. + // a writer gets the lock by: + // changing the pending writer bit from 0 to 1 and then + // waiting for the number of readers to become 0 + // an upgrade is successful when: + // a reader is able to obtain the pending writer bit + std::atomic<uint32_t> _state; + + // Convenience function used to check if the pending writer bit is + // set in the given value. + bool has_pending_writer(uint32_t value) noexcept { + return (value & 1); + } + + // Wait for all readers to release their locks. + void wait_for_zero_readers(uint32_t &value) { + while (value != 1) { + std::this_thread::yield(); + value = _state.load(std::memory_order_acquire); + } + } + +public: + RWSpinLock() noexcept : _state(0) { + static_assert(std::atomic<uint32_t>::is_always_lock_free); + } + + // implementation of Lockable named requirement - vvv + + void lock() noexcept { + uint32_t expected = 0; + uint32_t desired = 1; + while (!_state.compare_exchange_weak(expected, desired, + std::memory_order_acquire, + std::memory_order_relaxed)) + { + while (has_pending_writer(expected)) { + std::this_thread::yield(); + expected = _state.load(std::memory_order_relaxed); + } + desired = expected + 1; + } + wait_for_zero_readers(desired); + } + + [[nodiscard]] bool try_lock() noexcept { + uint32_t expected = 0; + return _state.compare_exchange_strong(expected, 1, + std::memory_order_acquire, + std::memory_order_relaxed); + } + + void unlock() noexcept { + _state.store(0, std::memory_order_release); + } + + // implementation of Lockable named requirement - ^^^ + + // implementation of SharedLockable named requirement - vvv + + void lock_shared() noexcept { + uint32_t expected = 0; + uint32_t desired = 2; + while (!_state.compare_exchange_weak(expected, desired, + std::memory_order_acquire, + std::memory_order_relaxed)) + { + while (has_pending_writer(expected)) { + std::this_thread::yield(); + expected = _state.load(std::memory_order_relaxed); + } + desired = expected + 2; + } + } + + [[nodiscard]] bool try_lock_shared() noexcept { + uint32_t expected = 0; + uint32_t desired = 2; + while (!_state.compare_exchange_weak(expected, desired, + std::memory_order_acquire, + std::memory_order_relaxed)) + { + if (has_pending_writer(expected)) { + return false; + } + desired = expected + 2; + } + return true; + } + + void unlock_shared() noexcept { + _state.fetch_sub(2, std::memory_order_release); + } + + // implementation of SharedLockable named requirement - ^^^ + + // try to upgrade a read (shared) lock to a write (unique) lock + bool try_convert_read_to_write() noexcept { + uint32_t expected = 2; + uint32_t desired = 1; + while (!_state.compare_exchange_weak(expected, desired, + std::memory_order_acquire, + std::memory_order_relaxed)) + { + if (has_pending_writer(expected)) { + return false; + } + desired = expected - 1; + } + wait_for_zero_readers(desired); + return true; + } + + // convert a write (unique) lock to a read (shared) lock + void convert_write_to_read() noexcept { + _state.store(2, std::memory_order_release); + } +}; + +template<typename T> +concept rw_upgrade_downgrade_lock = requires(T a, T b) { + { a.try_convert_read_to_write() } -> std::same_as<bool>; + { b.convert_write_to_read() } -> std::same_as<void>; +}; + +template <rw_upgrade_downgrade_lock T> +[[nodiscard]] std::unique_lock<T> try_upgrade(std::shared_lock<T> &&guard) noexcept { + assert(guard.owns_lock()); + if (guard.mutex()->try_convert_read_to_write()) { + return {*guard.release(), std::adopt_lock}; + } else { + return {}; + } +} + +template <rw_upgrade_downgrade_lock T> +[[nodiscard]] std::shared_lock<T> downgrade(std::unique_lock<T> &&guard) noexcept { + assert(guard.owns_lock()); + guard.mutex()->convert_write_to_read(); + return {*guard.release(), std::adopt_lock}; +} + +} |