summaryrefslogtreecommitdiffstats
path: root/eval
diff options
context:
space:
mode:
authorHåvard Pettersen <havardpe@oath.com>2020-06-13 12:50:36 +0000
committerHåvard Pettersen <havardpe@oath.com>2020-06-13 12:50:36 +0000
commit228e6e6dcca222f97036c66ec8f635a227a12404 (patch)
tree61b3611487facfe036878802f7a1caf51677952d /eval
parent9b2ebf004a4c57090bd4d55ed4c660443a7fa446 (diff)
add aggr typifier and use it
Diffstat (limited to 'eval')
-rw-r--r--eval/src/vespa/eval/eval/aggr.cpp14
-rw-r--r--eval/src/vespa/eval/eval/aggr.h19
-rw-r--r--eval/src/vespa/eval/tensor/dense/dense_single_reduce_function.cpp30
3 files changed, 33 insertions, 30 deletions
diff --git a/eval/src/vespa/eval/eval/aggr.cpp b/eval/src/vespa/eval/eval/aggr.cpp
index d10bbc4abb8..8efb0ec9fe7 100644
--- a/eval/src/vespa/eval/eval/aggr.cpp
+++ b/eval/src/vespa/eval/eval/aggr.cpp
@@ -71,15 +71,11 @@ Aggregator::~Aggregator()
Aggregator &
Aggregator::create(Aggr aggr, Stash &stash)
{
- switch (aggr) {
- case Aggr::AVG: return stash.create<Wrapper<aggr::Avg<double>>>();
- case Aggr::COUNT: return stash.create<Wrapper<aggr::Count<double>>>();
- case Aggr::PROD: return stash.create<Wrapper<aggr::Prod<double>>>();
- case Aggr::SUM: return stash.create<Wrapper<aggr::Sum<double>>>();
- case Aggr::MAX: return stash.create<Wrapper<aggr::Max<double>>>();
- case Aggr::MIN: return stash.create<Wrapper<aggr::Min<double>>>();
- }
- LOG_ABORT("should not be reached");
+ return TypifyAggr::resolve(aggr, [&stash](auto t)->Aggregator&
+ {
+ using T = typename decltype(t)::template templ<double>;
+ return stash.create<Wrapper<T>>();
+ });
}
std::vector<Aggr>
diff --git a/eval/src/vespa/eval/eval/aggr.h b/eval/src/vespa/eval/eval/aggr.h
index 8dea54d8abc..e7431c2c23b 100644
--- a/eval/src/vespa/eval/eval/aggr.h
+++ b/eval/src/vespa/eval/eval/aggr.h
@@ -118,5 +118,24 @@ public:
};
} // namespave vespalib::eval::aggr
+
+struct TypifyAggr {
+ template <template<typename> typename A> struct Result {
+ static constexpr bool is_type = false;
+ template <typename T> using templ = A<T>;
+ };
+ template <typename F> static decltype(auto) resolve(Aggr aggr, F &&f) {
+ switch (aggr) {
+ case Aggr::AVG: return f(Result<aggr::Avg>());
+ case Aggr::COUNT: return f(Result<aggr::Count>());
+ case Aggr::PROD: return f(Result<aggr::Prod>());
+ case Aggr::SUM: return f(Result<aggr::Sum>());
+ case Aggr::MAX: return f(Result<aggr::Max>());
+ case Aggr::MIN: return f(Result<aggr::Min>());
+ }
+ abort();
+ }
+};
+
} // namespace vespalib::eval
} // namespace vespalib
diff --git a/eval/src/vespa/eval/tensor/dense/dense_single_reduce_function.cpp b/eval/src/vespa/eval/tensor/dense/dense_single_reduce_function.cpp
index 663993b6c26..571bcb79c9f 100644
--- a/eval/src/vespa/eval/tensor/dense/dense_single_reduce_function.cpp
+++ b/eval/src/vespa/eval/tensor/dense/dense_single_reduce_function.cpp
@@ -2,6 +2,7 @@
#include "dense_single_reduce_function.h"
#include "dense_tensor_view.h"
+#include <vespa/vespalib/util/typify.h>
#include <vespa/eval/eval/value.h>
namespace vespalib::tensor {
@@ -12,6 +13,8 @@ using eval::TensorEngine;
using eval::TensorFunction;
using eval::Value;
using eval::ValueType;
+using eval::TypifyCellType;
+using eval::TypifyAggr;
using eval::as;
using namespace eval::tensor_function;
@@ -66,28 +69,13 @@ void my_single_reduce_op(InterpretedFunction::State &state, uint64_t param) {
state.pop_push(state.stash.create<DenseTensorView>(params.result_type, TypedCells(dst_cells)));
}
-template <typename CT>
-InterpretedFunction::op_function my_select_2(Aggr aggr) {
- switch (aggr) {
- case Aggr::AVG: return my_single_reduce_op<CT, Avg<CT>>;
- case Aggr::COUNT: return my_single_reduce_op<CT, Count<CT>>;
- case Aggr::PROD: return my_single_reduce_op<CT, Prod<CT>>;
- case Aggr::SUM: return my_single_reduce_op<CT, Sum<CT>>;
- case Aggr::MAX: return my_single_reduce_op<CT, Max<CT>>;
- case Aggr::MIN: return my_single_reduce_op<CT, Min<CT>>;
+struct MyGetFun {
+ template <typename R1, typename R2> static auto invoke() {
+ return my_single_reduce_op<R1, typename R2::template templ<R1>>;
}
- abort();
-}
+};
-InterpretedFunction::op_function my_select(CellType cell_type, Aggr aggr) {
- if (cell_type == ValueType::CellType::DOUBLE) {
- return my_select_2<double>(aggr);
- }
- if (cell_type == ValueType::CellType::FLOAT) {
- return my_select_2<float>(aggr);
- }
- abort();
-}
+using MyTypify = TypifyValue<TypifyCellType,TypifyAggr>;
bool check_input_type(const ValueType &type) {
return (type.is_dense() && ((type.cell_type() == CellType::FLOAT) || (type.cell_type() == CellType::DOUBLE)));
@@ -109,7 +97,7 @@ DenseSingleReduceFunction::~DenseSingleReduceFunction() = default;
InterpretedFunction::Instruction
DenseSingleReduceFunction::compile_self(const TensorEngine &, Stash &stash) const
{
- auto op = my_select(result_type().cell_type(), _aggr);
+ auto op = typify_invoke<2,MyTypify,MyGetFun>(result_type().cell_type(), _aggr);
auto &params = stash.create<Params>(result_type(), child().result_type(), _dim_idx);
static_assert(sizeof(uint64_t) == sizeof(&params));
return InterpretedFunction::Instruction(op, (uint64_t)&params);