diff options
Diffstat (limited to 'eval')
3 files changed, 28 insertions, 5 deletions
diff --git a/eval/src/vespa/eval/tensor/dense/dense_tensor_address_combiner.cpp b/eval/src/vespa/eval/tensor/dense/dense_tensor_address_combiner.cpp index ef2a56d4582..f9be59d7eb5 100644 --- a/eval/src/vespa/eval/tensor/dense/dense_tensor_address_combiner.cpp +++ b/eval/src/vespa/eval/tensor/dense/dense_tensor_address_combiner.cpp @@ -17,17 +17,22 @@ DenseTensorAddressCombiner::DenseTensorAddressCombiner(const eval::ValueType &lh auto rhsItrEnd = rhs.dimensions().cend(); for (const auto &lhsDim : lhs.dimensions()) { while ((rhsItr != rhsItrEnd) && (rhsItr->name < lhsDim.name)) { + _right.emplace_back(_ops.size(), rhsItr-rhs.dimensions().cbegin()); _ops.push_back(AddressOp::RHS); ++rhsItr; } if ((rhsItr != rhsItrEnd) && (rhsItr->name == lhsDim.name)) { + _left.emplace_back(_ops.size(), _left.size()); + _commonRight.emplace_back(_ops.size(), rhsItr-rhs.dimensions().cbegin()); _ops.push_back(AddressOp::BOTH); ++rhsItr; } else { + _left.emplace_back(_ops.size(), _left.size()); _ops.push_back(AddressOp::LHS); } } while (rhsItr != rhsItrEnd) { + _right.emplace_back(_ops.size(), rhsItr-rhs.dimensions().cbegin()); _ops.push_back(AddressOp::RHS); ++rhsItr; } diff --git a/eval/src/vespa/eval/tensor/dense/dense_tensor_address_combiner.h b/eval/src/vespa/eval/tensor/dense/dense_tensor_address_combiner.h index 957d87e8f24..923025b5324 100644 --- a/eval/src/vespa/eval/tensor/dense/dense_tensor_address_combiner.h +++ b/eval/src/vespa/eval/tensor/dense/dense_tensor_address_combiner.h @@ -35,18 +35,35 @@ private: Address::value_type nextLabel() { return _address[_idx++]; } }; + using Mapping = std::vector<std::pair<uint32_t, uint32_t>>; std::vector<AddressOp> _ops; Address _combinedAddress; + Mapping _left; + Mapping _commonRight; + Mapping _right; + void update(const Address & addr, const Mapping & mapping) { + for (const auto & m : mapping) { + _combinedAddress[m.first] = addr[m.second]; + } + } public: DenseTensorAddressCombiner(const eval::ValueType &lhs, const eval::ValueType &rhs); ~DenseTensorAddressCombiner(); + void updateLeftAndCommon(const Address & addr) { update(addr, _left); } + void updateRight(const Address & addr) { update(addr, _right); } + bool hasCommonWithRight(const Address & addr) const { + for (const auto & m : _commonRight) { + if (_combinedAddress[m.first] != addr[m.second]) return false; + } + return true; + } const Address &address() const { return _combinedAddress; } - bool combine(const CellsIterator &lhsItr, const CellsIterator &rhsItr) { + bool combine(const Address & lhs, const Address & rhs) { uint32_t index(0); - AddressReader lhsReader(lhsItr.address()); - AddressReader rhsReader(rhsItr.address()); + AddressReader lhsReader(lhs); + AddressReader rhsReader(rhs); for (const auto &op : _ops) { switch (op) { case AddressOp::LHS: diff --git a/eval/src/vespa/eval/tensor/dense/dense_tensor_apply.hpp b/eval/src/vespa/eval/tensor/dense/dense_tensor_apply.hpp index dc47d02d47c..999433713be 100644 --- a/eval/src/vespa/eval/tensor/dense/dense_tensor_apply.hpp +++ b/eval/src/vespa/eval/tensor/dense/dense_tensor_apply.hpp @@ -15,9 +15,10 @@ apply(const DenseTensorView &lhs, const DenseTensorView &rhs, Function &&func) DenseTensorAddressCombiner combiner(lhs.fast_type(), rhs.fast_type()); DirectDenseTensorBuilder builder(DenseTensorAddressCombiner::combineDimensions(lhs.fast_type(), rhs.fast_type())); for (DenseTensorCellsIterator lhsItr = lhs.cellsIterator(); lhsItr.valid(); lhsItr.next()) { + combiner.updateLeftAndCommon(lhsItr.address()); for (DenseTensorCellsIterator rhsItr = rhs.cellsIterator(); rhsItr.valid(); rhsItr.next()) { - bool combineSuccess = combiner.combine(lhsItr, rhsItr); - if (combineSuccess) { + if (combiner.hasCommonWithRight(rhsItr.address())) { + combiner.updateRight(rhsItr.address()); builder.insertCell(combiner.address(), func(lhsItr.cell(), rhsItr.cell())); } } |