diff options
author | Tor Egge <Tor.Egge@broadpark.no> | 2019-02-25 16:52:24 +0100 |
---|---|---|
committer | GitHub <noreply@github.com> | 2019-02-25 16:52:24 +0100 |
commit | 4b46918b47774d15b4882aff7db693699383ca61 (patch) | |
tree | d7c6498e7390fb8bc2bf8b3c6290c25c5be7fcb6 /eval | |
parent | 50e5898f70d4ea1ece5065b06ea7f3a0755463b9 (diff) | |
parent | e3ab5b19197122709f06636001955e8c84345a0f (diff) |
Merge pull request #8604 from vespa-engine/geirst/remove-and-modify-for-mixed-tensors
Geirst/remove and modify for mixed tensors
Diffstat (limited to 'eval')
5 files changed, 156 insertions, 22 deletions
diff --git a/eval/src/tests/tensor/tensor_modify_operation/tensor_modify_operation_test.cpp b/eval/src/tests/tensor/tensor_modify_operation/tensor_modify_operation_test.cpp index ff59b28d60c..31f17b3eed2 100644 --- a/eval/src/tests/tensor/tensor_modify_operation/tensor_modify_operation_test.cpp +++ b/eval/src/tests/tensor/tensor_modify_operation/tensor_modify_operation_test.cpp @@ -1,5 +1,6 @@ // Copyright 2019 Oath Inc. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +#include <vespa/eval/eval/operation.h> #include <vespa/eval/eval/tensor_spec.h> #include <vespa/eval/tensor/cell_values.h> #include <vespa/eval/tensor/default_tensor_engine.h> @@ -13,16 +14,6 @@ using vespalib::eval::TensorSpec; using vespalib::tensor::test::makeTensor; using namespace vespalib::tensor; -namespace { - -double -replace(double, double b) -{ - return b; -} - -} - void checkUpdate(const TensorSpec &source, const TensorSpec &update, const TensorSpec &expect) { @@ -30,7 +21,7 @@ checkUpdate(const TensorSpec &source, const TensorSpec &update, const TensorSpec auto updateTensor = makeTensor<SparseTensor>(update); const CellValues cellValues(*updateTensor); - auto actualTensor = sourceTensor->modify(replace, cellValues); + auto actualTensor = sourceTensor->modify(vespalib::eval::operation::Add::f, cellValues); auto actual = actualTensor->toSpec(); auto expectTensor = makeTensor<Tensor>(expect); auto expectPadded = expectTensor->toSpec(); @@ -45,7 +36,7 @@ TEST(TensorModifyTest, sparse_tensors_can_be_modified) TensorSpec("tensor(x{},y{})") .add({{"x","8"},{"y","9"}}, 2), TensorSpec("tensor(x{},y{})") - .add({{"x","8"},{"y","9"}}, 2) + .add({{"x","8"},{"y","9"}}, 13) .add({{"x","9"},{"y","9"}}, 11)); } @@ -57,10 +48,27 @@ TEST(TensorModifyTest, dense_tensors_can_be_modified) TensorSpec("tensor(x{},y{})") .add({{"x","8"},{"y","9"}}, 2), TensorSpec("tensor(x[10],y[10])") - .add({{"x",8},{"y",9}}, 2) + .add({{"x",8},{"y",9}}, 13) .add({{"x",9},{"y",9}}, 11)); } +TEST(TensorModifyTest, mixed_tensors_can_be_modified) +{ + checkUpdate(TensorSpec("tensor(x{},y[2])") + .add({{"x","a"},{"y",0}}, 2) + .add({{"x","a"},{"y",1}}, 3) + .add({{"x","b"},{"y",0}}, 4) + .add({{"x","b"},{"y",1}}, 5), + TensorSpec("tensor(x{},y{})") + .add({{"x","a"},{"y","0"}}, 6) + .add({{"x","b"},{"y","1"}}, 7), + TensorSpec("tensor(x{},y[2])") + .add({{"x","a"},{"y",0}}, 8) + .add({{"x","a"},{"y",1}}, 3) + .add({{"x","b"},{"y",0}}, 4) + .add({{"x","b"},{"y",1}}, 12)); +} + TEST(TensorModifyTest, sparse_tensors_ignore_updates_to_missing_cells) { checkUpdate(TensorSpec("tensor(x{},y{})") @@ -70,7 +78,7 @@ TEST(TensorModifyTest, sparse_tensors_ignore_updates_to_missing_cells) .add({{"x","7"},{"y","9"}}, 2) .add({{"x","8"},{"y","9"}}, 2), TensorSpec("tensor(x{},y{})") - .add({{"x","8"},{"y","9"}}, 2) + .add({{"x","8"},{"y","9"}}, 13) .add({{"x","9"},{"y","9"}}, 11)); } @@ -83,8 +91,21 @@ TEST(TensorModifyTest, dense_tensors_ignore_updates_to_out_of_range_cells) .add({{"x","8"},{"y","9"}}, 2) .add({{"x","10"},{"y","9"}}, 2), TensorSpec("tensor(x[10],y[10])") - .add({{"x",8},{"y",9}}, 2) + .add({{"x",8},{"y",9}}, 13) .add({{"x",9},{"y",9}}, 11)); } +TEST(TensorModifyTest, mixed_tensors_ignore_updates_to_missing_or_out_of_range_cells) +{ + checkUpdate(TensorSpec("tensor(x{},y[2])") + .add({{"x","a"},{"y",0}}, 2) + .add({{"x","a"},{"y",1}}, 3), + TensorSpec("tensor(x{},y{})") + .add({{"x","a"},{"y","2"}}, 4) + .add({{"x","c"},{"y","0"}}, 5), + TensorSpec("tensor(x{},y[2])") + .add({{"x","a"},{"y",0}}, 2) + .add({{"x","a"},{"y",1}}, 3)); +} + GTEST_MAIN_RUN_ALL_TESTS diff --git a/eval/src/tests/tensor/tensor_remove_operation/tensor_remove_operation_test.cpp b/eval/src/tests/tensor/tensor_remove_operation/tensor_remove_operation_test.cpp index cb28019c4ee..df21b46691d 100644 --- a/eval/src/tests/tensor/tensor_remove_operation/tensor_remove_operation_test.cpp +++ b/eval/src/tests/tensor/tensor_remove_operation/tensor_remove_operation_test.cpp @@ -43,4 +43,53 @@ TEST(TensorRemoveTest, all_cells_can_be_removed_from_a_sparse_tensor) TensorSpec("tensor(x{},y{})")); } +TEST(TensorRemoveTest, cells_can_be_removed_from_a_mixed_tensor) +{ + assertRemove(TensorSpec("tensor(x{},y[2])") + .add({{"x","a"},{"y",0}}, 2) + .add({{"x","a"},{"y",1}}, 3) + .add({{"x","b"},{"y",0}}, 4) + .add({{"x","b"},{"y",1}}, 5), + TensorSpec("tensor(x{})") + .add({{"x","b"}}, 1) + .add({{"x","c"}}, 1), + TensorSpec("tensor(x{},y[2])") + .add({{"x","a"},{"y",0}}, 2) + .add({{"x","a"},{"y",1}}, 3)); + + assertRemove(TensorSpec("tensor(x{},y{},z[2])") + .add({{"x","a"},{"y","c"},{"z",0}}, 2) + .add({{"x","a"},{"y","c"},{"z",1}}, 3) + .add({{"x","b"},{"y","c"},{"z",0}}, 4) + .add({{"x","b"},{"y","c"},{"z",1}}, 5), + TensorSpec("tensor(x{},y{})") + .add({{"x","b"},{"y","c"}}, 1) + .add({{"x","c"},{"y","c"}}, 1), + TensorSpec("tensor(x{},y{},z[2])") + .add({{"x","a"},{"y","c"},{"z",0}}, 2) + .add({{"x","a"},{"y","c"},{"z",1}}, 3)); + + assertRemove(TensorSpec("tensor(x{},y[1],z[2])") + .add({{"x","a"},{"y",0},{"z",0}}, 2) + .add({{"x","a"},{"y",0},{"z",1}}, 3) + .add({{"x","b"},{"y",0},{"z",0}}, 4) + .add({{"x","b"},{"y",0},{"z",1}}, 5), + TensorSpec("tensor(x{})") + .add({{"x","b"}}, 1) + .add({{"x","c"}}, 1), + TensorSpec("tensor(x{},y[1],z[2])") + .add({{"x","a"},{"y",0},{"z",0}}, 2) + .add({{"x","a"},{"y",0},{"z",1}}, 3)); +} + +TEST(TensorRemoveTest, all_cells_can_be_removed_from_a_mixed_tensor) +{ + assertRemove(TensorSpec("tensor(x{},y[2])") + .add({{"x","a"},{"y",0}}, 2) + .add({{"x","a"},{"y",1}}, 3), + TensorSpec("tensor(x{})") + .add({{"x","a"}}, 1), + TensorSpec("tensor(x{},y[2])")); +} + GTEST_MAIN_RUN_ALL_TESTS diff --git a/eval/src/vespa/eval/tensor/cell_values.h b/eval/src/vespa/eval/tensor/cell_values.h index dabdcde3294..4b8fd33376a 100644 --- a/eval/src/vespa/eval/tensor/cell_values.h +++ b/eval/src/vespa/eval/tensor/cell_values.h @@ -23,6 +23,10 @@ public: void accept(TensorVisitor &visitor) const { _tensor.accept(visitor); } + + eval::TensorSpec toSpec() const { + return _tensor.toSpec(); + } }; } diff --git a/eval/src/vespa/eval/tensor/sparse/sparse_tensor.h b/eval/src/vespa/eval/tensor/sparse/sparse_tensor.h index 7eebff1f010..c182c09c6b0 100644 --- a/eval/src/vespa/eval/tensor/sparse/sparse_tensor.h +++ b/eval/src/vespa/eval/tensor/sparse/sparse_tensor.h @@ -2,10 +2,10 @@ #pragma once +#include "sparse_tensor_address_ref.h" #include <vespa/eval/tensor/cell_function.h> #include <vespa/eval/tensor/tensor.h> #include <vespa/eval/tensor/tensor_address.h> -#include "sparse_tensor_address_ref.h" #include <vespa/eval/tensor/types.h> #include <vespa/vespalib/stllike/hash_map.h> #include <vespa/vespalib/stllike/string.h> diff --git a/eval/src/vespa/eval/tensor/wrapped_simple_tensor.cpp b/eval/src/vespa/eval/tensor/wrapped_simple_tensor.cpp index 1268d6fa9cb..a982a4b0fe1 100644 --- a/eval/src/vespa/eval/tensor/wrapped_simple_tensor.cpp +++ b/eval/src/vespa/eval/tensor/wrapped_simple_tensor.cpp @@ -1,8 +1,9 @@ // Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. -#include "wrapped_simple_tensor.h" +#include "cell_values.h" #include "tensor_address_builder.h" #include "tensor_visitor.h" +#include "wrapped_simple_tensor.h" #include <vespa/eval/eval/simple_tensor_engine.h> #include <vespa/eval/eval/tensor_spec.h> #include <vespa/vespalib/util/stringfmt.h> @@ -80,10 +81,42 @@ WrappedSimpleTensor::reduce(join_fun_t, const std::vector<vespalib::string> &) c LOG_ABORT("should not be reached"); } +namespace { + +TensorSpec::Address +convertToOnlyMappedDimensions(const TensorSpec::Address &address) +{ + TensorSpec::Address result; + for (const auto &elem : address) { + if (elem.second.is_indexed()) { + result.emplace(std::make_pair(elem.first, + TensorSpec::Label(vespalib::make_string("%zu", elem.second.index)))); + } else { + result.emplace(elem); + } + } + return result; +} + +} + std::unique_ptr<Tensor> -WrappedSimpleTensor::modify(join_fun_t, const CellValues &) const +WrappedSimpleTensor::modify(join_fun_t op, const CellValues &cellValues) const { - LOG_ABORT("should not be reached"); + TensorSpec oldTensor = toSpec(); + TensorSpec toModify = cellValues.toSpec(); + TensorSpec result(type().to_spec()); + + for (const auto &cell : oldTensor.cells()) { + TensorSpec::Address mappedAddress = convertToOnlyMappedDimensions(cell.first); + auto itr = toModify.cells().find(mappedAddress); + if (itr != toModify.cells().end()) { + result.add(cell.first, op(cell.second, itr->second)); + } else { + result.add(cell.first, cell.second); + } + } + return std::make_unique<WrappedSimpleTensor>(SimpleTensor::create(result)); } std::unique_ptr<Tensor> @@ -114,10 +147,37 @@ WrappedSimpleTensor::add(const Tensor &arg) const return std::make_unique<WrappedSimpleTensor>(SimpleTensor::create(result)); } +namespace { + +TensorSpec::Address +extractMappedDimensions(const TensorSpec::Address &address) +{ + TensorSpec::Address result; + for (const auto &elem : address) { + if (elem.second.is_mapped()) { + result.emplace(elem); + } + } + return result; +} + +} + std::unique_ptr<Tensor> -WrappedSimpleTensor::remove(const CellValues &) const +WrappedSimpleTensor::remove(const CellValues &cellAddresses) const { - LOG_ABORT("should not be reached"); + TensorSpec oldTensor = toSpec(); + TensorSpec toRemove = cellAddresses.toSpec(); + TensorSpec result(type().to_spec()); + + for (const auto &cell : oldTensor.cells()) { + TensorSpec::Address mappedAddress = extractMappedDimensions(cell.first); + auto itr = toRemove.cells().find(mappedAddress); + if (itr == toRemove.cells().end()) { + result.add(cell.first, cell.second); + } + } + return std::make_unique<WrappedSimpleTensor>(SimpleTensor::create(result)); } -} // namespace vespalib::tensor +} |