diff options
author | Håvard Pettersen <havardpe@yahooinc.com> | 2022-11-08 10:56:02 +0000 |
---|---|---|
committer | Håvard Pettersen <havardpe@yahooinc.com> | 2022-11-08 10:56:02 +0000 |
commit | c3c218051eb9b8f75c23db3ca6bad06e9cd3314a (patch) | |
tree | 68f74d19d5ab807fae5187fda5c3889fa4e800cf /vespalib | |
parent | d29cb7e64a30a93b4ab445c872449809cdde6bcd (diff) |
return value forwarding for Lazy<T>
Diffstat (limited to 'vespalib')
-rw-r--r-- | vespalib/CMakeLists.txt | 1 | ||||
-rw-r--r-- | vespalib/src/tests/coro/lazy/lazy_test.cpp | 21 | ||||
-rw-r--r-- | vespalib/src/tests/coro/received/CMakeLists.txt | 9 | ||||
-rw-r--r-- | vespalib/src/tests/coro/received/received_test.cpp | 143 | ||||
-rw-r--r-- | vespalib/src/vespa/vespalib/coro/completion.h | 58 | ||||
-rw-r--r-- | vespalib/src/vespa/vespalib/coro/lazy.h | 64 | ||||
-rw-r--r-- | vespalib/src/vespa/vespalib/coro/received.h | 63 |
7 files changed, 268 insertions, 91 deletions
diff --git a/vespalib/CMakeLists.txt b/vespalib/CMakeLists.txt index c498639533f..a21f623fbfa 100644 --- a/vespalib/CMakeLists.txt +++ b/vespalib/CMakeLists.txt @@ -45,6 +45,7 @@ vespa_define_module( src/tests/coro/detached src/tests/coro/generator src/tests/coro/lazy + src/tests/coro/received src/tests/cpu_usage src/tests/crc src/tests/crypto diff --git a/vespalib/src/tests/coro/lazy/lazy_test.cpp b/vespalib/src/tests/coro/lazy/lazy_test.cpp index ec27bf195ec..29aac4440fc 100644 --- a/vespalib/src/tests/coro/lazy/lazy_test.cpp +++ b/vespalib/src/tests/coro/lazy/lazy_test.cpp @@ -168,4 +168,25 @@ TEST(LazyTest, async_wait_with_move_only_result) { EXPECT_EQ(*(result.get_value()), 123); } +struct Refs { + Gate &gate; + Received<std::unique_ptr<int>> &result; + Refs(Gate &gate_in, Received<std::unique_ptr<int>> &result_in) + : gate(gate_in), result(result_in) {} +}; + +TEST(LazyTest, async_wait_with_move_only_result_and_move_only_lambda) { + Gate gate; + Received<std::unique_ptr<int>> result; + vespalib::ThreadStackExecutor executor(1, 128_Ki); + auto lazy = schedule_on(executor, move_only_int()); + async_wait(std::move(lazy), [refs = std::make_unique<Refs>(gate,result)](auto res) + { + refs->result = std::move(res); + refs->gate.countDown(); + }); + gate.await(); + EXPECT_EQ(*(result.get_value()), 123); +} + GTEST_MAIN_RUN_ALL_TESTS() diff --git a/vespalib/src/tests/coro/received/CMakeLists.txt b/vespalib/src/tests/coro/received/CMakeLists.txt new file mode 100644 index 00000000000..2441d557664 --- /dev/null +++ b/vespalib/src/tests/coro/received/CMakeLists.txt @@ -0,0 +1,9 @@ +# Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +vespa_add_executable(vespalib_received_test_app TEST + SOURCES + received_test.cpp + DEPENDS + vespalib + GTest::GTest +) +vespa_add_test(NAME vespalib_received_test_app COMMAND vespalib_received_test_app) diff --git a/vespalib/src/tests/coro/received/received_test.cpp b/vespalib/src/tests/coro/received/received_test.cpp new file mode 100644 index 00000000000..96d1e7942af --- /dev/null +++ b/vespalib/src/tests/coro/received/received_test.cpp @@ -0,0 +1,143 @@ +// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. + +#include <vespa/vespalib/coro/received.h> +#include <vespa/vespalib/util/time.h> +#include <vespa/vespalib/gtest/gtest.h> +#include <memory> + +using vespalib::coro::Received; + +TEST(ReceivedTest, can_store_simple_value) { + Received<int> result; + result.set_value(42); + EXPECT_TRUE(result.has_value()); + EXPECT_FALSE(result.has_error()); + EXPECT_FALSE(result.was_canceled()); + EXPECT_FALSE(result.get_error()); + EXPECT_EQ(result.get_value(), 42); +} + +TEST(ReceivedTest, can_store_error) { + Received<int> result; + auto err = std::make_exception_ptr(std::runtime_error("stuff happened")); + result.set_error(err); + EXPECT_FALSE(result.has_value()); + EXPECT_TRUE(result.has_error()); + EXPECT_FALSE(result.was_canceled()); + EXPECT_EQ(result.get_error(), err); +} + +TEST(ReceivedTest, can_store_nothing) { + Received<int> result; + result.set_done(); + EXPECT_FALSE(result.has_value()); + EXPECT_FALSE(result.has_error()); + EXPECT_TRUE(result.was_canceled()); +} + +TEST(ReceivedTest, can_store_move_only_value) { + Received<std::unique_ptr<int>> result; + result.set_value(std::make_unique<int>(42)); + EXPECT_TRUE(result.has_value()); + EXPECT_FALSE(result.has_error()); + EXPECT_FALSE(result.was_canceled()); + EXPECT_FALSE(result.get_error()); + auto res = std::move(result).get_value(); + EXPECT_EQ(*res, 42); + EXPECT_TRUE(result.has_value()); + EXPECT_EQ(result.get_value().get(), nullptr); +} + +TEST(ReceivedTest, can_forward_value_to_std_promise) { + Received<std::unique_ptr<int>> result; + result.set_value(std::make_unique<int>(42)); + std::promise<std::unique_ptr<int>> promise; + auto future = promise.get_future(); + result.forward(promise); + ASSERT_TRUE(future.wait_for(0ms) == std::future_status::ready); + EXPECT_EQ(*future.get(), 42); +} + +TEST(ReceivedTest, can_forward_error_to_std_promise) { + Received<int> result; + auto err = std::make_exception_ptr(std::runtime_error("stuff happened")); + result.set_error(err); + std::promise<int> promise; + auto future = promise.get_future(); + result.forward(promise); + ASSERT_TRUE(future.wait_for(0ms) == std::future_status::ready); + EXPECT_THROW(future.get(), std::runtime_error); +} + +TEST(ReceivedTest, can_forward_nothing_as_error_to_std_promise) { + Received<int> result; + result.set_done(); + std::promise<int> promise; + auto future = promise.get_future(); + result.forward(promise); + ASSERT_TRUE(future.wait_for(0ms) == std::future_status::ready); + EXPECT_THROW(future.get(), vespalib::coro::UnavailableResultException); +} + +struct MyReceiver { + std::unique_ptr<int> value; + std::exception_ptr error; + bool done; + MyReceiver() : value(), error(), done(false) {} + void set_value(std::unique_ptr<int> v) { value = std::move(v); } + void set_error(std::exception_ptr err) { error = err; } + void set_done() { done = true; } + ~MyReceiver(); +}; +MyReceiver::~MyReceiver() = default; +static_assert(vespalib::coro::receiver_of<MyReceiver,std::unique_ptr<int>>); + +TEST(ReceivedTest, can_forward_value_to_receiver) { + Received<std::unique_ptr<int>> result; + result.set_value(std::make_unique<int>(42)); + MyReceiver r; + result.forward(r); + EXPECT_EQ(*r.value, 42); + EXPECT_FALSE(r.error); + EXPECT_FALSE(r.done); +} + +TEST(ReceivedTest, can_forward_error_to_receiver) { + Received<std::unique_ptr<int>> result; + auto err = std::make_exception_ptr(std::runtime_error("stuff happened")); + result.set_error(err); + MyReceiver r; + result.forward(r); + EXPECT_EQ(r.error, err); + EXPECT_TRUE(r.value.get() == nullptr); + EXPECT_FALSE(r.done); +} + +TEST(ReceivedTest, can_forward_nothing_to_receiver) { + Received<std::unique_ptr<int>> result; + result.set_done(); + MyReceiver r; + result.forward(r); + EXPECT_TRUE(r.done); + EXPECT_FALSE(r.error); + EXPECT_TRUE(r.value.get() == nullptr); +} + +TEST(ReceivedTest, can_forward_itself_to_lvalue_lambda_callback) { + Received<std::unique_ptr<int>> result; + result.set_value(std::make_unique<int>(42)); + Received<std::unique_ptr<int>> other_result; + auto callback = [&](auto res){ other_result = std::move(res); }; + result.forward(callback); + EXPECT_EQ(*other_result.get_value(), 42); +} + +TEST(ReceivedTest, can_forward_itself_to_rvalue_lambda_callback) { + Received<std::unique_ptr<int>> result; + result.set_value(std::make_unique<int>(42)); + Received<std::unique_ptr<int>> other_result; + result.forward([&](auto res){ other_result = std::move(res); }); + EXPECT_EQ(*other_result.get_value(), 42); +} + +GTEST_MAIN_RUN_ALL_TESTS() diff --git a/vespalib/src/vespa/vespalib/coro/completion.h b/vespalib/src/vespa/vespalib/coro/completion.h index f323d8c68bf..caac3e5fb8a 100644 --- a/vespalib/src/vespa/vespalib/coro/completion.h +++ b/vespalib/src/vespa/vespalib/coro/completion.h @@ -13,7 +13,7 @@ namespace vespalib::coro { -// Resume/start the coroutine responsible for calculating the result +// Resume (start) the coroutine responsible for calculating the result // and signal the receiver when it completes or fails. Note that the // detached coroutine will own both the coroutine calculating the // result and the receiver that is later notified of the result. The @@ -24,15 +24,15 @@ namespace vespalib::coro { // execution where the coroutine represented by Lazy<T> is the // sender. Execution parameters can be encapsulated inside Lazy<T> // using composition (for example which executor should run the -// coroutine). +// coroutine). The receiver in this context may be either an actual +// receiver_of<T>, a callback function accepting a Received<T> or an +// std::promise. The different cases are handled by the overloaded +// Recieved<T>::forward function template. template <typename T, typename R> Detached connect_resume(Lazy<T> value, R receiver) { - try { - receiver.set_value(co_await std::move(value)); - } catch (...) { - receiver.set_error(std::current_exception()); - } + auto&& result = co_await value.forward(); + result.forward(receiver); } // replace Lazy<T> with std::future<T> to be able to synchronously @@ -40,45 +40,12 @@ Detached connect_resume(Lazy<T> value, R receiver) { template <typename T> std::future<T> make_future(Lazy<T> value) { - struct receiver { - std::promise<T> promise; - receiver() : promise() {} - void set_value(T value) { - promise.set_value(std::move(value)); - } - void set_error(std::exception_ptr error) { - promise.set_exception(error); - } - }; - receiver my_receiver; - auto future = my_receiver.promise.get_future(); - connect_resume(std::move(value), std::move(my_receiver)); + std::promise<T> promise; + auto future = promise.get_future(); + connect_resume(std::move(value), std::move(promise)); return future; } -// Create a receiver from a function object (typically a lambda -// closure) that takes a received value (stored receiver result) as -// its only parameter. - -template <typename T, typename F> -auto make_receiver(F &&f) { - struct receiver { - Received<T> result; - std::decay_t<F> fun; - receiver(F &&f) - : result(), fun(std::forward<F>(f)) {} - void set_value(T value) { - result.set_value(std::move(value)); - fun(std::move(result)); - } - void set_error(std::exception_ptr why) { - result.set_error(why); - fun(std::move(result)); - } - }; - return receiver(std::forward<F>(f)); -} - /** * Wait for a lazy value to be calculated synchronously. Make sure the * thread waiting is not needed in the calculation of the value, or @@ -93,12 +60,11 @@ T sync_wait(Lazy<T> value) { * Wait for a lazy value to be calculated asynchronously; the provided * callback will be called with a Received<T> when the Lazy<T> is * done. Both the callback itself and the Lazy<T> will be destructed - * afterwards; cleaning up the coroutine tree representing the - * calculation. + * afterwards. **/ template <typename T, typename F> void async_wait(Lazy<T> value, F &&f) { - connect_resume(std::move(value), make_receiver<T>(std::forward<F>(f))); + connect_resume(std::move(value), std::forward<F>(f)); } } diff --git a/vespalib/src/vespa/vespalib/coro/lazy.h b/vespalib/src/vespa/vespalib/coro/lazy.h index 144b5c945f0..04e6b9c8835 100644 --- a/vespalib/src/vespa/vespalib/coro/lazy.h +++ b/vespalib/src/vespa/vespalib/coro/lazy.h @@ -2,6 +2,8 @@ #pragma once +#include "received.h" + #include <concepts> #include <coroutine> #include <optional> @@ -39,69 +41,53 @@ public: return awaiter(); } template <typename RET> - requires std::is_convertible_v<RET&&,T> - void return_value(RET &&ret_value) noexcept(std::is_nothrow_constructible_v<T,RET&&>) { - value = std::forward<RET>(ret_value); + void return_value(RET &&ret_value) { + result.set_value(std::forward<RET>(ret_value)); } void unhandled_exception() noexcept { - exception = std::current_exception(); + result.set_error(std::current_exception()); } - std::optional<T> value; - std::exception_ptr exception; + Received<T> result; std::coroutine_handle<> waiter; promise_type(promise_type &&) = delete; promise_type(const promise_type &) = delete; - promise_type() noexcept : value(std::nullopt), exception(), waiter(std::noop_coroutine()) {} - T &result() & { - if (exception) { - std::rethrow_exception(exception); - } - return *value; - } - T &&result() && { - if (exception) { - std::rethrow_exception(exception); - } - return std::move(*value); - } + promise_type() noexcept : result(), waiter(std::noop_coroutine()) {} ~promise_type(); }; using Handle = std::coroutine_handle<promise_type>; private: Handle _handle; - - struct awaiter_base { + + template <typename RET> + struct WaitFor { Handle handle; - awaiter_base(Handle handle_in) noexcept : handle(handle_in) {} + WaitFor(Handle handle_in) noexcept : handle(handle_in) {} bool await_ready() const noexcept { return handle.done(); } Handle await_suspend(std::coroutine_handle<> waiter) const noexcept { handle.promise().waiter = waiter; return handle; } + decltype(auto) await_resume() const { return RET::get(handle.promise()); } + }; + struct LValue { + static T& get(auto &&promise) { return promise.result.get_value(); } }; - + struct RValue { + static T&& get(auto &&promise) { return std::move(promise.result.get_value()); } + }; + struct Result { + static Received<T>&& get(auto &&promise) { return std::move(promise.result); } + }; + public: Lazy(const Lazy &) = delete; Lazy &operator=(const Lazy &) = delete; explicit Lazy(Handle handle_in) noexcept : _handle(handle_in) {} Lazy(Lazy &&rhs) noexcept : _handle(std::exchange(rhs._handle, nullptr)) {} - auto operator co_await() & noexcept { - struct awaiter : awaiter_base { - using awaiter_base::handle; - awaiter(Handle handle_in) noexcept : awaiter_base(handle_in) {} - decltype(auto) await_resume() const { return handle.promise().result(); } - }; - return awaiter(_handle); - } - auto operator co_await() && noexcept { - struct awaiter : awaiter_base { - using awaiter_base::handle; - awaiter(Handle handle_in) noexcept : awaiter_base(handle_in) {} - decltype(auto) await_resume() const { return std::move(handle.promise()).result(); } - }; - return awaiter(_handle); - } + auto operator co_await() & noexcept { return WaitFor<LValue>(_handle); } + auto operator co_await() && noexcept { return WaitFor<RValue>(_handle); } + auto forward() noexcept { return WaitFor<Result>(_handle); } ~Lazy() { if (_handle) { _handle.destroy(); diff --git a/vespalib/src/vespa/vespalib/coro/received.h b/vespalib/src/vespa/vespalib/coro/received.h index 4f2efddcfa1..abc66cd2a9d 100644 --- a/vespalib/src/vespa/vespalib/coro/received.h +++ b/vespalib/src/vespa/vespalib/coro/received.h @@ -2,9 +2,11 @@ #pragma once +#include <memory> #include <variant> #include <exception> #include <stdexcept> +#include <future> namespace vespalib::coro { @@ -12,6 +14,20 @@ struct UnavailableResultException : std::runtime_error { using std::runtime_error::runtime_error; }; +// concept indicating that R may be used to receive T +template <typename R, typename T> +concept receiver_of = requires(R r, T t, std::exception_ptr e) { + r.set_value(std::move(t)); + r.set_error(e); + r.set_done(); +}; + +// concept indicating that R is a completion callback accepting T +template <typename R, typename T> +concept completion_callback_for = requires(R r, T t) { + r(std::move(t)); +}; + /** * Simple value wrapper that stores the result observed by a receiver * (value/error/done). A receiver is the continuation of an @@ -21,26 +37,61 @@ template <std::movable T> class Received { private: std::variant<std::exception_ptr,T> _value; + std::exception_ptr normalize_error() const { + if (auto ex = std::get<0>(_value)) { + return ex; + } else { + return std::make_exception_ptr(UnavailableResultException("tried to access the result of a canceled operation")); + } + } public: Received() : _value() {} - void set_value(T value) { _value.template emplace<1>(std::move(value)); } + void set_value(T &&value) { _value.template emplace<1>(std::move(value)); } + void set_value(const T &value) { _value.template emplace<1>(value); } void set_error(std::exception_ptr exception) { _value.template emplace<0>(exception); } void set_done() { _value.template emplace<0>(nullptr); } bool has_value() const { return (_value.index() == 1); } bool has_error() const { return (_value.index() == 0) && bool(std::get<0>(_value)); } bool was_canceled() const { return !has_value() && !has_error(); } - std::exception_ptr get_error() const { return has_error() ? std::get<0>(_value) : std::exception_ptr(); } - T get_value() { + std::exception_ptr get_error() const { + return has_value() ? std::exception_ptr() : std::get<0>(_value); + } + T &get_value() & { + if (_value.index() == 1) { + return std::get<1>(_value); + } else { + std::rethrow_exception(normalize_error()); + } + } + T &&get_value() && { return std::move(get_value()); } + template <typename R> + requires completion_callback_for<R,Received> + void forward(R &&r) { + r(std::move(*this)); + } + template <typename R> + requires receiver_of<R,T> + void forward(R &r) { if (_value.index() == 1) { - return std::move(std::get<1>(_value)); + r.set_value(std::get<1>(std::move(_value))); } else { if (auto ex = std::get<0>(_value)) { - std::rethrow_exception(ex); + r.set_error(ex); } else { - throw UnavailableResultException("tried to access the result of a canceled operation"); + r.set_done(); } } } + void forward(std::promise<T> &r) { + if (_value.index() == 1) { + r.set_value(std::get<1>(std::move(_value))); + } else { + r.set_exception(normalize_error()); + } + } }; +static_assert(receiver_of<Received<int>, int>); +static_assert(receiver_of<Received<std::unique_ptr<int>>, std::unique_ptr<int>>); + } |