diff options
Diffstat (limited to 'eval/src')
-rw-r--r-- | eval/src/vespa/eval/tensor/dense/dense_inplace_join_function.cpp | 76 | ||||
-rw-r--r-- | eval/src/vespa/eval/tensor/dense/dense_inplace_join_function.h | 9 |
2 files changed, 36 insertions, 49 deletions
diff --git a/eval/src/vespa/eval/tensor/dense/dense_inplace_join_function.cpp b/eval/src/vespa/eval/tensor/dense/dense_inplace_join_function.cpp index 2e37406664c..53a5fe9bb27 100644 --- a/eval/src/vespa/eval/tensor/dense/dense_inplace_join_function.cpp +++ b/eval/src/vespa/eval/tensor/dense/dense_inplace_join_function.cpp @@ -17,57 +17,43 @@ using namespace eval::tensor_function; namespace { -ArrayRef<double> getMutableCells(const eval::Value &value) { - const DenseTensorView &denseTensor = static_cast<const DenseTensorView &>(value); - return unconstify(denseTensor.cellsRef()); -} - -ConstArrayRef<double> getConstCells(const eval::Value &value) { +CellsRef getCellsRef(const eval::Value &value) { const DenseTensorView &denseTensor = static_cast<const DenseTensorView &>(value); return denseTensor.cellsRef(); } -void my_inplace_left_join_op(eval::InterpretedFunction::State &state, uint64_t param) { - const Value &lhs = state.peek(1); - const Value &rhs = state.peek(0); +template <bool write_left> +void my_inplace_join_op(eval::InterpretedFunction::State &state, uint64_t param) { join_fun_t function = (join_fun_t)param; - ArrayRef<double> left_cells = getMutableCells(lhs); - ConstArrayRef<double> right_cells = getConstCells(rhs); - auto rhs_iter = right_cells.cbegin(); - for (double &cell: left_cells) { - cell = function(cell, *rhs_iter); - ++rhs_iter; + CellsRef lhs_cells = getCellsRef(state.peek(1)); + CellsRef rhs_cells = getCellsRef(state.peek(0)); + auto dst_cells = unconstify(write_left ? lhs_cells : rhs_cells); + for (size_t i = 0; i < dst_cells.size(); ++i) { + dst_cells[i] = function(lhs_cells[i], rhs_cells[i]); } - assert(rhs_iter == right_cells.cend()); - state.pop_pop_push(lhs); -} - -void my_inplace_right_join_op(eval::InterpretedFunction::State &state, uint64_t param) { - const Value &lhs = state.peek(1); - const Value &rhs = state.peek(0); - join_fun_t function = (join_fun_t)param; - ConstArrayRef<double> left_cells = getConstCells(lhs); - ArrayRef<double> right_cells = getMutableCells(rhs); - auto lhs_iter = left_cells.cbegin(); - for (double &cell: right_cells) { - cell = function(*lhs_iter, cell); - ++lhs_iter; + if (write_left) { + state.stack.pop_back(); + } else { + const Value &result = state.stack.back(); + state.pop_pop_push(result); } - assert(lhs_iter == left_cells.cend()); - state.pop_pop_push(rhs); } -bool isConcreteDenseTensor(const ValueType &type) { - return (type.is_dense() && !type.is_abstract()); +bool sameShapeConcreteDenseTensors(const ValueType &a, const ValueType &b) { + return (a.is_dense() && !a.is_abstract() && (a == b)); } } // namespace vespalib::tensor::<unnamed> -DenseInplaceJoinFunction::DenseInplaceJoinFunction(const eval::tensor_function::Join &orig, bool left_is_mutable) - : eval::tensor_function::Op2(orig.result_type(), orig.lhs(), orig.rhs()), - _function(orig.function()), - _left_is_mutable(left_is_mutable) +DenseInplaceJoinFunction::DenseInplaceJoinFunction(const ValueType &result_type, + const TensorFunction &lhs, + const TensorFunction &rhs, + join_fun_t function_in, + bool write_left_in) + : eval::tensor_function::Op2(result_type, lhs, rhs), + _function(function_in), + _write_left(write_left_in) { } @@ -78,11 +64,8 @@ DenseInplaceJoinFunction::~DenseInplaceJoinFunction() eval::InterpretedFunction::Instruction DenseInplaceJoinFunction::compile_self(Stash &) const { - if (_left_is_mutable) { - return eval::InterpretedFunction::Instruction(my_inplace_left_join_op, (uint64_t)_function); - } else { - return eval::InterpretedFunction::Instruction(my_inplace_right_join_op, (uint64_t)_function); - } + auto op = _write_left ? my_inplace_join_op<true> : my_inplace_join_op<false>; + return eval::InterpretedFunction::Instruction(op, (uint64_t)_function); } const TensorFunction & @@ -91,12 +74,11 @@ DenseInplaceJoinFunction::optimize(const eval::TensorFunction &expr, Stash &stas if (auto join = as<Join>(expr)) { const TensorFunction &lhs = join->lhs(); const TensorFunction &rhs = join->rhs(); - if ((lhs.result_is_mutable() || rhs.result_is_mutable()) - && join->result_type() == lhs.result_type() - && join->result_type() == rhs.result_type() - && isConcreteDenseTensor(join->result_type())) + if ((lhs.result_is_mutable() || rhs.result_is_mutable()) && + sameShapeConcreteDenseTensors(lhs.result_type(), rhs.result_type())) { - return stash.create<DenseInplaceJoinFunction>(*join, lhs.result_is_mutable()); + return stash.create<DenseInplaceJoinFunction>(join->result_type(), lhs, rhs, + join->function(), lhs.result_is_mutable()); } } return expr; diff --git a/eval/src/vespa/eval/tensor/dense/dense_inplace_join_function.h b/eval/src/vespa/eval/tensor/dense/dense_inplace_join_function.h index 176191c995c..de2cdae3778 100644 --- a/eval/src/vespa/eval/tensor/dense/dense_inplace_join_function.h +++ b/eval/src/vespa/eval/tensor/dense/dense_inplace_join_function.h @@ -15,11 +15,16 @@ public: using join_fun_t = ::vespalib::eval::tensor_function::join_fun_t; private: join_fun_t _function; - bool _left_is_mutable; + bool _write_left; public: - DenseInplaceJoinFunction(const eval::tensor_function::Join &orig, bool left_is_mutable); + DenseInplaceJoinFunction(const eval::ValueType &result_type, + const TensorFunction &lhs, + const TensorFunction &rhs, + join_fun_t function_in, + bool write_left_in); ~DenseInplaceJoinFunction(); join_fun_t function() const { return _function; } + bool write_left() const { return _write_left; } bool result_is_mutable() const override { return true; } eval::InterpretedFunction::Instruction compile_self(Stash &stash) const override; static const eval::TensorFunction &optimize(const eval::TensorFunction &expr, Stash &stash); |