summaryrefslogtreecommitdiffstats
path: root/eval
diff options
context:
space:
mode:
authorGeir Storli <geirst@verizonmedia.com>2019-02-25 15:01:42 +0000
committerGeir Storli <geirst@verizonmedia.com>2019-02-25 15:01:42 +0000
commit4792d566c8373d4e48b540fe7cfb23bdd4bc5d10 (patch)
treec47389dfb82cdd55a1c922ee3182e6941829581c /eval
parent0e1c348bb3e27fc762806f7dbc444474e9036615 (diff)
Support modify operation on mixed tensors.
Diffstat (limited to 'eval')
-rw-r--r--eval/src/tests/tensor/tensor_modify_operation/tensor_modify_operation_test.cpp30
-rw-r--r--eval/src/vespa/eval/tensor/wrapped_simple_tensor.cpp36
2 files changed, 64 insertions, 2 deletions
diff --git a/eval/src/tests/tensor/tensor_modify_operation/tensor_modify_operation_test.cpp b/eval/src/tests/tensor/tensor_modify_operation/tensor_modify_operation_test.cpp
index ff59b28d60c..0b68088f93b 100644
--- a/eval/src/tests/tensor/tensor_modify_operation/tensor_modify_operation_test.cpp
+++ b/eval/src/tests/tensor/tensor_modify_operation/tensor_modify_operation_test.cpp
@@ -61,6 +61,23 @@ TEST(TensorModifyTest, dense_tensors_can_be_modified)
.add({{"x",9},{"y",9}}, 11));
}
+TEST(TensorModifyTest, mixed_tensors_can_be_modified)
+{
+ checkUpdate(TensorSpec("tensor(x{},y[2])")
+ .add({{"x","a"},{"y",0}}, 2)
+ .add({{"x","a"},{"y",1}}, 3)
+ .add({{"x","b"},{"y",0}}, 4)
+ .add({{"x","b"},{"y",1}}, 5),
+ TensorSpec("tensor(x{},y{})")
+ .add({{"x","a"},{"y","0"}}, 6)
+ .add({{"x","b"},{"y","1"}}, 7),
+ TensorSpec("tensor(x{},y[2])")
+ .add({{"x","a"},{"y",0}}, 6)
+ .add({{"x","a"},{"y",1}}, 3)
+ .add({{"x","b"},{"y",0}}, 4)
+ .add({{"x","b"},{"y",1}}, 7));
+}
+
TEST(TensorModifyTest, sparse_tensors_ignore_updates_to_missing_cells)
{
checkUpdate(TensorSpec("tensor(x{},y{})")
@@ -87,4 +104,17 @@ TEST(TensorModifyTest, dense_tensors_ignore_updates_to_out_of_range_cells)
.add({{"x",9},{"y",9}}, 11));
}
+TEST(TensorModifyTest, mixed_tensors_ignore_updates_to_missing_or_out_of_range_cells)
+{
+ checkUpdate(TensorSpec("tensor(x{},y[2])")
+ .add({{"x","a"},{"y",0}}, 2)
+ .add({{"x","a"},{"y",1}}, 3),
+ TensorSpec("tensor(x{},y{})")
+ .add({{"x","a"},{"y","2"}}, 4)
+ .add({{"x","c"},{"y","0"}}, 5),
+ TensorSpec("tensor(x{},y[2])")
+ .add({{"x","a"},{"y",0}}, 2)
+ .add({{"x","a"},{"y",1}}, 3));
+}
+
GTEST_MAIN_RUN_ALL_TESTS
diff --git a/eval/src/vespa/eval/tensor/wrapped_simple_tensor.cpp b/eval/src/vespa/eval/tensor/wrapped_simple_tensor.cpp
index 9d451c4639d..a982a4b0fe1 100644
--- a/eval/src/vespa/eval/tensor/wrapped_simple_tensor.cpp
+++ b/eval/src/vespa/eval/tensor/wrapped_simple_tensor.cpp
@@ -81,10 +81,42 @@ WrappedSimpleTensor::reduce(join_fun_t, const std::vector<vespalib::string> &) c
LOG_ABORT("should not be reached");
}
+namespace {
+
+TensorSpec::Address
+convertToOnlyMappedDimensions(const TensorSpec::Address &address)
+{
+ TensorSpec::Address result;
+ for (const auto &elem : address) {
+ if (elem.second.is_indexed()) {
+ result.emplace(std::make_pair(elem.first,
+ TensorSpec::Label(vespalib::make_string("%zu", elem.second.index))));
+ } else {
+ result.emplace(elem);
+ }
+ }
+ return result;
+}
+
+}
+
std::unique_ptr<Tensor>
-WrappedSimpleTensor::modify(join_fun_t, const CellValues &) const
+WrappedSimpleTensor::modify(join_fun_t op, const CellValues &cellValues) const
{
- LOG_ABORT("should not be reached");
+ TensorSpec oldTensor = toSpec();
+ TensorSpec toModify = cellValues.toSpec();
+ TensorSpec result(type().to_spec());
+
+ for (const auto &cell : oldTensor.cells()) {
+ TensorSpec::Address mappedAddress = convertToOnlyMappedDimensions(cell.first);
+ auto itr = toModify.cells().find(mappedAddress);
+ if (itr != toModify.cells().end()) {
+ result.add(cell.first, op(cell.second, itr->second));
+ } else {
+ result.add(cell.first, cell.second);
+ }
+ }
+ return std::make_unique<WrappedSimpleTensor>(SimpleTensor::create(result));
}
std::unique_ptr<Tensor>