aboutsummaryrefslogtreecommitdiffstats
path: root/vespalib/src/vespa/vespalib/test/nexus.h
blob: 7be35d424633310166ef8711478cfa9d97367de8 (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.

#pragma once

#include "thread_meets.h"
#include <vespa/vespalib/util/thread.h>
#include <vespa/vespalib/util/require.h>
#include <optional>
#include <variant>

namespace vespalib::test {

class Nexus;
/**
 * Satisfied by any type whose values can be called with a single
 * 'Nexus &' argument. Used to constrain the thread entry function
 * handed to Nexus::run. The type of the call result is not
 * constrained.
 **/
template <typename T>
concept nexus_thread_entry = requires(T &&fn, Nexus &nexus) {
    fn(nexus);
};

/**
 * Utility intended to make it easier to write multi-threaded code for
 * testing and benchmarking.
 *
 * Nexus::run calls the same entry function once per participating
 * thread and passes each call its own Nexus object. Through that
 * object the entry function can query its thread id and the total
 * number of threads, and use the vote object shared by all threads
 * of the run (see ThreadMeets::Vote for the exact voting semantics).
 *
 * Nexus objects can only be created by Nexus::run and can be neither
 * copied nor moved.
 **/
class Nexus
{
private:
    using vote_t = vespalib::test::ThreadMeets::Vote;
    // vote object shared by all threads taking part in the same run
    vote_t &_vote;
    // id of the thread this object was handed to (0 .. num_threads-1)
    size_t _thread_id;
    // private: only Nexus::run creates (and destroys) Nexus objects
    Nexus(vote_t &vote, size_t thread_id) noexcept : _vote(vote), _thread_id(thread_id) {}
    // declared here, defined outside this header
    ~Nexus();
public:
    Nexus(const Nexus &) = delete;
    Nexus(Nexus &&) = delete;
    Nexus &operator=(const Nexus &) = delete;
    Nexus &operator=(Nexus &&) = delete;

    /** Number of threads taking part in the run (size of the shared vote). **/
    size_t num_threads() const noexcept { return _vote.size(); }

    /** Id of the calling thread within the run. **/
    size_t thread_id() const noexcept { return _thread_id; }

    /** True for the thread with id 0, which is the thread that called run. **/
    bool is_main() const noexcept { return thread_id() == 0; }

    /**
     * Cast a vote using the shared vote object and return its
     * outcome. How individual votes are combined is decided by
     * ThreadMeets::Vote.
     **/
    bool vote(bool my_vote) { return _vote(my_vote); }

    /**
     * Cast a 'true' vote on the shared vote object and require that
     * the outcome is also 'true'.
     **/
    void barrier() { REQUIRE_EQ(_vote(true), true); }

    /**
     * Tag type; when passed as the 'merge' argument to run, the value
     * returned by thread 0 is used as the result and the values
     * returned by all other threads are dropped.
     **/
    struct select_thread_0 {};

    /** Merge function combining two per-thread results with operator+. **/
    constexpr static auto merge_sum() { return [](auto lhs, auto rhs){ return lhs + rhs; }; }

    /**
     * Run 'entry' once in each of 'num_threads' threads and combine
     * the values it returns.
     *
     * Thread 0 is the calling thread; threads 1 .. num_threads-1 are
     * started on a local ThreadPool that is joined before returning.
     * Each thread gets its own Nexus bound to a vote object of size
     * 'num_threads'.
     *
     * If 'entry' returns void, so does this function. Otherwise the
     * decayed return values are combined into a single value that is
     * returned: with select_thread_0 only the value from thread 0 is
     * kept, with any other 'merge' the values are folded pairwise as
     * merge(accumulated, incoming) under a mutex, in the order the
     * threads happen to deliver them.
     *
     * The per-thread code runs inside noexcept lambdas, so an
     * exception escaping 'entry' or 'merge' terminates the program.
     **/
    static auto run(size_t num_threads, auto &&entry, auto &&merge) requires nexus_thread_entry<decltype(entry)> {
        using result_t = std::decay_t<decltype(entry(std::declval<Nexus&>()))>;
        constexpr bool is_void = std::same_as<result_t, void>;
        constexpr bool keep_thread_0 = std::same_as<std::decay_t<decltype(merge)>,select_thread_0>;
        // void cannot be stored in an optional; use a placeholder type instead
        using stored_t = std::conditional_t<is_void, std::monostate, result_t>;
        ThreadPool pool;
        vote_t vote(num_threads);
        std::mutex lock;
        std::optional<stored_t> result;
        // fold the value produced by one thread into the shared result
        auto collect = [&](Nexus &ctx, stored_t value) noexcept {
            if constexpr (keep_thread_0) {
                // single writer (the calling thread); no locking needed
                if (ctx.thread_id() == 0) {
                    result = std::move(value);
                }
            } else {
                std::lock_guard guard(lock);
                if (!result.has_value()) {
                    // first value to arrive seeds the accumulator
                    result = std::move(value);
                } else {
                    result = merge(std::move(result).value(),
                                   std::move(value));
                }
            }
        };
        // what each participating thread does
        auto worker = [&](size_t id) noexcept {
            Nexus ctx(vote, id);
            if constexpr (!is_void) {
                collect(ctx, entry(ctx));
            } else {
                entry(ctx);
            }
        };
        // ids 1.. run on pool threads; id 0 runs right here
        for (size_t id = 1; id < num_threads; ++id) {
            pool.start([id,&worker]() noexcept { worker(id); });
        }
        worker(0);
        pool.join();
        if constexpr (is_void) {
            return;
        } else {
            return std::move(result).value();
        }
    }

    /**
     * Same as the three-argument run, using select_thread_0 as merge
     * strategy: the value returned by thread 0 is the result.
     **/
    static auto run(size_t num_threads, auto &&entry) requires nexus_thread_entry<decltype(entry)> {
        return run(num_threads, std::forward<decltype(entry)>(entry), select_thread_0{});
    }
};

}