diff options
8 files changed, 89 insertions, 43 deletions
diff --git a/vespalib/CMakeLists.txt b/vespalib/CMakeLists.txt index 907cab7dbc6..61ec9f1ccf9 100644 --- a/vespalib/CMakeLists.txt +++ b/vespalib/CMakeLists.txt @@ -73,6 +73,7 @@ vespa_define_module( src/tests/stringfmt src/tests/sync src/tests/tensor/sparse_tensor_builder + src/tests/tensor/dense_tensor_address_combiner src/tests/tensor/dense_tensor_builder src/tests/tensor/dense_tensor_operations src/tests/tensor/tensor_address diff --git a/vespalib/src/tests/tensor/dense_tensor_address_combiner/CMakeLists.txt b/vespalib/src/tests/tensor/dense_tensor_address_combiner/CMakeLists.txt new file mode 100644 index 00000000000..65e7c711b19 --- /dev/null +++ b/vespalib/src/tests/tensor/dense_tensor_address_combiner/CMakeLists.txt @@ -0,0 +1,9 @@ +# Copyright 2016 Yahoo Inc. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +vespa_add_executable(vespalib_dense_tensor_address_combiner_test_app TEST + SOURCES + dense_tensor_address_combiner_test.cpp + DEPENDS + vespalib + vespalib_vespalib_tensor +) +vespa_add_test(NAME vespalib_dense_tensor_address_combiner_test_app COMMAND vespalib_dense_tensor_address_combiner_test_app) diff --git a/vespalib/src/tests/tensor/dense_tensor_address_combiner/FILES b/vespalib/src/tests/tensor/dense_tensor_address_combiner/FILES new file mode 100644 index 00000000000..0a49bd4647b --- /dev/null +++ b/vespalib/src/tests/tensor/dense_tensor_address_combiner/FILES @@ -0,0 +1 @@ +dense_tensor_address_combiner_test.cpp diff --git a/vespalib/src/tests/tensor/dense_tensor_address_combiner/dense_tensor_address_combiner_test.cpp b/vespalib/src/tests/tensor/dense_tensor_address_combiner/dense_tensor_address_combiner_test.cpp new file mode 100644 index 00000000000..1192469e006 --- /dev/null +++ b/vespalib/src/tests/tensor/dense_tensor_address_combiner/dense_tensor_address_combiner_test.cpp @@ -0,0 +1,36 @@ +// Copyright 2016 Yahoo Inc. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. + +#include <vespa/vespalib/testkit/test_kit.h> +#include <vespa/vespalib/tensor/dense/dense_tensor_address_combiner.h> +#include <vespa/vespalib/test/insertion_operators.h> + +using namespace vespalib::tensor; +using DimensionsMeta = DenseTensor::DimensionsMeta; + +std::ostream & +operator<<(std::ostream &out, const DenseTensor::DimensionMeta &dimMeta) +{ + out << dimMeta.dimension() << "[" << dimMeta.size() << "]"; + return out; +} + +DimensionsMeta +combine(const DimensionsMeta &lhs, const DimensionsMeta &rhs) +{ + return DenseTensorAddressCombiner::combineDimensions(lhs, rhs); +} + +TEST("require that dimensions can be combined") +{ + EXPECT_EQUAL(DimensionsMeta({{"a", 3}, {"b", 5}}), combine({{"a", 3}}, {{"b", 5}})); + EXPECT_EQUAL(DimensionsMeta({{"a", 3}, {"b", 5}}), combine({{"a", 3}, {"b", 5}}, {{"b", 5}})); + EXPECT_EQUAL(DimensionsMeta({{"a", 3}, {"b", 5}}), combine({{"a", 3}, {"b", 7}}, {{"b", 5}})); + EXPECT_EQUAL(DimensionsMeta({{"a", 3}, {"b", 11}, {"c", 5}, {"d", 7}, {"e", 17}}), + combine({{"a", 3}, {"c", 5}, {"d", 7}}, + {{"b", 11}, {"c", 13}, {"e", 17}})); + EXPECT_EQUAL(DimensionsMeta({{"a", 3}, {"b", 11}, {"c", 5}, {"d", 7}, {"e", 17}}), + combine({{"b", 11}, {"c", 13}, {"e", 17}}, + {{"a", 3}, {"c", 5}, {"d", 7}})); +} + +TEST_MAIN() { TEST_RUN_ALL(); } diff --git a/vespalib/src/tests/tensor/dense_tensor_operations/dense_tensor_operations_test.cpp b/vespalib/src/tests/tensor/dense_tensor_operations/dense_tensor_operations_test.cpp index aea81ad6b77..b5d0769d37d 100644 --- a/vespalib/src/tests/tensor/dense_tensor_operations/dense_tensor_operations_test.cpp +++ b/vespalib/src/tests/tensor/dense_tensor_operations/dense_tensor_operations_test.cpp @@ -212,19 +212,22 @@ template <typename FixtureType> void testTensorAdd(FixtureType &f) { - f.assertAdd({},{},{}, false); - f.assertAdd({ {{{"x",0}}, 8} }, - { {{{"x",0}}, 3} }, - { {{{"x",0}}, 5} }); - f.assertAdd({ {{{"x",0}}, -2} }, - { {{{"x",0}}, 3} }, - { {{{"x",0}}, -5} }); - f.assertAdd({ {{{"x",0}}, 10}, {{{"x",1}}, 16} }, - { {{{"x",0}}, 3}, {{{"x",1}}, 5} }, - { {{{"x",0}}, 7}, {{{"x",1}}, 11} }); - f.assertAdd({ {{{"x",0},{"y",0}}, 8} }, - { {{{"x",0},{"y",0}}, 3} }, - { {{{"x",0},{"y",0}}, 5} }); + TEST_DO(f.assertAdd({},{},{}, false)); + TEST_DO(f.assertAdd({ {{{"x",0}}, 8} }, + { {{{"x",0}}, 3} }, + { {{{"x",0}}, 5} })); + TEST_DO(f.assertAdd({ {{{"x",0}}, -2} }, + { {{{"x",0}}, 3} }, + { {{{"x",0}}, -5} })); + TEST_DO(f.assertAdd({ {{{"x",0}}, 10}, {{{"x",1}}, 16} }, + { {{{"x",0}}, 3}, {{{"x",1}}, 5} }, + { {{{"x",0}}, 7}, {{{"x",1}}, 11} })); + TEST_DO(f.assertAdd({ {{{"x",0},{"y",0}}, 8} }, + { {{{"x",0},{"y",0}}, 3} }, + { {{{"x",0},{"y",0}}, 5} })); + TEST_DO(f.assertAdd({ {{{"x",0}}, 3} }, + { {{{"x",0}}, 3} }, + { {{{"x",1}}, 5} })); } template <typename FixtureType> diff --git a/vespalib/src/vespa/vespalib/tensor/dense/dense_tensor.cpp b/vespalib/src/vespa/vespalib/tensor/dense/dense_tensor.cpp index c155d8ed1e3..e0855691bc9 100644 --- a/vespalib/src/vespa/vespalib/tensor/dense/dense_tensor.cpp +++ b/vespalib/src/vespa/vespalib/tensor/dense/dense_tensor.cpp @@ -211,10 +211,9 @@ DenseTensor::add(const Tensor &arg) const if (!rhs) { return Tensor::UP(); } - checkDimensions(*this, *rhs, "add"); - return joinDenseTensors(*this, *rhs, - [](double lhsValue, double rhsValue) - { return lhsValue + rhsValue; }); + return dense::apply(*this, *rhs, + [](double lhsValue, double rhsValue) + { return lhsValue + rhsValue; }); } Tensor::UP diff --git a/vespalib/src/vespa/vespalib/tensor/dense/dense_tensor.h b/vespalib/src/vespa/vespalib/tensor/dense/dense_tensor.h index c8a82bfe73e..5b7912c43fc 100644 --- a/vespalib/src/vespa/vespalib/tensor/dense/dense_tensor.h +++ b/vespalib/src/vespa/vespalib/tensor/dense/dense_tensor.h @@ -9,7 +9,8 @@ namespace vespalib { namespace tensor { /** - * TODO + * A dense tensor where all dimensions are indexed. + * Tensor cells are stored in an underlying array according to the order of the dimensions. */ class DenseTensor : public Tensor { diff --git a/vespalib/src/vespa/vespalib/tensor/dense/dense_tensor_address_combiner.cpp b/vespalib/src/vespa/vespalib/tensor/dense/dense_tensor_address_combiner.cpp index 2ad4228e0ec..88fe86ca9e6 100644 --- a/vespalib/src/vespa/vespalib/tensor/dense/dense_tensor_address_combiner.cpp +++ b/vespalib/src/vespa/vespalib/tensor/dense/dense_tensor_address_combiner.cpp @@ -9,6 +9,7 @@ namespace vespalib { namespace tensor { using Address = DenseTensorAddressCombiner::Address; +using DimensionMeta = DenseTensor::DimensionMeta; using DimensionsMeta = DenseTensorAddressCombiner::DimensionsMeta; namespace { @@ -88,35 +89,30 @@ DenseTensorAddressCombiner::combine(const CellsIterator &lhsItr, return true; } -namespace { - -void -validateDimensionsMeta(const DimensionsMeta &dimensionsMeta) -{ - for (size_t i = 1; i < dimensionsMeta.size(); ++i) { - const auto &prevDimMeta = dimensionsMeta[i-1]; - const auto &currDimMeta = dimensionsMeta[i]; - if ((prevDimMeta.dimension() == currDimMeta.dimension()) && - (prevDimMeta.size() != currDimMeta.size())) - { - throw IllegalArgumentException(make_string( - "Shared dimension '%s' has mis-matching label ranges: " - "[0, %zu> vs [0, %zu>. This is not supported.", - prevDimMeta.dimension().c_str(), prevDimMeta.size(), currDimMeta.size())); - } - } -} - -} - DimensionsMeta DenseTensorAddressCombiner::combineDimensions(const DimensionsMeta &lhs, const DimensionsMeta &rhs) { + // NOTE: both lhs and rhs are sorted according to dimension names. DimensionsMeta result; - std::set_union(lhs.cbegin(), lhs.cend(), - rhs.cbegin(), rhs.cend(), - std::back_inserter(result)); - validateDimensionsMeta(result); + auto lhsItr = lhs.cbegin(); + auto rhsItr = rhs.cbegin(); + while (lhsItr != lhs.end() && rhsItr != rhs.end()) { + if (lhsItr->dimension() == rhsItr->dimension()) { + result.emplace_back(DimensionMeta(lhsItr->dimension(), std::min(lhsItr->size(), rhsItr->size()))); + ++lhsItr; + ++rhsItr; + } else if (lhsItr->dimension() < rhsItr->dimension()) { + result.emplace_back(*lhsItr++); + } else { + result.emplace_back(*rhsItr++); + } + } + while (lhsItr != lhs.end()) { + result.emplace_back(*lhsItr++); + } + while (rhsItr != rhs.end()) { + result.emplace_back(*rhsItr++); + } return result; } |