diff options
author | Håvard Pettersen <havardpe@yahooinc.com> | 2022-06-07 10:01:15 +0000 |
---|---|---|
committer | Håvard Pettersen <havardpe@yahooinc.com> | 2022-06-07 10:01:15 +0000 |
commit | 05019f25b0d748dd34bbcfb2dc839611e09d96f0 (patch) | |
tree | 753ee21e96270fe03d10b8ab7e05d962c4d11798 /eval | |
parent | 38e71d4979792c42b0d163268ad1335cf3176b37 (diff) |
full reduce with COUNT aggregator is cell count
Diffstat (limited to 'eval')
-rw-r--r-- | eval/src/tests/instruction/generic_reduce/generic_reduce_test.cpp | 2 | ||||
-rw-r--r-- | eval/src/vespa/eval/instruction/generic_reduce.cpp | 12 |
2 files changed, 12 insertions, 2 deletions
diff --git a/eval/src/tests/instruction/generic_reduce/generic_reduce_test.cpp b/eval/src/tests/instruction/generic_reduce/generic_reduce_test.cpp index fe531785278..d7298db4b68 100644 --- a/eval/src/tests/instruction/generic_reduce/generic_reduce_test.cpp +++ b/eval/src/tests/instruction/generic_reduce/generic_reduce_test.cpp @@ -75,7 +75,7 @@ void test_generic_reduce_with(const ValueBuilderFactory &factory) { auto input = layout.cpy().cells(ct); if (input.bad_scalar()) continue; SCOPED_TRACE(fmt("tensor type: %s, num_cells: %zu", input.gen().type().c_str(), input.gen().cells().size())); - for (Aggr aggr: {Aggr::SUM, Aggr::AVG, Aggr::MIN, Aggr::MAX}) { + for (Aggr aggr: {Aggr::SUM, Aggr::AVG, Aggr::MIN, Aggr::MAX, Aggr::COUNT}) { SCOPED_TRACE(fmt("aggregator: %s", AggrNames::name_of(aggr)->c_str())); auto t = layout.type(); for (const auto & dim: t.dimensions()) { diff --git a/eval/src/vespa/eval/instruction/generic_reduce.cpp b/eval/src/vespa/eval/instruction/generic_reduce.cpp index 71eb94f4118..ee74dd49fad 100644 --- a/eval/src/vespa/eval/instruction/generic_reduce.cpp +++ b/eval/src/vespa/eval/instruction/generic_reduce.cpp @@ -154,6 +154,12 @@ void my_generic_dense_reduce_op(State &state, uint64_t param_in) { } }; +template <typename ICT> +void my_count_cells_op(State &state, uint64_t) { + auto cells = state.peek(0).cells().typify<ICT>(); + state.pop_push(state.stash.create<DoubleValue>(cells.size())); +}; + template <typename ICT, typename AGGR> void my_full_reduce_op(State &state, uint64_t) { auto cells = state.peek(0).cells().typify<ICT>(); @@ -194,7 +200,11 @@ struct SelectGenericReduceOp { using OCT = CellValueType<ICM::value.reduce(OIS::value).cell_type>; using AggrType = typename AGGR::template templ<OCT>; if constexpr (OIS::value) { - return my_full_reduce_op<ICT, AggrType>; + if constexpr (AggrType::enum_value() == Aggr::COUNT) { + return my_count_cells_op<ICT>; + } else { + return my_full_reduce_op<ICT, AggrType>; + } } else { if (param.sparse_plan.should_forward_index()) { return my_generic_dense_reduce_op<ICT, OCT, AggrType, true>; |