diff options
-rw-r--r-- | vespalib/CMakeLists.txt | 1 | ||||
-rw-r--r-- | vespalib/src/tests/rw_spin_lock/CMakeLists.txt | 8 | ||||
-rw-r--r-- | vespalib/src/tests/rw_spin_lock/rw_spin_lock_test.cpp | 355 | ||||
-rw-r--r-- | vespalib/src/tests/shared_string_repo/shared_string_repo_test.cpp | 41 | ||||
-rw-r--r-- | vespalib/src/vespa/vespalib/test/thread_meets.cpp | 31 | ||||
-rw-r--r-- | vespalib/src/vespa/vespalib/test/thread_meets.h | 63 | ||||
-rw-r--r-- | vespalib/src/vespa/vespalib/util/rendezvous.h | 13 | ||||
-rw-r--r-- | vespalib/src/vespa/vespalib/util/rw_spin_lock.h | 189 |
8 files changed, 48 insertions, 653 deletions
diff --git a/vespalib/CMakeLists.txt b/vespalib/CMakeLists.txt index c1f6f2cbbff..6d19988b96b 100644 --- a/vespalib/CMakeLists.txt +++ b/vespalib/CMakeLists.txt @@ -147,7 +147,6 @@ vespa_define_module( src/tests/require src/tests/runnable_pair src/tests/rusage - src/tests/rw_spin_lock src/tests/sequencedtaskexecutor src/tests/sha1 src/tests/shared_operation_throttler diff --git a/vespalib/src/tests/rw_spin_lock/CMakeLists.txt b/vespalib/src/tests/rw_spin_lock/CMakeLists.txt deleted file mode 100644 index d05322e79f6..00000000000 --- a/vespalib/src/tests/rw_spin_lock/CMakeLists.txt +++ /dev/null @@ -1,8 +0,0 @@ -# Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. -vespa_add_executable(vespalib_rw_spin_lock_test_app TEST - SOURCES - rw_spin_lock_test.cpp - DEPENDS - vespalib -) -vespa_add_test(NAME vespalib_rw_spin_lock_test_app COMMAND vespalib_rw_spin_lock_test_app) diff --git a/vespalib/src/tests/rw_spin_lock/rw_spin_lock_test.cpp b/vespalib/src/tests/rw_spin_lock/rw_spin_lock_test.cpp deleted file mode 100644 index 207284d0db0..00000000000 --- a/vespalib/src/tests/rw_spin_lock/rw_spin_lock_test.cpp +++ /dev/null @@ -1,355 +0,0 @@ -// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. - -#include <vespa/vespalib/util/spin_lock.h> -#include <vespa/vespalib/util/rw_spin_lock.h> -#include <vespa/vespalib/util/atomic.h> -#include <vespa/vespalib/util/time.h> -#include <vespa/vespalib/util/classname.h> -#include <vespa/vespalib/test/thread_meets.h> -#include <vespa/vespalib/testkit/test_kit.h> -#include <type_traits> -#include <ranges> -#include <random> -#include <array> - -using namespace vespalib; -using namespace vespalib::atomic; - -duration budget = 250ms; -constexpr size_t LOOP_CNT = 4096; -constexpr double LOOP_FACTOR = double(LOOP_CNT); - -//----------------------------------------------------------------------------- - -struct DummyLock { - constexpr DummyLock() noexcept {} - // BasicLockable - constexpr void lock() noexcept {} - constexpr void unlock() noexcept {} - // SharedLockable - constexpr void lock_shared() noexcept {} - [[nodiscard]] constexpr bool try_lock_shared() noexcept { return true; } - constexpr void unlock_shared() noexcept {} - // rw_upgrade_downgrade_lock - [[nodiscard]] constexpr bool try_convert_read_to_write() noexcept { return true; } - constexpr void convert_write_to_read() noexcept {} -}; - -//----------------------------------------------------------------------------- - -struct MyState { - static constexpr size_t SZ = 5; - std::array<size_t,SZ> state = {0,0,0,0,0}; - std::atomic<size_t> inconsistent_reads = 0; - void update() { - std::array<size_t,SZ> tmp; - for (size_t i = 0; i < SZ; ++i) { - tmp[i] = load_ref_relaxed(state[i]); - } - for (int n = 0; n < 1024; ++n) { - for (size_t i = 0; i < SZ; ++i) { - store_ref_relaxed(state[i], tmp[i] + 1); - } - } - } - void peek() { - std::array<size_t,SZ> tmp; - for (size_t i = 0; i < SZ; ++i) { - tmp[i] = load_ref_relaxed(state[i]); - } - for (int n = 0; n < 1024; ++n) { - for (size_t i = 0; i < SZ; ++i) { - if (load_ref_relaxed(state[i]) != tmp[i]) [[unlikely]] { - inconsistent_reads.fetch_add(1, std::memory_order_relaxed); - } - } - } - } - bool check(size_t expect) const { - if (inconsistent_reads > 0) { - return false; - } - for (const auto& value: state) { - if (load_ref_relaxed(value) != expect) { - return false; - } - } - return true; - } - void report(size_t expect, const char *name) const { - if (check(expect)) { - fprintf(stderr, "%s is thread safe\n", name); - } else { - fprintf(stderr, "%s is not thread safe\n", name); - fprintf(stderr, " inconsistent reads: %zu\n", inconsistent_reads.load()); - fprintf(stderr, " expected %zu, got [%zu,%zu,%zu,%zu,%zu]\n", - expect, state[0], state[1], state[2], state[3], state[4]); - } - } -}; - -// do work while waiting for other threads to be ready -class ActiveBarrier : Rendezvous<bool,bool> { -private: - std::atomic<uint32_t> _ready_cnt; - void mingle() override { - _ready_cnt.store(0, std::memory_order_relaxed); - } -public: - ActiveBarrier(size_t n) : Rendezvous<bool,bool>(n), _ready_cnt(0) {} - void operator()(auto &do_work) { - if (_ready_cnt.fetch_add(1, std::memory_order_relaxed) + 1 < size()) { - do_work(); - } - while (_ready_cnt.load(std::memory_order_relaxed) < size()) { - do_work(); - } - rendezvous(false); - } -}; - -// random generator used to make per-thread decisions -class Rnd { -private: - std::mt19937 _engine; - std::uniform_int_distribution<int> _dist; -public: - Rnd(uint32_t seed) : _engine(seed), _dist(0,9999) {} - bool operator()(int bp) { return _dist(_engine) < bp; } -}; - -void fork_join(auto &&thread_fun, size_t num_threads) { - assert(num_threads > 0); - ThreadPool pool; - for (size_t i = 1; i < num_threads; ++i) { - pool.start([i,&thread_fun]{ thread_fun(i); }); - } - thread_fun(0); - pool.join(); -} - -auto apply_merge(auto &&inputs, auto &&perform, auto &&merge) { - using output_t = std::decay_t<decltype(perform(*inputs.begin()))>; - std::mutex lock; - std::optional<output_t> output; - auto handle_result = [&](output_t result) { - std::lock_guard guard(lock); - if (output.has_value()) { - output = merge(std::move(output).value(), std::move(result)); - } else { - output = std::move(result); - } - }; - ThreadPool pool; - for (auto &&item: inputs) { - pool.start([item,&perform,&handle_result]{ handle_result(perform(item)); }); - } - pool.join(); - return output.value(); -} - -//----------------------------------------------------------------------------- - -template<typename T> -concept basic_lockable = requires(T a) { - { a.lock() } -> std::same_as<void>; - { a.unlock() } -> std::same_as<void>; -}; - -template<typename T> -concept lockable = requires(T a) { - { a.try_lock() } -> std::same_as<bool>; - { a.lock() } -> std::same_as<void>; - { a.unlock() } -> std::same_as<void>; -}; - -template<typename T> -concept shared_lockable = requires(T a) { - { a.try_lock_shared() } -> std::same_as<bool>; - { a.lock_shared() } -> std::same_as<void>; - { a.unlock_shared() } -> std::same_as<void>; -}; - -template<typename T> -concept can_upgrade = requires(std::shared_lock<T> a, std::unique_lock<T> b) { - { try_upgrade(std::move(a)) } -> std::same_as<std::unique_lock<T>>; - { downgrade(std::move(b)) } -> std::same_as<std::shared_lock<T>>; -}; - -//----------------------------------------------------------------------------- - -template <size_t N> -auto run_loop(auto &f) { - static_assert(N % 4 == 0); - for (size_t i = 0; i < N / 4; ++i) { - f(); f(); f(); f(); - } -} - -double measure_ns(auto &work) __attribute__((noinline)); -double measure_ns(auto &work) { - auto t0 = steady_clock::now(); - run_loop<LOOP_CNT>(work); - return count_ns(steady_clock::now() - t0) / LOOP_FACTOR; -} - -struct BenchmarkResult { - double cost_ns = std::numeric_limits<double>::max(); - double range_ns = 0.0; -}; - -struct Meets { - vespalib::test::ThreadMeets::Vote vote; - vespalib::test::ThreadMeets::Avg avg; - vespalib::test::ThreadMeets::Range<double> range; - ActiveBarrier active_wait; - Meets(size_t num_threads) - : vote(num_threads), avg(num_threads), range(num_threads), active_wait(num_threads) {} - ~Meets(); -}; -Meets::~Meets() = default; - -BenchmarkResult benchmark_ns(auto &&work, size_t num_threads = 1) { - Timer timer; - Meets meets(num_threads); - BenchmarkResult result; - auto hook = [&](size_t thread_id) { - for (bool once_more = true; meets.vote(once_more); once_more = (timer.elapsed() < budget)) { - auto my_ns = measure_ns(work); - meets.active_wait(work); - auto cost_ns = meets.avg(my_ns); - auto range_ns = meets.range(my_ns); - if (thread_id == 0 && cost_ns < result.cost_ns) { - result.cost_ns = cost_ns; - result.range_ns = range_ns; - } - } - }; - fork_join(hook, num_threads); - return result; -} - -//----------------------------------------------------------------------------- - -template <typename T> -void estimate_cost() { - T lock; - auto name = getClassName(lock); - static_assert(basic_lockable<T>); - fprintf(stderr, "%s unique lock/unlock: %g ns\n", name.c_str(), - benchmark_ns([&lock]{ lock.lock(); lock.unlock(); }).cost_ns); - if constexpr (shared_lockable<T>) { - fprintf(stderr, "%s shared lock/unlock: %g ns\n", name.c_str(), - benchmark_ns([&lock]{ lock.lock_shared(); lock.unlock_shared(); }).cost_ns); - } - if constexpr (can_upgrade<T>) { - auto guard = std::shared_lock(lock); - fprintf(stderr, "%s upgrade/downgrade: %g ns\n", name.c_str(), - benchmark_ns([&lock]{ - assert(lock.try_convert_read_to_write()); - lock.convert_write_to_read(); - }).cost_ns); - } -} - -//----------------------------------------------------------------------------- - -template <typename T> -size_t thread_safety_loop(T &lock, MyState &state, Meets &meets, int read_bp, size_t thread_id) { - Timer timer; - Rnd rnd(thread_id); - size_t write_cnt = 0; - BenchmarkResult result; - auto do_work = [&] - { - if (rnd(read_bp)) { - if constexpr (shared_lockable<T>) { - std::shared_lock guard(lock); - state.peek(); - } else { - std::lock_guard guard(lock); - state.peek(); - } - } else { - { - std::lock_guard guard(lock); - state.update(); - } - ++write_cnt; - } - }; - for (bool once_more = true; meets.vote(once_more); once_more = (timer.elapsed() < budget)) { - auto my_est = measure_ns(do_work); - meets.active_wait(do_work); - auto cost_ns = meets.avg(my_est); - auto range_ns = meets.range(my_est); - if (cost_ns < result.cost_ns) { - result.cost_ns = cost_ns; - result.range_ns = range_ns; - } - } - if (thread_id == 0) { - fprintf(stderr, "---> %s with %2zu threads (%5d bp r): %12.2f ns, range: %12.2f ns\n", - getClassName(lock).c_str(), meets.vote.size(), read_bp, result.cost_ns, result.range_ns); - } - return write_cnt; -} - -//----------------------------------------------------------------------------- - -TEST("require that rw spin locks can be used with lock_guard, unique_lock and shared_lock") { - static_assert(basic_lockable<RWSpinLock>); - static_assert(lockable<RWSpinLock>); - static_assert(shared_lockable<RWSpinLock>); - static_assert(can_upgrade<RWSpinLock>); - RWSpinLock lock; - { auto guard = std::lock_guard(lock); } - { auto guard = std::unique_lock(lock); } - { auto guard = std::shared_lock(lock); } -} - -TEST("estimate basic cost") { - Rnd rnd(123); - MyState state; - fprintf(stderr, " rnd cost: %8.2f ns\n", benchmark_ns([&]{ rnd(50); }).cost_ns); - fprintf(stderr, " peek cost: %8.2f ns\n", benchmark_ns([&]{ state.peek(); }).cost_ns); - fprintf(stderr, "update cost: %8.2f ns\n", benchmark_ns([&]{ state.update(); }).cost_ns); -} - -void benchmark_lock(auto &lock) { - size_t expect = 0; - auto state = std::make_unique<MyState>(); - for (size_t bp: {10000, 9999, 5000, 0}) { - for (size_t num_threads: {8, 4, 2, 1}) { - Meets meets(num_threads); - auto hook = [&](size_t thread_id) { - return thread_safety_loop(lock, *state, meets, bp, thread_id); - }; - expect += apply_merge(std::views::iota(size_t(0), num_threads), hook, [](auto a, auto b){ return a + b; }); - } - } - state->report(expect, getClassName(lock).c_str()); - EXPECT_TRUE(state->check(expect)); -} - -TEST_F("benchmark RWSpinLock", RWSpinLock()) { benchmark_lock(f1); } -TEST_F("benchmark std::shared_mutex", std::shared_mutex()) { benchmark_lock(f1); } -TEST_F("benchmark std::mutex", std::mutex()) { benchmark_lock(f1); } -TEST_F("benchmark SpinLock", SpinLock()) { benchmark_lock(f1); } - -TEST("estimate single-threaded lock/unlock cost") { - estimate_cost<DummyLock>(); - estimate_cost<SpinLock>(); - estimate_cost<std::mutex>(); - estimate_cost<RWSpinLock>(); - estimate_cost<std::shared_mutex>(); -} - -int main(int argc, char **argv) { - TEST_MASTER.init(__FILE__); - if ((argc == 2) && (argv[1] == std::string("bench"))) { - budget = 5s; - } - TEST_RUN_ALL(); - return (TEST_MASTER.fini() ? 0 : 1); -} diff --git a/vespalib/src/tests/shared_string_repo/shared_string_repo_test.cpp b/vespalib/src/tests/shared_string_repo/shared_string_repo_test.cpp index 910c2d017ba..dfcba14ba63 100644 --- a/vespalib/src/tests/shared_string_repo/shared_string_repo_test.cpp +++ b/vespalib/src/tests/shared_string_repo/shared_string_repo_test.cpp @@ -1,7 +1,7 @@ // Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. #include <vespa/vespalib/util/shared_string_repo.h> -#include <vespa/vespalib/test/thread_meets.h> +#include <vespa/vespalib/util/rendezvous.h> #include <vespa/vespalib/util/time.h> #include <vespa/vespalib/util/size_literals.h> #include <vespa/vespalib/util/stringfmt.h> @@ -115,8 +115,41 @@ std::unique_ptr<StringIdVector> make_weak_handles(const Handles &handles) { //----------------------------------------------------------------------------- -using Avg = vespalib::test::ThreadMeets::Avg; -using Vote = vespalib::test::ThreadMeets::Vote; +struct Avg : Rendezvous<double, double> { + explicit Avg(size_t n) : Rendezvous<double, double>(n) {} + void mingle() override { + double sum = 0; + for (size_t i = 0; i < size(); ++i) { + sum += in(i); + } + double result = sum / size(); + for (size_t i = 0; i < size(); ++i) { + out(i) = result; + } + } + double operator()(double value) { return rendezvous(value); } +}; + +struct Vote : Rendezvous<bool, bool> { + explicit Vote(size_t n) : Rendezvous<bool, bool>(n) {} + void mingle() override { + size_t true_cnt = 0; + size_t false_cnt = 0; + for (size_t i = 0; i < size(); ++i) { + if (in(i)) { + ++true_cnt; + } else { + ++false_cnt; + } + } + bool result = (true_cnt > false_cnt); + for (size_t i = 0; i < size(); ++i) { + out(i) = result; + } + } + [[nodiscard]] size_t num_threads() const { return size(); } + bool operator()(bool flag) { return rendezvous(flag); } +}; //----------------------------------------------------------------------------- @@ -141,7 +174,7 @@ struct Fixture { : avg(num_threads), vote(num_threads), work(make_strings(work_size)), direct_work(make_direct_strings(work_size)), start_time(steady_clock::now()) {} ~Fixture() { if (verbose) { - fprintf(stderr, "benchmark results for %zu threads:\n", vote.size()); + fprintf(stderr, "benchmark results for %zu threads:\n", vote.num_threads()); for (const auto &[tag, ms_cost]: time_ms) { fprintf(stderr, " %s: %g ms\n", tag.c_str(), ms_cost); } diff --git a/vespalib/src/vespa/vespalib/test/thread_meets.cpp b/vespalib/src/vespa/vespalib/test/thread_meets.cpp index 607179c53f9..9d23e0eab28 100644 --- a/vespalib/src/vespa/vespalib/test/thread_meets.cpp +++ b/vespalib/src/vespa/vespalib/test/thread_meets.cpp @@ -9,35 +9,4 @@ ThreadMeets::Nop::mingle() { } -void -ThreadMeets::Avg::mingle() -{ - double sum = 0; - for (size_t i = 0; i < size(); ++i) { - sum += in(i); - } - double result = sum / size(); - for (size_t i = 0; i < size(); ++i) { - out(i) = result; - } -} - -void -ThreadMeets::Vote::mingle() -{ - size_t true_cnt = 0; - size_t false_cnt = 0; - for (size_t i = 0; i < size(); ++i) { - if (in(i)) { - ++true_cnt; - } else { - ++false_cnt; - } - } - bool result = (true_cnt > false_cnt); - for (size_t i = 0; i < size(); ++i) { - out(i) = result; - } -} - } diff --git a/vespalib/src/vespa/vespalib/test/thread_meets.h b/vespalib/src/vespa/vespalib/test/thread_meets.h index 7ef4dcb9921..62ca7779935 100644 --- a/vespalib/src/vespa/vespalib/test/thread_meets.h +++ b/vespalib/src/vespa/vespalib/test/thread_meets.h @@ -12,67 +12,10 @@ namespace vespalib::test { struct ThreadMeets { // can be used as a simple thread barrier struct Nop : vespalib::Rendezvous<bool,bool> { - explicit Nop(size_t N) : vespalib::Rendezvous<bool,bool>(N) {} + Nop(size_t N) : vespalib::Rendezvous<bool,bool>(N) {} void operator()() { rendezvous(false); } void mingle() override; }; - // calculate the average value across threads - struct Avg : Rendezvous<double, double> { - explicit Avg(size_t n) : Rendezvous<double, double>(n) {} - double operator()(double value) { return rendezvous(value); } - void mingle() override; - }; - // threads vote for true/false, majority wins (false on tie) - struct Vote : Rendezvous<bool, bool> { - explicit Vote(size_t n) : Rendezvous<bool, bool>(n) {} - bool operator()(bool flag) { return rendezvous(flag); } - void mingle() override; - }; - // sum of values across all threads - template <typename T> - struct Sum : vespalib::Rendezvous<T,T> { - using vespalib::Rendezvous<T,T>::in; - using vespalib::Rendezvous<T,T>::out; - using vespalib::Rendezvous<T,T>::size; - using vespalib::Rendezvous<T,T>::rendezvous; - explicit Sum(size_t N) : vespalib::Rendezvous<T,T>(N) {} - T operator()(T value) { return rendezvous(value); } - void mingle() override { - T acc{}; - for (size_t i = 0; i < size(); ++i) { - acc += in(i); - } - for (size_t i = 0; i < size(); ++i) { - out(i) = acc; - } - } - }; - // range of values across all threads - template <typename T> - struct Range : vespalib::Rendezvous<T,T> { - using vespalib::Rendezvous<T,T>::in; - using vespalib::Rendezvous<T,T>::out; - using vespalib::Rendezvous<T,T>::size; - using vespalib::Rendezvous<T,T>::rendezvous; - explicit Range(size_t N) : vespalib::Rendezvous<T,T>(N) {} - T operator()(T value) { return rendezvous(value); } - void mingle() override { - T min = in(0); - T max = in(0); - for (size_t i = 1; i < size(); ++i) { - if (in(i) < min) { - min = in(i); - } - if (in(i) > max) { - max = in(i); - } - } - T result = (max - min); - for (size_t i = 0; i < size(); ++i) { - out(i) = result; - } - } - }; // swap values between 2 threads template <typename T> struct Swap : vespalib::Rendezvous<T,T> { @@ -82,8 +25,8 @@ struct ThreadMeets { Swap() : vespalib::Rendezvous<T,T>(2) {} T operator()(T input) { return rendezvous(input); } void mingle() override { - out(1) = std::move(in(0)); - out(0) = std::move(in(1)); + out(1) = in(0); + out(0) = in(1); } }; }; diff --git a/vespalib/src/vespa/vespalib/util/rendezvous.h b/vespalib/src/vespa/vespalib/util/rendezvous.h index 17a8729c54c..2880f325d96 100644 --- a/vespalib/src/vespa/vespalib/util/rendezvous.h +++ b/vespalib/src/vespa/vespalib/util/rendezvous.h @@ -50,6 +50,14 @@ private: protected: /** + * Obtain the number of input and output values to be handled by + * mingle. This function is called by mingle. + * + * @return number of input and output values + **/ + size_t size() const { return _size; } + + /** * Obtain an input parameter. This function is called by mingle. * * @return reference to the appropriate input @@ -79,11 +87,6 @@ public: virtual ~Rendezvous(); /** - * @return number of participants - **/ - size_t size() const { return _size; } - - /** * Called by individual threads to synchronize execution and share * state with the mingle function. * diff --git a/vespalib/src/vespa/vespalib/util/rw_spin_lock.h b/vespalib/src/vespa/vespalib/util/rw_spin_lock.h deleted file mode 100644 index f2c15dcc0eb..00000000000 --- a/vespalib/src/vespa/vespalib/util/rw_spin_lock.h +++ /dev/null @@ -1,189 +0,0 @@ -// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. - -#pragma once - -#include <mutex> -#include <shared_mutex> -#include <atomic> -#include <thread> -#include <cassert> -#include <utility> - -namespace vespalib { - -/** - * A reader-writer spin lock implementation. - * - * reader: shared access for any number of readers - * writer: exclusive access for a single writer - * - * valid lock combinations: - * {} - * {N readers} - * {1 writer} - * - * Trying to obtain a write lock will lead to not granting new read - * locks. - * - * This lock is intended for use-cases that involves mostly reading, - * with a little bit of writing. - * - * This class implements the Lockable and SharedLockable named - * requirements from the standard library, making it directly usable - * with std::shared_lock (reader) and std::unique_lock (writer) - * - * There is also some special glue added for lock upgrading and - * downgrading. - * - * NOTE: this implementation is experimental, mostly intended for - * benchmarking and trying to identify use-cases that work with - * rw locks. Upgrade locks that do not block readers might be - * implementet in the future. - **/ -class RWSpinLock { -private: - // [31: num readers][1: pending writer] - // a reader gets the lock by: - // increasing the number of readers while the pending writer bit is not set. - // a writer gets the lock by: - // changing the pending writer bit from 0 to 1 and then - // waiting for the number of readers to become 0 - // an upgrade is successful when: - // a reader is able to obtain the pending writer bit - std::atomic<uint32_t> _state; - - // Convenience function used to check if the pending writer bit is - // set in the given value. - bool has_pending_writer(uint32_t value) noexcept { - return (value & 1); - } - - // Wait for all readers to release their locks. - void wait_for_zero_readers(uint32_t &value) { - while (value != 1) { - std::this_thread::yield(); - value = _state.load(std::memory_order_acquire); - } - } - -public: - RWSpinLock() noexcept : _state(0) { - static_assert(std::atomic<uint32_t>::is_always_lock_free); - } - - // implementation of Lockable named requirement - vvv - - void lock() noexcept { - uint32_t expected = 0; - uint32_t desired = 1; - while (!_state.compare_exchange_weak(expected, desired, - std::memory_order_acquire, - std::memory_order_relaxed)) - { - while (has_pending_writer(expected)) { - std::this_thread::yield(); - expected = _state.load(std::memory_order_relaxed); - } - desired = expected + 1; - } - wait_for_zero_readers(desired); - } - - [[nodiscard]] bool try_lock() noexcept { - uint32_t expected = 0; - return _state.compare_exchange_strong(expected, 1, - std::memory_order_acquire, - std::memory_order_relaxed); - } - - void unlock() noexcept { - _state.store(0, std::memory_order_release); - } - - // implementation of Lockable named requirement - ^^^ - - // implementation of SharedLockable named requirement - vvv - - void lock_shared() noexcept { - uint32_t expected = 0; - uint32_t desired = 2; - while (!_state.compare_exchange_weak(expected, desired, - std::memory_order_acquire, - std::memory_order_relaxed)) - { - while (has_pending_writer(expected)) { - std::this_thread::yield(); - expected = _state.load(std::memory_order_relaxed); - } - desired = expected + 2; - } - } - - [[nodiscard]] bool try_lock_shared() noexcept { - uint32_t expected = 0; - uint32_t desired = 2; - while (!_state.compare_exchange_weak(expected, desired, - std::memory_order_acquire, - std::memory_order_relaxed)) - { - if (has_pending_writer(expected)) { - return false; - } - desired = expected + 2; - } - return true; - } - - void unlock_shared() noexcept { - _state.fetch_sub(2, std::memory_order_release); - } - - // implementation of SharedLockable named requirement - ^^^ - - // try to upgrade a read (shared) lock to a write (unique) lock - bool try_convert_read_to_write() noexcept { - uint32_t expected = 2; - uint32_t desired = 1; - while (!_state.compare_exchange_weak(expected, desired, - std::memory_order_acquire, - std::memory_order_relaxed)) - { - if (has_pending_writer(expected)) { - return false; - } - desired = expected - 1; - } - wait_for_zero_readers(desired); - return true; - } - - // convert a write (unique) lock to a read (shared) lock - void convert_write_to_read() noexcept { - _state.store(2, std::memory_order_release); - } -}; - -template<typename T> -concept rw_upgrade_downgrade_lock = requires(T a, T b) { - { a.try_convert_read_to_write() } -> std::same_as<bool>; - { b.convert_write_to_read() } -> std::same_as<void>; -}; - -template <rw_upgrade_downgrade_lock T> -[[nodiscard]] std::unique_lock<T> try_upgrade(std::shared_lock<T> &&guard) noexcept { - assert(guard.owns_lock()); - if (guard.mutex()->try_convert_read_to_write()) { - return {*guard.release(), std::adopt_lock}; - } else { - return {}; - } -} - -template <rw_upgrade_downgrade_lock T> -[[nodiscard]] std::shared_lock<T> downgrade(std::unique_lock<T> &&guard) noexcept { - assert(guard.owns_lock()); - guard.mutex()->convert_write_to_read(); - return {*guard.release(), std::adopt_lock}; -} - -} |