summaryrefslogtreecommitdiffstats
path: root/eval
diff options
context:
space:
mode:
authorGeir Storli <geirst@verizonmedia.com>2019-02-25 14:38:30 +0000
committerGeir Storli <geirst@verizonmedia.com>2019-02-25 14:38:30 +0000
commit0e1c348bb3e27fc762806f7dbc444474e9036615 (patch)
tree7de8fc50b826c4a5978110e5cc2865c644624731 /eval
parent3148c6dc8e5d7911ccf0bbb533edaa4ceb3b7c5d (diff)
Support remove operation on mixed tensors.
Diffstat (limited to 'eval')
-rw-r--r--eval/src/tests/tensor/tensor_remove_operation/tensor_remove_operation_test.cpp49
-rw-r--r--eval/src/vespa/eval/tensor/cell_values.h4
-rw-r--r--eval/src/vespa/eval/tensor/sparse/sparse_tensor.h2
-rw-r--r--eval/src/vespa/eval/tensor/wrapped_simple_tensor.cpp36
4 files changed, 86 insertions, 5 deletions
diff --git a/eval/src/tests/tensor/tensor_remove_operation/tensor_remove_operation_test.cpp b/eval/src/tests/tensor/tensor_remove_operation/tensor_remove_operation_test.cpp
index cb28019c4ee..df21b46691d 100644
--- a/eval/src/tests/tensor/tensor_remove_operation/tensor_remove_operation_test.cpp
+++ b/eval/src/tests/tensor/tensor_remove_operation/tensor_remove_operation_test.cpp
@@ -43,4 +43,53 @@ TEST(TensorRemoveTest, all_cells_can_be_removed_from_a_sparse_tensor)
TensorSpec("tensor(x{},y{})"));
}
+TEST(TensorRemoveTest, cells_can_be_removed_from_a_mixed_tensor)
+{
+ assertRemove(TensorSpec("tensor(x{},y[2])")
+ .add({{"x","a"},{"y",0}}, 2)
+ .add({{"x","a"},{"y",1}}, 3)
+ .add({{"x","b"},{"y",0}}, 4)
+ .add({{"x","b"},{"y",1}}, 5),
+ TensorSpec("tensor(x{})")
+ .add({{"x","b"}}, 1)
+ .add({{"x","c"}}, 1),
+ TensorSpec("tensor(x{},y[2])")
+ .add({{"x","a"},{"y",0}}, 2)
+ .add({{"x","a"},{"y",1}}, 3));
+
+ assertRemove(TensorSpec("tensor(x{},y{},z[2])")
+ .add({{"x","a"},{"y","c"},{"z",0}}, 2)
+ .add({{"x","a"},{"y","c"},{"z",1}}, 3)
+ .add({{"x","b"},{"y","c"},{"z",0}}, 4)
+ .add({{"x","b"},{"y","c"},{"z",1}}, 5),
+ TensorSpec("tensor(x{},y{})")
+ .add({{"x","b"},{"y","c"}}, 1)
+ .add({{"x","c"},{"y","c"}}, 1),
+ TensorSpec("tensor(x{},y{},z[2])")
+ .add({{"x","a"},{"y","c"},{"z",0}}, 2)
+ .add({{"x","a"},{"y","c"},{"z",1}}, 3));
+
+ assertRemove(TensorSpec("tensor(x{},y[1],z[2])")
+ .add({{"x","a"},{"y",0},{"z",0}}, 2)
+ .add({{"x","a"},{"y",0},{"z",1}}, 3)
+ .add({{"x","b"},{"y",0},{"z",0}}, 4)
+ .add({{"x","b"},{"y",0},{"z",1}}, 5),
+ TensorSpec("tensor(x{})")
+ .add({{"x","b"}}, 1)
+ .add({{"x","c"}}, 1),
+ TensorSpec("tensor(x{},y[1],z[2])")
+ .add({{"x","a"},{"y",0},{"z",0}}, 2)
+ .add({{"x","a"},{"y",0},{"z",1}}, 3));
+}
+
+TEST(TensorRemoveTest, all_cells_can_be_removed_from_a_mixed_tensor)
+{
+ assertRemove(TensorSpec("tensor(x{},y[2])")
+ .add({{"x","a"},{"y",0}}, 2)
+ .add({{"x","a"},{"y",1}}, 3),
+ TensorSpec("tensor(x{})")
+ .add({{"x","a"}}, 1),
+ TensorSpec("tensor(x{},y[2])"));
+}
+
GTEST_MAIN_RUN_ALL_TESTS
diff --git a/eval/src/vespa/eval/tensor/cell_values.h b/eval/src/vespa/eval/tensor/cell_values.h
index dabdcde3294..4b8fd33376a 100644
--- a/eval/src/vespa/eval/tensor/cell_values.h
+++ b/eval/src/vespa/eval/tensor/cell_values.h
@@ -23,6 +23,10 @@ public:
void accept(TensorVisitor &visitor) const {
_tensor.accept(visitor);
}
+
+ eval::TensorSpec toSpec() const {
+ return _tensor.toSpec();
+ }
};
}
diff --git a/eval/src/vespa/eval/tensor/sparse/sparse_tensor.h b/eval/src/vespa/eval/tensor/sparse/sparse_tensor.h
index 7eebff1f010..c182c09c6b0 100644
--- a/eval/src/vespa/eval/tensor/sparse/sparse_tensor.h
+++ b/eval/src/vespa/eval/tensor/sparse/sparse_tensor.h
@@ -2,10 +2,10 @@
#pragma once
+#include "sparse_tensor_address_ref.h"
#include <vespa/eval/tensor/cell_function.h>
#include <vespa/eval/tensor/tensor.h>
#include <vespa/eval/tensor/tensor_address.h>
-#include "sparse_tensor_address_ref.h"
#include <vespa/eval/tensor/types.h>
#include <vespa/vespalib/stllike/hash_map.h>
#include <vespa/vespalib/stllike/string.h>
diff --git a/eval/src/vespa/eval/tensor/wrapped_simple_tensor.cpp b/eval/src/vespa/eval/tensor/wrapped_simple_tensor.cpp
index 1268d6fa9cb..9d451c4639d 100644
--- a/eval/src/vespa/eval/tensor/wrapped_simple_tensor.cpp
+++ b/eval/src/vespa/eval/tensor/wrapped_simple_tensor.cpp
@@ -1,8 +1,9 @@
// Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
-#include "wrapped_simple_tensor.h"
+#include "cell_values.h"
#include "tensor_address_builder.h"
#include "tensor_visitor.h"
+#include "wrapped_simple_tensor.h"
#include <vespa/eval/eval/simple_tensor_engine.h>
#include <vespa/eval/eval/tensor_spec.h>
#include <vespa/vespalib/util/stringfmt.h>
@@ -114,10 +115,37 @@ WrappedSimpleTensor::add(const Tensor &arg) const
return std::make_unique<WrappedSimpleTensor>(SimpleTensor::create(result));
}
+namespace {
+
+TensorSpec::Address
+extractMappedDimensions(const TensorSpec::Address &address)
+{
+ TensorSpec::Address result;
+ for (const auto &elem : address) {
+ if (elem.second.is_mapped()) {
+ result.emplace(elem);
+ }
+ }
+ return result;
+}
+
+}
+
std::unique_ptr<Tensor>
-WrappedSimpleTensor::remove(const CellValues &) const
+WrappedSimpleTensor::remove(const CellValues &cellAddresses) const
{
- LOG_ABORT("should not be reached");
+ TensorSpec oldTensor = toSpec();
+ TensorSpec toRemove = cellAddresses.toSpec();
+ TensorSpec result(type().to_spec());
+
+ for (const auto &cell : oldTensor.cells()) {
+ TensorSpec::Address mappedAddress = extractMappedDimensions(cell.first);
+ auto itr = toRemove.cells().find(mappedAddress);
+ if (itr == toRemove.cells().end()) {
+ result.add(cell.first, cell.second);
+ }
+ }
+ return std::make_unique<WrappedSimpleTensor>(SimpleTensor::create(result));
}
-} // namespace vespalib::tensor
+}