summaryrefslogtreecommitdiffstats
path: root/eval/src
diff options
context:
space:
mode:
authorGeir Storli <geirst@verizonmedia.com>2019-02-20 14:19:27 +0000
committerGeir Storli <geirst@verizonmedia.com>2019-02-20 14:19:27 +0000
commitc6e92173cf30de539ef1afa4f62585efaa4b9050 (patch)
treee5bcbba39d9e025dd99f0df10ecdfa75dc50b1be /eval/src
parent85e394563c8b711a1a0307c8ac5953c1817f5629 (diff)
Implement remove operation for sparse tensor.
Diffstat (limited to 'eval/src')
-rw-r--r--eval/src/tests/tensor/tensor_remove_operation/CMakeLists.txt8
-rw-r--r--eval/src/tests/tensor/tensor_remove_operation/tensor_remove_operation_test.cpp46
-rw-r--r--eval/src/vespa/eval/tensor/dense/dense_tensor_view.cpp6
-rw-r--r--eval/src/vespa/eval/tensor/dense/dense_tensor_view.h1
-rw-r--r--eval/src/vespa/eval/tensor/sparse/CMakeLists.txt3
-rw-r--r--eval/src/vespa/eval/tensor/sparse/sparse_tensor.cpp12
-rw-r--r--eval/src/vespa/eval/tensor/sparse/sparse_tensor.h1
-rw-r--r--eval/src/vespa/eval/tensor/sparse/sparse_tensor_remove.cpp33
-rw-r--r--eval/src/vespa/eval/tensor/sparse/sparse_tensor_remove.h32
-rw-r--r--eval/src/vespa/eval/tensor/tensor.h6
-rw-r--r--eval/src/vespa/eval/tensor/wrapped_simple_tensor.cpp6
-rw-r--r--eval/src/vespa/eval/tensor/wrapped_simple_tensor.h1
12 files changed, 154 insertions, 1 deletions
diff --git a/eval/src/tests/tensor/tensor_remove_operation/CMakeLists.txt b/eval/src/tests/tensor/tensor_remove_operation/CMakeLists.txt
new file mode 100644
index 00000000000..8dfb8181f2b
--- /dev/null
+++ b/eval/src/tests/tensor/tensor_remove_operation/CMakeLists.txt
@@ -0,0 +1,8 @@
+# Copyright 2019 Oath Inc. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+vespa_add_executable(eval_tensor_remove_operation_test_app TEST
+ SOURCES
+ tensor_remove_operation_test.cpp
+ DEPENDS
+ vespaeval
+)
+vespa_add_test(NAME eval_tensor_remove_operation_test_app COMMAND eval_tensor_remove_operation_test_app)
diff --git a/eval/src/tests/tensor/tensor_remove_operation/tensor_remove_operation_test.cpp b/eval/src/tests/tensor/tensor_remove_operation/tensor_remove_operation_test.cpp
new file mode 100644
index 00000000000..8b0c44a6e06
--- /dev/null
+++ b/eval/src/tests/tensor/tensor_remove_operation/tensor_remove_operation_test.cpp
@@ -0,0 +1,46 @@
+// Copyright 2019 Oath Inc. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+
+#include <vespa/eval/eval/tensor_spec.h>
+#include <vespa/eval/tensor/cell_values.h>
+#include <vespa/eval/tensor/default_tensor_engine.h>
+#include <vespa/eval/tensor/sparse/sparse_tensor.h>
+#include <vespa/eval/tensor/test/test_utils.h>
+#include <vespa/vespalib/testkit/test_kit.h>
+
+using vespalib::eval::Value;
+using vespalib::eval::TensorSpec;
+using vespalib::tensor::test::makeTensor;
+using namespace vespalib::tensor;
+
+void
+assertRemove(const TensorSpec &source, const TensorSpec &arg, const TensorSpec &expected)
+{
+ auto sourceTensor = makeTensor<Tensor>(source);
+ auto argTensor = makeTensor<SparseTensor>(arg);
+ auto resultTensor = sourceTensor->remove(CellValues(*argTensor));
+ auto actual = resultTensor->toSpec();
+ EXPECT_EQUAL(actual, expected);
+}
+
+TEST("require that cells can be removed from a sparse tensor")
+{
+ assertRemove(TensorSpec("tensor(x{},y{})")
+ .add({{"x","a"},{"y","b"}}, 2)
+ .add({{"x","c"},{"y","d"}}, 3),
+ TensorSpec("tensor(x{},y{})")
+ .add({{"x","c"},{"y","d"}}, 1)
+ .add({{"x","e"},{"y","f"}}, 1),
+ TensorSpec("tensor(x{},y{})")
+ .add({{"x","a"},{"y","b"}}, 2));
+}
+
+TEST("require that all cells can be removed from a sparse tensor")
+{
+ assertRemove(TensorSpec("tensor(x{},y{})")
+ .add({{"x","a"},{"y","b"}}, 2),
+ TensorSpec("tensor(x{},y{})")
+ .add({{"x","a"},{"y","b"}}, 1),
+ TensorSpec("tensor(x{},y{})"));
+}
+
+TEST_MAIN() { TEST_RUN_ALL(); }
diff --git a/eval/src/vespa/eval/tensor/dense/dense_tensor_view.cpp b/eval/src/vespa/eval/tensor/dense/dense_tensor_view.cpp
index 6243f79a971..164ec042384 100644
--- a/eval/src/vespa/eval/tensor/dense/dense_tensor_view.cpp
+++ b/eval/src/vespa/eval/tensor/dense/dense_tensor_view.cpp
@@ -299,4 +299,10 @@ DenseTensorView::add(const Tensor &) const
LOG_ABORT("should not be reached");
}
+std::unique_ptr<Tensor>
+DenseTensorView::remove(const CellValues &) const
+{
+ LOG_ABORT("should not be reached");
+}
+
}
diff --git a/eval/src/vespa/eval/tensor/dense/dense_tensor_view.h b/eval/src/vespa/eval/tensor/dense/dense_tensor_view.h
index f470e9d374f..11ed9639cc6 100644
--- a/eval/src/vespa/eval/tensor/dense/dense_tensor_view.h
+++ b/eval/src/vespa/eval/tensor/dense/dense_tensor_view.h
@@ -55,6 +55,7 @@ public:
Tensor::UP reduce(join_fun_t op, const std::vector<vespalib::string> &dimensions) const override;
std::unique_ptr<Tensor> modify(join_fun_t op, const CellValues &cellValues) const override;
std::unique_ptr<Tensor> add(const Tensor &arg) const override;
+ std::unique_ptr<Tensor> remove(const CellValues &) const override;
bool equals(const Tensor &arg) const override;
Tensor::UP clone() const override;
eval::TensorSpec toSpec() const override;
diff --git a/eval/src/vespa/eval/tensor/sparse/CMakeLists.txt b/eval/src/vespa/eval/tensor/sparse/CMakeLists.txt
index d50c6d5db10..2d142d98ba1 100644
--- a/eval/src/vespa/eval/tensor/sparse/CMakeLists.txt
+++ b/eval/src/vespa/eval/tensor/sparse/CMakeLists.txt
@@ -8,8 +8,9 @@ vespa_add_library(eval_tensor_sparse OBJECT
sparse_tensor_address_padder.cpp
sparse_tensor_address_reducer.cpp
sparse_tensor_address_ref.cpp
+ sparse_tensor_builder.cpp
sparse_tensor_match.cpp
sparse_tensor_modify.cpp
- sparse_tensor_builder.cpp
+ sparse_tensor_remove.cpp
sparse_tensor_unsorted_address_builder.cpp
)
diff --git a/eval/src/vespa/eval/tensor/sparse/sparse_tensor.cpp b/eval/src/vespa/eval/tensor/sparse/sparse_tensor.cpp
index e3ee9593d80..ded9310b450 100644
--- a/eval/src/vespa/eval/tensor/sparse/sparse_tensor.cpp
+++ b/eval/src/vespa/eval/tensor/sparse/sparse_tensor.cpp
@@ -7,6 +7,7 @@
#include "sparse_tensor_match.h"
#include "sparse_tensor_modify.h"
#include "sparse_tensor_reduce.hpp"
+#include "sparse_tensor_remove.h"
#include <vespa/eval/eval/operation.h>
#include <vespa/eval/tensor/cell_values.h>
#include <vespa/eval/tensor/tensor_address_builder.h>
@@ -215,6 +216,17 @@ SparseTensor::add(const Tensor &arg) const
return adder.build();
}
+std::unique_ptr<Tensor>
+SparseTensor::remove(const CellValues &cellAddresses) const
+{
+ Cells cells;
+ Stash stash;
+ copyCells(cells, _cells, stash);
+ SparseTensorRemove remover(_type, std::move(cells), std::move(stash));
+ cellAddresses.accept(remover);
+ return remover.build();
+}
+
}
VESPALIB_HASH_MAP_INSTANTIATE_H_E_M(vespalib::tensor::SparseTensorAddressRef, double, vespalib::hash<vespalib::tensor::SparseTensorAddressRef>,
diff --git a/eval/src/vespa/eval/tensor/sparse/sparse_tensor.h b/eval/src/vespa/eval/tensor/sparse/sparse_tensor.h
index 107cba7a673..7eebff1f010 100644
--- a/eval/src/vespa/eval/tensor/sparse/sparse_tensor.h
+++ b/eval/src/vespa/eval/tensor/sparse/sparse_tensor.h
@@ -47,6 +47,7 @@ public:
Tensor::UP reduce(join_fun_t op, const std::vector<vespalib::string> &dimensions) const override;
std::unique_ptr<Tensor> modify(join_fun_t op, const CellValues &cellValues) const override;
std::unique_ptr<Tensor> add(const Tensor &arg) const override;
+ std::unique_ptr<Tensor> remove(const CellValues &cellAddresses) const override;
bool equals(const Tensor &arg) const override;
Tensor::UP clone() const override;
eval::TensorSpec toSpec() const override;
diff --git a/eval/src/vespa/eval/tensor/sparse/sparse_tensor_remove.cpp b/eval/src/vespa/eval/tensor/sparse/sparse_tensor_remove.cpp
new file mode 100644
index 00000000000..76af1e3b5fb
--- /dev/null
+++ b/eval/src/vespa/eval/tensor/sparse/sparse_tensor_remove.cpp
@@ -0,0 +1,33 @@
+// Copyright 2019 Oath Inc. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+
+#include "sparse_tensor_remove.h"
+#include <vespa/eval/tensor/tensor_address_element_iterator.h>
+
+namespace vespalib::tensor {
+
+SparseTensorRemove::SparseTensorRemove(const eval::ValueType &type, Cells &&cells, Stash &&stash)
+ : _type(type),
+ _cells(std::move(cells)),
+ _stash(std::move(stash)),
+ _addressBuilder()
+{
+}
+
+SparseTensorRemove::~SparseTensorRemove() = default;
+
+void
+SparseTensorRemove::visit(const TensorAddress &address, double value)
+{
+ (void) value;
+ _addressBuilder.populate(_type, address);
+ auto addressRef = _addressBuilder.getAddressRef();
+ _cells.erase(addressRef);
+}
+
+std::unique_ptr<Tensor>
+SparseTensorRemove::build()
+{
+ return std::make_unique<SparseTensor>(std::move(_type), std::move(_cells), std::move(_stash));
+}
+
+}
diff --git a/eval/src/vespa/eval/tensor/sparse/sparse_tensor_remove.h b/eval/src/vespa/eval/tensor/sparse/sparse_tensor_remove.h
new file mode 100644
index 00000000000..3d5905d8f41
--- /dev/null
+++ b/eval/src/vespa/eval/tensor/sparse/sparse_tensor_remove.h
@@ -0,0 +1,32 @@
+// Copyright 2019 Oath Inc. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+
+#pragma once
+
+#include "sparse_tensor.h"
+#include "sparse_tensor_address_builder.h"
+#include <vespa/eval/tensor/tensor_visitor.h>
+
+namespace vespalib::tensor {
+
+/**
+ * This class handles a tensor remove operation on a sparse tensor.
+ *
+ * Creates a new tensor by removing the cells matching the cell addresses visited.
+ * The value associated with the address is ignored.
+ */
+class SparseTensorRemove : public TensorVisitor {
+private:
+ using Cells = SparseTensor::Cells;
+ eval::ValueType _type;
+ Cells _cells;
+ Stash _stash;
+ SparseTensorAddressBuilder _addressBuilder;
+
+public:
+ SparseTensorRemove(const eval::ValueType &type, Cells &&cells, Stash &&stash);
+ ~SparseTensorRemove();
+ void visit(const TensorAddress &address, double value) override;
+ std::unique_ptr<Tensor> build();
+};
+
+}
diff --git a/eval/src/vespa/eval/tensor/tensor.h b/eval/src/vespa/eval/tensor/tensor.h
index cdb9d90d3a3..4061ed9c115 100644
--- a/eval/src/vespa/eval/tensor/tensor.h
+++ b/eval/src/vespa/eval/tensor/tensor.h
@@ -49,6 +49,12 @@ public:
*/
virtual std::unique_ptr<Tensor> add(const Tensor &arg) const = 0;
+ /**
+ * Creates a new tensor by removing the cells matching the given cell addresses.
+ * The value associated with the address is ignored.
+ */
+ virtual std::unique_ptr<Tensor> remove(const CellValues &cellAddresses) const = 0;
+
virtual bool equals(const Tensor &arg) const = 0; // want to remove, but needed by document
virtual Tensor::UP clone() const = 0; // want to remove, but needed by document
virtual eval::TensorSpec toSpec() const = 0;
diff --git a/eval/src/vespa/eval/tensor/wrapped_simple_tensor.cpp b/eval/src/vespa/eval/tensor/wrapped_simple_tensor.cpp
index 66fd2978a53..9df59a63873 100644
--- a/eval/src/vespa/eval/tensor/wrapped_simple_tensor.cpp
+++ b/eval/src/vespa/eval/tensor/wrapped_simple_tensor.cpp
@@ -89,4 +89,10 @@ WrappedSimpleTensor::add(const Tensor &) const
LOG_ABORT("should not be reached");
}
+std::unique_ptr<Tensor>
+WrappedSimpleTensor::remove(const CellValues &) const
+{
+ LOG_ABORT("should not be reached");
+}
+
} // namespace vespalib::tensor
diff --git a/eval/src/vespa/eval/tensor/wrapped_simple_tensor.h b/eval/src/vespa/eval/tensor/wrapped_simple_tensor.h
index 2d877b6fbbc..e7ffe7a755f 100644
--- a/eval/src/vespa/eval/tensor/wrapped_simple_tensor.h
+++ b/eval/src/vespa/eval/tensor/wrapped_simple_tensor.h
@@ -40,6 +40,7 @@ public:
Tensor::UP reduce(join_fun_t, const std::vector<vespalib::string> &) const override;
std::unique_ptr<Tensor> modify(join_fun_t, const CellValues &) const override;
std::unique_ptr<Tensor> add(const Tensor &arg) const override;
+ std::unique_ptr<Tensor> remove(const CellValues &) const override;
};
} // namespace vespalib::tensor