summaryrefslogtreecommitdiffstats
path: root/eval/src
diff options
context:
space:
mode:
authorGeir Storli <geirst@verizonmedia.com>2019-02-25 08:58:15 +0000
committerGeir Storli <geirst@verizonmedia.com>2019-02-25 08:58:15 +0000
commit3167155f4ab0beaef9436bacb0b8a6fdb7764dac (patch)
tree2be4f91dd6f9412186a4a187883ab495a6b11d89 /eval/src
parent531435ea9289334a543fae7cef4d56a4d4ef34fb (diff)
Support add operation on mixed tensors.
Diffstat (limited to 'eval/src')
-rw-r--r--eval/src/tests/tensor/tensor_add_operation/tensor_add_operation_test.cpp47
-rw-r--r--eval/src/vespa/eval/tensor/wrapped_simple_tensor.cpp29
2 files changed, 74 insertions, 2 deletions
diff --git a/eval/src/tests/tensor/tensor_add_operation/tensor_add_operation_test.cpp b/eval/src/tests/tensor/tensor_add_operation/tensor_add_operation_test.cpp
index 2f36aba6e3a..4c92dc717a7 100644
--- a/eval/src/tests/tensor/tensor_add_operation/tensor_add_operation_test.cpp
+++ b/eval/src/tests/tensor/tensor_add_operation/tensor_add_operation_test.cpp
@@ -21,6 +21,15 @@ assertAdd(const TensorSpec &source, const TensorSpec &arg, const TensorSpec &exp
EXPECT_EQ(actual, expected);
}
+void
+assertNullTensor(const TensorSpec &source, const TensorSpec &arg)
+{
+ auto sourceTensor = makeTensor<Tensor>(source);
+ auto argTensor = makeTensor<Tensor>(arg);
+ auto resultTensor = sourceTensor->add(*argTensor);
+ EXPECT_FALSE(resultTensor);
+}
+
TEST(TensorAddTest, cells_can_be_added_to_a_sparse_tensor)
{
assertAdd(TensorSpec("tensor(x{},y{})")
@@ -35,4 +44,42 @@ TEST(TensorAddTest, cells_can_be_added_to_a_sparse_tensor)
.add({{"x","e"},{"y","f"}}, 7));
}
+TEST(TensorAddTest, cells_can_be_added_to_a_mixed_tensor)
+{
+ assertAdd(TensorSpec("tensor(x{},y[2])")
+ .add({{"x","a"},{"y",0}}, 2)
+ .add({{"x","a"},{"y",1}}, 3)
+ .add({{"x","b"},{"y",0}}, 4)
+ .add({{"x","b"},{"y",1}}, 5),
+ TensorSpec("tensor(x{},y[2])")
+ .add({{"x","b"},{"y",0}}, 6)
+ .add({{"x","b"},{"y",1}}, 7)
+ .add({{"x","c"},{"y",0}}, 8)
+ .add({{"x","c"},{"y",1}}, 9),
+ TensorSpec("tensor(x{},y[2])")
+ .add({{"x","a"},{"y",0}}, 2)
+ .add({{"x","a"},{"y",1}}, 3)
+ .add({{"x","b"},{"y",0}}, 6)
+ .add({{"x","b"},{"y",1}}, 7)
+ .add({{"x","c"},{"y",0}}, 8)
+ .add({{"x","c"},{"y",1}}, 9));
+}
+
+TEST(TensorAddTest, cells_can_be_added_to_empty_mixed_tensor)
+{
+ assertAdd(TensorSpec("tensor(x{},y[2])"),
+ TensorSpec("tensor(x{},y[2])")
+ .add({{"x","b"},{"y",0}}, 6)
+ .add({{"x","b"},{"y",1}}, 7),
+ TensorSpec("tensor(x{},y[2])")
+ .add({{"x","b"},{"y",0}}, 6)
+ .add({{"x","b"},{"y",1}}, 7));
+}
+
+TEST(TensorAddTest, tensors_of_different_types_cannot_be_added_together)
+{
+ assertNullTensor(TensorSpec("tensor(x{},y[2])"), TensorSpec("tensor(x{},y{})"));
+ assertNullTensor(TensorSpec("tensor(x{},y[2])"), TensorSpec("tensor(x{},y[3])"));
+}
+
GTEST_MAIN_RUN_ALL_TESTS
diff --git a/eval/src/vespa/eval/tensor/wrapped_simple_tensor.cpp b/eval/src/vespa/eval/tensor/wrapped_simple_tensor.cpp
index 9df59a63873..1268d6fa9cb 100644
--- a/eval/src/vespa/eval/tensor/wrapped_simple_tensor.cpp
+++ b/eval/src/vespa/eval/tensor/wrapped_simple_tensor.cpp
@@ -12,6 +12,9 @@ LOG_SETUP(".eval.tensor.wrapped_simple_tensor");
namespace vespalib::tensor {
+using eval::SimpleTensor;
+using eval::TensorSpec;
+
bool
WrappedSimpleTensor::equals(const Tensor &arg) const
{
@@ -84,9 +87,31 @@ WrappedSimpleTensor::modify(join_fun_t, const CellValues &) const
}
std::unique_ptr<Tensor>
-WrappedSimpleTensor::add(const Tensor &) const
+WrappedSimpleTensor::add(const Tensor &arg) const
{
- LOG_ABORT("should not be reached");
+ const auto *rhs = dynamic_cast<const WrappedSimpleTensor *>(&arg);
+ if (!rhs || type() != rhs->type()) {
+ return Tensor::UP();
+ }
+
+ TensorSpec oldTensor = toSpec();
+ TensorSpec argTensor = rhs->toSpec();
+ TensorSpec result(type().to_spec());
+ for (const auto &cell : oldTensor.cells()) {
+ auto argItr = argTensor.cells().find(cell.first);
+ if (argItr != argTensor.cells().end()) {
+ result.add(argItr->first, argItr->second);
+ } else {
+ result.add(cell.first, cell.second);
+ }
+ }
+ for (const auto &cell : argTensor.cells()) {
+ auto resultItr = result.cells().find(cell.first);
+ if (resultItr == result.cells().end()) {
+ result.add(cell.first, cell.second);
+ }
+ }
+ return std::make_unique<WrappedSimpleTensor>(SimpleTensor::create(result));
}
std::unique_ptr<Tensor>