From 3167155f4ab0beaef9436bacb0b8a6fdb7764dac Mon Sep 17 00:00:00 2001 From: Geir Storli Date: Mon, 25 Feb 2019 08:58:15 +0000 Subject: Support add operation on mixed tensors. --- .../tensor_add_operation_test.cpp | 47 ++++++++++++++++++++++ .../vespa/eval/tensor/wrapped_simple_tensor.cpp | 29 ++++++++++++- 2 files changed, 74 insertions(+), 2 deletions(-) (limited to 'eval/src') diff --git a/eval/src/tests/tensor/tensor_add_operation/tensor_add_operation_test.cpp b/eval/src/tests/tensor/tensor_add_operation/tensor_add_operation_test.cpp index 2f36aba6e3a..4c92dc717a7 100644 --- a/eval/src/tests/tensor/tensor_add_operation/tensor_add_operation_test.cpp +++ b/eval/src/tests/tensor/tensor_add_operation/tensor_add_operation_test.cpp @@ -21,6 +21,15 @@ assertAdd(const TensorSpec &source, const TensorSpec &arg, const TensorSpec &exp EXPECT_EQ(actual, expected); } +void +assertNullTensor(const TensorSpec &source, const TensorSpec &arg) +{ + auto sourceTensor = makeTensor(source); + auto argTensor = makeTensor(arg); + auto resultTensor = sourceTensor->add(*argTensor); + EXPECT_FALSE(resultTensor); +} + TEST(TensorAddTest, cells_can_be_added_to_a_sparse_tensor) { assertAdd(TensorSpec("tensor(x{},y{})") @@ -35,4 +44,42 @@ TEST(TensorAddTest, cells_can_be_added_to_a_sparse_tensor) .add({{"x","e"},{"y","f"}}, 7)); } +TEST(TensorAddTest, cells_can_be_added_to_a_mixed_tensor) +{ + assertAdd(TensorSpec("tensor(x{},y[2])") + .add({{"x","a"},{"y",0}}, 2) + .add({{"x","a"},{"y",1}}, 3) + .add({{"x","b"},{"y",0}}, 4) + .add({{"x","b"},{"y",1}}, 5), + TensorSpec("tensor(x{},y[2])") + .add({{"x","b"},{"y",0}}, 6) + .add({{"x","b"},{"y",1}}, 7) + .add({{"x","c"},{"y",0}}, 8) + .add({{"x","c"},{"y",1}}, 9), + TensorSpec("tensor(x{},y[2])") + .add({{"x","a"},{"y",0}}, 2) + .add({{"x","a"},{"y",1}}, 3) + .add({{"x","b"},{"y",0}}, 6) + .add({{"x","b"},{"y",1}}, 7) + .add({{"x","c"},{"y",0}}, 8) + .add({{"x","c"},{"y",1}}, 9)); +} + +TEST(TensorAddTest, cells_can_be_added_to_empty_mixed_tensor) +{ + assertAdd(TensorSpec("tensor(x{},y[2])"), + TensorSpec("tensor(x{},y[2])") + .add({{"x","b"},{"y",0}}, 6) + .add({{"x","b"},{"y",1}}, 7), + TensorSpec("tensor(x{},y[2])") + .add({{"x","b"},{"y",0}}, 6) + .add({{"x","b"},{"y",1}}, 7)); +} + +TEST(TensorAddTest, tensors_of_different_types_cannot_be_added_together) +{ + assertNullTensor(TensorSpec("tensor(x{},y[2])"), TensorSpec("tensor(x{},y{})")); + assertNullTensor(TensorSpec("tensor(x{},y[2])"), TensorSpec("tensor(x{},y[3])")); +} + GTEST_MAIN_RUN_ALL_TESTS diff --git a/eval/src/vespa/eval/tensor/wrapped_simple_tensor.cpp b/eval/src/vespa/eval/tensor/wrapped_simple_tensor.cpp index 9df59a63873..1268d6fa9cb 100644 --- a/eval/src/vespa/eval/tensor/wrapped_simple_tensor.cpp +++ b/eval/src/vespa/eval/tensor/wrapped_simple_tensor.cpp @@ -12,6 +12,9 @@ LOG_SETUP(".eval.tensor.wrapped_simple_tensor"); namespace vespalib::tensor { +using eval::SimpleTensor; +using eval::TensorSpec; + bool WrappedSimpleTensor::equals(const Tensor &arg) const { @@ -84,9 +87,31 @@ WrappedSimpleTensor::modify(join_fun_t, const CellValues &) const } std::unique_ptr -WrappedSimpleTensor::add(const Tensor &) const +WrappedSimpleTensor::add(const Tensor &arg) const { - LOG_ABORT("should not be reached"); + const auto *rhs = dynamic_cast(&arg); + if (!rhs || type() != rhs->type()) { + return Tensor::UP(); + } + + TensorSpec oldTensor = toSpec(); + TensorSpec argTensor = rhs->toSpec(); + TensorSpec result(type().to_spec()); + for (const auto &cell : oldTensor.cells()) { + auto argItr = argTensor.cells().find(cell.first); + if (argItr != argTensor.cells().end()) { + result.add(argItr->first, argItr->second); + } else { + result.add(cell.first, cell.second); + } + } + for (const auto &cell : argTensor.cells()) { + auto resultItr = result.cells().find(cell.first); + if (resultItr == result.cells().end()) { + result.add(cell.first, cell.second); + } + } + return std::make_unique(SimpleTensor::create(result)); } std::unique_ptr -- cgit v1.2.3