From 0e1c348bb3e27fc762806f7dbc444474e9036615 Mon Sep 17 00:00:00 2001 From: Geir Storli Date: Mon, 25 Feb 2019 14:38:30 +0000 Subject: Support remove operation on mixed tensors. --- .../tensor_remove_operation_test.cpp | 49 ++++++++++++++++++++++ eval/src/vespa/eval/tensor/cell_values.h | 4 ++ eval/src/vespa/eval/tensor/sparse/sparse_tensor.h | 2 +- .../vespa/eval/tensor/wrapped_simple_tensor.cpp | 36 ++++++++++++++-- 4 files changed, 86 insertions(+), 5 deletions(-) (limited to 'eval/src') diff --git a/eval/src/tests/tensor/tensor_remove_operation/tensor_remove_operation_test.cpp b/eval/src/tests/tensor/tensor_remove_operation/tensor_remove_operation_test.cpp index cb28019c4ee..df21b46691d 100644 --- a/eval/src/tests/tensor/tensor_remove_operation/tensor_remove_operation_test.cpp +++ b/eval/src/tests/tensor/tensor_remove_operation/tensor_remove_operation_test.cpp @@ -43,4 +43,53 @@ TEST(TensorRemoveTest, all_cells_can_be_removed_from_a_sparse_tensor) TensorSpec("tensor(x{},y{})")); } +TEST(TensorRemoveTest, cells_can_be_removed_from_a_mixed_tensor) +{ + assertRemove(TensorSpec("tensor(x{},y[2])") + .add({{"x","a"},{"y",0}}, 2) + .add({{"x","a"},{"y",1}}, 3) + .add({{"x","b"},{"y",0}}, 4) + .add({{"x","b"},{"y",1}}, 5), + TensorSpec("tensor(x{})") + .add({{"x","b"}}, 1) + .add({{"x","c"}}, 1), + TensorSpec("tensor(x{},y[2])") + .add({{"x","a"},{"y",0}}, 2) + .add({{"x","a"},{"y",1}}, 3)); + + assertRemove(TensorSpec("tensor(x{},y{},z[2])") + .add({{"x","a"},{"y","c"},{"z",0}}, 2) + .add({{"x","a"},{"y","c"},{"z",1}}, 3) + .add({{"x","b"},{"y","c"},{"z",0}}, 4) + .add({{"x","b"},{"y","c"},{"z",1}}, 5), + TensorSpec("tensor(x{},y{})") + .add({{"x","b"},{"y","c"}}, 1) + .add({{"x","c"},{"y","c"}}, 1), + TensorSpec("tensor(x{},y{},z[2])") + .add({{"x","a"},{"y","c"},{"z",0}}, 2) + .add({{"x","a"},{"y","c"},{"z",1}}, 3)); + + assertRemove(TensorSpec("tensor(x{},y[1],z[2])") + .add({{"x","a"},{"y",0},{"z",0}}, 2) + .add({{"x","a"},{"y",0},{"z",1}}, 3) + .add({{"x","b"},{"y",0},{"z",0}}, 4) + .add({{"x","b"},{"y",0},{"z",1}}, 5), + TensorSpec("tensor(x{})") + .add({{"x","b"}}, 1) + .add({{"x","c"}}, 1), + TensorSpec("tensor(x{},y[1],z[2])") + .add({{"x","a"},{"y",0},{"z",0}}, 2) + .add({{"x","a"},{"y",0},{"z",1}}, 3)); +} + +TEST(TensorRemoveTest, all_cells_can_be_removed_from_a_mixed_tensor) +{ + assertRemove(TensorSpec("tensor(x{},y[2])") + .add({{"x","a"},{"y",0}}, 2) + .add({{"x","a"},{"y",1}}, 3), + TensorSpec("tensor(x{})") + .add({{"x","a"}}, 1), + TensorSpec("tensor(x{},y[2])")); +} + GTEST_MAIN_RUN_ALL_TESTS diff --git a/eval/src/vespa/eval/tensor/cell_values.h b/eval/src/vespa/eval/tensor/cell_values.h index dabdcde3294..4b8fd33376a 100644 --- a/eval/src/vespa/eval/tensor/cell_values.h +++ b/eval/src/vespa/eval/tensor/cell_values.h @@ -23,6 +23,10 @@ public: void accept(TensorVisitor &visitor) const { _tensor.accept(visitor); } + + eval::TensorSpec toSpec() const { + return _tensor.toSpec(); + } }; } diff --git a/eval/src/vespa/eval/tensor/sparse/sparse_tensor.h b/eval/src/vespa/eval/tensor/sparse/sparse_tensor.h index 7eebff1f010..c182c09c6b0 100644 --- a/eval/src/vespa/eval/tensor/sparse/sparse_tensor.h +++ b/eval/src/vespa/eval/tensor/sparse/sparse_tensor.h @@ -2,10 +2,10 @@ #pragma once +#include "sparse_tensor_address_ref.h" #include #include #include -#include "sparse_tensor_address_ref.h" #include #include #include diff --git a/eval/src/vespa/eval/tensor/wrapped_simple_tensor.cpp b/eval/src/vespa/eval/tensor/wrapped_simple_tensor.cpp index 1268d6fa9cb..9d451c4639d 100644 --- a/eval/src/vespa/eval/tensor/wrapped_simple_tensor.cpp +++ b/eval/src/vespa/eval/tensor/wrapped_simple_tensor.cpp @@ -1,8 +1,9 @@ // Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. -#include "wrapped_simple_tensor.h" +#include "cell_values.h" #include "tensor_address_builder.h" #include "tensor_visitor.h" +#include "wrapped_simple_tensor.h" #include #include #include @@ -114,10 +115,37 @@ WrappedSimpleTensor::add(const Tensor &arg) const return std::make_unique(SimpleTensor::create(result)); } +namespace { + +TensorSpec::Address +extractMappedDimensions(const TensorSpec::Address &address) +{ + TensorSpec::Address result; + for (const auto &elem : address) { + if (elem.second.is_mapped()) { + result.emplace(elem); + } + } + return result; +} + +} + std::unique_ptr -WrappedSimpleTensor::remove(const CellValues &) const +WrappedSimpleTensor::remove(const CellValues &cellAddresses) const { - LOG_ABORT("should not be reached"); + TensorSpec oldTensor = toSpec(); + TensorSpec toRemove = cellAddresses.toSpec(); + TensorSpec result(type().to_spec()); + + for (const auto &cell : oldTensor.cells()) { + TensorSpec::Address mappedAddress = extractMappedDimensions(cell.first); + auto itr = toRemove.cells().find(mappedAddress); + if (itr == toRemove.cells().end()) { + result.add(cell.first, cell.second); + } + } + return std::make_unique(SimpleTensor::create(result)); } -} // namespace vespalib::tensor +} -- cgit v1.2.3