diff options
author | Arne Juul <arnej@yahoo-inc.com> | 2019-06-24 12:36:02 +0000 |
---|---|---|
committer | Arne Juul <arnej@yahoo-inc.com> | 2019-06-24 12:36:02 +0000 |
commit | 98ffc28bceaaaa47f1213283ed6bee1e9d1a2bac (patch) | |
tree | f357cc218498777b84b3d19fe6b822fc4e61229c /eval/src | |
parent | 52be221f7e73ed6c01464b7c7d160be28466dde5 (diff) |
use DenseDimensionCombiner class instead
Diffstat (limited to 'eval/src')
3 files changed, 24 insertions, 56 deletions
diff --git a/eval/src/vespa/eval/tensor/dense/dense_dimension_combiner.cpp b/eval/src/vespa/eval/tensor/dense/dense_dimension_combiner.cpp index 194f4d0eae3..80aef3c79b6 100644 --- a/eval/src/vespa/eval/tensor/dense/dense_dimension_combiner.cpp +++ b/eval/src/vespa/eval/tensor/dense/dense_dimension_combiner.cpp @@ -11,17 +11,16 @@ DenseDimensionCombiner::DenseDimensionCombiner(const eval::ValueType &lhs, const eval::ValueType &rhs) : _leftDims(), _rightDims(), _commonDims(), _leftIndex(0), _rightIndex(0), _outputIndex(0), - _leftOnlySize(1u), _rightOnlySize(1u), _outputSize(1u) + _leftOnlySize(1u), _rightOnlySize(1u), _outputSize(1u), + result_type(eval::ValueType::join(lhs, rhs)) { - eval::ValueType outputType = eval::ValueType::join(lhs, rhs); - assert(lhs.is_dense()); assert(rhs.is_dense()); - assert(outputType.is_dense()); + assert(result_type.is_dense()); const auto &lDims = lhs.dimensions(); const auto &rDims = rhs.dimensions(); - const auto &oDims = outputType.dimensions(); + const auto &oDims = result_type.dimensions(); size_t i = lDims.size(); size_t j = rDims.size(); diff --git a/eval/src/vespa/eval/tensor/dense/dense_dimension_combiner.h b/eval/src/vespa/eval/tensor/dense/dense_dimension_combiner.h index 91449f122cd..6408c86902f 100644 --- a/eval/src/vespa/eval/tensor/dense/dense_dimension_combiner.h +++ b/eval/src/vespa/eval/tensor/dense/dense_dimension_combiner.h @@ -127,6 +127,8 @@ public: _outputIndex += _outputSize; } + const eval::ValueType result_type; + DenseDimensionCombiner(const eval::ValueType &lhs, const eval::ValueType &rhs); ~DenseDimensionCombiner(); diff --git a/eval/src/vespa/eval/tensor/dense/dense_tensor_apply.hpp b/eval/src/vespa/eval/tensor/dense/dense_tensor_apply.hpp index fa1e59c87db..e71840f392c 100644 --- a/eval/src/vespa/eval/tensor/dense/dense_tensor_apply.hpp +++ b/eval/src/vespa/eval/tensor/dense/dense_tensor_apply.hpp @@ -3,78 +3,45 @@ #pragma once #include "dense_tensor_apply.h" -#include "dense_tensor_address_combiner.h" +#include "dense_dimension_combiner.h" #include "direct_dense_tensor_builder.h" namespace vespalib::tensor::dense { template <typename Function> std::unique_ptr<Tensor> -apply(DenseTensorAddressCombiner & combiner, DirectDenseTensorBuilder & builder, - const DenseTensorView &lhs, const DenseTensorView::CellsRef & rhsCells, Function &&func) __attribute__((noinline)); +apply(DenseDimensionCombiner & combiner, DirectDenseTensorBuilder & builder, + const DenseTensorView::CellsRef & lhsCells, + const DenseTensorView::CellsRef & rhsCells, Function &&func) __attribute__((noinline)); template <typename Function> std::unique_ptr<Tensor> -apply(DenseTensorAddressCombiner & combiner, DirectDenseTensorBuilder & builder, - const DenseTensorView &lhs, const DenseTensorView::CellsRef & rhsCells, Function &&func) +apply(DenseDimensionCombiner & combiner, DirectDenseTensorBuilder & builder, + const DenseTensorView::CellsRef & lhsCells, + const DenseTensorView::CellsRef & rhsCells, Function &&func) { - for (DenseTensorCellsIterator lhsItr = lhs.cellsIterator(); lhsItr.valid(); lhsItr.next()) { - combiner.updateLeftAndCommon(lhsItr.address()); - if (combiner.updateCommon()) { - combiner.for_each_right(rhsCells, [&func, &builder, &lhsItr](size_t combined, double rhsCell) { - builder.insertCell(combined, func(lhsItr.cell(), rhsCell)); - }); + for (combiner.leftReset(); combiner.leftInRange(); combiner.stepLeft()) { + for (combiner.rightReset(); combiner.rightInRange(); combiner.stepRight()) { + for (combiner.commonReset(); combiner.commonInRange(); combiner.stepCommon()) { + size_t outIdx = combiner.outputIdx(); + size_t l = combiner.leftIdx(); + size_t r = combiner.rightIdx(); + builder.insertCell(outIdx, func(lhsCells[l], rhsCells[r])); + } } } return builder.build(); } - -template <typename Function> -std::unique_ptr<Tensor> -apply_no_rightonly_dimensions(DenseTensorAddressCombiner & combiner, DirectDenseTensorBuilder & builder, - const DenseTensorView &lhs, const DenseTensorView::CellsRef & rhsCells, - Function &&func) __attribute__((noinline)); - -template <typename Function> -std::unique_ptr<Tensor> -apply_no_rightonly_dimensions(DenseTensorAddressCombiner & combiner, DirectDenseTensorBuilder & builder, - const DenseTensorView &lhs, const DenseTensorView::CellsRef & rhsCells, Function &&func) -{ - for (DenseTensorCellsIterator lhsItr = lhs.cellsIterator(); lhsItr.valid(); lhsItr.next()) { - combiner.updateLeftAndCommon(lhsItr.address()); - if (combiner.updateCommon()) { - builder.insertCell(combiner.address(), func(lhsItr.cell(), rhsCells[combiner.rightCellIndex()])); - } - } - return builder.build(); -} - -template <typename Function> -std::unique_ptr<Tensor> -apply(const DenseTensorView &lhs, const DenseTensorView &rhs, Function &&func) -{ - eval::ValueType resultType = DenseTensorAddressCombiner::combineDimensions(lhs.fast_type(), rhs.fast_type()); - DenseTensorAddressCombiner combiner(resultType, lhs.fast_type(), rhs.fast_type()); - DirectDenseTensorBuilder builder(resultType); - if (combiner.hasAnyRightOnlyDimensions()) { - return apply(combiner, builder, lhs, rhs.cellsRef(), std::move(func)); - } else { - return apply_no_rightonly_dimensions(combiner, builder, lhs, rhs.cellsRef(), std::move(func)); - } -} - template <typename Function> std::unique_ptr<Tensor> apply(const DenseTensorView &lhs, const Tensor &rhs, Function &&func) { const DenseTensorView *view = dynamic_cast<const DenseTensorView *>(&rhs); if (view) { - return apply(lhs, *view, func); - } - const DenseTensor *dense = dynamic_cast<const DenseTensor *>(&rhs); - if (dense) { - return apply(lhs, *dense, func); + DenseDimensionCombiner combiner(lhs.fast_type(), view->fast_type()); + DirectDenseTensorBuilder builder(combiner.result_type); + return apply(combiner, builder, lhs.cellsRef(), view->cellsRef(), std::move(func)); } return Tensor::UP(); } |