summaryrefslogtreecommitdiffstats
path: root/eval
diff options
context:
space:
mode:
authorHenning Baldersheim <balder@yahoo-inc.com>2018-01-02 12:24:42 +0100
committerHenning Baldersheim <balder@yahoo-inc.com>2018-01-09 17:44:06 +0100
commit450cb5bfb064f6320f57f48e899208f5d13cccbf (patch)
tree347cd5731d20924865f965ab5b02ec14c4d55a7b /eval
parent2646fc43074bf6ec8e72eff7b21bd0db59685961 (diff)
Update the combined adress inline.
Diffstat (limited to 'eval')
-rw-r--r--eval/src/vespa/eval/tensor/dense/dense_tensor_address_combiner.cpp9
-rw-r--r--eval/src/vespa/eval/tensor/dense/dense_tensor_address_combiner.h39
-rw-r--r--eval/src/vespa/eval/tensor/dense/dense_tensor_apply.hpp7
3 files changed, 27 insertions, 28 deletions
diff --git a/eval/src/vespa/eval/tensor/dense/dense_tensor_address_combiner.cpp b/eval/src/vespa/eval/tensor/dense/dense_tensor_address_combiner.cpp
index 23d75412289..704fdc6d1ea 100644
--- a/eval/src/vespa/eval/tensor/dense/dense_tensor_address_combiner.cpp
+++ b/eval/src/vespa/eval/tensor/dense/dense_tensor_address_combiner.cpp
@@ -73,21 +73,16 @@ DenseTensorAddressCombiner::combineDimensions(const eval::ValueType &lhs,
CommonDenseTensorCellsIterator::CommonDenseTensorCellsIterator(const Mapping & common,
+ const Mapping & right,
const eval::ValueType &type_in,
CellsRef cells)
: _type(type_in),
_cells(cells),
_address(type_in.dimensions().size(), 0),
_common(common),
- _mutable(_address.size()),
+ _mutable(right),
_accumulatedSize(_address.size())
{
- for (uint32_t i(0); i < _address.size(); i++) {
- _mutable[i] = i;
- }
- for (auto cur = _common.rbegin(); cur != _common.rend(); cur++) {
- _mutable.erase(_mutable.begin() + cur->second);
- }
size_t multiplier = 1;
for (int32_t i(_address.size() - 1); i >= 0; i--) {
_accumulatedSize[i] = multiplier;
diff --git a/eval/src/vespa/eval/tensor/dense/dense_tensor_address_combiner.h b/eval/src/vespa/eval/tensor/dense/dense_tensor_address_combiner.h
index 875ad79aa87..5039e1a2fbc 100644
--- a/eval/src/vespa/eval/tensor/dense/dense_tensor_address_combiner.h
+++ b/eval/src/vespa/eval/tensor/dense/dense_tensor_address_combiner.h
@@ -58,11 +58,13 @@ public:
return true;
}
- const Mapping & commonRight() const { return _commonRight; }
+ const Mapping & getCommonRight() const { return _commonRight; }
+ const Mapping & getRight() const { return _right; }
bool hasAnyRightOnlyDimensions() const { return ! _right.empty(); }
const Address &address() const { return _combinedAddress; }
+ Address &address() { return _combinedAddress; }
bool combine(const Address & lhs, const Address & rhs) {
uint32_t index(0);
@@ -108,8 +110,8 @@ private:
const eval::ValueType &_type;
CellsRef _cells;
Address _address;
- Mapping _common;
- std::vector<uint32_t> _mutable;
+ const Mapping &_common;
+ const Mapping &_mutable;
std::vector<size_t> _accumulatedSize;
double cell(size_t cellIdx) const { return _cells[cellIdx]; }
@@ -121,32 +123,35 @@ private:
return cellIdx;
}
public:
- CommonDenseTensorCellsIterator(const Mapping & common, const eval::ValueType &type_in, CellsRef cells);
+ CommonDenseTensorCellsIterator(const Mapping & common, const Mapping & right,
+ const eval::ValueType &type_in, CellsRef cells);
~CommonDenseTensorCellsIterator();
template <typename Func>
- void for_each(Func && func) {
+ void for_each(Address & combined, Func && func) {
const int32_t lastDimension = _mutable.size() - 1;
int32_t curDimension = lastDimension;
size_t cellIdx = index(_address);
while (curDimension >= 0) {
- const uint32_t dim = _mutable[curDimension];
- size_type & index = _address[dim];
+ const uint32_t rdim = _mutable[curDimension].second;
+ const uint32_t cdim = _mutable[curDimension].first;
+ size_type & rindex = _address[rdim];
+ size_type & cindex = combined[cdim];
if (curDimension == lastDimension) {
- for (index = 0; index < _type.dimensions()[dim].size; index++) {
- func(_address, cell(cellIdx));
- cellIdx += _accumulatedSize[dim];
+ for (rindex = 0, cindex = 0; rindex < _type.dimensions()[rdim].size; rindex++, cindex++) {
+ func(combined, cell(cellIdx));
+ cellIdx += _accumulatedSize[rdim];
}
- index = 0;
- cellIdx -= _accumulatedSize[dim] * _type.dimensions()[dim].size;
+ rindex = 0; cindex = 0;
+ cellIdx -= _accumulatedSize[rdim] * _type.dimensions()[rdim].size;
curDimension--;
} else {
- if (index < _type.dimensions()[dim].size) {
- index++;
- cellIdx += _accumulatedSize[dim];
+ if (rindex < _type.dimensions()[rdim].size) {
+ rindex++; cindex++;
+ cellIdx += _accumulatedSize[rdim];
curDimension++;
} else {
- cellIdx -= _accumulatedSize[dim] * _type.dimensions()[dim].size;
- index = 0;
+ cellIdx -= _accumulatedSize[rdim] * _type.dimensions()[rdim].size;
+ rindex = 0; cindex = 0;
curDimension--;
}
}
diff --git a/eval/src/vespa/eval/tensor/dense/dense_tensor_apply.hpp b/eval/src/vespa/eval/tensor/dense/dense_tensor_apply.hpp
index 315e2653432..062980a7296 100644
--- a/eval/src/vespa/eval/tensor/dense/dense_tensor_apply.hpp
+++ b/eval/src/vespa/eval/tensor/dense/dense_tensor_apply.hpp
@@ -16,9 +16,8 @@ apply(DenseTensorAddressCombiner & combiner, DirectDenseTensorBuilder & builder,
for (DenseTensorCellsIterator lhsItr = lhs.cellsIterator(); lhsItr.valid(); lhsItr.next()) {
combiner.updateLeftAndCommon(lhsItr.address());
if (rhsIter.updateCommon(combiner.address())) {
- rhsIter.for_each([&combiner, &func, &builder, &lhsItr](const DenseTensorCellsIterator::Address & right, double rhsCell) {
- combiner.updateRight(right);
- builder.insertCell(combiner.address(), func(lhsItr.cell(), rhsCell));
+ rhsIter.for_each(combiner.address(), [&combiner, &func, &builder, &lhsItr](const DenseTensorCellsIterator::Address & combined, double rhsCell) {
+ builder.insertCell(combined, func(lhsItr.cell(), rhsCell));
});
}
}
@@ -46,7 +45,7 @@ apply(const DenseTensorView &lhs, const DenseTensorView &rhs, Function &&func)
{
DenseTensorAddressCombiner combiner(lhs.fast_type(), rhs.fast_type());
DirectDenseTensorBuilder builder(DenseTensorAddressCombiner::combineDimensions(lhs.fast_type(), rhs.fast_type()));
- CommonDenseTensorCellsIterator rhsIter(combiner.commonRight(), rhs.fast_type(), rhs.cellsRef());
+ CommonDenseTensorCellsIterator rhsIter(combiner.getCommonRight(), combiner.getRight(), rhs.fast_type(), rhs.cellsRef());
if (combiner.hasAnyRightOnlyDimensions()) {
return apply(combiner, builder, rhsIter, lhs, std::move(func));
} else {