aboutsummaryrefslogtreecommitdiffstats
path: root/vespalib
diff options
context:
space:
mode:
authorHåvard Pettersen <havardpe@yahooinc.com>2023-01-25 15:55:43 +0000
committerHåvard Pettersen <havardpe@yahooinc.com>2023-01-26 16:23:17 +0000
commit7c4fd0b043f0cc37bad278d31c9a3c4c324d7b23 (patch)
tree9f18108dd714dd69e9db101c411333408908a052 /vespalib
parenteec4962db2f6dc302f7195c1b20a6818dbc0178a (diff)
track coroutines waiting for values
Diffstat (limited to 'vespalib')
-rw-r--r--vespalib/CMakeLists.txt1
-rw-r--r--vespalib/src/tests/coro/waiting_for/CMakeLists.txt9
-rw-r--r--vespalib/src/tests/coro/waiting_for/waiting_for_test.cpp110
-rw-r--r--vespalib/src/vespa/vespalib/coro/lazy.h12
-rw-r--r--vespalib/src/vespa/vespalib/coro/received.h5
-rw-r--r--vespalib/src/vespa/vespalib/coro/waiting_for.h108
6 files changed, 235 insertions, 10 deletions
diff --git a/vespalib/CMakeLists.txt b/vespalib/CMakeLists.txt
index 8509d5fc382..76308260578 100644
--- a/vespalib/CMakeLists.txt
+++ b/vespalib/CMakeLists.txt
@@ -50,6 +50,7 @@ vespa_define_module(
src/tests/coro/generator
src/tests/coro/lazy
src/tests/coro/received
+ src/tests/coro/waiting_for
src/tests/cpu_usage
src/tests/crc
src/tests/crypto
diff --git a/vespalib/src/tests/coro/waiting_for/CMakeLists.txt b/vespalib/src/tests/coro/waiting_for/CMakeLists.txt
new file mode 100644
index 00000000000..d9eaa7eaf03
--- /dev/null
+++ b/vespalib/src/tests/coro/waiting_for/CMakeLists.txt
@@ -0,0 +1,9 @@
+# Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+vespa_add_executable(vespalib_waiting_for_test_app TEST
+ SOURCES
+ waiting_for_test.cpp
+ DEPENDS
+ vespalib
+ GTest::GTest
+)
+vespa_add_test(NAME vespalib_waiting_for_test_app COMMAND vespalib_waiting_for_test_app)
diff --git a/vespalib/src/tests/coro/waiting_for/waiting_for_test.cpp b/vespalib/src/tests/coro/waiting_for/waiting_for_test.cpp
new file mode 100644
index 00000000000..385d4ad24e3
--- /dev/null
+++ b/vespalib/src/tests/coro/waiting_for/waiting_for_test.cpp
@@ -0,0 +1,110 @@
+// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+
+#include <vespa/vespalib/coro/lazy.h>
+#include <vespa/vespalib/coro/completion.h>
+#include <vespa/vespalib/coro/waiting_for.h>
+#include <vespa/vespalib/util/time.h>
+#include <vespa/vespalib/gtest/gtest.h>
+
+using namespace vespalib::coro;
+
+struct AsyncService {
+ std::vector<WaitingFor<int>> pending;
+ auto get_value() {
+ return awaiter_for<int>([&](WaitingFor<int> handle)
+ {
+ pending.push_back(std::move(handle));
+ });
+ }
+};
+
+struct AsyncVoidService {
+ std::vector<void*> pending;
+ auto get_value() {
+ return awaiter_for<int>([&](WaitingFor<int> handle)
+ {
+ pending.push_back(handle.release());
+ });
+ }
+};
+
+struct SyncService {
+ auto get_value() {
+ return awaiter_for<int>([](WaitingFor<int> handle)
+ {
+ handle.set_value(42);
+ return handle.release_waiter(); // symmetric transfer
+ });
+ }
+};
+
+template<typename Service>
+Lazy<int> wait_for_value(Service &service) {
+ int value = co_await service.get_value();
+ co_return value;
+}
+
+template <typename T>
+Lazy<T> wait_any(auto &&fun) {
+ T result = co_await fun();
+ co_return std::move(result);
+}
+
+TEST(WaitingForTest, wait_for_external_async_int) {
+ AsyncService service;
+ auto res = make_future(wait_for_value(service));
+ EXPECT_TRUE(res.wait_for(0ms) == std::future_status::timeout);
+ ASSERT_EQ(service.pending.size(), 1);
+ service.pending[0].set_value(42);
+ EXPECT_TRUE(res.wait_for(0ms) == std::future_status::timeout);
+ service.pending.clear();
+ EXPECT_TRUE(res.wait_for(0ms) == std::future_status::ready);
+ EXPECT_EQ(res.get(), 42);
+}
+
+TEST(WaitingForTest, wait_for_external_async_int_via_void_ptr) {
+ AsyncVoidService service;
+ auto res = make_future(wait_for_value(service));
+ EXPECT_TRUE(res.wait_for(0ms) == std::future_status::timeout);
+ ASSERT_EQ(service.pending.size(), 1);
+ {
+ auto handle = WaitingFor<int>::from_pointer(service.pending[0]);
+ handle.set_value(42);
+ EXPECT_TRUE(res.wait_for(0ms) == std::future_status::timeout);
+ }
+ EXPECT_TRUE(res.wait_for(0ms) == std::future_status::ready);
+ EXPECT_EQ(res.get(), 42);
+}
+
+TEST(WaitingForTest, wait_for_external_sync_int) {
+ SyncService service;
+ auto res = make_future(wait_for_value(service));
+ EXPECT_TRUE(res.wait_for(0ms) == std::future_status::ready);
+ EXPECT_EQ(res.get(), 42);
+}
+
+TEST(WaitingForTest, wait_for_move_only_value) {
+ auto val = std::make_unique<int>(42);
+ auto fun = [&val](auto handle){ handle.set_value(std::move(val)); }; // asymmetric transfer
+ auto res = make_future(wait_any<decltype(val)>([&fun](){ return awaiter_for<decltype(val)>(fun); }));
+ EXPECT_TRUE(res.wait_for(0ms) == std::future_status::ready);
+ EXPECT_EQ(*res.get(), 42);
+}
+
+TEST(WaitingForTest, set_error) {
+ PromiseState<int> state;
+ WaitingFor<int> pending = WaitingFor<int>::from_state(state);
+ pending.set_error(std::make_exception_ptr(13));
+ EXPECT_TRUE(state.result.has_error());
+}
+
+TEST(WaitingForTest, set_done) {
+ PromiseState<int> state;
+ WaitingFor<int> pending = WaitingFor<int>::from_state(state);
+ pending.set_value(5);
+ EXPECT_TRUE(state.result.has_value());
+ pending.set_done();
+ EXPECT_TRUE(state.result.was_canceled());
+}
+
+GTEST_MAIN_RUN_ALL_TESTS()
diff --git a/vespalib/src/vespa/vespalib/coro/lazy.h b/vespalib/src/vespa/vespalib/coro/lazy.h
index 974968d0c77..17077dccc9f 100644
--- a/vespalib/src/vespa/vespalib/coro/lazy.h
+++ b/vespalib/src/vespa/vespalib/coro/lazy.h
@@ -2,9 +2,8 @@
#pragma once
-#include "received.h"
+#include "waiting_for.h"
-#include <concepts>
#include <coroutine>
#include <optional>
#include <exception>
@@ -27,7 +26,8 @@ namespace vespalib::coro {
template <std::movable T>
class [[nodiscard]] Lazy {
public:
- struct promise_type {
+ struct promise_type final : PromiseState<T> {
+ using PromiseState<T>::result;
Lazy<T> get_return_object() { return Lazy(Handle::from_promise(*this)); }
static std::suspend_always initial_suspend() noexcept { return {}; }
static auto final_suspend() noexcept {
@@ -47,11 +47,7 @@ public:
void unhandled_exception() noexcept {
result.set_error(std::current_exception());
}
- Received<T> result;
- std::coroutine_handle<> waiter;
- promise_type(promise_type &&) = delete;
- promise_type(const promise_type &) = delete;
- promise_type() noexcept : result(), waiter(std::noop_coroutine()) {}
+ promise_type() noexcept : PromiseState<T>() {}
~promise_type();
};
using Handle = std::coroutine_handle<promise_type>;
diff --git a/vespalib/src/vespa/vespalib/coro/received.h b/vespalib/src/vespa/vespalib/coro/received.h
index abc66cd2a9d..305a187249c 100644
--- a/vespalib/src/vespa/vespalib/coro/received.h
+++ b/vespalib/src/vespa/vespalib/coro/received.h
@@ -3,6 +3,7 @@
#pragma once
#include <memory>
+#include <concepts>
#include <variant>
#include <exception>
#include <stdexcept>
@@ -46,8 +47,8 @@ private:
}
public:
Received() : _value() {}
- void set_value(T &&value) { _value.template emplace<1>(std::move(value)); }
- void set_value(const T &value) { _value.template emplace<1>(value); }
+ template <typename RET>
+ void set_value(RET &&value) { _value.template emplace<1>(std::forward<RET>(value)); }
void set_error(std::exception_ptr exception) { _value.template emplace<0>(exception); }
void set_done() { _value.template emplace<0>(nullptr); }
bool has_value() const { return (_value.index() == 1); }
diff --git a/vespalib/src/vespa/vespalib/coro/waiting_for.h b/vespalib/src/vespa/vespalib/coro/waiting_for.h
new file mode 100644
index 00000000000..2e11a9cb38c
--- /dev/null
+++ b/vespalib/src/vespa/vespalib/coro/waiting_for.h
@@ -0,0 +1,108 @@
+// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+
+#pragma once
+
+#include "received.h"
+#include <coroutine>
+#include <utility>
+
+namespace vespalib::coro {
+
+// State representing that someone (waiter) is waiting for something
+// (result). This object cannot be moved or copied.
+template <typename T>
+struct PromiseState {
+ Received<T> result;
+ std::coroutine_handle<> waiter;
+ PromiseState(const PromiseState &) = delete;
+ PromiseState &operator=(const PromiseState &) = delete;
+ PromiseState(PromiseState &&) = delete;
+ PromiseState &operator=(PromiseState &&) = delete;
+ PromiseState() noexcept : result(), waiter(std::noop_coroutine()) {}
+ ~PromiseState();
+};
+template <typename T>
+PromiseState<T>::~PromiseState() = default;
+
+// A thin (smart) wrapper referencing a PromiseState<T> representing
+// that a coroutine is waiting for a value. This class acts as a
+// receiver in order to set the result value. When the owning
+// reference is deleted, the waiting coroutine will be resumed.
+template <typename T>
+class WaitingFor {
+private:
+ PromiseState<T> *_state;
+ WaitingFor(PromiseState<T> *state) noexcept : _state(state) {}
+public:
+ WaitingFor(WaitingFor &&rhs) noexcept : _state(std::exchange(rhs._state, nullptr)) {}
+ WaitingFor(WaitingFor &rhs) = delete;
+ WaitingFor &operator=(WaitingFor &rhs) = delete;
+ ~WaitingFor();
+ template <typename RET>
+ void set_value(RET &&value) {
+ _state->result.set_value(std::forward<RET>(value));
+ }
+ void set_error(std::exception_ptr exception) {
+ _state->result.set_error(exception);
+ }
+ void set_done() {
+ _state->result.set_done();
+ }
+ std::coroutine_handle<> release_waiter() {
+ return std::exchange(_state->waiter, std::noop_coroutine());
+ }
+ void *release() {
+ return std::exchange(_state, nullptr);
+ }
+ static WaitingFor from_pointer(void *ptr) {
+ PromiseState<T> *state = reinterpret_cast<PromiseState<T>*>(ptr);
+ return {state};
+ }
+ static WaitingFor from_state(PromiseState<T> &state) {
+ return {&state};
+ }
+};
+
+template <typename T>
+WaitingFor<T>::~WaitingFor()
+{
+ if (_state != nullptr) {
+ _state->waiter.resume();
+ }
+}
+
+static_assert(receiver_of<WaitingFor<int>, int>);
+static_assert(receiver_of<WaitingFor<std::unique_ptr<int>>, std::unique_ptr<int>>);
+
+// Create a custom awaiter that will return a value of type T when the
+// coroutine is resumed. The waiting coroutine will be represented as
+// a WaitingFor<T> that is passed as the only parameter to 'f'. The
+// return value of 'f' is returned from await_suspend, which means it
+// must be void, bool or coroutine handle. If 'f' returns a value
+// indicating that the coroutine should be resumed immediately,
+// WaitingFor<T>::release_waiter() must be called to avoid resume
+// being called as well. Note that await_ready will always return
+// false, since the coroutine needs to be suspended in order to create
+// the WaitingFor<T> object needed. Also, the WaitingFor<T> api
+// implies that the value will be set from the outside and thus cannot
+// be ready up-front. Also note that await_resume must return T by
+// value, since the awaiter containing the result is a temporary
+// object.
+template <typename T, typename F>
+auto awaiter_for(F &&f) {
+ struct awaiter final : PromiseState<T> {
+ using PromiseState<T>::result;
+ using PromiseState<T>::waiter;
+ std::decay_t<F> fun;
+ awaiter(F &&f) : PromiseState<T>(), fun(std::forward<F>(f)) {}
+ bool await_ready() const noexcept { return false; }
+ T await_resume() { return std::move(result).get_value(); }
+ decltype(auto) await_suspend(std::coroutine_handle<> handle) __attribute__((noinline)) {
+ waiter = handle;
+ return fun(WaitingFor<T>::from_state(*this));
+ }
+ };
+ return awaiter(std::forward<F>(f));
+}
+
+}