diff options
author | Håvard Pettersen <havardpe@yahooinc.com> | 2023-01-25 15:55:43 +0000 |
---|---|---|
committer | Håvard Pettersen <havardpe@yahooinc.com> | 2023-01-26 16:23:17 +0000 |
commit | 7c4fd0b043f0cc37bad278d31c9a3c4c324d7b23 (patch) | |
tree | 9f18108dd714dd69e9db101c411333408908a052 /vespalib | |
parent | eec4962db2f6dc302f7195c1b20a6818dbc0178a (diff) |
track coroutines waiting for values
Diffstat (limited to 'vespalib')
-rw-r--r-- | vespalib/CMakeLists.txt | 1 | ||||
-rw-r--r-- | vespalib/src/tests/coro/waiting_for/CMakeLists.txt | 9 | ||||
-rw-r--r-- | vespalib/src/tests/coro/waiting_for/waiting_for_test.cpp | 110 | ||||
-rw-r--r-- | vespalib/src/vespa/vespalib/coro/lazy.h | 12 | ||||
-rw-r--r-- | vespalib/src/vespa/vespalib/coro/received.h | 5 | ||||
-rw-r--r-- | vespalib/src/vespa/vespalib/coro/waiting_for.h | 108 |
6 files changed, 235 insertions, 10 deletions
diff --git a/vespalib/CMakeLists.txt b/vespalib/CMakeLists.txt index 8509d5fc382..76308260578 100644 --- a/vespalib/CMakeLists.txt +++ b/vespalib/CMakeLists.txt @@ -50,6 +50,7 @@ vespa_define_module( src/tests/coro/generator src/tests/coro/lazy src/tests/coro/received + src/tests/coro/waiting_for src/tests/cpu_usage src/tests/crc src/tests/crypto diff --git a/vespalib/src/tests/coro/waiting_for/CMakeLists.txt b/vespalib/src/tests/coro/waiting_for/CMakeLists.txt new file mode 100644 index 00000000000..d9eaa7eaf03 --- /dev/null +++ b/vespalib/src/tests/coro/waiting_for/CMakeLists.txt @@ -0,0 +1,9 @@ +# Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +vespa_add_executable(vespalib_waiting_for_test_app TEST + SOURCES + waiting_for_test.cpp + DEPENDS + vespalib + GTest::GTest +) +vespa_add_test(NAME vespalib_waiting_for_test_app COMMAND vespalib_waiting_for_test_app) diff --git a/vespalib/src/tests/coro/waiting_for/waiting_for_test.cpp b/vespalib/src/tests/coro/waiting_for/waiting_for_test.cpp new file mode 100644 index 00000000000..385d4ad24e3 --- /dev/null +++ b/vespalib/src/tests/coro/waiting_for/waiting_for_test.cpp @@ -0,0 +1,110 @@ +// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. + +#include <vespa/vespalib/coro/lazy.h> +#include <vespa/vespalib/coro/completion.h> +#include <vespa/vespalib/coro/waiting_for.h> +#include <vespa/vespalib/util/time.h> +#include <vespa/vespalib/gtest/gtest.h> + +using namespace vespalib::coro; + +struct AsyncService { + std::vector<WaitingFor<int>> pending; + auto get_value() { + return awaiter_for<int>([&](WaitingFor<int> handle) + { + pending.push_back(std::move(handle)); + }); + } +}; + +struct AsyncVoidService { + std::vector<void*> pending; + auto get_value() { + return awaiter_for<int>([&](WaitingFor<int> handle) + { + pending.push_back(handle.release()); + }); + } +}; + +struct SyncService { + auto get_value() { + return awaiter_for<int>([](WaitingFor<int> handle) + { + handle.set_value(42); + return handle.release_waiter(); // symmetric transfer + }); + } +}; + +template<typename Service> +Lazy<int> wait_for_value(Service &service) { + int value = co_await service.get_value(); + co_return value; +} + +template <typename T> +Lazy<T> wait_any(auto &&fun) { + T result = co_await fun(); + co_return std::move(result); +} + +TEST(WaitingForTest, wait_for_external_async_int) { + AsyncService service; + auto res = make_future(wait_for_value(service)); + EXPECT_TRUE(res.wait_for(0ms) == std::future_status::timeout); + ASSERT_EQ(service.pending.size(), 1); + service.pending[0].set_value(42); + EXPECT_TRUE(res.wait_for(0ms) == std::future_status::timeout); + service.pending.clear(); + EXPECT_TRUE(res.wait_for(0ms) == std::future_status::ready); + EXPECT_EQ(res.get(), 42); +} + +TEST(WaitingForTest, wait_for_external_async_int_via_void_ptr) { + AsyncVoidService service; + auto res = make_future(wait_for_value(service)); + EXPECT_TRUE(res.wait_for(0ms) == std::future_status::timeout); + ASSERT_EQ(service.pending.size(), 1); + { + auto handle = WaitingFor<int>::from_pointer(service.pending[0]); + handle.set_value(42); + EXPECT_TRUE(res.wait_for(0ms) == std::future_status::timeout); + } + EXPECT_TRUE(res.wait_for(0ms) == std::future_status::ready); + EXPECT_EQ(res.get(), 42); +} + +TEST(WaitingForTest, wait_for_external_sync_int) { + SyncService service; + auto res = make_future(wait_for_value(service)); + EXPECT_TRUE(res.wait_for(0ms) == std::future_status::ready); + EXPECT_EQ(res.get(), 42); +} + +TEST(WaitingForTest, wait_for_move_only_value) { + auto val = std::make_unique<int>(42); + auto fun = [&val](auto handle){ handle.set_value(std::move(val)); }; // asymmetric transfer + auto res = make_future(wait_any<decltype(val)>([&fun](){ return awaiter_for<decltype(val)>(fun); })); + EXPECT_TRUE(res.wait_for(0ms) == std::future_status::ready); + EXPECT_EQ(*res.get(), 42); +} + +TEST(WaitingForTest, set_error) { + PromiseState<int> state; + WaitingFor<int> pending = WaitingFor<int>::from_state(state); + pending.set_error(std::make_exception_ptr(13)); + EXPECT_TRUE(state.result.has_error()); +} + +TEST(WaitingForTest, set_done) { + PromiseState<int> state; + WaitingFor<int> pending = WaitingFor<int>::from_state(state); + pending.set_value(5); + EXPECT_TRUE(state.result.has_value()); + pending.set_done(); + EXPECT_TRUE(state.result.was_canceled()); +} + +GTEST_MAIN_RUN_ALL_TESTS() diff --git a/vespalib/src/vespa/vespalib/coro/lazy.h b/vespalib/src/vespa/vespalib/coro/lazy.h index 974968d0c77..17077dccc9f 100644 --- a/vespalib/src/vespa/vespalib/coro/lazy.h +++ b/vespalib/src/vespa/vespalib/coro/lazy.h @@ -2,9 +2,8 @@ #pragma once -#include "received.h" +#include "waiting_for.h" -#include <concepts> #include <coroutine> #include <optional> #include <exception> @@ -27,7 +26,8 @@ namespace vespalib::coro { template <std::movable T> class [[nodiscard]] Lazy { public: - struct promise_type { + struct promise_type final : PromiseState<T> { + using PromiseState<T>::result; Lazy<T> get_return_object() { return Lazy(Handle::from_promise(*this)); } static std::suspend_always initial_suspend() noexcept { return {}; } static auto final_suspend() noexcept { @@ -47,11 +47,7 @@ public: void unhandled_exception() noexcept { result.set_error(std::current_exception()); } - Received<T> result; - std::coroutine_handle<> waiter; - promise_type(promise_type &&) = delete; - promise_type(const promise_type &) = delete; - promise_type() noexcept : result(), waiter(std::noop_coroutine()) {} + promise_type() noexcept : PromiseState<T>() {} ~promise_type(); }; using Handle = std::coroutine_handle<promise_type>; diff --git a/vespalib/src/vespa/vespalib/coro/received.h b/vespalib/src/vespa/vespalib/coro/received.h index abc66cd2a9d..305a187249c 100644 --- a/vespalib/src/vespa/vespalib/coro/received.h +++ b/vespalib/src/vespa/vespalib/coro/received.h @@ -3,6 +3,7 @@ #pragma once #include <memory> +#include <concepts> #include <variant> #include <exception> #include <stdexcept> @@ -46,8 +47,8 @@ private: } public: Received() : _value() {} - void set_value(T &&value) { _value.template emplace<1>(std::move(value)); } - void set_value(const T &value) { _value.template emplace<1>(value); } + template <typename RET> + void set_value(RET &&value) { _value.template emplace<1>(std::forward<RET>(value)); } void set_error(std::exception_ptr exception) { _value.template emplace<0>(exception); } void set_done() { _value.template emplace<0>(nullptr); } bool has_value() const { return (_value.index() == 1); } diff --git a/vespalib/src/vespa/vespalib/coro/waiting_for.h b/vespalib/src/vespa/vespalib/coro/waiting_for.h new file mode 100644 index 00000000000..2e11a9cb38c --- /dev/null +++ b/vespalib/src/vespa/vespalib/coro/waiting_for.h @@ -0,0 +1,108 @@ +// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. + +#pragma once + +#include "received.h" +#include <coroutine> +#include <utility> + +namespace vespalib::coro { + +// State representing that someone (waiter) is waiting for something +// (result). This object cannot be moved or copied. +template <typename T> +struct PromiseState { + Received<T> result; + std::coroutine_handle<> waiter; + PromiseState(const PromiseState &) = delete; + PromiseState &operator=(const PromiseState &) = delete; + PromiseState(PromiseState &&) = delete; + PromiseState &operator=(PromiseState &&) = delete; + PromiseState() noexcept : result(), waiter(std::noop_coroutine()) {} + ~PromiseState(); +}; +template <typename T> +PromiseState<T>::~PromiseState() = default; + +// A thin (smart) wrapper referencing a PromiseState<T> representing +// that a coroutine is waiting for a value. This class acts as a +// receiver in order to set the result value. When the owning +// reference is deleted, the waiting coroutine will be resumed. +template <typename T> +class WaitingFor { +private: + PromiseState<T> *_state; + WaitingFor(PromiseState<T> *state) noexcept : _state(state) {} +public: + WaitingFor(WaitingFor &&rhs) noexcept : _state(std::exchange(rhs._state, nullptr)) {} + WaitingFor(WaitingFor &rhs) = delete; + WaitingFor &operator=(WaitingFor &rhs) = delete; + ~WaitingFor(); + template <typename RET> + void set_value(RET &&value) { + _state->result.set_value(std::forward<RET>(value)); + } + void set_error(std::exception_ptr exception) { + _state->result.set_error(exception); + } + void set_done() { + _state->result.set_done(); + } + std::coroutine_handle<> release_waiter() { + return std::exchange(_state->waiter, std::noop_coroutine()); + } + void *release() { + return std::exchange(_state, nullptr); + } + static WaitingFor from_pointer(void *ptr) { + PromiseState<T> *state = reinterpret_cast<PromiseState<T>*>(ptr); + return {state}; + } + static WaitingFor from_state(PromiseState<T> &state) { + return {&state}; + } +}; + +template <typename T> +WaitingFor<T>::~WaitingFor() +{ + if (_state != nullptr) { + _state->waiter.resume(); + } +} + +static_assert(receiver_of<WaitingFor<int>, int>); +static_assert(receiver_of<WaitingFor<std::unique_ptr<int>>, std::unique_ptr<int>>); + +// Create a custom awaiter that will return a value of type T when the +// coroutine is resumed. The waiting coroutine will be represented as +// a WaitingFor<T> that is passed as the only parameter to 'f'. The +// return value of 'f' is returned from await_suspend, which means it +// must be void, bool or coroutine handle. If 'f' returns a value +// indicating that the coroutine should be resumed immediately, +// WaitingFor<T>::release_waiter() must be called to avoid resume +// being called as well. Note that await_ready will always return +// false, since the coroutine needs to be suspended in order to create +// the WaitingFor<T> object needed. Also, the WaitingFor<T> api +// implies that the value will be set from the outside and thus cannot +// be ready up-front. Also note that await_resume must return T by +// value, since the awaiter containing the result is a temporary +// object. +template <typename T, typename F> +auto awaiter_for(F &&f) { + struct awaiter final : PromiseState<T> { + using PromiseState<T>::result; + using PromiseState<T>::waiter; + std::decay_t<F> fun; + awaiter(F &&f) : PromiseState<T>(), fun(std::forward<F>(f)) {} + bool await_ready() const noexcept { return false; } + T await_resume() { return std::move(result).get_value(); } + decltype(auto) await_suspend(std::coroutine_handle<> handle) __attribute__((noinline)) { + waiter = handle; + return fun(WaitingFor<T>::from_state(*this)); + } + }; + return awaiter(std::forward<F>(f)); +} + +} |