summaryrefslogtreecommitdiffstats
path: root/eval
diff options
context:
space:
mode:
authorHenning Baldersheim <balder@yahoo-inc.com>2018-01-02 02:36:38 +0100
committerHenning Baldersheim <balder@yahoo-inc.com>2018-01-09 17:44:05 +0100
commitea0c4d476aa7e11418eb52b253a780d7a3da7ed6 (patch)
tree06af6048e5039dbd67710b36e635e7bc668866d6 /eval
parent8964ab67ad73b53209988cc56e86207422d408ff (diff)
Iterate only the valid dimensions.
Diffstat (limited to 'eval')
-rw-r--r--eval/src/vespa/eval/tensor/dense/dense_tensor_address_combiner.cpp25
-rw-r--r--eval/src/vespa/eval/tensor/dense/dense_tensor_address_combiner.h79
-rw-r--r--eval/src/vespa/eval/tensor/dense/dense_tensor_apply.hpp11
3 files changed, 109 insertions, 6 deletions
diff --git a/eval/src/vespa/eval/tensor/dense/dense_tensor_address_combiner.cpp b/eval/src/vespa/eval/tensor/dense/dense_tensor_address_combiner.cpp
index f9be59d7eb5..23d75412289 100644
--- a/eval/src/vespa/eval/tensor/dense/dense_tensor_address_combiner.cpp
+++ b/eval/src/vespa/eval/tensor/dense/dense_tensor_address_combiner.cpp
@@ -71,4 +71,29 @@ DenseTensorAddressCombiner::combineDimensions(const eval::ValueType &lhs,
eval::ValueType::tensor_type(std::move(result)));
}
+
+CommonDenseTensorCellsIterator::CommonDenseTensorCellsIterator(const Mapping & common,
+ const eval::ValueType &type_in,
+ CellsRef cells)
+ : _type(type_in),
+ _cells(cells),
+ _address(type_in.dimensions().size(), 0),
+ _common(common),
+ _mutable(_address.size()),
+ _accumulatedSize(_address.size())
+{
+ for (uint32_t i(0); i < _address.size(); i++) {
+ _mutable[i] = i;
+ }
+ for (auto cur = _common.rbegin(); cur != _common.rend(); cur++) {
+ _mutable.erase(_mutable.begin() + cur->second);
+ }
+ size_t multiplier = 1;
+ for (int32_t i(_address.size() - 1); i >= 0; i--) {
+ _accumulatedSize[i] = multiplier;
+ multiplier *= type_in.dimensions()[i].size;
+ }
+}
+CommonDenseTensorCellsIterator::~CommonDenseTensorCellsIterator() = default;
+
}
diff --git a/eval/src/vespa/eval/tensor/dense/dense_tensor_address_combiner.h b/eval/src/vespa/eval/tensor/dense/dense_tensor_address_combiner.h
index 923025b5324..f56e46020fd 100644
--- a/eval/src/vespa/eval/tensor/dense/dense_tensor_address_combiner.h
+++ b/eval/src/vespa/eval/tensor/dense/dense_tensor_address_combiner.h
@@ -19,6 +19,7 @@ class DenseTensorAddressCombiner
{
public:
using Address = DenseTensorCellsIterator::Address;
+ using Mapping = std::vector<std::pair<uint32_t, uint32_t>>;
private:
enum class AddressOp { LHS, RHS, BOTH };
@@ -35,7 +36,6 @@ private:
Address::value_type nextLabel() { return _address[_idx++]; }
};
- using Mapping = std::vector<std::pair<uint32_t, uint32_t>>;
std::vector<AddressOp> _ops;
Address _combinedAddress;
Mapping _left;
@@ -58,6 +58,8 @@ public:
return true;
}
+ const Mapping & commonRight() const { return _commonRight; }
+
const Address &address() const { return _combinedAddress; }
bool combine(const Address & lhs, const Address & rhs) {
@@ -88,4 +90,79 @@ public:
static eval::ValueType combineDimensions(const eval::ValueType &lhs, const eval::ValueType &rhs);
};
+
+/**
+ * Utility class to iterate over common cells in a dense tensor.
+ */
+class CommonDenseTensorCellsIterator
+{
+public:
+ using size_type = eval::ValueType::Dimension::size_type;
+ using Address = std::vector<size_type>;
+ using Mapping = DenseTensorAddressCombiner::Mapping;
+private:
+ using Dims = std::vector<uint32_t>;
+ using CellsRef = vespalib::ConstArrayRef<double>;
+ const eval::ValueType &_type;
+ CellsRef _cells;
+ Address _address;
+ Mapping _common;
+ std::vector<uint32_t> _mutable;
+ std::vector<size_t> _accumulatedSize;
+
+ double cell(size_t cellIdx) const { return _cells[cellIdx]; }
+ size_t index(const Address &address) const {
+ size_t cellIdx(0);
+ for (uint32_t i(0); i < address.size(); i++) {
+ cellIdx += address[i]*_accumulatedSize[i];
+ }
+ return cellIdx;
+ }
+public:
+ CommonDenseTensorCellsIterator(const Mapping & common, const eval::ValueType &type_in, CellsRef cells);
+ ~CommonDenseTensorCellsIterator();
+ template <typename Func>
+ void for_each(Func && func) {
+ if (_mutable.empty()) {
+ func(_address, cell(index(_address)));
+ return;
+ }
+ const int32_t lastDimension = _mutable.size() - 1;
+ int32_t curDimension = lastDimension;
+ size_t cellIdx = index(_address);
+ while (curDimension >= 0) {
+ const uint32_t dim = _mutable[curDimension];
+ size_type & index = _address[dim];
+ if (curDimension == lastDimension) {
+ for (index = 0; index < _type.dimensions()[dim].size; index++) {
+ func(_address, cell(cellIdx));
+ cellIdx += _accumulatedSize[dim];
+ }
+ index = 0;
+ cellIdx -= _accumulatedSize[dim] * _type.dimensions()[dim].size;
+ curDimension--;
+ } else {
+ if (index < _type.dimensions()[dim].size) {
+ index++;
+ cellIdx += _accumulatedSize[dim];
+ curDimension++;
+ } else {
+ cellIdx -= _accumulatedSize[dim] * _type.dimensions()[dim].size;
+ index = 0;
+ curDimension--;
+ }
+ }
+ }
+ }
+ bool updateCommon(const Address & combined) {
+ for (const auto & m : _common) {
+ if (combined[m.first] >= _type.dimensions()[m.second].size) return false;
+ _address[m.second] = combined[m.first];
+ }
+ return true;
+ }
+
+ const eval::ValueType &fast_type() const { return _type; }
+};
+
}
diff --git a/eval/src/vespa/eval/tensor/dense/dense_tensor_apply.hpp b/eval/src/vespa/eval/tensor/dense/dense_tensor_apply.hpp
index 999433713be..d5982765fc7 100644
--- a/eval/src/vespa/eval/tensor/dense/dense_tensor_apply.hpp
+++ b/eval/src/vespa/eval/tensor/dense/dense_tensor_apply.hpp
@@ -13,14 +13,15 @@ std::unique_ptr<Tensor>
apply(const DenseTensorView &lhs, const DenseTensorView &rhs, Function &&func)
{
DenseTensorAddressCombiner combiner(lhs.fast_type(), rhs.fast_type());
+ CommonDenseTensorCellsIterator rhsIter(combiner.commonRight(), rhs.fast_type(), rhs.cellsRef());
DirectDenseTensorBuilder builder(DenseTensorAddressCombiner::combineDimensions(lhs.fast_type(), rhs.fast_type()));
for (DenseTensorCellsIterator lhsItr = lhs.cellsIterator(); lhsItr.valid(); lhsItr.next()) {
combiner.updateLeftAndCommon(lhsItr.address());
- for (DenseTensorCellsIterator rhsItr = rhs.cellsIterator(); rhsItr.valid(); rhsItr.next()) {
- if (combiner.hasCommonWithRight(rhsItr.address())) {
- combiner.updateRight(rhsItr.address());
- builder.insertCell(combiner.address(), func(lhsItr.cell(), rhsItr.cell()));
- }
+ if (rhsIter.updateCommon(combiner.address())) {
+ rhsIter.for_each([&combiner, &func, &builder, &lhsItr](const DenseTensorCellsIterator::Address & right, double rhsCell) {
+ combiner.updateRight(right);
+ builder.insertCell(combiner.address(), func(lhsItr.cell(), rhsCell));
+ });
}
}
return builder.build();