summary | refs | log | tree | commit | diff | stats
path: root/eval
diff options
context:
space:
mode:
authorHenning Baldersheim <balder@yahoo-inc.com>2018-01-02 11:58:19 +0100
committerHenning Baldersheim <balder@yahoo-inc.com>2018-01-09 17:44:06 +0100
commit2646fc43074bf6ec8e72eff7b21bd0db59685961 (patch)
tree6e48bafe5e16f9fd9b0c425079d6f2b3e5677fed /eval
parent87a2ec1d0c645382f247e480df08e81e16a9943a (diff)
Split into two methods to avoid an `if` in the inner loop.
Diffstat (limited to 'eval')
-rw-r--r--eval/src/vespa/eval/tensor/dense/dense_tensor_address_combiner.h9
-rw-r--r--eval/src/vespa/eval/tensor/dense/dense_tensor_apply.hpp35
2 files changed, 36 insertions, 8 deletions
diff --git a/eval/src/vespa/eval/tensor/dense/dense_tensor_address_combiner.h b/eval/src/vespa/eval/tensor/dense/dense_tensor_address_combiner.h
index f56e46020fd..875ad79aa87 100644
--- a/eval/src/vespa/eval/tensor/dense/dense_tensor_address_combiner.h
+++ b/eval/src/vespa/eval/tensor/dense/dense_tensor_address_combiner.h
@@ -60,6 +60,8 @@ public:
const Mapping & commonRight() const { return _commonRight; }
+ // True when `_right` is non-empty. `_right` is declared outside this hunk;
+ // presumably it holds the dimensions found only in the right-hand tensor
+ // (TODO confirm). The two-tensor apply() uses this to choose between the
+ // general loop and apply_no_rightonly_dimensions().
+ bool hasAnyRightOnlyDimensions() const { return ! _right.empty(); }
+
const Address &address() const { return _combinedAddress; }
bool combine(const Address & lhs, const Address & rhs) {
@@ -123,10 +125,6 @@ public:
~CommonDenseTensorCellsIterator();
template <typename Func>
void for_each(Func && func) {
- if (_mutable.empty()) {
- func(_address, cell(index(_address)));
- return;
- }
const int32_t lastDimension = _mutable.size() - 1;
int32_t curDimension = lastDimension;
size_t cellIdx = index(_address);
@@ -161,6 +159,9 @@ public:
}
return true;
}
+ // Returns the cell value found at the iterator's current `_address`.
+ // Gives callers direct access to the single cell that for_each() used to
+ // hand out through its removed `_mutable.empty()` shortcut, so the
+ // no-right-only-dimensions path can read it without going through for_each().
+ double cell() const {
+ return cell(index(_address));
+ }
const eval::ValueType &fast_type() const { return _type; }
};
diff --git a/eval/src/vespa/eval/tensor/dense/dense_tensor_apply.hpp b/eval/src/vespa/eval/tensor/dense/dense_tensor_apply.hpp
index d5982765fc7..315e2653432 100644
--- a/eval/src/vespa/eval/tensor/dense/dense_tensor_apply.hpp
+++ b/eval/src/vespa/eval/tensor/dense/dense_tensor_apply.hpp
@@ -10,11 +10,9 @@ namespace vespalib::tensor::dense {
template <typename Function>
std::unique_ptr<Tensor>
-apply(const DenseTensorView &lhs, const DenseTensorView &rhs, Function &&func)
+apply(DenseTensorAddressCombiner & combiner, DirectDenseTensorBuilder & builder,
+ CommonDenseTensorCellsIterator & rhsIter, const DenseTensorView &lhs, Function &&func)
{
- DenseTensorAddressCombiner combiner(lhs.fast_type(), rhs.fast_type());
- CommonDenseTensorCellsIterator rhsIter(combiner.commonRight(), rhs.fast_type(), rhs.cellsRef());
- DirectDenseTensorBuilder builder(DenseTensorAddressCombiner::combineDimensions(lhs.fast_type(), rhs.fast_type()));
for (DenseTensorCellsIterator lhsItr = lhs.cellsIterator(); lhsItr.valid(); lhsItr.next()) {
combiner.updateLeftAndCommon(lhsItr.address());
if (rhsIter.updateCommon(combiner.address())) {
@@ -29,6 +27,35 @@ apply(const DenseTensorView &lhs, const DenseTensorView &rhs, Function &&func)
+/**
+ * Combines lhs with the tensor behind rhsIter for the case where
+ * combiner.hasAnyRightOnlyDimensions() is false (see the two-tensor apply()).
+ *
+ * For every lhs cell the combined address is updated, and if rhsIter accepts
+ * that address, func(lhsCell, rhsCell) is inserted into builder at the combined
+ * address. Exactly one rhs cell is read per lhs cell (rhsIter.cell()), so no
+ * inner iteration over rhs cells is needed.
+ *
+ * @param combiner address combiner; mutated on every lhs cell
+ * @param builder  receives the result cells and is finished with build()
+ * @param rhsIter  iterator over the rhs cells; repositioned on every lhs cell
+ * @param lhs      left-hand tensor whose cells drive the loop
+ * @param func     callable invoked as func(double lhsCell, double rhsCell)
+ * @return the tensor produced by builder.build()
+ */
template <typename Function>
std::unique_ptr<Tensor>
+apply_no_rightonly_dimensions(DenseTensorAddressCombiner & combiner, DirectDenseTensorBuilder & builder,
+ CommonDenseTensorCellsIterator & rhsIter,
+ const DenseTensorView &lhs, Function &&func)
+{
+ for (DenseTensorCellsIterator lhsItr = lhs.cellsIterator(); lhsItr.valid(); lhsItr.next()) {
+ combiner.updateLeftAndCommon(lhsItr.address());
+ // lhs cells for which updateCommon() reports false are skipped
+ // (presumably no matching rhs cell exists -- confirm in updateCommon()).
+ if (rhsIter.updateCommon(combiner.address())) {
+ builder.insertCell(combiner.address(), func(lhsItr.cell(), rhsIter.cell()));
+ }
+ }
+ return builder.build();
+}
+
+/**
+ * Applies func cell-wise to two dense tensor views and returns the result.
+ *
+ * Builds the address combiner, the result builder and the rhs cell iterator
+ * once, then dispatches to one of two loops so that the choice is made a
+ * single time instead of inside the inner loop:
+ *  - the general apply() overload when combiner.hasAnyRightOnlyDimensions(),
+ *  - apply_no_rightonly_dimensions() otherwise.
+ *
+ * @param lhs  left-hand tensor view
+ * @param rhs  right-hand tensor view
+ * @param func callable taking the lhs and rhs cell values
+ * @return the combined tensor produced by the selected loop
+ */
+template <typename Function>
+std::unique_ptr<Tensor>
+apply(const DenseTensorView &lhs, const DenseTensorView &rhs, Function &&func)
+{
+ DenseTensorAddressCombiner combiner(lhs.fast_type(), rhs.fast_type());
+ DirectDenseTensorBuilder builder(DenseTensorAddressCombiner::combineDimensions(lhs.fast_type(), rhs.fast_type()));
+ CommonDenseTensorCellsIterator rhsIter(combiner.commonRight(), rhs.fast_type(), rhs.cellsRef());
+ // func is a forwarding reference: std::forward keeps the caller's value
+ // category, whereas std::move would cast an lvalue callable to an rvalue.
+ if (combiner.hasAnyRightOnlyDimensions()) {
+ return apply(combiner, builder, rhsIter, lhs, std::forward<Function>(func));
+ } else {
+ return apply_no_rightonly_dimensions(combiner, builder, rhsIter, lhs, std::forward<Function>(func));
+ }
+}
+
+template <typename Function>
+std::unique_ptr<Tensor>
apply(const DenseTensorView &lhs, const Tensor &rhs, Function &&func)
{
const DenseTensorView *view = dynamic_cast<const DenseTensorView *>(&rhs);