aboutsummaryrefslogtreecommitdiffstats
path: root/eval
diff options
context:
space:
mode:
authorHenning Baldersheim <balder@yahoo-inc.com>2018-01-09 23:05:15 +0100
committerHenning Baldersheim <balder@yahoo-inc.com>2018-01-09 23:05:15 +0100
commit0a8534171a39c00f7d1938c14eb4ec96a0c02692 (patch)
treec79788568b55d3caa6ab45febd447b8412dec457 /eval
parenta6a23d5e1d07aad3f0c88f7a70539fcf15fa1029 (diff)
Make and address context to keep code together.
Diffstat (limited to 'eval')
-rw-r--r--eval/src/vespa/eval/tensor/dense/dense_tensor_address_combiner.cpp20
-rw-r--r--eval/src/vespa/eval/tensor/dense/dense_tensor_address_combiner.h57
2 files changed, 47 insertions, 30 deletions
diff --git a/eval/src/vespa/eval/tensor/dense/dense_tensor_address_combiner.cpp b/eval/src/vespa/eval/tensor/dense/dense_tensor_address_combiner.cpp
index 58f1d35dc48..55a5d90fd4b 100644
--- a/eval/src/vespa/eval/tensor/dense/dense_tensor_address_combiner.cpp
+++ b/eval/src/vespa/eval/tensor/dense/dense_tensor_address_combiner.cpp
@@ -10,11 +10,8 @@ DenseTensorAddressCombiner::~DenseTensorAddressCombiner() = default;
DenseTensorAddressCombiner::DenseTensorAddressCombiner(const eval::ValueType &lhs, const eval::ValueType &rhs,
CellsRef rhsCells)
- : _rightType(rhs),
+ : _rightAddress(rhs, rhsCells),
_combinedAddress(),
- _rightCells(rhsCells),
- _rightAddress(rhs.dimensions().size(), 0),
- _rightAccumulatedSize(_rightAddress.size()),
_left(),
_commonRight(),
_right()
@@ -41,13 +38,22 @@ DenseTensorAddressCombiner::DenseTensorAddressCombiner(const eval::ValueType &lh
++rhsItr;
}
_combinedAddress.resize(numDimensions);
+}
+
+DenseTensorAddressCombiner::AddressContext::AddressContext(const eval::ValueType &type, CellsRef cells)
+ : _type(type),
+ _cells(cells),
+ _address(type.dimensions().size(), 0),
+ _accumulatedSize(_address.size())
+{
size_t multiplier = 1;
- for (int32_t i(_rightAddress.size() - 1); i >= 0; i--) {
- _rightAccumulatedSize[i] = multiplier;
- multiplier *= _rightType.dimensions()[i].size;
+ for (int32_t i(_address.size() - 1); i >= 0; i--) {
+ _accumulatedSize[i] = multiplier;
+ multiplier *= type.dimensions()[i].size;
}
}
+
eval::ValueType
DenseTensorAddressCombiner::combineDimensions(const eval::ValueType &lhs, const eval::ValueType &rhs)
{
diff --git a/eval/src/vespa/eval/tensor/dense/dense_tensor_address_combiner.h b/eval/src/vespa/eval/tensor/dense/dense_tensor_address_combiner.h
index a45944ef9bf..5b7c2cb4cb4 100644
--- a/eval/src/vespa/eval/tensor/dense/dense_tensor_address_combiner.h
+++ b/eval/src/vespa/eval/tensor/dense/dense_tensor_address_combiner.h
@@ -25,11 +25,29 @@ private:
using CellsRef = vespalib::ConstArrayRef<double>;
using size_type = eval::ValueType::Dimension::size_type;
- const eval::ValueType &_rightType;
+ class AddressContext {
+ public:
+ AddressContext(const eval::ValueType &type, CellsRef cells);
+ size_type dimSize(uint32_t dim) const { return _type.dimensions()[dim].size; }
+ double cell() const { return cell(index()); }
+ double cell(size_t cellIdx) const { return _cells[cellIdx]; }
+ size_t index() const {
+ size_t cellIdx(0);
+ for (uint32_t i(0); i < _address.size(); i++) {
+ cellIdx += _address[i]*_accumulatedSize[i];
+ }
+ return cellIdx;
+ }
+
+ const eval::ValueType &_type;
+ CellsRef _cells;
+ Address _address;
+ std::vector<size_t> _accumulatedSize;
+ };
+
+ AddressContext _rightAddress;
Address _combinedAddress;
- CellsRef _rightCells;
- Address _rightAddress;
- std::vector<size_t> _rightAccumulatedSize;
+
Mapping _left;
Mapping _commonRight;
Mapping _right;
@@ -38,14 +56,7 @@ private:
_combinedAddress[m.first] = addr[m.second];
}
}
- double rightCell(size_t cellIdx) const { return _rightCells[cellIdx]; }
- size_t rightIndex(const Address &address) const {
- size_t cellIdx(0);
- for (uint32_t i(0); i < address.size(); i++) {
- cellIdx += address[i]*_rightAccumulatedSize[i];
- }
- return cellIdx;
- }
+
public:
DenseTensorAddressCombiner(const eval::ValueType &lhs, const eval::ValueType &rhs, CellsRef rhsCells);
~DenseTensorAddressCombiner();
@@ -57,39 +68,39 @@ public:
bool updateCommonRight() {
for (const auto & m : _commonRight) {
- if (_combinedAddress[m.first] >= _rightType.dimensions()[m.second].size) {
+ if (_combinedAddress[m.first] >= _rightAddress._type.dimensions()[m.second].size) {
return false;
}
- _rightAddress[m.second] = _combinedAddress[m.first];
+ _rightAddress._address[m.second] = _combinedAddress[m.first];
}
return true;
}
- double rightCell() { return rightCell(rightIndex(_rightAddress)); }
+ double rightCell() { return _rightAddress.cell(); }
template <typename Func>
void for_each(Func && func) {
const int32_t lastDimension = _right.size() - 1;
int32_t curDimension = lastDimension;
- size_t rightCellIdx = rightIndex(_rightAddress);
+ size_t rightCellIdx = _rightAddress.index();
while (curDimension >= 0) {
const uint32_t rdim = _right[curDimension].second;
const uint32_t cdim = _right[curDimension].first;
size_type & cindex = _combinedAddress[cdim];
if (curDimension == lastDimension) {
- for (cindex = 0; cindex < _rightType.dimensions()[rdim].size; cindex++) {
- func(_combinedAddress, rightCell(rightCellIdx));
- rightCellIdx += _rightAccumulatedSize[rdim];
+ for (cindex = 0; cindex < _rightAddress.dimSize(rdim); cindex++) {
+ func(_combinedAddress, _rightAddress.cell(rightCellIdx));
+ rightCellIdx += _rightAddress._accumulatedSize[rdim];
}
cindex = 0;
- rightCellIdx -= _rightAccumulatedSize[rdim] * _rightType.dimensions()[rdim].size;
+ rightCellIdx -= _rightAddress._accumulatedSize[rdim] * _rightAddress.dimSize(rdim);
curDimension--;
} else {
- if (cindex < _rightType.dimensions()[rdim].size) {
+ if (cindex < _rightAddress.dimSize(rdim)) {
cindex++;
- rightCellIdx += _rightAccumulatedSize[rdim];
+ rightCellIdx += _rightAddress._accumulatedSize[rdim];
curDimension++;
} else {
- rightCellIdx -= _rightAccumulatedSize[rdim] * _rightType.dimensions()[rdim].size;
+ rightCellIdx -= _rightAddress._accumulatedSize[rdim] * _rightAddress.dimSize(rdim);
cindex = 0;
curDimension--;
}