diff options
Diffstat (limited to 'eval/src')
-rw-r--r-- | eval/src/vespa/eval/tensor/dense/dense_dimension_combiner.cpp | 19 | ||||
-rw-r--r-- | eval/src/vespa/eval/tensor/dense/dense_dimension_combiner.h | 115 |
2 files changed, 56 insertions, 78 deletions
diff --git a/eval/src/vespa/eval/tensor/dense/dense_dimension_combiner.cpp b/eval/src/vespa/eval/tensor/dense/dense_dimension_combiner.cpp index 72990306c39..22c8ff12ad1 100644 --- a/eval/src/vespa/eval/tensor/dense/dense_dimension_combiner.cpp +++ b/eval/src/vespa/eval/tensor/dense/dense_dimension_combiner.cpp @@ -9,9 +9,10 @@ DenseDimensionCombiner::~DenseDimensionCombiner() = default; DenseDimensionCombiner::DenseDimensionCombiner(const eval::ValueType &lhs, const eval::ValueType &rhs) - : _leftDims(), _rightDims(), _commonDims(), - _leftIndex(0), _rightIndex(0), _outputIndex(0), - _leftOnlySize(1u), _rightOnlySize(1u), _outputSize(1u), + : _left(), _right(), + _commonDims(), + _outputIndex(0), + _outputSize(1u), result_type(eval::ValueType::join(lhs, rhs)) { assert(lhs.is_dense()); @@ -48,8 +49,8 @@ DenseDimensionCombiner::DenseDimensionCombiner(const eval::ValueType &lhs, lMul *= cd.size; rMul *= cd.size; oMul *= cd.size; - _leftOnlySize *= cd.size; - _rightOnlySize *= cd.size; + _left.totalSize *= cd.size; + _right.totalSize *= cd.size; _outputSize *= cd.size; _commonDims.push_back(cd); } else { @@ -61,9 +62,9 @@ DenseDimensionCombiner::DenseDimensionCombiner(const eval::ValueType &lhs, ld.size = oDims[k].size; lMul *= ld.size; oMul *= ld.size; - _leftOnlySize *= ld.size; _outputSize *= ld.size; - _leftDims.push_back(ld); + _left.totalSize *= ld.size; + _left.dims.push_back(ld); } } else { // right dim match @@ -78,9 +79,9 @@ DenseDimensionCombiner::DenseDimensionCombiner(const eval::ValueType &lhs, rd.size = oDims[k].size; rMul *= rd.size; oMul *= rd.size; - _rightOnlySize *= rd.size; _outputSize *= rd.size; - _rightDims.push_back(rd); + _right.totalSize *= rd.size; + _right.dims.push_back(rd); } } } diff --git a/eval/src/vespa/eval/tensor/dense/dense_dimension_combiner.h b/eval/src/vespa/eval/tensor/dense/dense_dimension_combiner.h index 70896e307e1..dd3f74bad9b 100644 --- a/eval/src/vespa/eval/tensor/dense/dense_dimension_combiner.h +++ b/eval/src/vespa/eval/tensor/dense/dense_dimension_combiner.h @@ -24,84 +24,61 @@ class DenseDimensionCombiner { uint32_t outputMultiplier; }; - std::vector<SideDim> _leftDims; - std::vector<SideDim> _rightDims; + struct SideDims { + std::vector<SideDim> dims; + uint32_t index; + uint32_t totalSize; + + SideDims() : dims(), index(0), totalSize(1u) {} + + void reset(uint32_t &outIndex) { + for (SideDim& d : dims) { + index -= d.idx * d.sideMultiplier; + outIndex -= d.idx * d.outputMultiplier; + d.idx = 0; + } + if (index >= totalSize) { + index -= totalSize; + } + } + void step(uint32_t &outIndex) { + for (SideDim& d : dims) { + d.idx++; + index += d.sideMultiplier; + outIndex += d.outputMultiplier; + if (d.idx < d.size) return; + index -= d.idx * d.sideMultiplier; + outIndex -= d.idx * d.outputMultiplier; + d.idx = 0; + } + index += totalSize; + } + }; + SideDims _left; + SideDims _right; std::vector<CommonDim> _commonDims; - - uint32_t _leftIndex; - uint32_t _rightIndex; uint32_t _outputIndex; - - uint32_t _leftOnlySize; - uint32_t _rightOnlySize; uint32_t _outputSize; public: - size_t leftIdx() const { return _leftIndex; } - size_t rightIdx() const { return _rightIndex; } + size_t leftIdx() const { return _left.index; } + size_t rightIdx() const { return _right.index; } size_t outputIdx() const { return _outputIndex; } - bool leftInRange() const { return _leftIndex < _leftOnlySize; } - bool rightInRange() const { return _rightIndex < _rightOnlySize; } + bool leftInRange() const { return _left.index < _left.totalSize; } + bool rightInRange() const { return _right.index < _right.totalSize; } bool commonInRange() const { return _outputIndex < _outputSize; } - void leftReset() { - for (SideDim& ld : _leftDims) { - _leftIndex -= ld.idx * ld.sideMultiplier; - _outputIndex -= ld.idx * ld.outputMultiplier; - ld.idx = 0; - } - if (_leftIndex >= _leftOnlySize) { - _leftIndex -= _leftOnlySize; - } - } - - void stepLeft() { - size_t lim = _leftDims.size(); - for (size_t i = 0; i < lim; ++i) { - SideDim& ld = _leftDims[i]; - ld.idx++; - _leftIndex += ld.sideMultiplier; - _outputIndex += ld.outputMultiplier; - if (ld.idx < ld.size) return; - _leftIndex -= ld.idx * ld.sideMultiplier; - _outputIndex -= ld.idx * ld.outputMultiplier; - ld.idx = 0; - } - _leftIndex += _leftOnlySize; - } - - - void rightReset() { - for (SideDim& rd : _rightDims) { - _rightIndex -= rd.idx * rd.sideMultiplier; - _outputIndex -= rd.idx * rd.outputMultiplier; - rd.idx = 0; - } - if (_rightIndex >= _rightOnlySize) { - _rightIndex -= _rightOnlySize; - } - } + void leftReset() { _left.reset(_outputIndex); } + void stepLeft() { _left.step(_outputIndex); } - void stepRight() { - size_t lim = _rightDims.size(); - for (size_t i = 0; i < lim; ++i) { - SideDim& rd = _rightDims[i]; - rd.idx++; - _rightIndex += rd.sideMultiplier; - _outputIndex += rd.outputMultiplier; - if (rd.idx < rd.size) return; - _rightIndex -= rd.idx * rd.sideMultiplier; - _outputIndex -= rd.idx * rd.outputMultiplier; - rd.idx = 0; - } - _rightIndex += _rightOnlySize; - } + void rightReset() { _right.reset(_outputIndex); } + void stepRight() { _right.step(_outputIndex); } void commonReset() { for (CommonDim& cd : _commonDims) { - _leftIndex -= cd.idx * cd.leftMultiplier; - _rightIndex -= cd.idx * cd.rightMultiplier; + _left.index -= cd.idx * cd.leftMultiplier; + _right.index -= cd.idx * cd.rightMultiplier; _outputIndex -= cd.idx * cd.outputMultiplier; cd.idx = 0; } @@ -115,12 +92,12 @@ public: for (size_t i = 0; i < lim; ++i) { CommonDim &cd = _commonDims[i]; cd.idx++; - _leftIndex += cd.leftMultiplier; - _rightIndex += cd.rightMultiplier; + _left.index += cd.leftMultiplier; + _right.index += cd.rightMultiplier; _outputIndex += cd.outputMultiplier; if (cd.idx < cd.size) return; - _leftIndex -= cd.idx * cd.leftMultiplier; - _rightIndex -= cd.idx * cd.rightMultiplier; + _left.index -= cd.idx * cd.leftMultiplier; + _right.index -= cd.idx * cd.rightMultiplier; _outputIndex -= cd.idx * cd.outputMultiplier; cd.idx = 0; } |