diff options
author | Henning Baldersheim <balder@yahoo-inc.com> | 2017-12-18 16:35:32 +0100 |
---|---|---|
committer | Henning Baldersheim <balder@yahoo-inc.com> | 2017-12-18 16:35:32 +0100 |
commit | 2a6a5e3abfc556b4c9a19b045b786f53ce337c5f (patch) | |
tree | 164e37872db6b5726c6a932c99371da91ebf84e4 /eval | |
parent | 4c2e8e00ea79f039ced9b41504bb5ad304687706 (diff) |
Consider overlapping dimensions when estimating resulting tensor space.
Diffstat (limited to 'eval')
5 files changed, 40 insertions, 24 deletions
diff --git a/eval/src/tests/tensor/sparse_tensor_builder/sparse_tensor_builder_test.cpp b/eval/src/tests/tensor/sparse_tensor_builder/sparse_tensor_builder_test.cpp index b45a06579cb..708c2f761f7 100644 --- a/eval/src/tests/tensor/sparse_tensor_builder/sparse_tensor_builder_test.cpp +++ b/eval/src/tests/tensor/sparse_tensor_builder/sparse_tensor_builder_test.cpp @@ -2,9 +2,11 @@ #include <vespa/vespalib/testkit/test_kit.h> #include <vespa/eval/tensor/sparse/sparse_tensor_builder.h> +#include <vespa/eval/tensor/sparse/sparse_tensor_address_combiner.h> #include <vespa/vespalib/test/insertion_operators.h> using namespace vespalib::tensor; +using namespace vespalib::tensor::sparse; using vespalib::eval::TensorSpec; using vespalib::eval::ValueType; @@ -57,10 +59,8 @@ TEST("require that tensor can be constructed") const ValueType &type = sparseTensor.type(); const SparseTensor::Cells &cells = sparseTensor.cells(); EXPECT_EQUAL(2u, cells.size()); - assertCellValue(10, TensorAddress({{"a","1"},{"b","2"}}), - type, cells); - assertCellValue(20, TensorAddress({{"c","3"},{"d","4"}}), - type, cells); + assertCellValue(10, TensorAddress({{"a","1"},{"b","2"}}), type, cells); + assertCellValue(20, TensorAddress({{"c","3"},{"d","4"}}), type, cells); } TEST("require that tensor can be converted to tensor spec") @@ -94,6 +94,18 @@ TEST("require that dimensions are extracted") EXPECT_EQUAL("tensor(a{},b{},c{})", sparseTensor.type().to_spec()); } +void verifyAddressCombiner(const ValueType & a, const ValueType & b, size_t numDim, size_t numOverlapping) { + TensorAddressCombiner combiner(a, b); + EXPECT_EQUAL(numDim, combiner.numDimensions()); + EXPECT_EQUAL(numOverlapping, combiner.numOverlappingDimensions()); +} +TEST("Test sparse tensor address combiner") { + verifyAddressCombiner(ValueType::tensor_type({{"a"}}), ValueType::tensor_type({{"b"}}), 2, 0); + verifyAddressCombiner(ValueType::tensor_type({{"a"}, {"b"}}), ValueType::tensor_type({{"b"}}), 2, 1); + verifyAddressCombiner(ValueType::tensor_type({{"a"}, {"b"}}), ValueType::tensor_type({{"b"}, {"c"}}), 3, 1); + +} + TEST("Test essential object sizes") { EXPECT_EQUAL(16u, sizeof(SparseTensorAddressRef)); EXPECT_EQUAL(24u, sizeof(std::pair<SparseTensorAddressRef, double>)); diff --git a/eval/src/vespa/eval/eval/value_type.h b/eval/src/vespa/eval/eval/value_type.h index 2988cc5204e..e304f51436f 100644 --- a/eval/src/vespa/eval/eval/value_type.h +++ b/eval/src/vespa/eval/eval/value_type.h @@ -6,8 +6,7 @@ #include <vector> #include <memory> -namespace vespalib { -namespace eval { +namespace vespalib::eval { /** * The type of a Value. This is used for type-resolution during @@ -91,5 +90,4 @@ public: std::ostream &operator<<(std::ostream &os, const ValueType &type); -} // namespace vespalib::eval -} // namespace vespalib +} diff --git a/eval/src/vespa/eval/tensor/sparse/sparse_tensor_address_combiner.cpp b/eval/src/vespa/eval/tensor/sparse/sparse_tensor_address_combiner.cpp index b386ec82528..9693832ea88 100644 --- a/eval/src/vespa/eval/tensor/sparse/sparse_tensor_address_combiner.cpp +++ b/eval/src/vespa/eval/tensor/sparse/sparse_tensor_address_combiner.cpp @@ -5,12 +5,9 @@ #include <vespa/eval/eval/value_type.h> #include <cassert> -namespace vespalib { -namespace tensor { -namespace sparse { +namespace vespalib::tensor::sparse { -TensorAddressCombiner::TensorAddressCombiner(const eval::ValueType &lhs, - const eval::ValueType &rhs) +TensorAddressCombiner::TensorAddressCombiner(const eval::ValueType &lhs, const eval::ValueType &rhs) { auto rhsItr = rhs.dimensions().cbegin(); auto rhsItrEnd = rhs.dimensions().cend(); @@ -32,8 +29,17 @@ TensorAddressCombiner::TensorAddressCombiner(const eval::ValueType &lhs, } } -TensorAddressCombiner::~TensorAddressCombiner() -{ +TensorAddressCombiner::~TensorAddressCombiner() = default; + +size_t +TensorAddressCombiner::numOverlappingDimensions() const { + size_t count = 0; + for (AddressOp op : _ops) { + if (op == AddressOp::BOTH) { + count++; + } + } + return count; } bool @@ -60,11 +66,7 @@ TensorAddressCombiner::combine(SparseTensorAddressRef lhsRef, add(lhsLabel); } } - assert(!lhs.valid()); - assert(!rhs.valid()); return true; } -} // namespace vespalib::tensor::sparse -} // namespace vespalib::tensor -} // namespace vespalib +} diff --git a/eval/src/vespa/eval/tensor/sparse/sparse_tensor_address_combiner.h b/eval/src/vespa/eval/tensor/sparse/sparse_tensor_address_combiner.h index 307ebe59ba7..491d5c9be8b 100644 --- a/eval/src/vespa/eval/tensor/sparse/sparse_tensor_address_combiner.h +++ b/eval/src/vespa/eval/tensor/sparse/sparse_tensor_address_combiner.h @@ -25,12 +25,12 @@ class TensorAddressCombiner : public SparseTensorAddressBuilder std::vector<AddressOp> _ops; public: - TensorAddressCombiner(const eval::ValueType &lhs, - const eval::ValueType &rhs); - + TensorAddressCombiner(const eval::ValueType &lhs, const eval::ValueType &rhs); ~TensorAddressCombiner(); bool combine(SparseTensorAddressRef lhsRef, SparseTensorAddressRef rhsRef); + size_t numOverlappingDimensions() const; + size_t numDimensions() const { return _ops.size(); } }; diff --git a/eval/src/vespa/eval/tensor/sparse/sparse_tensor_apply.hpp b/eval/src/vespa/eval/tensor/sparse/sparse_tensor_apply.hpp index 0b9e127dd82..2027e0afc9d 100644 --- a/eval/src/vespa/eval/tensor/sparse/sparse_tensor_apply.hpp +++ b/eval/src/vespa/eval/tensor/sparse/sparse_tensor_apply.hpp @@ -15,7 +15,11 @@ apply(const SparseTensor &lhs, const SparseTensor &rhs, Function &&func) { DirectTensorBuilder<SparseTensor> builder(lhs.combineDimensionsWith(rhs)); TensorAddressCombiner addressCombiner(lhs.fast_type(), rhs.fast_type()); - builder.reserve((lhs.cells().size() * rhs.cells())*2); + size_t estimatedCells = (lhs.cells().size() * rhs.cells().size()); + if (addressCombiner.numOverlappingDimensions() != 0) { + estimatedCells = std::min(lhs.cells().size(), rhs.cells().size()); + } + builder.reserve(estimatedCells*2); for (const auto &lhsCell : lhs.cells()) { for (const auto &rhsCell : rhs.cells()) { bool combineSuccess = addressCombiner.combine(lhsCell.first, rhsCell.first); |