summaryrefslogtreecommitdiffstats
path: root/eval
diff options
context:
space:
mode:
authorArne Juul <arnej@yahoo-inc.com>2019-06-24 12:36:02 +0000
committerArne Juul <arnej@yahoo-inc.com>2019-06-24 12:36:02 +0000
commit98ffc28bceaaaa47f1213283ed6bee1e9d1a2bac (patch)
treef357cc218498777b84b3d19fe6b822fc4e61229c /eval
parent52be221f7e73ed6c01464b7c7d160be28466dde5 (diff)
use DenseDimensionCombiner class instead
Diffstat (limited to 'eval')
-rw-r--r--eval/src/vespa/eval/tensor/dense/dense_dimension_combiner.cpp9
-rw-r--r--eval/src/vespa/eval/tensor/dense/dense_dimension_combiner.h2
-rw-r--r--eval/src/vespa/eval/tensor/dense/dense_tensor_apply.hpp69
3 files changed, 24 insertions, 56 deletions
diff --git a/eval/src/vespa/eval/tensor/dense/dense_dimension_combiner.cpp b/eval/src/vespa/eval/tensor/dense/dense_dimension_combiner.cpp
index 194f4d0eae3..80aef3c79b6 100644
--- a/eval/src/vespa/eval/tensor/dense/dense_dimension_combiner.cpp
+++ b/eval/src/vespa/eval/tensor/dense/dense_dimension_combiner.cpp
@@ -11,17 +11,16 @@ DenseDimensionCombiner::DenseDimensionCombiner(const eval::ValueType &lhs,
const eval::ValueType &rhs)
: _leftDims(), _rightDims(), _commonDims(),
_leftIndex(0), _rightIndex(0), _outputIndex(0),
- _leftOnlySize(1u), _rightOnlySize(1u), _outputSize(1u)
+ _leftOnlySize(1u), _rightOnlySize(1u), _outputSize(1u),
+ result_type(eval::ValueType::join(lhs, rhs))
{
- eval::ValueType outputType = eval::ValueType::join(lhs, rhs);
-
assert(lhs.is_dense());
assert(rhs.is_dense());
- assert(outputType.is_dense());
+ assert(result_type.is_dense());
const auto &lDims = lhs.dimensions();
const auto &rDims = rhs.dimensions();
- const auto &oDims = outputType.dimensions();
+ const auto &oDims = result_type.dimensions();
size_t i = lDims.size();
size_t j = rDims.size();
diff --git a/eval/src/vespa/eval/tensor/dense/dense_dimension_combiner.h b/eval/src/vespa/eval/tensor/dense/dense_dimension_combiner.h
index 91449f122cd..6408c86902f 100644
--- a/eval/src/vespa/eval/tensor/dense/dense_dimension_combiner.h
+++ b/eval/src/vespa/eval/tensor/dense/dense_dimension_combiner.h
@@ -127,6 +127,8 @@ public:
_outputIndex += _outputSize;
}
+ const eval::ValueType result_type;
+
DenseDimensionCombiner(const eval::ValueType &lhs, const eval::ValueType &rhs);
~DenseDimensionCombiner();
diff --git a/eval/src/vespa/eval/tensor/dense/dense_tensor_apply.hpp b/eval/src/vespa/eval/tensor/dense/dense_tensor_apply.hpp
index fa1e59c87db..e71840f392c 100644
--- a/eval/src/vespa/eval/tensor/dense/dense_tensor_apply.hpp
+++ b/eval/src/vespa/eval/tensor/dense/dense_tensor_apply.hpp
@@ -3,78 +3,45 @@
#pragma once
#include "dense_tensor_apply.h"
-#include "dense_tensor_address_combiner.h"
+#include "dense_dimension_combiner.h"
#include "direct_dense_tensor_builder.h"
namespace vespalib::tensor::dense {
template <typename Function>
std::unique_ptr<Tensor>
-apply(DenseTensorAddressCombiner & combiner, DirectDenseTensorBuilder & builder,
- const DenseTensorView &lhs, const DenseTensorView::CellsRef & rhsCells, Function &&func) __attribute__((noinline));
+apply(DenseDimensionCombiner & combiner, DirectDenseTensorBuilder & builder,
+ const DenseTensorView::CellsRef & lhsCells,
+ const DenseTensorView::CellsRef & rhsCells, Function &&func) __attribute__((noinline));
template <typename Function>
std::unique_ptr<Tensor>
-apply(DenseTensorAddressCombiner & combiner, DirectDenseTensorBuilder & builder,
- const DenseTensorView &lhs, const DenseTensorView::CellsRef & rhsCells, Function &&func)
+apply(DenseDimensionCombiner & combiner, DirectDenseTensorBuilder & builder,
+ const DenseTensorView::CellsRef & lhsCells,
+ const DenseTensorView::CellsRef & rhsCells, Function &&func)
{
- for (DenseTensorCellsIterator lhsItr = lhs.cellsIterator(); lhsItr.valid(); lhsItr.next()) {
- combiner.updateLeftAndCommon(lhsItr.address());
- if (combiner.updateCommon()) {
- combiner.for_each_right(rhsCells, [&func, &builder, &lhsItr](size_t combined, double rhsCell) {
- builder.insertCell(combined, func(lhsItr.cell(), rhsCell));
- });
+ for (combiner.leftReset(); combiner.leftInRange(); combiner.stepLeft()) {
+ for (combiner.rightReset(); combiner.rightInRange(); combiner.stepRight()) {
+ for (combiner.commonReset(); combiner.commonInRange(); combiner.stepCommon()) {
+ size_t outIdx = combiner.outputIdx();
+ size_t l = combiner.leftIdx();
+ size_t r = combiner.rightIdx();
+ builder.insertCell(outIdx, func(lhsCells[l], rhsCells[r]));
+ }
}
}
return builder.build();
}
-
-template <typename Function>
-std::unique_ptr<Tensor>
-apply_no_rightonly_dimensions(DenseTensorAddressCombiner & combiner, DirectDenseTensorBuilder & builder,
- const DenseTensorView &lhs, const DenseTensorView::CellsRef & rhsCells,
- Function &&func) __attribute__((noinline));
-
-template <typename Function>
-std::unique_ptr<Tensor>
-apply_no_rightonly_dimensions(DenseTensorAddressCombiner & combiner, DirectDenseTensorBuilder & builder,
- const DenseTensorView &lhs, const DenseTensorView::CellsRef & rhsCells, Function &&func)
-{
- for (DenseTensorCellsIterator lhsItr = lhs.cellsIterator(); lhsItr.valid(); lhsItr.next()) {
- combiner.updateLeftAndCommon(lhsItr.address());
- if (combiner.updateCommon()) {
- builder.insertCell(combiner.address(), func(lhsItr.cell(), rhsCells[combiner.rightCellIndex()]));
- }
- }
- return builder.build();
-}
-
-template <typename Function>
-std::unique_ptr<Tensor>
-apply(const DenseTensorView &lhs, const DenseTensorView &rhs, Function &&func)
-{
- eval::ValueType resultType = DenseTensorAddressCombiner::combineDimensions(lhs.fast_type(), rhs.fast_type());
- DenseTensorAddressCombiner combiner(resultType, lhs.fast_type(), rhs.fast_type());
- DirectDenseTensorBuilder builder(resultType);
- if (combiner.hasAnyRightOnlyDimensions()) {
- return apply(combiner, builder, lhs, rhs.cellsRef(), std::move(func));
- } else {
- return apply_no_rightonly_dimensions(combiner, builder, lhs, rhs.cellsRef(), std::move(func));
- }
-}
-
template <typename Function>
std::unique_ptr<Tensor>
apply(const DenseTensorView &lhs, const Tensor &rhs, Function &&func)
{
const DenseTensorView *view = dynamic_cast<const DenseTensorView *>(&rhs);
if (view) {
- return apply(lhs, *view, func);
- }
- const DenseTensor *dense = dynamic_cast<const DenseTensor *>(&rhs);
- if (dense) {
- return apply(lhs, *dense, func);
+ DenseDimensionCombiner combiner(lhs.fast_type(), view->fast_type());
+ DirectDenseTensorBuilder builder(combiner.result_type);
+ return apply(combiner, builder, lhs.cellsRef(), view->cellsRef(), std::move(func));
}
return Tensor::UP();
}