diff options
author | HÃ¥vard Pettersen <3535158+havardpe@users.noreply.github.com> | 2020-01-07 12:25:45 +0100 |
---|---|---|
committer | GitHub <noreply@github.com> | 2020-01-07 12:25:45 +0100 |
commit | 59b58e2da8bf846d078d227ddd477d86f474ba1a (patch) | |
tree | 0d4c3b31bd022b5deba49efed27177915922a79d | |
parent | 60f0c5cc078e67b5d5d9f22eb4e8e9067816e77c (diff) | |
parent | 496a11fb34253292b1496da440742b388c72163b (diff) |
Merge pull request #11664 from vespa-engine/havardpe/overload
Havardpe/overload
-rw-r--r-- | eval/src/vespa/eval/eval/tensor_function.cpp | 65 | ||||
-rw-r--r-- | eval/src/vespa/eval/eval/tensor_function.h | 16 | ||||
-rw-r--r-- | eval/src/vespa/eval/tensor/dense/dense_tensor_peek_function.cpp | 20 | ||||
-rw-r--r-- | vespalib/CMakeLists.txt | 2 | ||||
-rw-r--r-- | vespalib/src/tests/overload/CMakeLists.txt | 9 | ||||
-rw-r--r-- | vespalib/src/tests/overload/overload_test.cpp | 19 | ||||
-rw-r--r-- | vespalib/src/tests/visit_ranges/CMakeLists.txt | 9 | ||||
-rw-r--r-- | vespalib/src/tests/visit_ranges/visit_ranges_test.cpp | 120 | ||||
-rw-r--r-- | vespalib/src/vespa/vespalib/util/overload.h | 15 | ||||
-rw-r--r-- | vespalib/src/vespa/vespalib/util/visit_ranges.h | 81 |
10 files changed, 314 insertions, 42 deletions
diff --git a/eval/src/vespa/eval/eval/tensor_function.cpp b/eval/src/vespa/eval/eval/tensor_function.cpp index e98a56a0148..9e429757982 100644 --- a/eval/src/vespa/eval/eval/tensor_function.cpp +++ b/eval/src/vespa/eval/eval/tensor_function.cpp @@ -164,20 +164,23 @@ void op_tensor_peek(State &state, uint64_t param) { TensorSpec::Address addr; size_t child_cnt = 0; for (auto pos = self.spec().rbegin(); pos != self.spec().rend(); ++pos) { - if (std::holds_alternative<TensorSpec::Label>(pos->second)) { - addr.emplace(pos->first, std::get<TensorSpec::Label>(pos->second)); - } else { - assert(std::holds_alternative<TensorFunction::Child>(pos->second)); - double index = round(state.peek(child_cnt++).as_double()); - size_t dim_idx = self.param_type().dimension_index(pos->first); - assert(dim_idx != ValueType::Dimension::npos); - const auto ¶m_dim = self.param_type().dimensions()[dim_idx]; - if (param_dim.is_mapped()) { - addr.emplace(pos->first, vespalib::make_string("%ld", int64_t(index))); - } else { - addr.emplace(pos->first, size_t(index)); - } - } + std::visit(vespalib::overload + { + [&](const TensorSpec::Label &label) { + addr.emplace(pos->first, label); + }, + [&](const TensorFunction::Child &) { + double index = round(state.peek(child_cnt++).as_double()); + size_t dim_idx = self.param_type().dimension_index(pos->first); + assert(dim_idx != ValueType::Dimension::npos); + const auto ¶m_dim = self.param_type().dimensions()[dim_idx]; + if (param_dim.is_mapped()) { + addr.emplace(pos->first, vespalib::make_string("%ld", int64_t(index))); + } else { + addr.emplace(pos->first, size_t(index)); + } + } + }, pos->second); } TensorSpec spec = state.engine.to_spec(state.peek(child_cnt++)); const Value &result = self.result_type().is_double() @@ -364,9 +367,13 @@ Peek::push_children(std::vector<Child::CREF> &children) const { children.emplace_back(_param); for (const auto &dim: _spec) { - if (std::holds_alternative<Child>(dim.second)) { - children.emplace_back(std::get<Child>(dim.second)); - } + std::visit(vespalib::overload + { + [&](const Child &child) { + children.emplace_back(child); + }, + [](const TensorSpec::Label &){} + }, dim.second); } } @@ -381,17 +388,19 @@ Peek::visit_children(vespalib::ObjectVisitor &visitor) const { ::visit(visitor, "param", _param.get()); for (const auto &dim: _spec) { - if (std::holds_alternative<TensorSpec::Label>(dim.second)) { - const auto &label = std::get<TensorSpec::Label>(dim.second); - if (label.is_mapped()) { - ::visit(visitor, dim.first, label.name); - } else { - ::visit(visitor, dim.first, label.index); - } - } else { - assert(std::holds_alternative<Child>(dim.second)); - ::visit(visitor, dim.first, std::get<Child>(dim.second).get()); - } + std::visit(vespalib::overload + { + [&](const TensorSpec::Label &label) { + if (label.is_mapped()) { + ::visit(visitor, dim.first, label.name); + } else { + ::visit(visitor, dim.first, label.index); + } + }, + [&](const Child &child) { + ::visit(visitor, dim.first, child.get()); + } + }, dim.second); } } diff --git a/eval/src/vespa/eval/eval/tensor_function.h b/eval/src/vespa/eval/eval/tensor_function.h index c95ffd17bbe..b019ab64e18 100644 --- a/eval/src/vespa/eval/eval/tensor_function.h +++ b/eval/src/vespa/eval/eval/tensor_function.h @@ -8,6 +8,7 @@ #include <vespa/vespalib/stllike/asciistream.h> #include <vespa/vespalib/stllike/string.h> #include <vespa/vespalib/util/arrayref.h> +#include <vespa/vespalib/util/overload.h> #include "tensor_spec.h" #include "lazy_params.h" #include "value_type.h" @@ -320,12 +321,15 @@ public: : Node(result_type_in), _param(param), _spec() { for (const auto &dim: spec) { - if (std::holds_alternative<TensorSpec::Label>(dim.second)) { - _spec.emplace(dim.first, std::get<TensorSpec::Label>(dim.second)); - } else { - assert(std::holds_alternative<Node::CREF>(dim.second)); - _spec.emplace(dim.first, std::get<Node::CREF>(dim.second).get()); - } + std::visit(vespalib::overload + { + [&](const TensorSpec::Label &label) { + _spec.emplace(dim.first, label); + }, + [&](const Node::CREF &ref) { + _spec.emplace(dim.first, ref.get()); + } + }, dim.second); } } const std::map<vespalib::string, MyLabel> &spec() const { return _spec; } diff --git a/eval/src/vespa/eval/tensor/dense/dense_tensor_peek_function.cpp b/eval/src/vespa/eval/tensor/dense/dense_tensor_peek_function.cpp index f9d15a377e9..2920ca26234 100644 --- a/eval/src/vespa/eval/tensor/dense/dense_tensor_peek_function.cpp +++ b/eval/src/vespa/eval/tensor/dense/dense_tensor_peek_function.cpp @@ -2,6 +2,7 @@ #include "dense_tensor_peek_function.h" #include "dense_tensor_view.h" +#include <vespa/vespalib/util/overload.h> #include <vespa/eval/eval/operation.h> #include <vespa/eval/eval/value.h> #include <vespa/eval/tensor/tensor.h> @@ -87,14 +88,17 @@ DenseTensorPeekFunction::optimize(const eval::TensorFunction &expr, Stash &stash for (auto dim = peek_type.dimensions().rbegin(); dim != peek_type.dimensions().rend(); ++dim) { auto dim_spec = peek->spec().find(dim->name); assert(dim_spec != peek->spec().end()); - if (std::holds_alternative<TensorSpec::Label>(dim_spec->second)) { - const auto &label = std::get<TensorSpec::Label>(dim_spec->second); - assert(label.is_indexed()); - spec.emplace_back(label.index, dim->size); - } else { - assert(std::holds_alternative<TensorFunction::Child>(dim_spec->second)); - spec.emplace_back(-1, dim->size); - } + + std::visit(vespalib::overload + { + [&](const TensorSpec::Label &label) { + assert(label.is_indexed()); + spec.emplace_back(label.index, dim->size); + }, + [&](const TensorFunction::Child &) { + spec.emplace_back(-1, dim->size); + } + }, dim_spec->second); } return stash.create<DenseTensorPeekFunction>(peek->copy_children(), spec); } diff --git a/vespalib/CMakeLists.txt b/vespalib/CMakeLists.txt index 14c14fe85ca..9339cdacea0 100644 --- a/vespalib/CMakeLists.txt +++ b/vespalib/CMakeLists.txt @@ -78,6 +78,7 @@ vespa_define_module( src/tests/net/tls/transport_options src/tests/objects/nbostream src/tests/optimized + src/tests/overload src/tests/portal src/tests/portal/handle_manager src/tests/portal/http_request @@ -130,6 +131,7 @@ vespa_define_module( src/tests/util/md5 src/tests/util/rcuvector src/tests/valgrind + src/tests/visit_ranges src/tests/websocket src/tests/zcurve diff --git a/vespalib/src/tests/overload/CMakeLists.txt b/vespalib/src/tests/overload/CMakeLists.txt new file mode 100644 index 00000000000..67aa6230225 --- /dev/null +++ b/vespalib/src/tests/overload/CMakeLists.txt @@ -0,0 +1,9 @@ +# Copyright 2020 Oath Inc. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +vespa_add_executable(vespalib_overload_test_app TEST + SOURCES + overload_test.cpp + DEPENDS + vespalib + gtest +) +vespa_add_test(NAME vespalib_overload_test_app COMMAND vespalib_overload_test_app) diff --git a/vespalib/src/tests/overload/overload_test.cpp b/vespalib/src/tests/overload/overload_test.cpp new file mode 100644 index 00000000000..ceae29ac02f --- /dev/null +++ b/vespalib/src/tests/overload/overload_test.cpp @@ -0,0 +1,19 @@ +// Copyright 2020 Oath Inc. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. + +#include <vespa/vespalib/util/overload.h> +#include <vespa/vespalib/gtest/gtest.h> +#include <variant> +#include <string> + +using namespace vespalib; + +TEST(OverloadTest, visit_with_overload_works) { + std::variant<std::string,int> a = 10; + std::variant<std::string,int> b = "foo"; + std::visit(overload{[](int v){ EXPECT_EQ(v,10); }, + [](const std::string &){ FAIL() << "invalid visit"; }}, a); + std::visit(overload{[](int){ FAIL() << "invalid visit"; }, + [](const std::string &v){ EXPECT_EQ(v, "foo"); }}, b); +} + +GTEST_MAIN_RUN_ALL_TESTS() diff --git a/vespalib/src/tests/visit_ranges/CMakeLists.txt b/vespalib/src/tests/visit_ranges/CMakeLists.txt new file mode 100644 index 00000000000..de94b2ebb1e --- /dev/null +++ b/vespalib/src/tests/visit_ranges/CMakeLists.txt @@ -0,0 +1,9 @@ +# Copyright 2020 Oath Inc. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +vespa_add_executable(vespalib_visit_ranges_test_app TEST + SOURCES + visit_ranges_test.cpp + DEPENDS + vespalib + gtest +) +vespa_add_test(NAME vespalib_visit_ranges_test_app COMMAND vespalib_visit_ranges_test_app) diff --git a/vespalib/src/tests/visit_ranges/visit_ranges_test.cpp b/vespalib/src/tests/visit_ranges/visit_ranges_test.cpp new file mode 100644 index 00000000000..cb1ac9bd9c3 --- /dev/null +++ b/vespalib/src/tests/visit_ranges/visit_ranges_test.cpp @@ -0,0 +1,120 @@ +// Copyright 2020 Oath Inc. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. + +#include <vespa/vespalib/util/overload.h> +#include <vespa/vespalib/util/visit_ranges.h> +#include <vespa/vespalib/gtest/gtest.h> +#include <variant> +#include <string> + +using namespace vespalib; + +TEST(VisitRangeExample, set_intersection) { + std::vector<int> first({1,3,7}); + std::vector<int> second({2,3,8}); + std::vector<int> result; + vespalib::visit_ranges(overload{[](visit_ranges_either, int) {}, + [&result](visit_ranges_both, int x, int) { result.push_back(x); }}, + first.begin(), first.end(), second.begin(), second.end()); + EXPECT_EQ(result, std::vector<int>({3})); +} + +TEST(VisitRangeExample, set_subtraction) { + std::vector<int> first({1,3,7}); + std::vector<int> second({2,3,8}); + std::vector<int> result; + vespalib::visit_ranges(overload{[&result](visit_ranges_first, int a) { result.push_back(a); }, + [](visit_ranges_second, int) {}, + [](visit_ranges_both, int, int) {}}, + first.begin(), first.end(), second.begin(), second.end()); + EXPECT_EQ(result, std::vector<int>({1,7})); +} + +TEST(VisitRangesTest, empty_ranges_can_be_visited) { + std::vector<int> a; + std::vector<int> b; + std::vector<int> c; + auto visitor = overload + { + [&c](visit_ranges_either, int) { + c.push_back(42); + }, + [&c](visit_ranges_both, int, int) { + c.push_back(42); + } + }; + vespalib::visit_ranges(visitor, a.begin(), a.end(), b.begin(), b.end()); + EXPECT_EQ(c, std::vector<int>({})); +} + +TEST(VisitRangesTest, simple_merge_can_be_implemented) { + std::vector<int> a({1,3,7}); + std::vector<int> b({2,3,8}); + std::vector<int> c; + auto visitor = overload + { + [&c](visit_ranges_either, int x) { + c.push_back(x); + }, + [&c](visit_ranges_both, int x, int y) { + c.push_back(x); + c.push_back(y); + } + }; + vespalib::visit_ranges(visitor, a.begin(), a.end(), b.begin(), b.end()); + EXPECT_EQ(c, std::vector<int>({1,2,3,3,7,8})); +} + +TEST(VisitRangesTest, simple_union_can_be_implemented) { + std::vector<int> a({1,3,7}); + std::vector<int> b({2,3,8}); + std::vector<int> c; + auto visitor = overload + { + [&c](visit_ranges_either, int x) { + c.push_back(x); + }, + [&c](visit_ranges_both, int x, int) { + c.push_back(x); + } + }; + vespalib::visit_ranges(visitor, a.begin(), a.end(), b.begin(), b.end()); + EXPECT_EQ(c, std::vector<int>({1,2,3,7,8})); +} + +TEST(VisitRangesTest, asymmetric_merge_can_be_implemented) { + std::vector<int> a({1,3,7}); + std::vector<int> b({2,3,8}); + std::vector<int> c; + auto visitor = overload + { + [&c](visit_ranges_first, int x) { + c.push_back(x); + }, + [&c](visit_ranges_second, int) {}, + [&c](visit_ranges_both, int x, int y) { + c.push_back(x * y); + } + }; + vespalib::visit_ranges(visitor, a.begin(), a.end(), b.begin(), b.end()); + EXPECT_EQ(c, std::vector<int>({1,9,7})); +} + +TEST(VisitRangesTest, comparator_can_be_specified) { + std::vector<int> a({7,3,1}); + std::vector<int> b({8,3,2}); + std::vector<int> c; + auto visitor = overload + { + [&c](visit_ranges_either, int x) { + c.push_back(x); + }, + [&c](visit_ranges_both, int x, int y) { + c.push_back(x); + c.push_back(y); + } + }; + vespalib::visit_ranges(visitor, a.begin(), a.end(), b.begin(), b.end(), std::greater<>()); + EXPECT_EQ(c, std::vector<int>({8,7,3,3,2,1})); +} + +GTEST_MAIN_RUN_ALL_TESTS() diff --git a/vespalib/src/vespa/vespalib/util/overload.h b/vespalib/src/vespa/vespalib/util/overload.h new file mode 100644 index 00000000000..d5af9dee2d3 --- /dev/null +++ b/vespalib/src/vespa/vespalib/util/overload.h @@ -0,0 +1,15 @@ +// Copyright 2020 Oath Inc. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. + +#pragma once + +namespace vespalib { + +/** + * Simple overload lambda composition class. To be replaced by + * standard overload functionality when available. (C++20) + **/ + +template<class... Ts> struct overload : Ts... { using Ts::operator()...; }; +template<class... Ts> overload(Ts...) -> overload<Ts...>; + +} diff --git a/vespalib/src/vespa/vespalib/util/visit_ranges.h b/vespalib/src/vespa/vespalib/util/visit_ranges.h new file mode 100644 index 00000000000..9cf1112fe0d --- /dev/null +++ b/vespalib/src/vespa/vespalib/util/visit_ranges.h @@ -0,0 +1,81 @@ +// Copyright 2020 Oath Inc. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. + +#pragma once + +#include <functional> + +namespace vespalib { + +struct visit_ranges_either {}; +struct visit_ranges_first : visit_ranges_either {}; +struct visit_ranges_second : visit_ranges_either {}; +struct visit_ranges_both {}; + +/** + * Visit elements from two distinct ranges in the order defined by the + * given comparator. The comparator must define a strict-weak ordering + * across all elements from both ranges and each range must already be + * sorted according to the comparator before calling this + * function. Pairs of elements from the two ranges (one from each) + * that are equal according to the comparator will be visited by a + * single callback. The different cases ('from the first range', 'from + * the second range' and 'from both ranges') are indicated by using + * tagged dispatch in the visitation callback. + * + * An example treating both inputs equally: + * <pre> + * TEST(VisitRangeExample, set_intersection) { + * std::vector<int> first({1,3,7}); + * std::vector<int> second({2,3,8}); + * std::vector<int> result; + * vespalib::visit_ranges(overload{[](visit_ranges_either, int) {}, + * [&result](visit_ranges_both, int x, int) { result.push_back(x); }}, + * first.begin(), first.end(), second.begin(), second.end()); + * EXPECT_EQ(result, std::vector<int>({3})); + * } + * </pre> + * + * An example treating the inputs differently: + * <pre> + * TEST(VisitRangeExample, set_subtraction) { + * std::vector<int> first({1,3,7}); + * std::vector<int> second({2,3,8}); + * std::vector<int> result; + * vespalib::visit_ranges(overload{[&result](visit_ranges_first, int a) { result.push_back(a); }, + * [](visit_ranges_second, int) {}, + * [](visit_ranges_both, int, int) {}}, + * first.begin(), first.end(), second.begin(), second.end()); + * EXPECT_EQ(result, std::vector<int>({1,7})); + * } + * </pre> + * + * The intention of this function is to simplify the implementation of + * merge-like operations. + **/ + +template <typename V, typename ItA, typename ItB, typename Cmp = std::less<> > +void visit_ranges(V &&visitor, ItA pos_a, ItA end_a, ItB pos_b, ItB end_b, Cmp cmp = Cmp()) { + while ((pos_a != end_a) && (pos_b != end_b)) { + if (cmp(*pos_a, *pos_b)) { + visitor(visit_ranges_first(), *pos_a); + ++pos_a; + } else if (cmp(*pos_b, *pos_a)) { + visitor(visit_ranges_second(), *pos_b); + ++pos_b; + } else { + visitor(visit_ranges_both(), *pos_a, *pos_b); + ++pos_a; + ++pos_b; + } + } + while (pos_a != end_a) { + visitor(visit_ranges_first(), *pos_a); + ++pos_a; + } + while (pos_b != end_b) { + visitor(visit_ranges_second(), *pos_b); + ++pos_b; + } +} + +} |