diff options
author | Håvard Pettersen <havardpe@oath.com> | 2020-06-13 12:50:36 +0000 |
---|---|---|
committer | Håvard Pettersen <havardpe@oath.com> | 2020-06-13 12:50:36 +0000 |
commit | 228e6e6dcca222f97036c66ec8f635a227a12404 (patch) | |
tree | 61b3611487facfe036878802f7a1caf51677952d /eval | |
parent | 9b2ebf004a4c57090bd4d55ed4c660443a7fa446 (diff) |
add aggr typifier and use it
Diffstat (limited to 'eval')
-rw-r--r-- | eval/src/vespa/eval/eval/aggr.cpp | 14 | ||||
-rw-r--r-- | eval/src/vespa/eval/eval/aggr.h | 19 | ||||
-rw-r--r-- | eval/src/vespa/eval/tensor/dense/dense_single_reduce_function.cpp | 30 |
3 files changed, 33 insertions, 30 deletions
diff --git a/eval/src/vespa/eval/eval/aggr.cpp b/eval/src/vespa/eval/eval/aggr.cpp index d10bbc4abb8..8efb0ec9fe7 100644 --- a/eval/src/vespa/eval/eval/aggr.cpp +++ b/eval/src/vespa/eval/eval/aggr.cpp @@ -71,15 +71,11 @@ Aggregator::~Aggregator() Aggregator & Aggregator::create(Aggr aggr, Stash &stash) { - switch (aggr) { - case Aggr::AVG: return stash.create<Wrapper<aggr::Avg<double>>>(); - case Aggr::COUNT: return stash.create<Wrapper<aggr::Count<double>>>(); - case Aggr::PROD: return stash.create<Wrapper<aggr::Prod<double>>>(); - case Aggr::SUM: return stash.create<Wrapper<aggr::Sum<double>>>(); - case Aggr::MAX: return stash.create<Wrapper<aggr::Max<double>>>(); - case Aggr::MIN: return stash.create<Wrapper<aggr::Min<double>>>(); - } - LOG_ABORT("should not be reached"); + return TypifyAggr::resolve(aggr, [&stash](auto t)->Aggregator& + { + using T = typename decltype(t)::template templ<double>; + return stash.create<Wrapper<T>>(); + }); } std::vector<Aggr> diff --git a/eval/src/vespa/eval/eval/aggr.h b/eval/src/vespa/eval/eval/aggr.h index 8dea54d8abc..e7431c2c23b 100644 --- a/eval/src/vespa/eval/eval/aggr.h +++ b/eval/src/vespa/eval/eval/aggr.h @@ -118,5 +118,24 @@ public: }; } // namespave vespalib::eval::aggr + +struct TypifyAggr { + template <template<typename> typename A> struct Result { + static constexpr bool is_type = false; + template <typename T> using templ = A<T>; + }; + template <typename F> static decltype(auto) resolve(Aggr aggr, F &&f) { + switch (aggr) { + case Aggr::AVG: return f(Result<aggr::Avg>()); + case Aggr::COUNT: return f(Result<aggr::Count>()); + case Aggr::PROD: return f(Result<aggr::Prod>()); + case Aggr::SUM: return f(Result<aggr::Sum>()); + case Aggr::MAX: return f(Result<aggr::Max>()); + case Aggr::MIN: return f(Result<aggr::Min>()); + } + abort(); + } +}; + } // namespace vespalib::eval } // namespace vespalib diff --git a/eval/src/vespa/eval/tensor/dense/dense_single_reduce_function.cpp b/eval/src/vespa/eval/tensor/dense/dense_single_reduce_function.cpp index 663993b6c26..571bcb79c9f 100644 --- a/eval/src/vespa/eval/tensor/dense/dense_single_reduce_function.cpp +++ b/eval/src/vespa/eval/tensor/dense/dense_single_reduce_function.cpp @@ -2,6 +2,7 @@ #include "dense_single_reduce_function.h" #include "dense_tensor_view.h" +#include <vespa/vespalib/util/typify.h> #include <vespa/eval/eval/value.h> namespace vespalib::tensor { @@ -12,6 +13,8 @@ using eval::TensorEngine; using eval::TensorFunction; using eval::Value; using eval::ValueType; +using eval::TypifyCellType; +using eval::TypifyAggr; using eval::as; using namespace eval::tensor_function; @@ -66,28 +69,13 @@ void my_single_reduce_op(InterpretedFunction::State &state, uint64_t param) { state.pop_push(state.stash.create<DenseTensorView>(params.result_type, TypedCells(dst_cells))); } -template <typename CT> -InterpretedFunction::op_function my_select_2(Aggr aggr) { - switch (aggr) { - case Aggr::AVG: return my_single_reduce_op<CT, Avg<CT>>; - case Aggr::COUNT: return my_single_reduce_op<CT, Count<CT>>; - case Aggr::PROD: return my_single_reduce_op<CT, Prod<CT>>; - case Aggr::SUM: return my_single_reduce_op<CT, Sum<CT>>; - case Aggr::MAX: return my_single_reduce_op<CT, Max<CT>>; - case Aggr::MIN: return my_single_reduce_op<CT, Min<CT>>; +struct MyGetFun { + template <typename R1, typename R2> static auto invoke() { + return my_single_reduce_op<R1, typename R2::template templ<R1>>; } - abort(); -} +}; -InterpretedFunction::op_function my_select(CellType cell_type, Aggr aggr) { - if (cell_type == ValueType::CellType::DOUBLE) { - return my_select_2<double>(aggr); - } - if (cell_type == ValueType::CellType::FLOAT) { - return my_select_2<float>(aggr); - } - abort(); -} +using MyTypify = TypifyValue<TypifyCellType,TypifyAggr>; bool check_input_type(const ValueType &type) { return (type.is_dense() && ((type.cell_type() == CellType::FLOAT) || (type.cell_type() == CellType::DOUBLE))); @@ -109,7 +97,7 @@ DenseSingleReduceFunction::~DenseSingleReduceFunction() = default; InterpretedFunction::Instruction DenseSingleReduceFunction::compile_self(const TensorEngine &, Stash &stash) const { - auto op = my_select(result_type().cell_type(), _aggr); + auto op = typify_invoke<2,MyTypify,MyGetFun>(result_type().cell_type(), _aggr); auto ¶ms = stash.create<Params>(result_type(), child().result_type(), _dim_idx); static_assert(sizeof(uint64_t) == sizeof(¶ms)); return InterpretedFunction::Instruction(op, (uint64_t)¶ms); |