aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorHåvard Pettersen <havardpe@oath.com>2020-01-03 15:31:12 +0000
committerHåvard Pettersen <havardpe@oath.com>2020-01-03 15:31:12 +0000
commit95a11020a168f9f068ac730f40eec0370571ca5a (patch)
treee23d798b336b025ae6cc93450ada77e98e21c4dc
parentdbe3a67718104c4150ae770294c23d8a41f0a16c (diff)
introduce overload class
and use it with std::visit when inspecting std::alternative
-rw-r--r--eval/src/vespa/eval/eval/tensor_function.cpp65
-rw-r--r--eval/src/vespa/eval/eval/tensor_function.h16
-rw-r--r--eval/src/vespa/eval/tensor/dense/dense_tensor_peek_function.cpp20
-rw-r--r--vespalib/CMakeLists.txt1
-rw-r--r--vespalib/src/tests/overload/CMakeLists.txt9
-rw-r--r--vespalib/src/tests/overload/overload_test.cpp19
-rw-r--r--vespalib/src/vespa/vespalib/util/overload.h15
7 files changed, 103 insertions, 42 deletions
diff --git a/eval/src/vespa/eval/eval/tensor_function.cpp b/eval/src/vespa/eval/eval/tensor_function.cpp
index e98a56a0148..9e429757982 100644
--- a/eval/src/vespa/eval/eval/tensor_function.cpp
+++ b/eval/src/vespa/eval/eval/tensor_function.cpp
@@ -164,20 +164,23 @@ void op_tensor_peek(State &state, uint64_t param) {
TensorSpec::Address addr;
size_t child_cnt = 0;
for (auto pos = self.spec().rbegin(); pos != self.spec().rend(); ++pos) {
- if (std::holds_alternative<TensorSpec::Label>(pos->second)) {
- addr.emplace(pos->first, std::get<TensorSpec::Label>(pos->second));
- } else {
- assert(std::holds_alternative<TensorFunction::Child>(pos->second));
- double index = round(state.peek(child_cnt++).as_double());
- size_t dim_idx = self.param_type().dimension_index(pos->first);
- assert(dim_idx != ValueType::Dimension::npos);
- const auto &param_dim = self.param_type().dimensions()[dim_idx];
- if (param_dim.is_mapped()) {
- addr.emplace(pos->first, vespalib::make_string("%ld", int64_t(index)));
- } else {
- addr.emplace(pos->first, size_t(index));
- }
- }
+ std::visit(vespalib::overload
+ {
+ [&](const TensorSpec::Label &label) {
+ addr.emplace(pos->first, label);
+ },
+ [&](const TensorFunction::Child &) {
+ double index = round(state.peek(child_cnt++).as_double());
+ size_t dim_idx = self.param_type().dimension_index(pos->first);
+ assert(dim_idx != ValueType::Dimension::npos);
+ const auto &param_dim = self.param_type().dimensions()[dim_idx];
+ if (param_dim.is_mapped()) {
+ addr.emplace(pos->first, vespalib::make_string("%ld", int64_t(index)));
+ } else {
+ addr.emplace(pos->first, size_t(index));
+ }
+ }
+ }, pos->second);
}
TensorSpec spec = state.engine.to_spec(state.peek(child_cnt++));
const Value &result = self.result_type().is_double()
@@ -364,9 +367,13 @@ Peek::push_children(std::vector<Child::CREF> &children) const
{
children.emplace_back(_param);
for (const auto &dim: _spec) {
- if (std::holds_alternative<Child>(dim.second)) {
- children.emplace_back(std::get<Child>(dim.second));
- }
+ std::visit(vespalib::overload
+ {
+ [&](const Child &child) {
+ children.emplace_back(child);
+ },
+ [](const TensorSpec::Label &){}
+ }, dim.second);
}
}
@@ -381,17 +388,19 @@ Peek::visit_children(vespalib::ObjectVisitor &visitor) const
{
::visit(visitor, "param", _param.get());
for (const auto &dim: _spec) {
- if (std::holds_alternative<TensorSpec::Label>(dim.second)) {
- const auto &label = std::get<TensorSpec::Label>(dim.second);
- if (label.is_mapped()) {
- ::visit(visitor, dim.first, label.name);
- } else {
- ::visit(visitor, dim.first, label.index);
- }
- } else {
- assert(std::holds_alternative<Child>(dim.second));
- ::visit(visitor, dim.first, std::get<Child>(dim.second).get());
- }
+ std::visit(vespalib::overload
+ {
+ [&](const TensorSpec::Label &label) {
+ if (label.is_mapped()) {
+ ::visit(visitor, dim.first, label.name);
+ } else {
+ ::visit(visitor, dim.first, label.index);
+ }
+ },
+ [&](const Child &child) {
+ ::visit(visitor, dim.first, child.get());
+ }
+ }, dim.second);
}
}
diff --git a/eval/src/vespa/eval/eval/tensor_function.h b/eval/src/vespa/eval/eval/tensor_function.h
index c95ffd17bbe..b019ab64e18 100644
--- a/eval/src/vespa/eval/eval/tensor_function.h
+++ b/eval/src/vespa/eval/eval/tensor_function.h
@@ -8,6 +8,7 @@
#include <vespa/vespalib/stllike/asciistream.h>
#include <vespa/vespalib/stllike/string.h>
#include <vespa/vespalib/util/arrayref.h>
+#include <vespa/vespalib/util/overload.h>
#include "tensor_spec.h"
#include "lazy_params.h"
#include "value_type.h"
@@ -320,12 +321,15 @@ public:
: Node(result_type_in), _param(param), _spec()
{
for (const auto &dim: spec) {
- if (std::holds_alternative<TensorSpec::Label>(dim.second)) {
- _spec.emplace(dim.first, std::get<TensorSpec::Label>(dim.second));
- } else {
- assert(std::holds_alternative<Node::CREF>(dim.second));
- _spec.emplace(dim.first, std::get<Node::CREF>(dim.second).get());
- }
+ std::visit(vespalib::overload
+ {
+ [&](const TensorSpec::Label &label) {
+ _spec.emplace(dim.first, label);
+ },
+ [&](const Node::CREF &ref) {
+ _spec.emplace(dim.first, ref.get());
+ }
+ }, dim.second);
}
}
const std::map<vespalib::string, MyLabel> &spec() const { return _spec; }
diff --git a/eval/src/vespa/eval/tensor/dense/dense_tensor_peek_function.cpp b/eval/src/vespa/eval/tensor/dense/dense_tensor_peek_function.cpp
index f9d15a377e9..2920ca26234 100644
--- a/eval/src/vespa/eval/tensor/dense/dense_tensor_peek_function.cpp
+++ b/eval/src/vespa/eval/tensor/dense/dense_tensor_peek_function.cpp
@@ -2,6 +2,7 @@
#include "dense_tensor_peek_function.h"
#include "dense_tensor_view.h"
+#include <vespa/vespalib/util/overload.h>
#include <vespa/eval/eval/operation.h>
#include <vespa/eval/eval/value.h>
#include <vespa/eval/tensor/tensor.h>
@@ -87,14 +88,17 @@ DenseTensorPeekFunction::optimize(const eval::TensorFunction &expr, Stash &stash
for (auto dim = peek_type.dimensions().rbegin(); dim != peek_type.dimensions().rend(); ++dim) {
auto dim_spec = peek->spec().find(dim->name);
assert(dim_spec != peek->spec().end());
- if (std::holds_alternative<TensorSpec::Label>(dim_spec->second)) {
- const auto &label = std::get<TensorSpec::Label>(dim_spec->second);
- assert(label.is_indexed());
- spec.emplace_back(label.index, dim->size);
- } else {
- assert(std::holds_alternative<TensorFunction::Child>(dim_spec->second));
- spec.emplace_back(-1, dim->size);
- }
+
+ std::visit(vespalib::overload
+ {
+ [&](const TensorSpec::Label &label) {
+ assert(label.is_indexed());
+ spec.emplace_back(label.index, dim->size);
+ },
+ [&](const TensorFunction::Child &) {
+ spec.emplace_back(-1, dim->size);
+ }
+ }, dim_spec->second);
}
return stash.create<DenseTensorPeekFunction>(peek->copy_children(), spec);
}
diff --git a/vespalib/CMakeLists.txt b/vespalib/CMakeLists.txt
index 14c14fe85ca..a76b8848b47 100644
--- a/vespalib/CMakeLists.txt
+++ b/vespalib/CMakeLists.txt
@@ -78,6 +78,7 @@ vespa_define_module(
src/tests/net/tls/transport_options
src/tests/objects/nbostream
src/tests/optimized
+ src/tests/overload
src/tests/portal
src/tests/portal/handle_manager
src/tests/portal/http_request
diff --git a/vespalib/src/tests/overload/CMakeLists.txt b/vespalib/src/tests/overload/CMakeLists.txt
new file mode 100644
index 00000000000..67aa6230225
--- /dev/null
+++ b/vespalib/src/tests/overload/CMakeLists.txt
@@ -0,0 +1,9 @@
+# Copyright 2020 Oath Inc. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+vespa_add_executable(vespalib_overload_test_app TEST
+ SOURCES
+ overload_test.cpp
+ DEPENDS
+ vespalib
+ gtest
+)
+vespa_add_test(NAME vespalib_overload_test_app COMMAND vespalib_overload_test_app)
diff --git a/vespalib/src/tests/overload/overload_test.cpp b/vespalib/src/tests/overload/overload_test.cpp
new file mode 100644
index 00000000000..ceae29ac02f
--- /dev/null
+++ b/vespalib/src/tests/overload/overload_test.cpp
@@ -0,0 +1,19 @@
+// Copyright 2020 Oath Inc. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+
+#include <vespa/vespalib/util/overload.h>
+#include <vespa/vespalib/gtest/gtest.h>
+#include <variant>
+#include <string>
+
+using namespace vespalib;
+
+TEST(OverloadTest, visit_with_overload_works) {
+ std::variant<std::string,int> a = 10;
+ std::variant<std::string,int> b = "foo";
+ std::visit(overload{[](int v){ EXPECT_EQ(v,10); },
+ [](const std::string &){ FAIL() << "invalid visit"; }}, a);
+ std::visit(overload{[](int){ FAIL() << "invalid visit"; },
+ [](const std::string &v){ EXPECT_EQ(v, "foo"); }}, b);
+}
+
+GTEST_MAIN_RUN_ALL_TESTS()
diff --git a/vespalib/src/vespa/vespalib/util/overload.h b/vespalib/src/vespa/vespalib/util/overload.h
new file mode 100644
index 00000000000..d5af9dee2d3
--- /dev/null
+++ b/vespalib/src/vespa/vespalib/util/overload.h
@@ -0,0 +1,15 @@
+// Copyright 2020 Oath Inc. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+
+#pragma once
+
+namespace vespalib {
+
+/**
+ * Simple overload lambda composition class. To be replaced by
+ * standard overload functionality when available. (C++20)
+ **/
+
+template<class... Ts> struct overload : Ts... { using Ts::operator()...; };
+template<class... Ts> overload(Ts...) -> overload<Ts...>;
+
+}