diff options
author | Håvard Pettersen <havardpe@oath.com> | 2020-01-03 15:31:12 +0000 |
---|---|---|
committer | Håvard Pettersen <havardpe@oath.com> | 2020-01-03 15:31:12 +0000 |
commit | 95a11020a168f9f068ac730f40eec0370571ca5a (patch) | |
tree | e23d798b336b025ae6cc93450ada77e98e21c4dc | |
parent | dbe3a67718104c4150ae770294c23d8a41f0a16c (diff) |
introduce overload class
and use it with std::visit when inspecting std::alternative
-rw-r--r-- | eval/src/vespa/eval/eval/tensor_function.cpp | 65 | ||||
-rw-r--r-- | eval/src/vespa/eval/eval/tensor_function.h | 16 | ||||
-rw-r--r-- | eval/src/vespa/eval/tensor/dense/dense_tensor_peek_function.cpp | 20 | ||||
-rw-r--r-- | vespalib/CMakeLists.txt | 1 | ||||
-rw-r--r-- | vespalib/src/tests/overload/CMakeLists.txt | 9 | ||||
-rw-r--r-- | vespalib/src/tests/overload/overload_test.cpp | 19 | ||||
-rw-r--r-- | vespalib/src/vespa/vespalib/util/overload.h | 15 |
7 files changed, 103 insertions, 42 deletions
diff --git a/eval/src/vespa/eval/eval/tensor_function.cpp b/eval/src/vespa/eval/eval/tensor_function.cpp index e98a56a0148..9e429757982 100644 --- a/eval/src/vespa/eval/eval/tensor_function.cpp +++ b/eval/src/vespa/eval/eval/tensor_function.cpp @@ -164,20 +164,23 @@ void op_tensor_peek(State &state, uint64_t param) { TensorSpec::Address addr; size_t child_cnt = 0; for (auto pos = self.spec().rbegin(); pos != self.spec().rend(); ++pos) { - if (std::holds_alternative<TensorSpec::Label>(pos->second)) { - addr.emplace(pos->first, std::get<TensorSpec::Label>(pos->second)); - } else { - assert(std::holds_alternative<TensorFunction::Child>(pos->second)); - double index = round(state.peek(child_cnt++).as_double()); - size_t dim_idx = self.param_type().dimension_index(pos->first); - assert(dim_idx != ValueType::Dimension::npos); - const auto ¶m_dim = self.param_type().dimensions()[dim_idx]; - if (param_dim.is_mapped()) { - addr.emplace(pos->first, vespalib::make_string("%ld", int64_t(index))); - } else { - addr.emplace(pos->first, size_t(index)); - } - } + std::visit(vespalib::overload + { + [&](const TensorSpec::Label &label) { + addr.emplace(pos->first, label); + }, + [&](const TensorFunction::Child &) { + double index = round(state.peek(child_cnt++).as_double()); + size_t dim_idx = self.param_type().dimension_index(pos->first); + assert(dim_idx != ValueType::Dimension::npos); + const auto ¶m_dim = self.param_type().dimensions()[dim_idx]; + if (param_dim.is_mapped()) { + addr.emplace(pos->first, vespalib::make_string("%ld", int64_t(index))); + } else { + addr.emplace(pos->first, size_t(index)); + } + } + }, pos->second); } TensorSpec spec = state.engine.to_spec(state.peek(child_cnt++)); const Value &result = self.result_type().is_double() @@ -364,9 +367,13 @@ Peek::push_children(std::vector<Child::CREF> &children) const { children.emplace_back(_param); for (const auto &dim: _spec) { - if (std::holds_alternative<Child>(dim.second)) { - children.emplace_back(std::get<Child>(dim.second)); - } + std::visit(vespalib::overload + { + [&](const Child &child) { + children.emplace_back(child); + }, + [](const TensorSpec::Label &){} + }, dim.second); } } @@ -381,17 +388,19 @@ Peek::visit_children(vespalib::ObjectVisitor &visitor) const { ::visit(visitor, "param", _param.get()); for (const auto &dim: _spec) { - if (std::holds_alternative<TensorSpec::Label>(dim.second)) { - const auto &label = std::get<TensorSpec::Label>(dim.second); - if (label.is_mapped()) { - ::visit(visitor, dim.first, label.name); - } else { - ::visit(visitor, dim.first, label.index); - } - } else { - assert(std::holds_alternative<Child>(dim.second)); - ::visit(visitor, dim.first, std::get<Child>(dim.second).get()); - } + std::visit(vespalib::overload + { + [&](const TensorSpec::Label &label) { + if (label.is_mapped()) { + ::visit(visitor, dim.first, label.name); + } else { + ::visit(visitor, dim.first, label.index); + } + }, + [&](const Child &child) { + ::visit(visitor, dim.first, child.get()); + } + }, dim.second); } } diff --git a/eval/src/vespa/eval/eval/tensor_function.h b/eval/src/vespa/eval/eval/tensor_function.h index c95ffd17bbe..b019ab64e18 100644 --- a/eval/src/vespa/eval/eval/tensor_function.h +++ b/eval/src/vespa/eval/eval/tensor_function.h @@ -8,6 +8,7 @@ #include <vespa/vespalib/stllike/asciistream.h> #include <vespa/vespalib/stllike/string.h> #include <vespa/vespalib/util/arrayref.h> +#include <vespa/vespalib/util/overload.h> #include "tensor_spec.h" #include "lazy_params.h" #include "value_type.h" @@ -320,12 +321,15 @@ public: : Node(result_type_in), _param(param), _spec() { for (const auto &dim: spec) { - if (std::holds_alternative<TensorSpec::Label>(dim.second)) { - _spec.emplace(dim.first, std::get<TensorSpec::Label>(dim.second)); - } else { - assert(std::holds_alternative<Node::CREF>(dim.second)); - _spec.emplace(dim.first, std::get<Node::CREF>(dim.second).get()); - } + std::visit(vespalib::overload + { + [&](const TensorSpec::Label &label) { + _spec.emplace(dim.first, label); + }, + [&](const Node::CREF &ref) { + _spec.emplace(dim.first, ref.get()); + } + }, dim.second); } } const std::map<vespalib::string, MyLabel> &spec() const { return _spec; } diff --git a/eval/src/vespa/eval/tensor/dense/dense_tensor_peek_function.cpp b/eval/src/vespa/eval/tensor/dense/dense_tensor_peek_function.cpp index f9d15a377e9..2920ca26234 100644 --- a/eval/src/vespa/eval/tensor/dense/dense_tensor_peek_function.cpp +++ b/eval/src/vespa/eval/tensor/dense/dense_tensor_peek_function.cpp @@ -2,6 +2,7 @@ #include "dense_tensor_peek_function.h" #include "dense_tensor_view.h" +#include <vespa/vespalib/util/overload.h> #include <vespa/eval/eval/operation.h> #include <vespa/eval/eval/value.h> #include <vespa/eval/tensor/tensor.h> @@ -87,14 +88,17 @@ DenseTensorPeekFunction::optimize(const eval::TensorFunction &expr, Stash &stash for (auto dim = peek_type.dimensions().rbegin(); dim != peek_type.dimensions().rend(); ++dim) { auto dim_spec = peek->spec().find(dim->name); assert(dim_spec != peek->spec().end()); - if (std::holds_alternative<TensorSpec::Label>(dim_spec->second)) { - const auto &label = std::get<TensorSpec::Label>(dim_spec->second); - assert(label.is_indexed()); - spec.emplace_back(label.index, dim->size); - } else { - assert(std::holds_alternative<TensorFunction::Child>(dim_spec->second)); - spec.emplace_back(-1, dim->size); - } + + std::visit(vespalib::overload + { + [&](const TensorSpec::Label &label) { + assert(label.is_indexed()); + spec.emplace_back(label.index, dim->size); + }, + [&](const TensorFunction::Child &) { + spec.emplace_back(-1, dim->size); + } + }, dim_spec->second); } return stash.create<DenseTensorPeekFunction>(peek->copy_children(), spec); } diff --git a/vespalib/CMakeLists.txt b/vespalib/CMakeLists.txt index 14c14fe85ca..a76b8848b47 100644 --- a/vespalib/CMakeLists.txt +++ b/vespalib/CMakeLists.txt @@ -78,6 +78,7 @@ vespa_define_module( src/tests/net/tls/transport_options src/tests/objects/nbostream src/tests/optimized + src/tests/overload src/tests/portal src/tests/portal/handle_manager src/tests/portal/http_request diff --git a/vespalib/src/tests/overload/CMakeLists.txt b/vespalib/src/tests/overload/CMakeLists.txt new file mode 100644 index 00000000000..67aa6230225 --- /dev/null +++ b/vespalib/src/tests/overload/CMakeLists.txt @@ -0,0 +1,9 @@ +# Copyright 2020 Oath Inc. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +vespa_add_executable(vespalib_overload_test_app TEST + SOURCES + overload_test.cpp + DEPENDS + vespalib + gtest +) +vespa_add_test(NAME vespalib_overload_test_app COMMAND vespalib_overload_test_app) diff --git a/vespalib/src/tests/overload/overload_test.cpp b/vespalib/src/tests/overload/overload_test.cpp new file mode 100644 index 00000000000..ceae29ac02f --- /dev/null +++ b/vespalib/src/tests/overload/overload_test.cpp @@ -0,0 +1,19 @@ +// Copyright 2020 Oath Inc. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. + +#include <vespa/vespalib/util/overload.h> +#include <vespa/vespalib/gtest/gtest.h> +#include <variant> +#include <string> + +using namespace vespalib; + +TEST(OverloadTest, visit_with_overload_works) { + std::variant<std::string,int> a = 10; + std::variant<std::string,int> b = "foo"; + std::visit(overload{[](int v){ EXPECT_EQ(v,10); }, + [](const std::string &){ FAIL() << "invalid visit"; }}, a); + std::visit(overload{[](int){ FAIL() << "invalid visit"; }, + [](const std::string &v){ EXPECT_EQ(v, "foo"); }}, b); +} + +GTEST_MAIN_RUN_ALL_TESTS() diff --git a/vespalib/src/vespa/vespalib/util/overload.h b/vespalib/src/vespa/vespalib/util/overload.h new file mode 100644 index 00000000000..d5af9dee2d3 --- /dev/null +++ b/vespalib/src/vespa/vespalib/util/overload.h @@ -0,0 +1,15 @@ +// Copyright 2020 Oath Inc. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. + +#pragma once + +namespace vespalib { + +/** + * Simple overload lambda composition class. To be replaced by + * standard overload functionality when available. (C++20) + **/ + +template<class... Ts> struct overload : Ts... { using Ts::operator()...; }; +template<class... Ts> overload(Ts...) -> overload<Ts...>; + +} |