aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorHÃ¥vard Pettersen <3535158+havardpe@users.noreply.github.com>2020-01-07 12:25:45 +0100
committerGitHub <noreply@github.com>2020-01-07 12:25:45 +0100
commit59b58e2da8bf846d078d227ddd477d86f474ba1a (patch)
tree0d4c3b31bd022b5deba49efed27177915922a79d
parent60f0c5cc078e67b5d5d9f22eb4e8e9067816e77c (diff)
parent496a11fb34253292b1496da440742b388c72163b (diff)
Merge pull request #11664 from vespa-engine/havardpe/overload
Havardpe/overload
-rw-r--r--eval/src/vespa/eval/eval/tensor_function.cpp65
-rw-r--r--eval/src/vespa/eval/eval/tensor_function.h16
-rw-r--r--eval/src/vespa/eval/tensor/dense/dense_tensor_peek_function.cpp20
-rw-r--r--vespalib/CMakeLists.txt2
-rw-r--r--vespalib/src/tests/overload/CMakeLists.txt9
-rw-r--r--vespalib/src/tests/overload/overload_test.cpp19
-rw-r--r--vespalib/src/tests/visit_ranges/CMakeLists.txt9
-rw-r--r--vespalib/src/tests/visit_ranges/visit_ranges_test.cpp120
-rw-r--r--vespalib/src/vespa/vespalib/util/overload.h15
-rw-r--r--vespalib/src/vespa/vespalib/util/visit_ranges.h81
10 files changed, 314 insertions, 42 deletions
diff --git a/eval/src/vespa/eval/eval/tensor_function.cpp b/eval/src/vespa/eval/eval/tensor_function.cpp
index e98a56a0148..9e429757982 100644
--- a/eval/src/vespa/eval/eval/tensor_function.cpp
+++ b/eval/src/vespa/eval/eval/tensor_function.cpp
@@ -164,20 +164,23 @@ void op_tensor_peek(State &state, uint64_t param) {
TensorSpec::Address addr;
size_t child_cnt = 0;
for (auto pos = self.spec().rbegin(); pos != self.spec().rend(); ++pos) {
- if (std::holds_alternative<TensorSpec::Label>(pos->second)) {
- addr.emplace(pos->first, std::get<TensorSpec::Label>(pos->second));
- } else {
- assert(std::holds_alternative<TensorFunction::Child>(pos->second));
- double index = round(state.peek(child_cnt++).as_double());
- size_t dim_idx = self.param_type().dimension_index(pos->first);
- assert(dim_idx != ValueType::Dimension::npos);
- const auto &param_dim = self.param_type().dimensions()[dim_idx];
- if (param_dim.is_mapped()) {
- addr.emplace(pos->first, vespalib::make_string("%ld", int64_t(index)));
- } else {
- addr.emplace(pos->first, size_t(index));
- }
- }
+ std::visit(vespalib::overload
+ {
+ [&](const TensorSpec::Label &label) {
+ addr.emplace(pos->first, label);
+ },
+ [&](const TensorFunction::Child &) {
+ double index = round(state.peek(child_cnt++).as_double());
+ size_t dim_idx = self.param_type().dimension_index(pos->first);
+ assert(dim_idx != ValueType::Dimension::npos);
+ const auto &param_dim = self.param_type().dimensions()[dim_idx];
+ if (param_dim.is_mapped()) {
+ addr.emplace(pos->first, vespalib::make_string("%ld", int64_t(index)));
+ } else {
+ addr.emplace(pos->first, size_t(index));
+ }
+ }
+ }, pos->second);
}
TensorSpec spec = state.engine.to_spec(state.peek(child_cnt++));
const Value &result = self.result_type().is_double()
@@ -364,9 +367,13 @@ Peek::push_children(std::vector<Child::CREF> &children) const
{
children.emplace_back(_param);
for (const auto &dim: _spec) {
- if (std::holds_alternative<Child>(dim.second)) {
- children.emplace_back(std::get<Child>(dim.second));
- }
+ std::visit(vespalib::overload
+ {
+ [&](const Child &child) {
+ children.emplace_back(child);
+ },
+ [](const TensorSpec::Label &){}
+ }, dim.second);
}
}
@@ -381,17 +388,19 @@ Peek::visit_children(vespalib::ObjectVisitor &visitor) const
{
::visit(visitor, "param", _param.get());
for (const auto &dim: _spec) {
- if (std::holds_alternative<TensorSpec::Label>(dim.second)) {
- const auto &label = std::get<TensorSpec::Label>(dim.second);
- if (label.is_mapped()) {
- ::visit(visitor, dim.first, label.name);
- } else {
- ::visit(visitor, dim.first, label.index);
- }
- } else {
- assert(std::holds_alternative<Child>(dim.second));
- ::visit(visitor, dim.first, std::get<Child>(dim.second).get());
- }
+ std::visit(vespalib::overload
+ {
+ [&](const TensorSpec::Label &label) {
+ if (label.is_mapped()) {
+ ::visit(visitor, dim.first, label.name);
+ } else {
+ ::visit(visitor, dim.first, label.index);
+ }
+ },
+ [&](const Child &child) {
+ ::visit(visitor, dim.first, child.get());
+ }
+ }, dim.second);
}
}
diff --git a/eval/src/vespa/eval/eval/tensor_function.h b/eval/src/vespa/eval/eval/tensor_function.h
index c95ffd17bbe..b019ab64e18 100644
--- a/eval/src/vespa/eval/eval/tensor_function.h
+++ b/eval/src/vespa/eval/eval/tensor_function.h
@@ -8,6 +8,7 @@
#include <vespa/vespalib/stllike/asciistream.h>
#include <vespa/vespalib/stllike/string.h>
#include <vespa/vespalib/util/arrayref.h>
+#include <vespa/vespalib/util/overload.h>
#include "tensor_spec.h"
#include "lazy_params.h"
#include "value_type.h"
@@ -320,12 +321,15 @@ public:
: Node(result_type_in), _param(param), _spec()
{
for (const auto &dim: spec) {
- if (std::holds_alternative<TensorSpec::Label>(dim.second)) {
- _spec.emplace(dim.first, std::get<TensorSpec::Label>(dim.second));
- } else {
- assert(std::holds_alternative<Node::CREF>(dim.second));
- _spec.emplace(dim.first, std::get<Node::CREF>(dim.second).get());
- }
+ std::visit(vespalib::overload
+ {
+ [&](const TensorSpec::Label &label) {
+ _spec.emplace(dim.first, label);
+ },
+ [&](const Node::CREF &ref) {
+ _spec.emplace(dim.first, ref.get());
+ }
+ }, dim.second);
}
}
const std::map<vespalib::string, MyLabel> &spec() const { return _spec; }
diff --git a/eval/src/vespa/eval/tensor/dense/dense_tensor_peek_function.cpp b/eval/src/vespa/eval/tensor/dense/dense_tensor_peek_function.cpp
index f9d15a377e9..2920ca26234 100644
--- a/eval/src/vespa/eval/tensor/dense/dense_tensor_peek_function.cpp
+++ b/eval/src/vespa/eval/tensor/dense/dense_tensor_peek_function.cpp
@@ -2,6 +2,7 @@
#include "dense_tensor_peek_function.h"
#include "dense_tensor_view.h"
+#include <vespa/vespalib/util/overload.h>
#include <vespa/eval/eval/operation.h>
#include <vespa/eval/eval/value.h>
#include <vespa/eval/tensor/tensor.h>
@@ -87,14 +88,17 @@ DenseTensorPeekFunction::optimize(const eval::TensorFunction &expr, Stash &stash
for (auto dim = peek_type.dimensions().rbegin(); dim != peek_type.dimensions().rend(); ++dim) {
auto dim_spec = peek->spec().find(dim->name);
assert(dim_spec != peek->spec().end());
- if (std::holds_alternative<TensorSpec::Label>(dim_spec->second)) {
- const auto &label = std::get<TensorSpec::Label>(dim_spec->second);
- assert(label.is_indexed());
- spec.emplace_back(label.index, dim->size);
- } else {
- assert(std::holds_alternative<TensorFunction::Child>(dim_spec->second));
- spec.emplace_back(-1, dim->size);
- }
+
+ std::visit(vespalib::overload
+ {
+ [&](const TensorSpec::Label &label) {
+ assert(label.is_indexed());
+ spec.emplace_back(label.index, dim->size);
+ },
+ [&](const TensorFunction::Child &) {
+ spec.emplace_back(-1, dim->size);
+ }
+ }, dim_spec->second);
}
return stash.create<DenseTensorPeekFunction>(peek->copy_children(), spec);
}
diff --git a/vespalib/CMakeLists.txt b/vespalib/CMakeLists.txt
index 14c14fe85ca..9339cdacea0 100644
--- a/vespalib/CMakeLists.txt
+++ b/vespalib/CMakeLists.txt
@@ -78,6 +78,7 @@ vespa_define_module(
src/tests/net/tls/transport_options
src/tests/objects/nbostream
src/tests/optimized
+ src/tests/overload
src/tests/portal
src/tests/portal/handle_manager
src/tests/portal/http_request
@@ -130,6 +131,7 @@ vespa_define_module(
src/tests/util/md5
src/tests/util/rcuvector
src/tests/valgrind
+ src/tests/visit_ranges
src/tests/websocket
src/tests/zcurve
diff --git a/vespalib/src/tests/overload/CMakeLists.txt b/vespalib/src/tests/overload/CMakeLists.txt
new file mode 100644
index 00000000000..67aa6230225
--- /dev/null
+++ b/vespalib/src/tests/overload/CMakeLists.txt
@@ -0,0 +1,9 @@
+# Copyright 2020 Oath Inc. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+vespa_add_executable(vespalib_overload_test_app TEST
+ SOURCES
+ overload_test.cpp
+ DEPENDS
+ vespalib
+ gtest
+)
+vespa_add_test(NAME vespalib_overload_test_app COMMAND vespalib_overload_test_app)
diff --git a/vespalib/src/tests/overload/overload_test.cpp b/vespalib/src/tests/overload/overload_test.cpp
new file mode 100644
index 00000000000..ceae29ac02f
--- /dev/null
+++ b/vespalib/src/tests/overload/overload_test.cpp
@@ -0,0 +1,19 @@
+// Copyright 2020 Oath Inc. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+
+#include <vespa/vespalib/util/overload.h>
+#include <vespa/vespalib/gtest/gtest.h>
+#include <variant>
+#include <string>
+
+using namespace vespalib;
+
+TEST(OverloadTest, visit_with_overload_works) {
+ std::variant<std::string,int> a = 10;
+ std::variant<std::string,int> b = "foo";
+ std::visit(overload{[](int v){ EXPECT_EQ(v,10); },
+ [](const std::string &){ FAIL() << "invalid visit"; }}, a);
+ std::visit(overload{[](int){ FAIL() << "invalid visit"; },
+ [](const std::string &v){ EXPECT_EQ(v, "foo"); }}, b);
+}
+
+GTEST_MAIN_RUN_ALL_TESTS()
diff --git a/vespalib/src/tests/visit_ranges/CMakeLists.txt b/vespalib/src/tests/visit_ranges/CMakeLists.txt
new file mode 100644
index 00000000000..de94b2ebb1e
--- /dev/null
+++ b/vespalib/src/tests/visit_ranges/CMakeLists.txt
@@ -0,0 +1,9 @@
+# Copyright 2020 Oath Inc. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+vespa_add_executable(vespalib_visit_ranges_test_app TEST
+ SOURCES
+ visit_ranges_test.cpp
+ DEPENDS
+ vespalib
+ gtest
+)
+vespa_add_test(NAME vespalib_visit_ranges_test_app COMMAND vespalib_visit_ranges_test_app)
diff --git a/vespalib/src/tests/visit_ranges/visit_ranges_test.cpp b/vespalib/src/tests/visit_ranges/visit_ranges_test.cpp
new file mode 100644
index 00000000000..cb1ac9bd9c3
--- /dev/null
+++ b/vespalib/src/tests/visit_ranges/visit_ranges_test.cpp
@@ -0,0 +1,120 @@
+// Copyright 2020 Oath Inc. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+
+#include <vespa/vespalib/util/overload.h>
+#include <vespa/vespalib/util/visit_ranges.h>
+#include <vespa/vespalib/gtest/gtest.h>
+#include <variant>
+#include <string>
+
+using namespace vespalib;
+
+TEST(VisitRangeExample, set_intersection) {
+ std::vector<int> first({1,3,7});
+ std::vector<int> second({2,3,8});
+ std::vector<int> result;
+ vespalib::visit_ranges(overload{[](visit_ranges_either, int) {},
+ [&result](visit_ranges_both, int x, int) { result.push_back(x); }},
+ first.begin(), first.end(), second.begin(), second.end());
+ EXPECT_EQ(result, std::vector<int>({3}));
+}
+
+TEST(VisitRangeExample, set_subtraction) {
+ std::vector<int> first({1,3,7});
+ std::vector<int> second({2,3,8});
+ std::vector<int> result;
+ vespalib::visit_ranges(overload{[&result](visit_ranges_first, int a) { result.push_back(a); },
+ [](visit_ranges_second, int) {},
+ [](visit_ranges_both, int, int) {}},
+ first.begin(), first.end(), second.begin(), second.end());
+ EXPECT_EQ(result, std::vector<int>({1,7}));
+}
+
+TEST(VisitRangesTest, empty_ranges_can_be_visited) {
+ std::vector<int> a;
+ std::vector<int> b;
+ std::vector<int> c;
+ auto visitor = overload
+ {
+ [&c](visit_ranges_either, int) {
+ c.push_back(42);
+ },
+ [&c](visit_ranges_both, int, int) {
+ c.push_back(42);
+ }
+ };
+ vespalib::visit_ranges(visitor, a.begin(), a.end(), b.begin(), b.end());
+ EXPECT_EQ(c, std::vector<int>({}));
+}
+
+TEST(VisitRangesTest, simple_merge_can_be_implemented) {
+ std::vector<int> a({1,3,7});
+ std::vector<int> b({2,3,8});
+ std::vector<int> c;
+ auto visitor = overload
+ {
+ [&c](visit_ranges_either, int x) {
+ c.push_back(x);
+ },
+ [&c](visit_ranges_both, int x, int y) {
+ c.push_back(x);
+ c.push_back(y);
+ }
+ };
+ vespalib::visit_ranges(visitor, a.begin(), a.end(), b.begin(), b.end());
+ EXPECT_EQ(c, std::vector<int>({1,2,3,3,7,8}));
+}
+
+TEST(VisitRangesTest, simple_union_can_be_implemented) {
+ std::vector<int> a({1,3,7});
+ std::vector<int> b({2,3,8});
+ std::vector<int> c;
+ auto visitor = overload
+ {
+ [&c](visit_ranges_either, int x) {
+ c.push_back(x);
+ },
+ [&c](visit_ranges_both, int x, int) {
+ c.push_back(x);
+ }
+ };
+ vespalib::visit_ranges(visitor, a.begin(), a.end(), b.begin(), b.end());
+ EXPECT_EQ(c, std::vector<int>({1,2,3,7,8}));
+}
+
+TEST(VisitRangesTest, asymmetric_merge_can_be_implemented) {
+ std::vector<int> a({1,3,7});
+ std::vector<int> b({2,3,8});
+ std::vector<int> c;
+ auto visitor = overload
+ {
+ [&c](visit_ranges_first, int x) {
+ c.push_back(x);
+ },
+ [&c](visit_ranges_second, int) {},
+ [&c](visit_ranges_both, int x, int y) {
+ c.push_back(x * y);
+ }
+ };
+ vespalib::visit_ranges(visitor, a.begin(), a.end(), b.begin(), b.end());
+ EXPECT_EQ(c, std::vector<int>({1,9,7}));
+}
+
+TEST(VisitRangesTest, comparator_can_be_specified) {
+ std::vector<int> a({7,3,1});
+ std::vector<int> b({8,3,2});
+ std::vector<int> c;
+ auto visitor = overload
+ {
+ [&c](visit_ranges_either, int x) {
+ c.push_back(x);
+ },
+ [&c](visit_ranges_both, int x, int y) {
+ c.push_back(x);
+ c.push_back(y);
+ }
+ };
+ vespalib::visit_ranges(visitor, a.begin(), a.end(), b.begin(), b.end(), std::greater<>());
+ EXPECT_EQ(c, std::vector<int>({8,7,3,3,2,1}));
+}
+
+GTEST_MAIN_RUN_ALL_TESTS()
diff --git a/vespalib/src/vespa/vespalib/util/overload.h b/vespalib/src/vespa/vespalib/util/overload.h
new file mode 100644
index 00000000000..d5af9dee2d3
--- /dev/null
+++ b/vespalib/src/vespa/vespalib/util/overload.h
@@ -0,0 +1,15 @@
+// Copyright 2020 Oath Inc. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+
+#pragma once
+
+namespace vespalib {
+
+/**
+ * Simple overload lambda composition class. To be replaced by
+ * standard overload functionality when available. (C++20)
+ **/
+
+template<class... Ts> struct overload : Ts... { using Ts::operator()...; };
+template<class... Ts> overload(Ts...) -> overload<Ts...>;
+
+}
diff --git a/vespalib/src/vespa/vespalib/util/visit_ranges.h b/vespalib/src/vespa/vespalib/util/visit_ranges.h
new file mode 100644
index 00000000000..9cf1112fe0d
--- /dev/null
+++ b/vespalib/src/vespa/vespalib/util/visit_ranges.h
@@ -0,0 +1,81 @@
+// Copyright 2020 Oath Inc. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+
+#pragma once
+
+#include <functional>
+
+namespace vespalib {
+
+struct visit_ranges_either {};
+struct visit_ranges_first : visit_ranges_either {};
+struct visit_ranges_second : visit_ranges_either {};
+struct visit_ranges_both {};
+
+/**
+ * Visit elements from two distinct ranges in the order defined by the
+ * given comparator. The comparator must define a strict-weak ordering
+ * across all elements from both ranges and each range must already be
+ * sorted according to the comparator before calling this
+ * function. Pairs of elements from the two ranges (one from each)
+ * that are equal according to the comparator will be visited by a
+ * single callback. The different cases ('from the first range', 'from
+ * the second range' and 'from both ranges') are indicated by using
+ * tagged dispatch in the visitation callback.
+ *
+ * An example treating both inputs equally:
+ * <pre>
+ * TEST(VisitRangeExample, set_intersection) {
+ * std::vector<int> first({1,3,7});
+ * std::vector<int> second({2,3,8});
+ * std::vector<int> result;
+ * vespalib::visit_ranges(overload{[](visit_ranges_either, int) {},
+ * [&result](visit_ranges_both, int x, int) { result.push_back(x); }},
+ * first.begin(), first.end(), second.begin(), second.end());
+ * EXPECT_EQ(result, std::vector<int>({3}));
+ * }
+ * </pre>
+ *
+ * An example treating the inputs differently:
+ * <pre>
+ * TEST(VisitRangeExample, set_subtraction) {
+ * std::vector<int> first({1,3,7});
+ * std::vector<int> second({2,3,8});
+ * std::vector<int> result;
+ * vespalib::visit_ranges(overload{[&result](visit_ranges_first, int a) { result.push_back(a); },
+ * [](visit_ranges_second, int) {},
+ * [](visit_ranges_both, int, int) {}},
+ * first.begin(), first.end(), second.begin(), second.end());
+ * EXPECT_EQ(result, std::vector<int>({1,7}));
+ * }
+ * </pre>
+ *
+ * The intention of this function is to simplify the implementation of
+ * merge-like operations.
+ **/
+
+template <typename V, typename ItA, typename ItB, typename Cmp = std::less<> >
+void visit_ranges(V &&visitor, ItA pos_a, ItA end_a, ItB pos_b, ItB end_b, Cmp cmp = Cmp()) {
+ while ((pos_a != end_a) && (pos_b != end_b)) {
+ if (cmp(*pos_a, *pos_b)) {
+ visitor(visit_ranges_first(), *pos_a);
+ ++pos_a;
+ } else if (cmp(*pos_b, *pos_a)) {
+ visitor(visit_ranges_second(), *pos_b);
+ ++pos_b;
+ } else {
+ visitor(visit_ranges_both(), *pos_a, *pos_b);
+ ++pos_a;
+ ++pos_b;
+ }
+ }
+ while (pos_a != end_a) {
+ visitor(visit_ranges_first(), *pos_a);
+ ++pos_a;
+ }
+ while (pos_b != end_b) {
+ visitor(visit_ranges_second(), *pos_b);
+ ++pos_b;
+ }
+}
+
+}