diff options
author | Håvard Pettersen <havardpe@yahooinc.com> | 2022-10-18 13:11:06 +0000 |
---|---|---|
committer | Håvard Pettersen <havardpe@yahooinc.com> | 2022-10-19 12:10:23 +0000 |
commit | d994d291dd376c03c86dc3e4f18c66022db111cc (patch) | |
tree | c8eecf9c086d3b39c8114722f733f7dce06968ea /vespalib/src | |
parent | 7a8571355a737aab02934ad5cf9fc8521c429b54 (diff) |
generator coroutine return value
Diffstat (limited to 'vespalib/src')
-rw-r--r-- | vespalib/src/tests/coro/generator/CMakeLists.txt | 9 | ||||
-rw-r--r-- | vespalib/src/tests/coro/generator/generator_test.cpp | 225 | ||||
-rw-r--r-- | vespalib/src/vespa/vespalib/coro/generator.h | 159 |
3 files changed, 393 insertions, 0 deletions
diff --git a/vespalib/src/tests/coro/generator/CMakeLists.txt b/vespalib/src/tests/coro/generator/CMakeLists.txt new file mode 100644 index 00000000000..b4f59c69451 --- /dev/null +++ b/vespalib/src/tests/coro/generator/CMakeLists.txt @@ -0,0 +1,9 @@ +# Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +vespa_add_executable(vespalib_generator_test_app TEST + SOURCES + generator_test.cpp + DEPENDS + vespalib + GTest::GTest +) +vespa_add_test(NAME vespalib_generator_test_app COMMAND vespalib_generator_test_app) diff --git a/vespalib/src/tests/coro/generator/generator_test.cpp b/vespalib/src/tests/coro/generator/generator_test.cpp new file mode 100644 index 00000000000..149fe379faa --- /dev/null +++ b/vespalib/src/tests/coro/generator/generator_test.cpp @@ -0,0 +1,225 @@ +// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. + +#include <vespa/vespalib/coro/lazy.h> +#include <vespa/vespalib/coro/generator.h> +#include <vespa/vespalib/util/require.h> +#include <vespa/vespalib/gtest/gtest.h> +#include <ranges> +#include <vector> + +using vespalib::coro::Lazy; +using vespalib::coro::Generator; + +class Unmovable { +private: + int _value; +public: + Unmovable() = delete; + Unmovable &operator=(const Unmovable &) = delete; + Unmovable(const Unmovable &) = delete; + Unmovable &operator=(Unmovable &&) = delete; + Unmovable(Unmovable &&) = delete; + Unmovable(int value) : _value(value) {} + int get() const { return _value; } +}; + +static_assert(std::input_iterator<Generator<int>::Iterator>); +static_assert(std::ranges::input_range<Generator<int>>); + +Lazy<int> foo() { co_return 0; } + +Generator<int> make_numbers(int begin, int end) { + // co_yield co_await foo(); + for (int i = begin; i < end; ++i) { + co_yield i; + } +} + +Generator<int> make_numbers(int begin, int split, int end) { + co_yield make_numbers(begin, split); + co_yield make_numbers(split, end); +} + +static_assert(std::input_iterator<Generator<std::unique_ptr<int>>::Iterator>); +static_assert(std::ranges::input_range<Generator<std::unique_ptr<int>>>); + +Generator<std::unique_ptr<int>> make_movable(int begin, int end) { + for (int i = begin; i < end; ++i) { + co_yield std::make_unique<int>(i); + } +} + +static_assert(std::input_iterator<Generator<Unmovable>::Iterator>); +static_assert(std::ranges::input_range<Generator<Unmovable>>); + +Generator<Unmovable> make_unmovable(int begin, int end) { + for (int i = begin; i < end; ++i) { + co_yield Unmovable(i); + } +} + +Generator<int> make_failed_numbers(int begin, int end, int fail) { + for (int i = begin; i < end; ++i) { + REQUIRE(i != fail); + co_yield i; + } +} + +Generator<int> make_safe(Generator<int> gen) { + try { + co_yield gen; + } catch (...) {} +} + +Generator<int> a_then_b(Generator<int> a, Generator<int> b) { + co_yield a; + co_yield b; +} + +TEST(GeneratorTest, generate_some_numbers) { + auto gen = make_numbers(1, 4); + auto pos = gen.begin(); + auto end = gen.end(); + ASSERT_FALSE(pos == end); + EXPECT_EQ(*pos, 1); + ++pos; + ASSERT_FALSE(pos == end); + EXPECT_EQ(*pos, 2); + ++pos; + ASSERT_FALSE(pos == end); + EXPECT_EQ(*pos, 3); + ++pos; + EXPECT_TRUE(pos == end); +} + +TEST(GeneratorTest, generate_no_numbers) { + auto gen = make_numbers(1, 1); + auto pos = gen.begin(); + auto end = gen.end(); + EXPECT_TRUE(pos == end); +} + +TEST(GeneratorTest, generate_movable_values) { + auto gen = make_movable(1,4); + std::vector<std::unique_ptr<int>> res; + for(auto pos = gen.begin(); pos != gen.end(); ++pos) { + res.push_back(*pos); + } + ASSERT_EQ(res.size(), 3); + EXPECT_EQ(*res[0], 1); + EXPECT_EQ(*res[1], 2); + EXPECT_EQ(*res[2], 3); +} + +TEST(GeneratorTest, generate_unmovable_values) { + auto gen = make_unmovable(1,4); + auto pos = gen.begin(); + auto end = gen.end(); + ASSERT_FALSE(pos == end); + EXPECT_EQ(pos->get(), 1); + ++pos; + ASSERT_FALSE(pos == end); + EXPECT_EQ(pos->get(), 2); + ++pos; + ASSERT_FALSE(pos == end); + EXPECT_EQ(pos->get(), 3); + ++pos; + EXPECT_TRUE(pos == end); +} + +TEST(GeneratorTest, range_based_for_loop) { + int expect = 1; + for (int x: make_numbers(1, 10)) { + EXPECT_EQ(x, expect); + ++expect; + } + EXPECT_EQ(expect, 10); +} + +TEST(GeneratorTest, explicit_range_for_loop) { + int expect = 1; + auto gen = make_numbers(1, 10); + auto pos = std::ranges::begin(gen); + auto end = std::ranges::end(gen); + for (; pos != end; ++pos) { + EXPECT_EQ(*pos, expect); + ++expect; + } + EXPECT_EQ(expect, 10); +} + +TEST(GeneratorTest, recursive_generator) { + int expect = 1; + for (int x: make_numbers(1, 4, 10)) { + EXPECT_EQ(x, expect); + ++expect; + } + EXPECT_EQ(expect, 10); +} + +TEST(GeneratorTest, deeper_recursive_generator) { + int expect = 1; + for (int x: a_then_b(make_numbers(1, 3, 5), make_numbers(5, 7, 10))) { + EXPECT_EQ(x, expect); + ++expect; + } + EXPECT_EQ(expect, 10); +} + +TEST(GeneratorTest, simple_exception) { + auto gen = make_failed_numbers(1, 10, 5); + auto pos = std::ranges::begin(gen); + auto end = std::ranges::end(gen); + EXPECT_EQ(*pos, 1); + EXPECT_EQ(*++pos, 2); + EXPECT_EQ(*++pos, 3); + EXPECT_EQ(*++pos, 4); + EXPECT_FALSE(pos == end); + EXPECT_THROW(++pos, vespalib::RequireFailedException); + EXPECT_TRUE(pos == end); +} + +TEST(GeneratorTest, forwarded_exception) { + auto gen = a_then_b(make_failed_numbers(1, 10, 5), make_numbers(10, 20)); + auto pos = std::ranges::begin(gen); + auto end = std::ranges::end(gen); + EXPECT_EQ(*pos, 1); + EXPECT_EQ(*++pos, 2); + EXPECT_EQ(*++pos, 3); + EXPECT_EQ(*++pos, 4); + EXPECT_FALSE(pos == end); + EXPECT_THROW(++pos, vespalib::RequireFailedException); + EXPECT_TRUE(pos == end); +} + +TEST(GeneratorTest, exception_captured_by_parent_generator) { + int expect = 1; + for (int x: a_then_b(make_safe(make_failed_numbers(1, 10, 5)), make_numbers(5, 10))) { + EXPECT_EQ(x, expect); + ++expect; + } + EXPECT_EQ(expect, 10); +} + +TEST(GeneratorTest, moving_iterator_with_recursive_generator) { + auto gen = a_then_b(make_numbers(1, 3, 5), make_numbers(5, 7, 9)); + auto pos = std::ranges::begin(gen); + auto end = std::ranges::end(gen); + EXPECT_EQ(*pos, 1); + EXPECT_EQ(*++pos, 2); + auto pos2 = std::move(pos); + EXPECT_EQ(*++pos2, 3); + EXPECT_EQ(*++pos2, 4); + auto pos3 = std::move(pos2); + EXPECT_EQ(*++pos3, 5); + EXPECT_EQ(*++pos3, 6); + auto pos4 = std::move(pos3); + EXPECT_EQ(*++pos4, 7); + EXPECT_EQ(*++pos4, 8); + auto pos5 = std::move(pos4); + EXPECT_FALSE(pos5 == end); + ++pos5; + EXPECT_TRUE(pos5 == end); +} + +GTEST_MAIN_RUN_ALL_TESTS() diff --git a/vespalib/src/vespa/vespalib/coro/generator.h b/vespalib/src/vespa/vespalib/coro/generator.h new file mode 100644 index 00000000000..01cf0931094 --- /dev/null +++ b/vespalib/src/vespa/vespalib/coro/generator.h @@ -0,0 +1,159 @@ +// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. + +#pragma once + +#include <concepts> +#include <coroutine> +#include <exception> +#include <utility> +#include <cstddef> +#include <iterator> + +namespace vespalib::coro { + +/** + * coroutine return type + * + * The coroutine is lazy (will suspend in initial_suspend) and + * destroyed from the outside (will suspend in final_suspend). A + * generator may produce any number of results using co_yield, but + * cannot use co_await (it must be synchronous). The values produced + * by the generator is accessed by using the generator as an + * input_range. A generator is recursive (it may yield another + * generator of the same type to include its values in the output). + **/ +template <typename T, typename ValueType = std::remove_cvref<T>> +class [[nodiscard]] Generator { +public: + using value_type = ValueType; + using Pointer = std::add_pointer_t<T>; + + class promise_type; + using Handle = std::coroutine_handle<promise_type>; + + class promise_type { + private: + Pointer _ptr; + std::exception_ptr _exception; + Handle *_itr_state; + Handle _parent; + + template <bool check_exception> + struct SwitchTo : std::suspend_always { + Handle next; + explicit SwitchTo(Handle next_in) : next(next_in) {} + std::coroutine_handle<> await_suspend(Handle prev) const noexcept { + if (next) { + Handle &itr_state = prev.promise().itr_state(); + itr_state = next; + next.promise().itr_state(itr_state); + return next; + } else { + return std::noop_coroutine(); + } + } + void await_resume() const noexcept(!check_exception) { + if (check_exception && next.promise()._exception) { + std::rethrow_exception(next.promise()._exception); + } + } + }; + + public: + promise_type(promise_type &&) = delete; + promise_type(const promise_type &) = delete; + promise_type() noexcept : _ptr(nullptr), _exception(), _itr_state(nullptr), _parent(nullptr) {} + Generator<T> get_return_object() { return Generator(Handle::from_promise(*this)); } + std::suspend_always initial_suspend() noexcept { return {}; } + auto final_suspend() noexcept { return SwitchTo<false>(_parent); } + std::suspend_always yield_value(T &&value) { + _ptr = &value; + return {}; + } + auto yield_value(const T &value) requires(!std::is_reference_v<T> && std::copy_constructible<T>) { + struct awaiter : std::suspend_always { + awaiter(const T &value, Pointer &ptr) : value_cpy(value) { + ptr = std::addressof(value_cpy); + } + awaiter(awaiter&&) = delete; + T value_cpy; + }; + return awaiter(value, _ptr); + } + auto yield_value(Generator &&child) { return yield_value(child); } + auto yield_value(Generator &child) { + child._handle.promise()._parent = Handle::from_promise(*this); + return SwitchTo<true>(child._handle); + } + void return_void() { _ptr = nullptr; } + void unhandled_exception() { + if (_parent) { + _exception = std::current_exception(); + } else { + throw; + } + } + T &&result() { + return std::forward<T>(*_ptr); + } + Pointer result_ptr() { + return _ptr; + } + Handle &itr_state() const noexcept { return *_itr_state; } + void itr_state(Handle &handle) noexcept { _itr_state = std::addressof(handle); } + template<typename U> std::suspend_always await_transform(U &&value) = delete; + }; + + class Iterator { + private: + Handle _handle; + public: + Iterator() noexcept : _handle(nullptr) {} + Iterator(Iterator &&rhs) noexcept = default; + Iterator &operator=(Iterator &&rhs) noexcept = default; + Iterator(const Iterator &rhs) = delete; + Iterator &operator=(const Iterator &) = delete; + explicit Iterator(Handle handle) : _handle(handle) { + _handle.promise().itr_state(_handle); + _handle.resume(); + } + using iterator_concept = std::input_iterator_tag; + using difference_type = std::ptrdiff_t; + using value_type = std::remove_cvref_t<T>; + bool operator==(std::default_sentinel_t) const { + return _handle.done(); + } + Iterator &operator++() { + _handle.promise().itr_state(_handle); + _handle.resume(); + return *this; + } + void operator++(int) { + operator++(); + } + decltype(auto) operator*() const { + return std::forward<T>(_handle.promise().result()); + } + auto operator->() const { + return _handle.promise().result_ptr(); + } + }; + +private: + Handle _handle; + +public: + Generator(const Generator &) = delete; + Generator &operator=(const Generator &) = delete; + explicit Generator(Handle handle_in) noexcept : _handle(handle_in) {} + Generator(Generator &&rhs) noexcept : _handle(std::exchange(rhs._handle, nullptr)) {} + ~Generator() { + if (_handle) { + _handle.destroy(); + } + } + auto begin() { return Iterator(_handle); } + auto end() const noexcept { return std::default_sentinel_t(); } +}; + +} |