summaryrefslogtreecommitdiffstats
path: root/eval
diff options
context:
space:
mode:
authorArne Juul <arnej@yahoo-inc.com>2019-06-25 09:34:36 +0000
committerArne Juul <arnej@yahoo-inc.com>2019-06-25 09:34:36 +0000
commit2e6a0ea56d399ba6966a39b0e11ef0b276c3e5af (patch)
tree530c280ec286eb05bfdc985ab0cf5498f94fd957 /eval
parent2a4f7a218cc47465745cafd03624898f7b46f574 (diff)
use common code for left and right step/reset
Diffstat (limited to 'eval')
-rw-r--r--eval/src/vespa/eval/tensor/dense/dense_dimension_combiner.cpp19
-rw-r--r--eval/src/vespa/eval/tensor/dense/dense_dimension_combiner.h115
2 files changed, 56 insertions, 78 deletions
diff --git a/eval/src/vespa/eval/tensor/dense/dense_dimension_combiner.cpp b/eval/src/vespa/eval/tensor/dense/dense_dimension_combiner.cpp
index 72990306c39..22c8ff12ad1 100644
--- a/eval/src/vespa/eval/tensor/dense/dense_dimension_combiner.cpp
+++ b/eval/src/vespa/eval/tensor/dense/dense_dimension_combiner.cpp
@@ -9,9 +9,10 @@ DenseDimensionCombiner::~DenseDimensionCombiner() = default;
DenseDimensionCombiner::DenseDimensionCombiner(const eval::ValueType &lhs,
const eval::ValueType &rhs)
- : _leftDims(), _rightDims(), _commonDims(),
- _leftIndex(0), _rightIndex(0), _outputIndex(0),
- _leftOnlySize(1u), _rightOnlySize(1u), _outputSize(1u),
+ : _left(), _right(),
+ _commonDims(),
+ _outputIndex(0),
+ _outputSize(1u),
result_type(eval::ValueType::join(lhs, rhs))
{
assert(lhs.is_dense());
@@ -48,8 +49,8 @@ DenseDimensionCombiner::DenseDimensionCombiner(const eval::ValueType &lhs,
lMul *= cd.size;
rMul *= cd.size;
oMul *= cd.size;
- _leftOnlySize *= cd.size;
- _rightOnlySize *= cd.size;
+ _left.totalSize *= cd.size;
+ _right.totalSize *= cd.size;
_outputSize *= cd.size;
_commonDims.push_back(cd);
} else {
@@ -61,9 +62,9 @@ DenseDimensionCombiner::DenseDimensionCombiner(const eval::ValueType &lhs,
ld.size = oDims[k].size;
lMul *= ld.size;
oMul *= ld.size;
- _leftOnlySize *= ld.size;
_outputSize *= ld.size;
- _leftDims.push_back(ld);
+ _left.totalSize *= ld.size;
+ _left.dims.push_back(ld);
}
} else {
// right dim match
@@ -78,9 +79,9 @@ DenseDimensionCombiner::DenseDimensionCombiner(const eval::ValueType &lhs,
rd.size = oDims[k].size;
rMul *= rd.size;
oMul *= rd.size;
- _rightOnlySize *= rd.size;
_outputSize *= rd.size;
- _rightDims.push_back(rd);
+ _right.totalSize *= rd.size;
+ _right.dims.push_back(rd);
}
}
}
diff --git a/eval/src/vespa/eval/tensor/dense/dense_dimension_combiner.h b/eval/src/vespa/eval/tensor/dense/dense_dimension_combiner.h
index 70896e307e1..dd3f74bad9b 100644
--- a/eval/src/vespa/eval/tensor/dense/dense_dimension_combiner.h
+++ b/eval/src/vespa/eval/tensor/dense/dense_dimension_combiner.h
@@ -24,84 +24,61 @@ class DenseDimensionCombiner {
uint32_t outputMultiplier;
};
- std::vector<SideDim> _leftDims;
- std::vector<SideDim> _rightDims;
+ struct SideDims {
+ std::vector<SideDim> dims;
+ uint32_t index;
+ uint32_t totalSize;
+
+ SideDims() : dims(), index(0), totalSize(1u) {}
+
+ void reset(uint32_t &outIndex) {
+ for (SideDim& d : dims) {
+ index -= d.idx * d.sideMultiplier;
+ outIndex -= d.idx * d.outputMultiplier;
+ d.idx = 0;
+ }
+ if (index >= totalSize) {
+ index -= totalSize;
+ }
+ }
+ void step(uint32_t &outIndex) {
+ for (SideDim& d : dims) {
+ d.idx++;
+ index += d.sideMultiplier;
+ outIndex += d.outputMultiplier;
+ if (d.idx < d.size) return;
+ index -= d.idx * d.sideMultiplier;
+ outIndex -= d.idx * d.outputMultiplier;
+ d.idx = 0;
+ }
+ index += totalSize;
+ }
+ };
+ SideDims _left;
+ SideDims _right;
std::vector<CommonDim> _commonDims;
-
- uint32_t _leftIndex;
- uint32_t _rightIndex;
uint32_t _outputIndex;
-
- uint32_t _leftOnlySize;
- uint32_t _rightOnlySize;
uint32_t _outputSize;
public:
- size_t leftIdx() const { return _leftIndex; }
- size_t rightIdx() const { return _rightIndex; }
+ size_t leftIdx() const { return _left.index; }
+ size_t rightIdx() const { return _right.index; }
size_t outputIdx() const { return _outputIndex; }
- bool leftInRange() const { return _leftIndex < _leftOnlySize; }
- bool rightInRange() const { return _rightIndex < _rightOnlySize; }
+ bool leftInRange() const { return _left.index < _left.totalSize; }
+ bool rightInRange() const { return _right.index < _right.totalSize; }
bool commonInRange() const { return _outputIndex < _outputSize; }
- void leftReset() {
- for (SideDim& ld : _leftDims) {
- _leftIndex -= ld.idx * ld.sideMultiplier;
- _outputIndex -= ld.idx * ld.outputMultiplier;
- ld.idx = 0;
- }
- if (_leftIndex >= _leftOnlySize) {
- _leftIndex -= _leftOnlySize;
- }
- }
-
- void stepLeft() {
- size_t lim = _leftDims.size();
- for (size_t i = 0; i < lim; ++i) {
- SideDim& ld = _leftDims[i];
- ld.idx++;
- _leftIndex += ld.sideMultiplier;
- _outputIndex += ld.outputMultiplier;
- if (ld.idx < ld.size) return;
- _leftIndex -= ld.idx * ld.sideMultiplier;
- _outputIndex -= ld.idx * ld.outputMultiplier;
- ld.idx = 0;
- }
- _leftIndex += _leftOnlySize;
- }
-
-
- void rightReset() {
- for (SideDim& rd : _rightDims) {
- _rightIndex -= rd.idx * rd.sideMultiplier;
- _outputIndex -= rd.idx * rd.outputMultiplier;
- rd.idx = 0;
- }
- if (_rightIndex >= _rightOnlySize) {
- _rightIndex -= _rightOnlySize;
- }
- }
+ void leftReset() { _left.reset(_outputIndex); }
+ void stepLeft() { _left.step(_outputIndex); }
- void stepRight() {
- size_t lim = _rightDims.size();
- for (size_t i = 0; i < lim; ++i) {
- SideDim& rd = _rightDims[i];
- rd.idx++;
- _rightIndex += rd.sideMultiplier;
- _outputIndex += rd.outputMultiplier;
- if (rd.idx < rd.size) return;
- _rightIndex -= rd.idx * rd.sideMultiplier;
- _outputIndex -= rd.idx * rd.outputMultiplier;
- rd.idx = 0;
- }
- _rightIndex += _rightOnlySize;
- }
+ void rightReset() { _right.reset(_outputIndex); }
+ void stepRight() { _right.step(_outputIndex); }
void commonReset() {
for (CommonDim& cd : _commonDims) {
- _leftIndex -= cd.idx * cd.leftMultiplier;
- _rightIndex -= cd.idx * cd.rightMultiplier;
+ _left.index -= cd.idx * cd.leftMultiplier;
+ _right.index -= cd.idx * cd.rightMultiplier;
_outputIndex -= cd.idx * cd.outputMultiplier;
cd.idx = 0;
}
@@ -115,12 +92,12 @@ public:
for (size_t i = 0; i < lim; ++i) {
CommonDim &cd = _commonDims[i];
cd.idx++;
- _leftIndex += cd.leftMultiplier;
- _rightIndex += cd.rightMultiplier;
+ _left.index += cd.leftMultiplier;
+ _right.index += cd.rightMultiplier;
_outputIndex += cd.outputMultiplier;
if (cd.idx < cd.size) return;
- _leftIndex -= cd.idx * cd.leftMultiplier;
- _rightIndex -= cd.idx * cd.rightMultiplier;
+ _left.index -= cd.idx * cd.leftMultiplier;
+ _right.index -= cd.idx * cd.rightMultiplier;
_outputIndex -= cd.idx * cd.outputMultiplier;
cd.idx = 0;
}