diff options
author | Håvard Pettersen <havardpe@oath.com> | 2020-01-03 15:31:12 +0000 |
---|---|---|
committer | Håvard Pettersen <havardpe@oath.com> | 2020-01-03 15:31:12 +0000 |
commit | 95a11020a168f9f068ac730f40eec0370571ca5a (patch) | |
tree | e23d798b336b025ae6cc93450ada77e98e21c4dc /eval | |
parent | dbe3a67718104c4150ae770294c23d8a41f0a16c (diff) |
introduce overload class
and use it with std::visit when inspecting std::alternative
Diffstat (limited to 'eval')
-rw-r--r-- | eval/src/vespa/eval/eval/tensor_function.cpp | 65 | ||||
-rw-r--r-- | eval/src/vespa/eval/eval/tensor_function.h | 16 | ||||
-rw-r--r-- | eval/src/vespa/eval/tensor/dense/dense_tensor_peek_function.cpp | 20 |
3 files changed, 59 insertions, 42 deletions
diff --git a/eval/src/vespa/eval/eval/tensor_function.cpp b/eval/src/vespa/eval/eval/tensor_function.cpp index e98a56a0148..9e429757982 100644 --- a/eval/src/vespa/eval/eval/tensor_function.cpp +++ b/eval/src/vespa/eval/eval/tensor_function.cpp @@ -164,20 +164,23 @@ void op_tensor_peek(State &state, uint64_t param) { TensorSpec::Address addr; size_t child_cnt = 0; for (auto pos = self.spec().rbegin(); pos != self.spec().rend(); ++pos) { - if (std::holds_alternative<TensorSpec::Label>(pos->second)) { - addr.emplace(pos->first, std::get<TensorSpec::Label>(pos->second)); - } else { - assert(std::holds_alternative<TensorFunction::Child>(pos->second)); - double index = round(state.peek(child_cnt++).as_double()); - size_t dim_idx = self.param_type().dimension_index(pos->first); - assert(dim_idx != ValueType::Dimension::npos); - const auto ¶m_dim = self.param_type().dimensions()[dim_idx]; - if (param_dim.is_mapped()) { - addr.emplace(pos->first, vespalib::make_string("%ld", int64_t(index))); - } else { - addr.emplace(pos->first, size_t(index)); - } - } + std::visit(vespalib::overload + { + [&](const TensorSpec::Label &label) { + addr.emplace(pos->first, label); + }, + [&](const TensorFunction::Child &) { + double index = round(state.peek(child_cnt++).as_double()); + size_t dim_idx = self.param_type().dimension_index(pos->first); + assert(dim_idx != ValueType::Dimension::npos); + const auto ¶m_dim = self.param_type().dimensions()[dim_idx]; + if (param_dim.is_mapped()) { + addr.emplace(pos->first, vespalib::make_string("%ld", int64_t(index))); + } else { + addr.emplace(pos->first, size_t(index)); + } + } + }, pos->second); } TensorSpec spec = state.engine.to_spec(state.peek(child_cnt++)); const Value &result = self.result_type().is_double() @@ -364,9 +367,13 @@ Peek::push_children(std::vector<Child::CREF> &children) const { children.emplace_back(_param); for (const auto &dim: _spec) { - if (std::holds_alternative<Child>(dim.second)) { - children.emplace_back(std::get<Child>(dim.second)); - } + std::visit(vespalib::overload + { + [&](const Child &child) { + children.emplace_back(child); + }, + [](const TensorSpec::Label &){} + }, dim.second); } } @@ -381,17 +388,19 @@ Peek::visit_children(vespalib::ObjectVisitor &visitor) const { ::visit(visitor, "param", _param.get()); for (const auto &dim: _spec) { - if (std::holds_alternative<TensorSpec::Label>(dim.second)) { - const auto &label = std::get<TensorSpec::Label>(dim.second); - if (label.is_mapped()) { - ::visit(visitor, dim.first, label.name); - } else { - ::visit(visitor, dim.first, label.index); - } - } else { - assert(std::holds_alternative<Child>(dim.second)); - ::visit(visitor, dim.first, std::get<Child>(dim.second).get()); - } + std::visit(vespalib::overload + { + [&](const TensorSpec::Label &label) { + if (label.is_mapped()) { + ::visit(visitor, dim.first, label.name); + } else { + ::visit(visitor, dim.first, label.index); + } + }, + [&](const Child &child) { + ::visit(visitor, dim.first, child.get()); + } + }, dim.second); } } diff --git a/eval/src/vespa/eval/eval/tensor_function.h b/eval/src/vespa/eval/eval/tensor_function.h index c95ffd17bbe..b019ab64e18 100644 --- a/eval/src/vespa/eval/eval/tensor_function.h +++ b/eval/src/vespa/eval/eval/tensor_function.h @@ -8,6 +8,7 @@ #include <vespa/vespalib/stllike/asciistream.h> #include <vespa/vespalib/stllike/string.h> #include <vespa/vespalib/util/arrayref.h> +#include <vespa/vespalib/util/overload.h> #include "tensor_spec.h" #include "lazy_params.h" #include "value_type.h" @@ -320,12 +321,15 @@ public: : Node(result_type_in), _param(param), _spec() { for (const auto &dim: spec) { - if (std::holds_alternative<TensorSpec::Label>(dim.second)) { - _spec.emplace(dim.first, std::get<TensorSpec::Label>(dim.second)); - } else { - assert(std::holds_alternative<Node::CREF>(dim.second)); - _spec.emplace(dim.first, std::get<Node::CREF>(dim.second).get()); - } + std::visit(vespalib::overload + { + [&](const TensorSpec::Label &label) { + _spec.emplace(dim.first, label); + }, + [&](const Node::CREF &ref) { + _spec.emplace(dim.first, ref.get()); + } + }, dim.second); } } const std::map<vespalib::string, MyLabel> &spec() const { return _spec; } diff --git a/eval/src/vespa/eval/tensor/dense/dense_tensor_peek_function.cpp b/eval/src/vespa/eval/tensor/dense/dense_tensor_peek_function.cpp index f9d15a377e9..2920ca26234 100644 --- a/eval/src/vespa/eval/tensor/dense/dense_tensor_peek_function.cpp +++ b/eval/src/vespa/eval/tensor/dense/dense_tensor_peek_function.cpp @@ -2,6 +2,7 @@ #include "dense_tensor_peek_function.h" #include "dense_tensor_view.h" +#include <vespa/vespalib/util/overload.h> #include <vespa/eval/eval/operation.h> #include <vespa/eval/eval/value.h> #include <vespa/eval/tensor/tensor.h> @@ -87,14 +88,17 @@ DenseTensorPeekFunction::optimize(const eval::TensorFunction &expr, Stash &stash for (auto dim = peek_type.dimensions().rbegin(); dim != peek_type.dimensions().rend(); ++dim) { auto dim_spec = peek->spec().find(dim->name); assert(dim_spec != peek->spec().end()); - if (std::holds_alternative<TensorSpec::Label>(dim_spec->second)) { - const auto &label = std::get<TensorSpec::Label>(dim_spec->second); - assert(label.is_indexed()); - spec.emplace_back(label.index, dim->size); - } else { - assert(std::holds_alternative<TensorFunction::Child>(dim_spec->second)); - spec.emplace_back(-1, dim->size); - } + + std::visit(vespalib::overload + { + [&](const TensorSpec::Label &label) { + assert(label.is_indexed()); + spec.emplace_back(label.index, dim->size); + }, + [&](const TensorFunction::Child &) { + spec.emplace_back(-1, dim->size); + } + }, dim_spec->second); } return stash.create<DenseTensorPeekFunction>(peek->copy_children(), spec); } |