diff options
Diffstat (limited to 'eval')
3 files changed, 109 insertions, 6 deletions
diff --git a/eval/src/vespa/eval/tensor/dense/dense_tensor_address_combiner.cpp b/eval/src/vespa/eval/tensor/dense/dense_tensor_address_combiner.cpp index f9be59d7eb5..23d75412289 100644 --- a/eval/src/vespa/eval/tensor/dense/dense_tensor_address_combiner.cpp +++ b/eval/src/vespa/eval/tensor/dense/dense_tensor_address_combiner.cpp @@ -71,4 +71,29 @@ DenseTensorAddressCombiner::combineDimensions(const eval::ValueType &lhs, eval::ValueType::tensor_type(std::move(result))); } + +CommonDenseTensorCellsIterator::CommonDenseTensorCellsIterator(const Mapping & common, + const eval::ValueType &type_in, + CellsRef cells) + : _type(type_in), + _cells(cells), + _address(type_in.dimensions().size(), 0), + _common(common), + _mutable(_address.size()), + _accumulatedSize(_address.size()) +{ + for (uint32_t i(0); i < _address.size(); i++) { + _mutable[i] = i; + } + for (auto cur = _common.rbegin(); cur != _common.rend(); cur++) { + _mutable.erase(_mutable.begin() + cur->second); + } + size_t multiplier = 1; + for (int32_t i(_address.size() - 1); i >= 0; i--) { + _accumulatedSize[i] = multiplier; + multiplier *= type_in.dimensions()[i].size; + } +} +CommonDenseTensorCellsIterator::~CommonDenseTensorCellsIterator() = default; + } diff --git a/eval/src/vespa/eval/tensor/dense/dense_tensor_address_combiner.h b/eval/src/vespa/eval/tensor/dense/dense_tensor_address_combiner.h index 923025b5324..f56e46020fd 100644 --- a/eval/src/vespa/eval/tensor/dense/dense_tensor_address_combiner.h +++ b/eval/src/vespa/eval/tensor/dense/dense_tensor_address_combiner.h @@ -19,6 +19,7 @@ class DenseTensorAddressCombiner { public: using Address = DenseTensorCellsIterator::Address; + using Mapping = std::vector<std::pair<uint32_t, uint32_t>>; private: enum class AddressOp { LHS, RHS, BOTH }; @@ -35,7 +36,6 @@ private: Address::value_type nextLabel() { return _address[_idx++]; } }; - using Mapping = std::vector<std::pair<uint32_t, uint32_t>>; std::vector<AddressOp> _ops; Address _combinedAddress; Mapping _left; @@ -58,6 +58,8 @@ public: return true; } + const Mapping & commonRight() const { return _commonRight; } + const Address &address() const { return _combinedAddress; } bool combine(const Address & lhs, const Address & rhs) { @@ -88,4 +90,79 @@ public: static eval::ValueType combineDimensions(const eval::ValueType &lhs, const eval::ValueType &rhs); }; + +/** + * Utility class to iterate over common cells in a dense tensor. + */ +class CommonDenseTensorCellsIterator +{ +public: + using size_type = eval::ValueType::Dimension::size_type; + using Address = std::vector<size_type>; + using Mapping = DenseTensorAddressCombiner::Mapping; +private: + using Dims = std::vector<uint32_t>; + using CellsRef = vespalib::ConstArrayRef<double>; + const eval::ValueType &_type; + CellsRef _cells; + Address _address; + Mapping _common; + std::vector<uint32_t> _mutable; + std::vector<size_t> _accumulatedSize; + + double cell(size_t cellIdx) const { return _cells[cellIdx]; } + size_t index(const Address &address) const { + size_t cellIdx(0); + for (uint32_t i(0); i < address.size(); i++) { + cellIdx += address[i]*_accumulatedSize[i]; + } + return cellIdx; + } +public: + CommonDenseTensorCellsIterator(const Mapping & common, const eval::ValueType &type_in, CellsRef cells); + ~CommonDenseTensorCellsIterator(); + template <typename Func> + void for_each(Func && func) { + if (_mutable.empty()) { + func(_address, cell(index(_address))); + return; + } + const int32_t lastDimension = _mutable.size() - 1; + int32_t curDimension = lastDimension; + size_t cellIdx = index(_address); + while (curDimension >= 0) { + const uint32_t dim = _mutable[curDimension]; + size_type & index = _address[dim]; + if (curDimension == lastDimension) { + for (index = 0; index < _type.dimensions()[dim].size; index++) { + func(_address, cell(cellIdx)); + cellIdx += _accumulatedSize[dim]; + } + index = 0; + cellIdx -= _accumulatedSize[dim] * _type.dimensions()[dim].size; + curDimension--; + } else { + if (index < _type.dimensions()[dim].size) { + index++; + cellIdx += _accumulatedSize[dim]; + curDimension++; + } else { + cellIdx -= _accumulatedSize[dim] * _type.dimensions()[dim].size; + index = 0; + curDimension--; + } + } + } + } + bool updateCommon(const Address & combined) { + for (const auto & m : _common) { + if (combined[m.first] >= _type.dimensions()[m.second].size) return false; + _address[m.second] = combined[m.first]; + } + return true; + } + + const eval::ValueType &fast_type() const { return _type; } +}; + } diff --git a/eval/src/vespa/eval/tensor/dense/dense_tensor_apply.hpp b/eval/src/vespa/eval/tensor/dense/dense_tensor_apply.hpp index 999433713be..d5982765fc7 100644 --- a/eval/src/vespa/eval/tensor/dense/dense_tensor_apply.hpp +++ b/eval/src/vespa/eval/tensor/dense/dense_tensor_apply.hpp @@ -13,14 +13,15 @@ std::unique_ptr<Tensor> apply(const DenseTensorView &lhs, const DenseTensorView &rhs, Function &&func) { DenseTensorAddressCombiner combiner(lhs.fast_type(), rhs.fast_type()); + CommonDenseTensorCellsIterator rhsIter(combiner.commonRight(), rhs.fast_type(), rhs.cellsRef()); DirectDenseTensorBuilder builder(DenseTensorAddressCombiner::combineDimensions(lhs.fast_type(), rhs.fast_type())); for (DenseTensorCellsIterator lhsItr = lhs.cellsIterator(); lhsItr.valid(); lhsItr.next()) { combiner.updateLeftAndCommon(lhsItr.address()); - for (DenseTensorCellsIterator rhsItr = rhs.cellsIterator(); rhsItr.valid(); rhsItr.next()) { - if (combiner.hasCommonWithRight(rhsItr.address())) { - combiner.updateRight(rhsItr.address()); - builder.insertCell(combiner.address(), func(lhsItr.cell(), rhsItr.cell())); - } + if (rhsIter.updateCommon(combiner.address())) { + rhsIter.for_each([&combiner, &func, &builder, &lhsItr](const DenseTensorCellsIterator::Address & right, double rhsCell) { + combiner.updateRight(right); + builder.insertCell(combiner.address(), func(lhsItr.cell(), rhsCell)); + }); } } return builder.build(); |