diff options
author | Henning Baldersheim <balder@yahoo-inc.com> | 2018-01-02 11:58:19 +0100 |
---|---|---|
committer | Henning Baldersheim <balder@yahoo-inc.com> | 2018-01-09 17:44:06 +0100 |
commit | 2646fc43074bf6ec8e72eff7b21bd0db59685961 (patch) | |
tree | 6e48bafe5e16f9fd9b0c425079d6f2b3e5677fed /eval | |
parent | 87a2ec1d0c645382f247e480df08e81e16a9943a (diff) |
Split in 2 methods to avoid if in inner loop.
Diffstat (limited to 'eval')
-rw-r--r-- | eval/src/vespa/eval/tensor/dense/dense_tensor_address_combiner.h | 9 | ||||
-rw-r--r-- | eval/src/vespa/eval/tensor/dense/dense_tensor_apply.hpp | 35 |
2 files changed, 36 insertions, 8 deletions
diff --git a/eval/src/vespa/eval/tensor/dense/dense_tensor_address_combiner.h b/eval/src/vespa/eval/tensor/dense/dense_tensor_address_combiner.h index f56e46020fd..875ad79aa87 100644 --- a/eval/src/vespa/eval/tensor/dense/dense_tensor_address_combiner.h +++ b/eval/src/vespa/eval/tensor/dense/dense_tensor_address_combiner.h @@ -60,6 +60,8 @@ public: const Mapping & commonRight() const { return _commonRight; } + bool hasAnyRightOnlyDimensions() const { return ! _right.empty(); } + const Address &address() const { return _combinedAddress; } bool combine(const Address & lhs, const Address & rhs) { @@ -123,10 +125,6 @@ public: ~CommonDenseTensorCellsIterator(); template <typename Func> void for_each(Func && func) { - if (_mutable.empty()) { - func(_address, cell(index(_address))); - return; - } const int32_t lastDimension = _mutable.size() - 1; int32_t curDimension = lastDimension; size_t cellIdx = index(_address); @@ -161,6 +159,9 @@ public: } return true; } + double cell() const { + return cell(index(_address)); + } const eval::ValueType &fast_type() const { return _type; } }; diff --git a/eval/src/vespa/eval/tensor/dense/dense_tensor_apply.hpp b/eval/src/vespa/eval/tensor/dense/dense_tensor_apply.hpp index d5982765fc7..315e2653432 100644 --- a/eval/src/vespa/eval/tensor/dense/dense_tensor_apply.hpp +++ b/eval/src/vespa/eval/tensor/dense/dense_tensor_apply.hpp @@ -10,11 +10,9 @@ namespace vespalib::tensor::dense { template <typename Function> std::unique_ptr<Tensor> -apply(const DenseTensorView &lhs, const DenseTensorView &rhs, Function &&func) +apply(DenseTensorAddressCombiner & combiner, DirectDenseTensorBuilder & builder, + CommonDenseTensorCellsIterator & rhsIter, const DenseTensorView &lhs, Function &&func) { - DenseTensorAddressCombiner combiner(lhs.fast_type(), rhs.fast_type()); - CommonDenseTensorCellsIterator rhsIter(combiner.commonRight(), rhs.fast_type(), rhs.cellsRef()); - DirectDenseTensorBuilder builder(DenseTensorAddressCombiner::combineDimensions(lhs.fast_type(), rhs.fast_type())); for (DenseTensorCellsIterator lhsItr = lhs.cellsIterator(); lhsItr.valid(); lhsItr.next()) { combiner.updateLeftAndCommon(lhsItr.address()); if (rhsIter.updateCommon(combiner.address())) { @@ -29,6 +27,35 @@ apply(const DenseTensorView &lhs, const DenseTensorView &rhs, Function &&func) template <typename Function> std::unique_ptr<Tensor> +apply_no_rightonly_dimensions(DenseTensorAddressCombiner & combiner, DirectDenseTensorBuilder & builder, + CommonDenseTensorCellsIterator & rhsIter, + const DenseTensorView &lhs, Function &&func) +{ + for (DenseTensorCellsIterator lhsItr = lhs.cellsIterator(); lhsItr.valid(); lhsItr.next()) { + combiner.updateLeftAndCommon(lhsItr.address()); + if (rhsIter.updateCommon(combiner.address())) { + builder.insertCell(combiner.address(), func(lhsItr.cell(), rhsIter.cell())); + } + } + return builder.build(); +} + +template <typename Function> +std::unique_ptr<Tensor> +apply(const DenseTensorView &lhs, const DenseTensorView &rhs, Function &&func) +{ + DenseTensorAddressCombiner combiner(lhs.fast_type(), rhs.fast_type()); + DirectDenseTensorBuilder builder(DenseTensorAddressCombiner::combineDimensions(lhs.fast_type(), rhs.fast_type())); + CommonDenseTensorCellsIterator rhsIter(combiner.commonRight(), rhs.fast_type(), rhs.cellsRef()); + if (combiner.hasAnyRightOnlyDimensions()) { + return apply(combiner, builder, rhsIter, lhs, std::move(func)); + } else { + return apply_no_rightonly_dimensions(combiner, builder, rhsIter, lhs, std::move(func)); + } +} + +template <typename Function> +std::unique_ptr<Tensor> apply(const DenseTensorView &lhs, const Tensor &rhs, Function &&func) { const DenseTensorView *view = dynamic_cast<const DenseTensorView *>(&rhs); |