aboutsummaryrefslogtreecommitdiffstats
path: root/eval
diff options
context:
space:
mode:
authorHåvard Pettersen <havardpe@oath.com>2018-03-02 12:39:21 +0000
committerHåvard Pettersen <havardpe@oath.com>2018-03-02 12:39:21 +0000
commit39c3e0ac942996910a9325fb5589b3b1bf6df0bd (patch)
tree5bf8d65373eef9837c812992913baa6203149e81 /eval
parent829e3ddec1e49911f8ffcc5b0f32846efd236e9e (diff)
minor cleanup
Diffstat (limited to 'eval')
-rw-r--r--eval/src/vespa/eval/tensor/dense/dense_inplace_join_function.cpp76
-rw-r--r--eval/src/vespa/eval/tensor/dense/dense_inplace_join_function.h9
2 files changed, 36 insertions, 49 deletions
diff --git a/eval/src/vespa/eval/tensor/dense/dense_inplace_join_function.cpp b/eval/src/vespa/eval/tensor/dense/dense_inplace_join_function.cpp
index 2e37406664c..53a5fe9bb27 100644
--- a/eval/src/vespa/eval/tensor/dense/dense_inplace_join_function.cpp
+++ b/eval/src/vespa/eval/tensor/dense/dense_inplace_join_function.cpp
@@ -17,57 +17,43 @@ using namespace eval::tensor_function;
namespace {
-ArrayRef<double> getMutableCells(const eval::Value &value) {
- const DenseTensorView &denseTensor = static_cast<const DenseTensorView &>(value);
- return unconstify(denseTensor.cellsRef());
-}
-
-ConstArrayRef<double> getConstCells(const eval::Value &value) {
+CellsRef getCellsRef(const eval::Value &value) {
const DenseTensorView &denseTensor = static_cast<const DenseTensorView &>(value);
return denseTensor.cellsRef();
}
-void my_inplace_left_join_op(eval::InterpretedFunction::State &state, uint64_t param) {
- const Value &lhs = state.peek(1);
- const Value &rhs = state.peek(0);
+template <bool write_left>
+void my_inplace_join_op(eval::InterpretedFunction::State &state, uint64_t param) {
join_fun_t function = (join_fun_t)param;
- ArrayRef<double> left_cells = getMutableCells(lhs);
- ConstArrayRef<double> right_cells = getConstCells(rhs);
- auto rhs_iter = right_cells.cbegin();
- for (double &cell: left_cells) {
- cell = function(cell, *rhs_iter);
- ++rhs_iter;
+ CellsRef lhs_cells = getCellsRef(state.peek(1));
+ CellsRef rhs_cells = getCellsRef(state.peek(0));
+ auto dst_cells = unconstify(write_left ? lhs_cells : rhs_cells);
+ for (size_t i = 0; i < dst_cells.size(); ++i) {
+ dst_cells[i] = function(lhs_cells[i], rhs_cells[i]);
}
- assert(rhs_iter == right_cells.cend());
- state.pop_pop_push(lhs);
-}
-
-void my_inplace_right_join_op(eval::InterpretedFunction::State &state, uint64_t param) {
- const Value &lhs = state.peek(1);
- const Value &rhs = state.peek(0);
- join_fun_t function = (join_fun_t)param;
- ConstArrayRef<double> left_cells = getConstCells(lhs);
- ArrayRef<double> right_cells = getMutableCells(rhs);
- auto lhs_iter = left_cells.cbegin();
- for (double &cell: right_cells) {
- cell = function(*lhs_iter, cell);
- ++lhs_iter;
+ if (write_left) {
+ state.stack.pop_back();
+ } else {
+ const Value &result = state.stack.back();
+ state.pop_pop_push(result);
}
- assert(lhs_iter == left_cells.cend());
- state.pop_pop_push(rhs);
}
-bool isConcreteDenseTensor(const ValueType &type) {
- return (type.is_dense() && !type.is_abstract());
+bool sameShapeConcreteDenseTensors(const ValueType &a, const ValueType &b) {
+ return (a.is_dense() && !a.is_abstract() && (a == b));
}
} // namespace vespalib::tensor::<unnamed>
-DenseInplaceJoinFunction::DenseInplaceJoinFunction(const eval::tensor_function::Join &orig, bool left_is_mutable)
- : eval::tensor_function::Op2(orig.result_type(), orig.lhs(), orig.rhs()),
- _function(orig.function()),
- _left_is_mutable(left_is_mutable)
+DenseInplaceJoinFunction::DenseInplaceJoinFunction(const ValueType &result_type,
+ const TensorFunction &lhs,
+ const TensorFunction &rhs,
+ join_fun_t function_in,
+ bool write_left_in)
+ : eval::tensor_function::Op2(result_type, lhs, rhs),
+ _function(function_in),
+ _write_left(write_left_in)
{
}
@@ -78,11 +64,8 @@ DenseInplaceJoinFunction::~DenseInplaceJoinFunction()
eval::InterpretedFunction::Instruction
DenseInplaceJoinFunction::compile_self(Stash &) const
{
- if (_left_is_mutable) {
- return eval::InterpretedFunction::Instruction(my_inplace_left_join_op, (uint64_t)_function);
- } else {
- return eval::InterpretedFunction::Instruction(my_inplace_right_join_op, (uint64_t)_function);
- }
+ auto op = _write_left ? my_inplace_join_op<true> : my_inplace_join_op<false>;
+ return eval::InterpretedFunction::Instruction(op, (uint64_t)_function);
}
const TensorFunction &
@@ -91,12 +74,11 @@ DenseInplaceJoinFunction::optimize(const eval::TensorFunction &expr, Stash &stas
if (auto join = as<Join>(expr)) {
const TensorFunction &lhs = join->lhs();
const TensorFunction &rhs = join->rhs();
- if ((lhs.result_is_mutable() || rhs.result_is_mutable())
- && join->result_type() == lhs.result_type()
- && join->result_type() == rhs.result_type()
- && isConcreteDenseTensor(join->result_type()))
+ if ((lhs.result_is_mutable() || rhs.result_is_mutable()) &&
+ sameShapeConcreteDenseTensors(lhs.result_type(), rhs.result_type()))
{
- return stash.create<DenseInplaceJoinFunction>(*join, lhs.result_is_mutable());
+ return stash.create<DenseInplaceJoinFunction>(join->result_type(), lhs, rhs,
+ join->function(), lhs.result_is_mutable());
}
}
return expr;
diff --git a/eval/src/vespa/eval/tensor/dense/dense_inplace_join_function.h b/eval/src/vespa/eval/tensor/dense/dense_inplace_join_function.h
index 176191c995c..de2cdae3778 100644
--- a/eval/src/vespa/eval/tensor/dense/dense_inplace_join_function.h
+++ b/eval/src/vespa/eval/tensor/dense/dense_inplace_join_function.h
@@ -15,11 +15,16 @@ public:
using join_fun_t = ::vespalib::eval::tensor_function::join_fun_t;
private:
join_fun_t _function;
- bool _left_is_mutable;
+ bool _write_left;
public:
- DenseInplaceJoinFunction(const eval::tensor_function::Join &orig, bool left_is_mutable);
+ DenseInplaceJoinFunction(const eval::ValueType &result_type,
+ const TensorFunction &lhs,
+ const TensorFunction &rhs,
+ join_fun_t function_in,
+ bool write_left_in);
~DenseInplaceJoinFunction();
join_fun_t function() const { return _function; }
+ bool write_left() const { return _write_left; }
bool result_is_mutable() const override { return true; }
eval::InterpretedFunction::Instruction compile_self(Stash &stash) const override;
static const eval::TensorFunction &optimize(const eval::TensorFunction &expr, Stash &stash);