summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorArne Juul <arnej@verizonmedia.com>2021-03-11 14:03:26 +0000
committerArne Juul <arnej@verizonmedia.com>2021-03-11 14:03:26 +0000
commit39f623ffdef1d3b37c8d4cb1b039ff0416a1c29b (patch)
treeefdba9d8ab2503157e22b08959b86cb5fd091ce3
parent264390393fc68a51d0c2a23b15bccedd2eb6f55a (diff)
DenseSingleReduceFunction only handles stable cell types (for now)
-rw-r--r--eval/src/vespa/eval/instruction/dense_single_reduce_function.cpp22
1 files changed, 12 insertions, 10 deletions
diff --git a/eval/src/vespa/eval/instruction/dense_single_reduce_function.cpp b/eval/src/vespa/eval/instruction/dense_single_reduce_function.cpp
index f8cbca5ed46..492b6380ad2 100644
--- a/eval/src/vespa/eval/instruction/dense_single_reduce_function.cpp
+++ b/eval/src/vespa/eval/instruction/dense_single_reduce_function.cpp
@@ -119,9 +119,9 @@ void my_single_reduce_op(InterpretedFunction::State &state, uint64_t param) {
struct MyGetFun {
template <typename R1, typename R2, typename R3, typename R4> static auto invoke() {
- using CT = CellValueType<R1::value.cell_type>;
- using AggrType = typename R2::template templ<CT>;
- return my_single_reduce_op<CT, AggrType, R3::value, R4::value>;
+ using OCT = CellValueType<R1::value.decay().cell_type>;
+ using AggrType = typename R2::template templ<OCT>;
+ return my_single_reduce_op<OCT, AggrType, R3::value, R4::value>;
}
};
@@ -229,7 +229,7 @@ DenseSingleReduceFunction::~DenseSingleReduceFunction() = default;
InterpretedFunction::Instruction
DenseSingleReduceFunction::compile_self(const ValueBuilderFactory &, Stash &stash) const
{
- auto op = typify_invoke<4,MyTypify,MyGetFun>(result_type().cell_meta().limit().not_scalar(),
+ auto op = typify_invoke<4,MyTypify,MyGetFun>(child().result_type().cell_meta().limit().not_scalar(),
_aggr,
(_reduce_size >= 8), (_inner_size == 1));
auto &params = stash.create<Params>(result_type(), _outer_size, _reduce_size, _inner_size);
@@ -241,13 +241,15 @@ DenseSingleReduceFunction::optimize(const TensorFunction &expr, Stash &stash)
{
if (auto reduce = as<Reduce>(expr)) {
const auto &child = reduce->child();
- auto spec_list = make_dense_single_reduce_list(child.result_type(), reduce->aggr(), reduce->dimensions());
- if (!spec_list.empty()) {
- const auto *prev = &child;
- for (const auto &spec: spec_list) {
- prev = &stash.create<DenseSingleReduceFunction>(spec, *prev);
+ if (reduce->result_type().cell_meta().eq(child.result_type().cell_meta())) {
+ auto spec_list = make_dense_single_reduce_list(child.result_type(), reduce->aggr(), reduce->dimensions());
+ if (!spec_list.empty()) {
+ const auto *prev = &child;
+ for (const auto &spec: spec_list) {
+ prev = &stash.create<DenseSingleReduceFunction>(spec, *prev);
+ }
+ return *prev;
}
- return *prev;
}
}
return expr;