diff options
author | Geir Storli <geirst@verizonmedia.com> | 2020-11-17 12:51:08 +0000 |
---|---|---|
committer | Geir Storli <geirst@verizonmedia.com> | 2020-11-17 12:57:49 +0000 |
commit | cf02c8777d8bff26b2f1cc73e342c38945b7c94c (patch) | |
tree | 92058c5a8aed850e26c622651437210a32246858 | |
parent | 4cd2c6a1d4d2ab7337678931271a815b535ce518 (diff) |
Add support for partial update remove operation where address is not fully specified.
-rw-r--r-- | eval/src/tests/tensor/partial_remove/partial_remove_test.cpp | 24 | ||||
-rw-r--r-- | eval/src/vespa/eval/tensor/partial_update.cpp | 78 |
2 files changed, 93 insertions, 9 deletions
diff --git a/eval/src/tests/tensor/partial_remove/partial_remove_test.cpp b/eval/src/tests/tensor/partial_remove/partial_remove_test.cpp
index 220eee0ba8f..e182fffa890 100644
--- a/eval/src/tests/tensor/partial_remove/partial_remove_test.cpp
+++ b/eval/src/tests/tensor/partial_remove/partial_remove_test.cpp
@@ -116,4 +116,28 @@ TEST(PartialRemoveTest, partial_remove_returns_nullptr_on_invalid_inputs) {
     }
 }
 
+void
+expect_partial_remove(const TensorSpec& input, const TensorSpec& remove, const TensorSpec& exp)
+{
+    auto act = perform_partial_remove(input, remove);
+    EXPECT_EQ(exp, act);
+}
+
+TEST(PartialRemoveTest, remove_where_address_is_not_fully_specified) {
+    auto input = TensorSpec("tensor(x{},y{})").
+        add({{"x", "a"},{"y", "c"}}, 3.0).
+        add({{"x", "a"},{"y", "d"}}, 5.0).
+        add({{"x", "b"},{"y", "c"}}, 7.0);
+
+    expect_partial_remove(input,TensorSpec("tensor(x{})").add({{"x", "a"}}, 1.0),
+                          TensorSpec("tensor(x{},y{})").add({{"x", "b"},{"y", "c"}}, 7.0));
+
+    expect_partial_remove(input, TensorSpec("tensor(y{})").add({{"y", "c"}}, 1.0),
+                          TensorSpec("tensor(x{},y{})").add({{"x", "a"},{"y", "d"}}, 5.0));
+
+    expect_partial_remove(input, TensorSpec("tensor(y{})").add({{"y", "d"}}, 1.0),
+                          TensorSpec("tensor(x{},y{})").add({{"x", "a"},{"y", "c"}}, 3.0)
+                          .add({{"x", "b"},{"y", "c"}}, 7.0));
+}
+
 GTEST_MAIN_RUN_ALL_TESTS()
diff --git a/eval/src/vespa/eval/tensor/partial_update.cpp b/eval/src/vespa/eval/tensor/partial_update.cpp
index 014ffeb2666..fa15b2a38ae 100644
--- a/eval/src/vespa/eval/tensor/partial_update.cpp
+++ b/eval/src/vespa/eval/tensor/partial_update.cpp
@@ -298,31 +298,91 @@ struct PerformRemove {
                             const ValueBuilderFactory &factory);
 };
 
