summaryrefslogtreecommitdiffstats
path: root/vespalib/src
diff options
context:
space:
mode:
authorHåvard Pettersen <havardpe@yahooinc.com>2022-10-18 13:11:06 +0000
committerHåvard Pettersen <havardpe@yahooinc.com>2022-10-19 12:10:23 +0000
commitd994d291dd376c03c86dc3e4f18c66022db111cc (patch)
treec8eecf9c086d3b39c8114722f733f7dce06968ea /vespalib/src
parent7a8571355a737aab02934ad5cf9fc8521c429b54 (diff)
generator coroutine return value
Diffstat (limited to 'vespalib/src')
-rw-r--r--vespalib/src/tests/coro/generator/CMakeLists.txt9
-rw-r--r--vespalib/src/tests/coro/generator/generator_test.cpp225
-rw-r--r--vespalib/src/vespa/vespalib/coro/generator.h159
3 files changed, 393 insertions, 0 deletions
diff --git a/vespalib/src/tests/coro/generator/CMakeLists.txt b/vespalib/src/tests/coro/generator/CMakeLists.txt
new file mode 100644
index 00000000000..b4f59c69451
--- /dev/null
+++ b/vespalib/src/tests/coro/generator/CMakeLists.txt
@@ -0,0 +1,9 @@
+# Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+vespa_add_executable(vespalib_generator_test_app TEST
+ SOURCES
+ generator_test.cpp
+ DEPENDS
+ vespalib
+ GTest::GTest
+)
+vespa_add_test(NAME vespalib_generator_test_app COMMAND vespalib_generator_test_app)
diff --git a/vespalib/src/tests/coro/generator/generator_test.cpp b/vespalib/src/tests/coro/generator/generator_test.cpp
new file mode 100644
index 00000000000..149fe379faa
--- /dev/null
+++ b/vespalib/src/tests/coro/generator/generator_test.cpp
@@ -0,0 +1,225 @@
+// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+
+#include <vespa/vespalib/coro/lazy.h>
+#include <vespa/vespalib/coro/generator.h>
+#include <vespa/vespalib/util/require.h>
+#include <vespa/vespalib/gtest/gtest.h>
+#include <ranges>
+#include <vector>
+
+using vespalib::coro::Lazy;
+using vespalib::coro::Generator;
+
+class Unmovable {
+private:
+ int _value;
+public:
+ Unmovable() = delete;
+ Unmovable &operator=(const Unmovable &) = delete;
+ Unmovable(const Unmovable &) = delete;
+ Unmovable &operator=(Unmovable &&) = delete;
+ Unmovable(Unmovable &&) = delete;
+ Unmovable(int value) : _value(value) {}
+ int get() const { return _value; }
+};
+
+static_assert(std::input_iterator<Generator<int>::Iterator>);
+static_assert(std::ranges::input_range<Generator<int>>);
+
+Lazy<int> foo() { co_return 0; }
+
+Generator<int> make_numbers(int begin, int end) {
+ // co_yield co_await foo();
+ for (int i = begin; i < end; ++i) {
+ co_yield i;
+ }
+}
+
+Generator<int> make_numbers(int begin, int split, int end) {
+ co_yield make_numbers(begin, split);
+ co_yield make_numbers(split, end);
+}
+
+static_assert(std::input_iterator<Generator<std::unique_ptr<int>>::Iterator>);
+static_assert(std::ranges::input_range<Generator<std::unique_ptr<int>>>);
+
+Generator<std::unique_ptr<int>> make_movable(int begin, int end) {
+ for (int i = begin; i < end; ++i) {
+ co_yield std::make_unique<int>(i);
+ }
+}
+
+static_assert(std::input_iterator<Generator<Unmovable>::Iterator>);
+static_assert(std::ranges::input_range<Generator<Unmovable>>);
+
+Generator<Unmovable> make_unmovable(int begin, int end) {
+ for (int i = begin; i < end; ++i) {
+ co_yield Unmovable(i);
+ }
+}
+
+Generator<int> make_failed_numbers(int begin, int end, int fail) {
+ for (int i = begin; i < end; ++i) {
+ REQUIRE(i != fail);
+ co_yield i;
+ }
+}
+
+Generator<int> make_safe(Generator<int> gen) {
+ try {
+ co_yield gen;
+ } catch (...) {}
+}
+
+Generator<int> a_then_b(Generator<int> a, Generator<int> b) {
+ co_yield a;
+ co_yield b;
+}
+
+TEST(GeneratorTest, generate_some_numbers) {
+ auto gen = make_numbers(1, 4);
+ auto pos = gen.begin();
+ auto end = gen.end();
+ ASSERT_FALSE(pos == end);
+ EXPECT_EQ(*pos, 1);
+ ++pos;
+ ASSERT_FALSE(pos == end);
+ EXPECT_EQ(*pos, 2);
+ ++pos;
+ ASSERT_FALSE(pos == end);
+ EXPECT_EQ(*pos, 3);
+ ++pos;
+ EXPECT_TRUE(pos == end);
+}
+
+TEST(GeneratorTest, generate_no_numbers) {
+ auto gen = make_numbers(1, 1);
+ auto pos = gen.begin();
+ auto end = gen.end();
+ EXPECT_TRUE(pos == end);
+}
+
+TEST(GeneratorTest, generate_movable_values) {
+ auto gen = make_movable(1,4);
+ std::vector<std::unique_ptr<int>> res;
+ for(auto pos = gen.begin(); pos != gen.end(); ++pos) {
+ res.push_back(*pos);
+ }
+ ASSERT_EQ(res.size(), 3);
+ EXPECT_EQ(*res[0], 1);
+ EXPECT_EQ(*res[1], 2);
+ EXPECT_EQ(*res[2], 3);
+}
+
+TEST(GeneratorTest, generate_unmovable_values) {
+ auto gen = make_unmovable(1,4);
+ auto pos = gen.begin();
+ auto end = gen.end();
+ ASSERT_FALSE(pos == end);
+ EXPECT_EQ(pos->get(), 1);
+ ++pos;
+ ASSERT_FALSE(pos == end);
+ EXPECT_EQ(pos->get(), 2);
+ ++pos;
+ ASSERT_FALSE(pos == end);
+ EXPECT_EQ(pos->get(), 3);
+ ++pos;
+ EXPECT_TRUE(pos == end);
+}
+
+TEST(GeneratorTest, range_based_for_loop) {
+ int expect = 1;
+ for (int x: make_numbers(1, 10)) {
+ EXPECT_EQ(x, expect);
+ ++expect;
+ }
+ EXPECT_EQ(expect, 10);
+}
+
+TEST(GeneratorTest, explicit_range_for_loop) {
+ int expect = 1;
+ auto gen = make_numbers(1, 10);
+ auto pos = std::ranges::begin(gen);
+ auto end = std::ranges::end(gen);
+ for (; pos != end; ++pos) {
+ EXPECT_EQ(*pos, expect);
+ ++expect;
+ }
+ EXPECT_EQ(expect, 10);
+}
+
+TEST(GeneratorTest, recursive_generator) {
+ int expect = 1;
+ for (int x: make_numbers(1, 4, 10)) {
+ EXPECT_EQ(x, expect);
+ ++expect;
+ }
+ EXPECT_EQ(expect, 10);
+}
+
+TEST(GeneratorTest, deeper_recursive_generator) {
+ int expect = 1;
+ for (int x: a_then_b(make_numbers(1, 3, 5), make_numbers(5, 7, 10))) {
+ EXPECT_EQ(x, expect);
+ ++expect;
+ }
+ EXPECT_EQ(expect, 10);
+}
+
+TEST(GeneratorTest, simple_exception) {
+ auto gen = make_failed_numbers(1, 10, 5);
+ auto pos = std::ranges::begin(gen);
+ auto end = std::ranges::end(gen);
+ EXPECT_EQ(*pos, 1);
+ EXPECT_EQ(*++pos, 2);
+ EXPECT_EQ(*++pos, 3);
+ EXPECT_EQ(*++pos, 4);
+ EXPECT_FALSE(pos == end);
+ EXPECT_THROW(++pos, vespalib::RequireFailedException);
+ EXPECT_TRUE(pos == end);
+}
+
+TEST(GeneratorTest, forwarded_exception) {
+ auto gen = a_then_b(make_failed_numbers(1, 10, 5), make_numbers(10, 20));
+ auto pos = std::ranges::begin(gen);
+ auto end = std::ranges::end(gen);
+ EXPECT_EQ(*pos, 1);
+ EXPECT_EQ(*++pos, 2);
+ EXPECT_EQ(*++pos, 3);
+ EXPECT_EQ(*++pos, 4);
+ EXPECT_FALSE(pos == end);
+ EXPECT_THROW(++pos, vespalib::RequireFailedException);
+ EXPECT_TRUE(pos == end);
+}
+
+TEST(GeneratorTest, exception_captured_by_parent_generator) {
+ int expect = 1;
+ for (int x: a_then_b(make_safe(make_failed_numbers(1, 10, 5)), make_numbers(5, 10))) {
+ EXPECT_EQ(x, expect);
+ ++expect;
+ }
+ EXPECT_EQ(expect, 10);
+}
+
+TEST(GeneratorTest, moving_iterator_with_recursive_generator) {
+ auto gen = a_then_b(make_numbers(1, 3, 5), make_numbers(5, 7, 9));
+ auto pos = std::ranges::begin(gen);
+ auto end = std::ranges::end(gen);
+ EXPECT_EQ(*pos, 1);
+ EXPECT_EQ(*++pos, 2);
+ auto pos2 = std::move(pos);
+ EXPECT_EQ(*++pos2, 3);
+ EXPECT_EQ(*++pos2, 4);
+ auto pos3 = std::move(pos2);
+ EXPECT_EQ(*++pos3, 5);
+ EXPECT_EQ(*++pos3, 6);
+ auto pos4 = std::move(pos3);
+ EXPECT_EQ(*++pos4, 7);
+ EXPECT_EQ(*++pos4, 8);
+ auto pos5 = std::move(pos4);
+ EXPECT_FALSE(pos5 == end);
+ ++pos5;
+ EXPECT_TRUE(pos5 == end);
+}
+
+GTEST_MAIN_RUN_ALL_TESTS()
diff --git a/vespalib/src/vespa/vespalib/coro/generator.h b/vespalib/src/vespa/vespalib/coro/generator.h
new file mode 100644
index 00000000000..01cf0931094
--- /dev/null
+++ b/vespalib/src/vespa/vespalib/coro/generator.h
@@ -0,0 +1,159 @@
+// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+
+#pragma once
+
+#include <concepts>
+#include <coroutine>
+#include <exception>
+#include <utility>
+#include <cstddef>
+#include <iterator>
+
+namespace vespalib::coro {
+
+/**
+ * coroutine return type
+ *
+ * The coroutine is lazy (will suspend in initial_suspend) and
+ * destroyed from the outside (will suspend in final_suspend). A
+ * generator may produce any number of results using co_yield, but
+ * cannot use co_await (it must be synchronous). The values produced
+ * by the generator is accessed by using the generator as an
+ * input_range. A generator is recursive (it may yield another
+ * generator of the same type to include its values in the output).
+ **/
+template <typename T, typename ValueType = std::remove_cvref<T>>
+class [[nodiscard]] Generator {
+public:
+ using value_type = ValueType;
+ using Pointer = std::add_pointer_t<T>;
+
+ class promise_type;
+ using Handle = std::coroutine_handle<promise_type>;
+
+ class promise_type {
+ private:
+ Pointer _ptr;
+ std::exception_ptr _exception;
+ Handle *_itr_state;
+ Handle _parent;
+
+ template <bool check_exception>
+ struct SwitchTo : std::suspend_always {
+ Handle next;
+ explicit SwitchTo(Handle next_in) : next(next_in) {}
+ std::coroutine_handle<> await_suspend(Handle prev) const noexcept {
+ if (next) {
+ Handle &itr_state = prev.promise().itr_state();
+ itr_state = next;
+ next.promise().itr_state(itr_state);
+ return next;
+ } else {
+ return std::noop_coroutine();
+ }
+ }
+ void await_resume() const noexcept(!check_exception) {
+ if (check_exception && next.promise()._exception) {
+ std::rethrow_exception(next.promise()._exception);
+ }
+ }
+ };
+
+ public:
+ promise_type(promise_type &&) = delete;
+ promise_type(const promise_type &) = delete;
+ promise_type() noexcept : _ptr(nullptr), _exception(), _itr_state(nullptr), _parent(nullptr) {}
+ Generator<T> get_return_object() { return Generator(Handle::from_promise(*this)); }
+ std::suspend_always initial_suspend() noexcept { return {}; }
+ auto final_suspend() noexcept { return SwitchTo<false>(_parent); }
+ std::suspend_always yield_value(T &&value) {
+ _ptr = &value;
+ return {};
+ }
+ auto yield_value(const T &value) requires(!std::is_reference_v<T> && std::copy_constructible<T>) {
+ struct awaiter : std::suspend_always {
+ awaiter(const T &value, Pointer &ptr) : value_cpy(value) {
+ ptr = std::addressof(value_cpy);
+ }
+ awaiter(awaiter&&) = delete;
+ T value_cpy;
+ };
+ return awaiter(value, _ptr);
+ }
+ auto yield_value(Generator &&child) { return yield_value(child); }
+ auto yield_value(Generator &child) {
+ child._handle.promise()._parent = Handle::from_promise(*this);
+ return SwitchTo<true>(child._handle);
+ }
+ void return_void() { _ptr = nullptr; }
+ void unhandled_exception() {
+ if (_parent) {
+ _exception = std::current_exception();
+ } else {
+ throw;
+ }
+ }
+ T &&result() {
+ return std::forward<T>(*_ptr);
+ }
+ Pointer result_ptr() {
+ return _ptr;
+ }
+ Handle &itr_state() const noexcept { return *_itr_state; }
+ void itr_state(Handle &handle) noexcept { _itr_state = std::addressof(handle); }
+ template<typename U> std::suspend_always await_transform(U &&value) = delete;
+ };
+
+ class Iterator {
+ private:
+ Handle _handle;
+ public:
+ Iterator() noexcept : _handle(nullptr) {}
+ Iterator(Iterator &&rhs) noexcept = default;
+ Iterator &operator=(Iterator &&rhs) noexcept = default;
+ Iterator(const Iterator &rhs) = delete;
+ Iterator &operator=(const Iterator &) = delete;
+ explicit Iterator(Handle handle) : _handle(handle) {
+ _handle.promise().itr_state(_handle);
+ _handle.resume();
+ }
+ using iterator_concept = std::input_iterator_tag;
+ using difference_type = std::ptrdiff_t;
+ using value_type = std::remove_cvref_t<T>;
+ bool operator==(std::default_sentinel_t) const {
+ return _handle.done();
+ }
+ Iterator &operator++() {
+ _handle.promise().itr_state(_handle);
+ _handle.resume();
+ return *this;
+ }
+ void operator++(int) {
+ operator++();
+ }
+ decltype(auto) operator*() const {
+ return std::forward<T>(_handle.promise().result());
+ }
+ auto operator->() const {
+ return _handle.promise().result_ptr();
+ }
+ };
+
+private:
+ Handle _handle;
+
+public:
+ Generator(const Generator &) = delete;
+ Generator &operator=(const Generator &) = delete;
+ explicit Generator(Handle handle_in) noexcept : _handle(handle_in) {}
+ Generator(Generator &&rhs) noexcept : _handle(std::exchange(rhs._handle, nullptr)) {}
+ ~Generator() {
+ if (_handle) {
+ _handle.destroy();
+ }
+ }
+ auto begin() { return Iterator(_handle); }
+ auto end() const noexcept { return std::default_sentinel_t(); }
+};
+
+}