summaryrefslogtreecommitdiffstats
path: root/eval
diff options
context:
space:
mode:
authorHenning Baldersheim <balder@yahoo-inc.com>2017-12-30 13:56:11 +0100
committerHenning Baldersheim <balder@yahoo-inc.com>2018-01-09 17:44:05 +0100
commit8964ab67ad73b53209988cc56e86207422d408ff (patch)
tree7593f2a1248dd5ae80e6f9ae86c399f57c934083 /eval
parentcf931a4d14aadd94a46659155cadb506966ee508 (diff)
Combine address in steps to reduce amount of work in inner loop.
Diffstat (limited to 'eval')
-rw-r--r--eval/src/vespa/eval/tensor/dense/dense_tensor_address_combiner.cpp5
-rw-r--r--eval/src/vespa/eval/tensor/dense/dense_tensor_address_combiner.h23
-rw-r--r--eval/src/vespa/eval/tensor/dense/dense_tensor_apply.hpp5
3 files changed, 28 insertions, 5 deletions
diff --git a/eval/src/vespa/eval/tensor/dense/dense_tensor_address_combiner.cpp b/eval/src/vespa/eval/tensor/dense/dense_tensor_address_combiner.cpp
index ef2a56d4582..f9be59d7eb5 100644
--- a/eval/src/vespa/eval/tensor/dense/dense_tensor_address_combiner.cpp
+++ b/eval/src/vespa/eval/tensor/dense/dense_tensor_address_combiner.cpp
@@ -17,17 +17,22 @@ DenseTensorAddressCombiner::DenseTensorAddressCombiner(const eval::ValueType &lh
auto rhsItrEnd = rhs.dimensions().cend();
for (const auto &lhsDim : lhs.dimensions()) {
while ((rhsItr != rhsItrEnd) && (rhsItr->name < lhsDim.name)) {
+ _right.emplace_back(_ops.size(), rhsItr-rhs.dimensions().cbegin());
_ops.push_back(AddressOp::RHS);
++rhsItr;
}
if ((rhsItr != rhsItrEnd) && (rhsItr->name == lhsDim.name)) {
+ _left.emplace_back(_ops.size(), _left.size());
+ _commonRight.emplace_back(_ops.size(), rhsItr-rhs.dimensions().cbegin());
_ops.push_back(AddressOp::BOTH);
++rhsItr;
} else {
+ _left.emplace_back(_ops.size(), _left.size());
_ops.push_back(AddressOp::LHS);
}
}
while (rhsItr != rhsItrEnd) {
+ _right.emplace_back(_ops.size(), rhsItr-rhs.dimensions().cbegin());
_ops.push_back(AddressOp::RHS);
++rhsItr;
}
diff --git a/eval/src/vespa/eval/tensor/dense/dense_tensor_address_combiner.h b/eval/src/vespa/eval/tensor/dense/dense_tensor_address_combiner.h
index 957d87e8f24..923025b5324 100644
--- a/eval/src/vespa/eval/tensor/dense/dense_tensor_address_combiner.h
+++ b/eval/src/vespa/eval/tensor/dense/dense_tensor_address_combiner.h
@@ -35,18 +35,35 @@ private:
Address::value_type nextLabel() { return _address[_idx++]; }
};
+ using Mapping = std::vector<std::pair<uint32_t, uint32_t>>;
std::vector<AddressOp> _ops;
Address _combinedAddress;
+ Mapping _left;
+ Mapping _commonRight;
+ Mapping _right;
+ void update(const Address & addr, const Mapping & mapping) {
+ for (const auto & m : mapping) {
+ _combinedAddress[m.first] = addr[m.second];
+ }
+ }
public:
DenseTensorAddressCombiner(const eval::ValueType &lhs, const eval::ValueType &rhs);
~DenseTensorAddressCombiner();
+ void updateLeftAndCommon(const Address & addr) { update(addr, _left); }
+ void updateRight(const Address & addr) { update(addr, _right); }
+ bool hasCommonWithRight(const Address & addr) const {
+ for (const auto & m : _commonRight) {
+ if (_combinedAddress[m.first] != addr[m.second]) return false;
+ }
+ return true;
+ }
const Address &address() const { return _combinedAddress; }
- bool combine(const CellsIterator &lhsItr, const CellsIterator &rhsItr) {
+ bool combine(const Address & lhs, const Address & rhs) {
uint32_t index(0);
- AddressReader lhsReader(lhsItr.address());
- AddressReader rhsReader(rhsItr.address());
+ AddressReader lhsReader(lhs);
+ AddressReader rhsReader(rhs);
for (const auto &op : _ops) {
switch (op) {
case AddressOp::LHS:
diff --git a/eval/src/vespa/eval/tensor/dense/dense_tensor_apply.hpp b/eval/src/vespa/eval/tensor/dense/dense_tensor_apply.hpp
index dc47d02d47c..999433713be 100644
--- a/eval/src/vespa/eval/tensor/dense/dense_tensor_apply.hpp
+++ b/eval/src/vespa/eval/tensor/dense/dense_tensor_apply.hpp
@@ -15,9 +15,10 @@ apply(const DenseTensorView &lhs, const DenseTensorView &rhs, Function &&func)
DenseTensorAddressCombiner combiner(lhs.fast_type(), rhs.fast_type());
DirectDenseTensorBuilder builder(DenseTensorAddressCombiner::combineDimensions(lhs.fast_type(), rhs.fast_type()));
for (DenseTensorCellsIterator lhsItr = lhs.cellsIterator(); lhsItr.valid(); lhsItr.next()) {
+ combiner.updateLeftAndCommon(lhsItr.address());
for (DenseTensorCellsIterator rhsItr = rhs.cellsIterator(); rhsItr.valid(); rhsItr.next()) {
- bool combineSuccess = combiner.combine(lhsItr, rhsItr);
- if (combineSuccess) {
+ if (combiner.hasCommonWithRight(rhsItr.address())) {
+ combiner.updateRight(rhsItr.address());
builder.insertCell(combiner.address(), func(lhsItr.cell(), rhsItr.cell()));
}
}