+/**
+ * Calculates the indexes of where the mapped modifier dimensions are found in the mapped input dimensions.
+ *
+ * The modifier dimensions should be a subset or all of the input dimensions.
+ * An empty vector is returned on type mismatch.
+ */
+std::vector<size_t>
+calc_mapped_dimension_indexes(const ValueType& input_type,
+                              const ValueType& modifier_type)
+{
+    auto input_dims = input_type.mapped_dimensions();
+    auto mod_dims = modifier_type.mapped_dimensions();
+    if (mod_dims.size() > input_dims.size()) {
+        return {};
+    }
+    std::vector<size_t> result(mod_dims.size());
+    size_t j = 0;
+    for (size_t i = 0; i < mod_dims.size(); ++i) {
+        while ((j < input_dims.size()) && (input_dims[j] != mod_dims[i])) {
+            ++j;
+        }
+        if (j >= input_dims.size()) {
+            return {};
+        }
+        result[i] = j;
+    }
+    return result;
+}
+
+struct ModifierCoords {
+
+    std::vector<const vespalib::stringref *> lookup_refs;
+    std::vector<size_t> lookup_view_dims;
+
+    ModifierCoords(const SparseCoords& input_coords,
+                   const std::vector<size_t>& input_dim_indexes,
+                   const ValueType& modifier_type)
+        : lookup_refs(modifier_type.dimensions().size()),
+          lookup_view_dims(modifier_type.dimensions().size())
+    {
+        assert(modifier_type.dimensions().size() == input_dim_indexes.size());
+        for (size_t i = 0; i < input_dim_indexes.size(); ++i) {
+            // Setup the modifier dimensions to point to the matching input dimensions.
+            lookup_refs[i] = &input_coords.addr[input_dim_indexes[i]];
+            lookup_view_dims[i] = i;
+        }
+    }
+    ~ModifierCoords() {}
+};
+
 template <typename ICT>
 Value::UP
 PerformRemove::invoke(const Value &input, const Value &modifier, const ValueBuilderFactory &factory)
 {
     const ValueType &input_type = input.type();
     const ValueType &modifier_type = modifier.type();
-    if (input_type.mapped_dimensions() != modifier_type.dimensions()) {
-        LOG(error, "when removing cells from a tensor, mapped dimensions must be equal. "
-            "Got input type %s versus modifier type %s",
-            input_type.to_spec().c_str(), modifier_type.to_spec().c_str());
-        return {};
-    }
     const size_t num_mapped_in_input = input_type.count_mapped_dimensions();
     if (num_mapped_in_input == 0) {
-        LOG(error, "cannot remove cells from a dense tensor of type %s",
+        LOG(error, "Cannot remove cells from a dense input tensor of type %s",
             input_type.to_spec().c_str());
         return {};
     }
+    if (modifier_type.count_indexed_dimensions() != 0) {
+        LOG(error, "Cannot remove cells using a modifier tensor of type %s",
+            modifier_type.to_spec().c_str());
+        return {};
+    }
+    auto input_dim_indexes = calc_mapped_dimension_indexes(input_type, modifier_type);
+    if (input_dim_indexes.empty()) {
+        LOG(error, "Tensor type mismatch when removing cells from a tensor. "
+            "Got input type %s versus modifier type %s",
+            input_type.to_spec().c_str(), modifier_type.to_spec().c_str());
+        return {};
+    }
     SparseCoords addrs(num_mapped_in_input);
-    auto modifier_view = modifier.index().create_view(addrs.lookup_view_dims);
+    ModifierCoords mod_coords(addrs, input_dim_indexes, modifier_type);
+    auto modifier_view = modifier.index().create_view(mod_coords.lookup_view_dims);
     const size_t expected_subspaces = input.index().size();
     const size_t dsss = input_type.dense_subspace_size();
     auto builder = factory.create_value_builder<ICT>(input_type, num_mapped_in_input, dsss, expected_subspaces);
     auto filter_by_modifier = [&] (const auto & lookup_refs, size_t) {
-        modifier_view->lookup(lookup_refs);
+        // The modifier dimensions are setup to point to the input dimensions address storage in ModifierCoords,
+        // so we don't need to use the lookup_refs argument.
+        (void) lookup_refs;
+        modifier_view->lookup(mod_coords.lookup_refs);
         size_t modifier_subspace_index;
         return !(modifier_view->next_result({}, modifier_subspace_index));
     }; |