aboutsummaryrefslogtreecommitdiffstats
path: root/eval
diff options
context:
space:
mode:
authorHåvard Pettersen <havardpe@yahooinc.com>2022-06-07 10:01:15 +0000
committerHåvard Pettersen <havardpe@yahooinc.com>2022-06-07 10:01:15 +0000
commit05019f25b0d748dd34bbcfb2dc839611e09d96f0 (patch)
tree753ee21e96270fe03d10b8ab7e05d962c4d11798 /eval
parent38e71d4979792c42b0d163268ad1335cf3176b37 (diff)
full reduce with COUNT aggregator is cell count
Diffstat (limited to 'eval')
-rw-r--r--eval/src/tests/instruction/generic_reduce/generic_reduce_test.cpp2
-rw-r--r--eval/src/vespa/eval/instruction/generic_reduce.cpp12
2 files changed, 12 insertions, 2 deletions
diff --git a/eval/src/tests/instruction/generic_reduce/generic_reduce_test.cpp b/eval/src/tests/instruction/generic_reduce/generic_reduce_test.cpp
index fe531785278..d7298db4b68 100644
--- a/eval/src/tests/instruction/generic_reduce/generic_reduce_test.cpp
+++ b/eval/src/tests/instruction/generic_reduce/generic_reduce_test.cpp
@@ -75,7 +75,7 @@ void test_generic_reduce_with(const ValueBuilderFactory &factory) {
auto input = layout.cpy().cells(ct);
if (input.bad_scalar()) continue;
SCOPED_TRACE(fmt("tensor type: %s, num_cells: %zu", input.gen().type().c_str(), input.gen().cells().size()));
- for (Aggr aggr: {Aggr::SUM, Aggr::AVG, Aggr::MIN, Aggr::MAX}) {
+ for (Aggr aggr: {Aggr::SUM, Aggr::AVG, Aggr::MIN, Aggr::MAX, Aggr::COUNT}) {
SCOPED_TRACE(fmt("aggregator: %s", AggrNames::name_of(aggr)->c_str()));
auto t = layout.type();
for (const auto & dim: t.dimensions()) {
diff --git a/eval/src/vespa/eval/instruction/generic_reduce.cpp b/eval/src/vespa/eval/instruction/generic_reduce.cpp
index 71eb94f4118..ee74dd49fad 100644
--- a/eval/src/vespa/eval/instruction/generic_reduce.cpp
+++ b/eval/src/vespa/eval/instruction/generic_reduce.cpp
@@ -154,6 +154,12 @@ void my_generic_dense_reduce_op(State &state, uint64_t param_in) {
}
};
+template <typename ICT>
+void my_count_cells_op(State &state, uint64_t) {
+ auto cells = state.peek(0).cells().typify<ICT>();
+ state.pop_push(state.stash.create<DoubleValue>(cells.size()));
+};
+
template <typename ICT, typename AGGR>
void my_full_reduce_op(State &state, uint64_t) {
auto cells = state.peek(0).cells().typify<ICT>();
@@ -194,7 +200,11 @@ struct SelectGenericReduceOp {
using OCT = CellValueType<ICM::value.reduce(OIS::value).cell_type>;
using AggrType = typename AGGR::template templ<OCT>;
if constexpr (OIS::value) {
- return my_full_reduce_op<ICT, AggrType>;
+ if constexpr (AggrType::enum_value() == Aggr::COUNT) {
+ return my_count_cells_op<ICT>;
+ } else {
+ return my_full_reduce_op<ICT, AggrType>;
+ }
} else {
if (param.sparse_plan.should_forward_index()) {
return my_generic_dense_reduce_op<ICT, OCT, AggrType, true>;