From c6e92173cf30de539ef1afa4f62585efaa4b9050 Mon Sep 17 00:00:00 2001 From: Geir Storli Date: Wed, 20 Feb 2019 14:19:27 +0000 Subject: Implement remove operation for sparse tensor. --- .../tensor/tensor_remove_operation/CMakeLists.txt | 8 ++++ .../tensor_remove_operation_test.cpp | 46 ++++++++++++++++++++++ .../vespa/eval/tensor/dense/dense_tensor_view.cpp | 6 +++ .../vespa/eval/tensor/dense/dense_tensor_view.h | 1 + eval/src/vespa/eval/tensor/sparse/CMakeLists.txt | 3 +- .../src/vespa/eval/tensor/sparse/sparse_tensor.cpp | 12 ++++++ eval/src/vespa/eval/tensor/sparse/sparse_tensor.h | 1 + .../eval/tensor/sparse/sparse_tensor_remove.cpp | 33 ++++++++++++++++ .../eval/tensor/sparse/sparse_tensor_remove.h | 32 +++++++++++++++ eval/src/vespa/eval/tensor/tensor.h | 6 +++ .../vespa/eval/tensor/wrapped_simple_tensor.cpp | 6 +++ eval/src/vespa/eval/tensor/wrapped_simple_tensor.h | 1 + 12 files changed, 154 insertions(+), 1 deletion(-) create mode 100644 eval/src/tests/tensor/tensor_remove_operation/CMakeLists.txt create mode 100644 eval/src/tests/tensor/tensor_remove_operation/tensor_remove_operation_test.cpp create mode 100644 eval/src/vespa/eval/tensor/sparse/sparse_tensor_remove.cpp create mode 100644 eval/src/vespa/eval/tensor/sparse/sparse_tensor_remove.h (limited to 'eval/src') diff --git a/eval/src/tests/tensor/tensor_remove_operation/CMakeLists.txt b/eval/src/tests/tensor/tensor_remove_operation/CMakeLists.txt new file mode 100644 index 00000000000..8dfb8181f2b --- /dev/null +++ b/eval/src/tests/tensor/tensor_remove_operation/CMakeLists.txt @@ -0,0 +1,8 @@ +# Copyright 2019 Oath Inc. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +vespa_add_executable(eval_tensor_remove_operation_test_app TEST + SOURCES + tensor_remove_operation_test.cpp + DEPENDS + vespaeval +) +vespa_add_test(NAME eval_tensor_remove_operation_test_app COMMAND eval_tensor_remove_operation_test_app) diff --git a/eval/src/tests/tensor/tensor_remove_operation/tensor_remove_operation_test.cpp b/eval/src/tests/tensor/tensor_remove_operation/tensor_remove_operation_test.cpp new file mode 100644 index 00000000000..8b0c44a6e06 --- /dev/null +++ b/eval/src/tests/tensor/tensor_remove_operation/tensor_remove_operation_test.cpp @@ -0,0 +1,46 @@ +// Copyright 2019 Oath Inc. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. + +#include +#include +#include +#include +#include +#include + +using vespalib::eval::Value; +using vespalib::eval::TensorSpec; +using vespalib::tensor::test::makeTensor; +using namespace vespalib::tensor; + +void +assertRemove(const TensorSpec &source, const TensorSpec &arg, const TensorSpec &expected) +{ + auto sourceTensor = makeTensor(source); + auto argTensor = makeTensor(arg); + auto resultTensor = sourceTensor->remove(CellValues(*argTensor)); + auto actual = resultTensor->toSpec(); + EXPECT_EQUAL(actual, expected); +} + +TEST("require that cells can be removed from a sparse tensor") +{ + assertRemove(TensorSpec("tensor(x{},y{})") + .add({{"x","a"},{"y","b"}}, 2) + .add({{"x","c"},{"y","d"}}, 3), + TensorSpec("tensor(x{},y{})") + .add({{"x","c"},{"y","d"}}, 1) + .add({{"x","e"},{"y","f"}}, 1), + TensorSpec("tensor(x{},y{})") + .add({{"x","a"},{"y","b"}}, 2)); +} + +TEST("require that all cells can be removed from a sparse tensor") +{ + assertRemove(TensorSpec("tensor(x{},y{})") + .add({{"x","a"},{"y","b"}}, 2), + TensorSpec("tensor(x{},y{})") + .add({{"x","a"},{"y","b"}}, 1), + TensorSpec("tensor(x{},y{})")); +} + +TEST_MAIN() { TEST_RUN_ALL(); } diff --git a/eval/src/vespa/eval/tensor/dense/dense_tensor_view.cpp b/eval/src/vespa/eval/tensor/dense/dense_tensor_view.cpp index 6243f79a971..164ec042384 100644 --- a/eval/src/vespa/eval/tensor/dense/dense_tensor_view.cpp +++ b/eval/src/vespa/eval/tensor/dense/dense_tensor_view.cpp @@ -299,4 +299,10 @@ DenseTensorView::add(const Tensor &) const LOG_ABORT("should not be reached"); } +std::unique_ptr +DenseTensorView::remove(const CellValues &) const +{ + LOG_ABORT("should not be reached"); +} + } diff --git a/eval/src/vespa/eval/tensor/dense/dense_tensor_view.h b/eval/src/vespa/eval/tensor/dense/dense_tensor_view.h index f470e9d374f..11ed9639cc6 100644 --- a/eval/src/vespa/eval/tensor/dense/dense_tensor_view.h +++ b/eval/src/vespa/eval/tensor/dense/dense_tensor_view.h @@ -55,6 +55,7 @@ public: Tensor::UP reduce(join_fun_t op, const std::vector &dimensions) const override; std::unique_ptr modify(join_fun_t op, const CellValues &cellValues) const override; std::unique_ptr add(const Tensor &arg) const override; + std::unique_ptr remove(const CellValues &) const override; bool equals(const Tensor &arg) const override; Tensor::UP clone() const override; eval::TensorSpec toSpec() const override; diff --git a/eval/src/vespa/eval/tensor/sparse/CMakeLists.txt b/eval/src/vespa/eval/tensor/sparse/CMakeLists.txt index d50c6d5db10..2d142d98ba1 100644 --- a/eval/src/vespa/eval/tensor/sparse/CMakeLists.txt +++ b/eval/src/vespa/eval/tensor/sparse/CMakeLists.txt @@ -8,8 +8,9 @@ vespa_add_library(eval_tensor_sparse OBJECT sparse_tensor_address_padder.cpp sparse_tensor_address_reducer.cpp sparse_tensor_address_ref.cpp + sparse_tensor_builder.cpp sparse_tensor_match.cpp sparse_tensor_modify.cpp - sparse_tensor_builder.cpp + sparse_tensor_remove.cpp sparse_tensor_unsorted_address_builder.cpp ) diff --git a/eval/src/vespa/eval/tensor/sparse/sparse_tensor.cpp b/eval/src/vespa/eval/tensor/sparse/sparse_tensor.cpp index e3ee9593d80..ded9310b450 100644 --- a/eval/src/vespa/eval/tensor/sparse/sparse_tensor.cpp +++ b/eval/src/vespa/eval/tensor/sparse/sparse_tensor.cpp @@ -7,6 +7,7 @@ #include "sparse_tensor_match.h" #include "sparse_tensor_modify.h" #include "sparse_tensor_reduce.hpp" +#include "sparse_tensor_remove.h" #include #include #include @@ -215,6 +216,17 @@ SparseTensor::add(const Tensor &arg) const return adder.build(); } +std::unique_ptr +SparseTensor::remove(const CellValues &cellAddresses) const +{ + Cells cells; + Stash stash; + copyCells(cells, _cells, stash); + SparseTensorRemove remover(_type, std::move(cells), std::move(stash)); + cellAddresses.accept(remover); + return remover.build(); +} + } VESPALIB_HASH_MAP_INSTANTIATE_H_E_M(vespalib::tensor::SparseTensorAddressRef, double, vespalib::hash, diff --git a/eval/src/vespa/eval/tensor/sparse/sparse_tensor.h b/eval/src/vespa/eval/tensor/sparse/sparse_tensor.h index 107cba7a673..7eebff1f010 100644 --- a/eval/src/vespa/eval/tensor/sparse/sparse_tensor.h +++ b/eval/src/vespa/eval/tensor/sparse/sparse_tensor.h @@ -47,6 +47,7 @@ public: Tensor::UP reduce(join_fun_t op, const std::vector &dimensions) const override; std::unique_ptr modify(join_fun_t op, const CellValues &cellValues) const override; std::unique_ptr add(const Tensor &arg) const override; + std::unique_ptr remove(const CellValues &cellAddresses) const override; bool equals(const Tensor &arg) const override; Tensor::UP clone() const override; eval::TensorSpec toSpec() const override; diff --git a/eval/src/vespa/eval/tensor/sparse/sparse_tensor_remove.cpp b/eval/src/vespa/eval/tensor/sparse/sparse_tensor_remove.cpp new file mode 100644 index 00000000000..76af1e3b5fb --- /dev/null +++ b/eval/src/vespa/eval/tensor/sparse/sparse_tensor_remove.cpp @@ -0,0 +1,33 @@ +// Copyright 2019 Oath Inc. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. + +#include "sparse_tensor_remove.h" +#include + +namespace vespalib::tensor { + +SparseTensorRemove::SparseTensorRemove(const eval::ValueType &type, Cells &&cells, Stash &&stash) + : _type(type), + _cells(std::move(cells)), + _stash(std::move(stash)), + _addressBuilder() +{ +} + +SparseTensorRemove::~SparseTensorRemove() = default; + +void +SparseTensorRemove::visit(const TensorAddress &address, double value) +{ + (void) value; + _addressBuilder.populate(_type, address); + auto addressRef = _addressBuilder.getAddressRef(); + _cells.erase(addressRef); +} + +std::unique_ptr +SparseTensorRemove::build() +{ + return std::make_unique(std::move(_type), std::move(_cells), std::move(_stash)); +} + +} diff --git a/eval/src/vespa/eval/tensor/sparse/sparse_tensor_remove.h b/eval/src/vespa/eval/tensor/sparse/sparse_tensor_remove.h new file mode 100644 index 00000000000..3d5905d8f41 --- /dev/null +++ b/eval/src/vespa/eval/tensor/sparse/sparse_tensor_remove.h @@ -0,0 +1,32 @@ +// Copyright 2019 Oath Inc. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. + +#pragma once + +#include "sparse_tensor.h" +#include "sparse_tensor_address_builder.h" +#include + +namespace vespalib::tensor { + +/** + * This class handles a tensor remove operation on a sparse tensor. + * + * Creates a new tensor by removing the cells matching the cell addresses visited. + * The value associated with the address is ignored. + */ +class SparseTensorRemove : public TensorVisitor { +private: + using Cells = SparseTensor::Cells; + eval::ValueType _type; + Cells _cells; + Stash _stash; + SparseTensorAddressBuilder _addressBuilder; + +public: + SparseTensorRemove(const eval::ValueType &type, Cells &&cells, Stash &&stash); + ~SparseTensorRemove(); + void visit(const TensorAddress &address, double value) override; + std::unique_ptr build(); +}; + +} diff --git a/eval/src/vespa/eval/tensor/tensor.h b/eval/src/vespa/eval/tensor/tensor.h index cdb9d90d3a3..4061ed9c115 100644 --- a/eval/src/vespa/eval/tensor/tensor.h +++ b/eval/src/vespa/eval/tensor/tensor.h @@ -49,6 +49,12 @@ public: */ virtual std::unique_ptr add(const Tensor &arg) const = 0; + /** + * Creates a new tensor by removing the cells matching the given cell addresses. + * The value associated with the address is ignored. + */ + virtual std::unique_ptr remove(const CellValues &cellAddresses) const = 0; + virtual bool equals(const Tensor &arg) const = 0; // want to remove, but needed by document virtual Tensor::UP clone() const = 0; // want to remove, but needed by document virtual eval::TensorSpec toSpec() const = 0; diff --git a/eval/src/vespa/eval/tensor/wrapped_simple_tensor.cpp b/eval/src/vespa/eval/tensor/wrapped_simple_tensor.cpp index 66fd2978a53..9df59a63873 100644 --- a/eval/src/vespa/eval/tensor/wrapped_simple_tensor.cpp +++ b/eval/src/vespa/eval/tensor/wrapped_simple_tensor.cpp @@ -89,4 +89,10 @@ WrappedSimpleTensor::add(const Tensor &) const LOG_ABORT("should not be reached"); } +std::unique_ptr +WrappedSimpleTensor::remove(const CellValues &) const +{ + LOG_ABORT("should not be reached"); +} + } // namespace vespalib::tensor diff --git a/eval/src/vespa/eval/tensor/wrapped_simple_tensor.h b/eval/src/vespa/eval/tensor/wrapped_simple_tensor.h index 2d877b6fbbc..e7ffe7a755f 100644 --- a/eval/src/vespa/eval/tensor/wrapped_simple_tensor.h +++ b/eval/src/vespa/eval/tensor/wrapped_simple_tensor.h @@ -40,6 +40,7 @@ public: Tensor::UP reduce(join_fun_t, const std::vector &) const override; std::unique_ptr modify(join_fun_t, const CellValues &) const override; std::unique_ptr add(const Tensor &arg) const override; + std::unique_ptr remove(const CellValues &) const override; }; } // namespace vespalib::tensor -- cgit v1.2.3