diff options
author | Håvard Pettersen <havardpe@oath.com> | 2020-02-05 15:06:53 +0000 |
---|---|---|
committer | Håvard Pettersen <havardpe@oath.com> | 2020-02-06 11:45:29 +0000 |
commit | 64359a8b55d9cfb24b2c39a890758a4f1ae8dda5 (patch) | |
tree | 5966f5969da7c87150050447193eea8210bc55ac /document | |
parent | e310eff84716e8ff7b681d04d22ac432f395b3e8 (diff) |
more robust tensor update
Diffstat (limited to 'document')
6 files changed, 69 insertions, 6 deletions
diff --git a/document/src/tests/documentupdatetestcase.cpp b/document/src/tests/documentupdatetestcase.cpp index 5543cb48ba4..18001c35da5 100644 --- a/document/src/tests/documentupdatetestcase.cpp +++ b/document/src/tests/documentupdatetestcase.cpp @@ -872,6 +872,13 @@ struct TensorUpdateFixture { EXPECT_EQ(actTensor, expTensor); } + void assertTensorNull() { + auto field = getTensor(); + auto tensor_field = dynamic_cast<TensorFieldValue*>(field.get()); + ASSERT_TRUE(tensor_field); + EXPECT_TRUE(tensor_field->getAsTensorPtr().get() == nullptr); + } + void assertTensor(const TensorSpec &expSpec) { auto expTensor = makeTensor(expSpec); assertTensor(*expTensor); @@ -886,6 +893,19 @@ struct TensorUpdateFixture { assertTensor(expTensor); } + void assertApplyUpdateNonExisting(const ValueUpdate &update, + const TensorSpec &expTensor) { + applyUpdate(update); + assertDocumentUpdated(); + assertTensor(expTensor); + } + + void assertApplyUpdateNonExisting(const ValueUpdate &update) { + applyUpdate(update); + assertDocumentUpdated(); + assertTensorNull(); + } + template <typename ValueUpdateType> void assertRoundtripSerialize(const ValueUpdateType &valueUpdate) { testRoundtripSerialize(valueUpdate, tensorDataType); @@ -933,6 +953,16 @@ TEST(DocumentUpdateTest, tensor_add_update_can_be_applied) .add({{"x", "c"}}, 7)); } +TEST(DocumentUpdateTest, tensor_add_update_can_be_applied_to_nonexisting_tensor) +{ + TensorUpdateFixture f; + f.assertApplyUpdateNonExisting(TensorAddUpdate(f.makeTensor(f.spec().add({{"x", "b"}}, 5) + .add({{"x", "c"}}, 7))), + + f.spec().add({{"x", "b"}}, 5) + .add({{"x", "c"}}, 7)); +} + TEST(DocumentUpdateTest, tensor_remove_update_can_be_applied) { TensorUpdateFixture f; @@ -944,6 +974,12 @@ TEST(DocumentUpdateTest, tensor_remove_update_can_be_applied) f.spec().add({{"x", "a"}}, 2)); } +TEST(DocumentUpdateTest, tensor_remove_update_can_be_applied_to_nonexisting_tensor) +{ + TensorUpdateFixture f; + f.assertApplyUpdateNonExisting(TensorRemoveUpdate(f.makeTensor(f.spec().add({{"x", "b"}}, 1)))); +} + TEST(DocumentUpdateTest, tensor_modify_update_can_be_applied) { TensorUpdateFixture f; @@ -970,6 +1006,13 @@ TEST(DocumentUpdateTest, tensor_modify_update_can_be_applied) .add({{"x", "b"}}, 15)); } +TEST(DocumentUpdateTest, tensor_modify_update_can_be_applied_to_nonexisting_tensor) +{ + TensorUpdateFixture f; + f.assertApplyUpdateNonExisting(TensorModifyUpdate(TensorModifyUpdate::Operation::ADD, + f.makeTensor(f.spec().add({{"x", "b"}}, 5)))); +} + TEST(DocumentUpdateTest, tensor_assign_update_can_be_roundtrip_serialized) { TensorUpdateFixture f; diff --git a/document/src/vespa/document/fieldvalue/tensorfieldvalue.cpp b/document/src/vespa/document/fieldvalue/tensorfieldvalue.cpp index 9b318a39f0a..56d7b6ab078 100644 --- a/document/src/vespa/document/fieldvalue/tensorfieldvalue.cpp +++ b/document/src/vespa/document/fieldvalue/tensorfieldvalue.cpp @@ -104,6 +104,19 @@ TensorFieldValue::operator=(std::unique_ptr<Tensor> rhs) } +void +TensorFieldValue::make_empty_if_not_existing() +{ + if (!_tensor) { + TensorSpec empty_spec(_dataType.getTensorType().to_spec()); + auto empty_value = Engine::ref().from_spec(empty_spec); + auto tensor_ptr = dynamic_cast<Tensor*>(empty_value.get()); + assert(tensor_ptr != nullptr); + _tensor.reset(tensor_ptr); + empty_value.release(); + } +} + void TensorFieldValue::accept(FieldValueVisitor &visitor) diff --git a/document/src/vespa/document/fieldvalue/tensorfieldvalue.h b/document/src/vespa/document/fieldvalue/tensorfieldvalue.h index 80b18be55a0..ea3f8dea9be 100644 --- a/document/src/vespa/document/fieldvalue/tensorfieldvalue.h +++ b/document/src/vespa/document/fieldvalue/tensorfieldvalue.h @@ -28,6 +28,8 @@ public: TensorFieldValue &operator=(const TensorFieldValue &rhs); TensorFieldValue &operator=(std::unique_ptr<vespalib::tensor::Tensor> rhs); + void make_empty_if_not_existing(); + virtual void accept(FieldValueVisitor &visitor) override; virtual void accept(ConstFieldValueVisitor &visitor) const override; virtual const DataType *getDataType() const override; diff --git a/document/src/vespa/document/update/tensor_add_update.cpp b/document/src/vespa/document/update/tensor_add_update.cpp index c35a8058133..2e5fa194c20 100644 --- a/document/src/vespa/document/update/tensor_add_update.cpp +++ b/document/src/vespa/document/update/tensor_add_update.cpp @@ -93,6 +93,7 @@ TensorAddUpdate::applyTo(FieldValue& value) const { if (value.inherits(TensorFieldValue::classId)) { TensorFieldValue &tensorFieldValue = static_cast<TensorFieldValue &>(value); + tensorFieldValue.make_empty_if_not_existing(); auto &oldTensor = tensorFieldValue.getAsTensorPtr(); auto newTensor = applyTo(*oldTensor); if (newTensor) { diff --git a/document/src/vespa/document/update/tensor_modify_update.cpp b/document/src/vespa/document/update/tensor_modify_update.cpp index 0d821de8922..dfc7479e5cd 100644 --- a/document/src/vespa/document/update/tensor_modify_update.cpp +++ b/document/src/vespa/document/update/tensor_modify_update.cpp @@ -174,9 +174,11 @@ TensorModifyUpdate::applyTo(FieldValue& value) const if (value.inherits(TensorFieldValue::classId)) { TensorFieldValue &tensorFieldValue = static_cast<TensorFieldValue &>(value); auto &oldTensor = tensorFieldValue.getAsTensorPtr(); - auto newTensor = applyTo(*oldTensor); - if (newTensor) { - tensorFieldValue = std::move(newTensor); + if (oldTensor) { + auto newTensor = applyTo(*oldTensor); + if (newTensor) { + tensorFieldValue = std::move(newTensor); + } } } else { vespalib::string err = make_string("Unable to perform a tensor modify update on a '%s' field value", diff --git a/document/src/vespa/document/update/tensor_remove_update.cpp b/document/src/vespa/document/update/tensor_remove_update.cpp index f3bf7da7a0b..91b4c0a6ca3 100644 --- a/document/src/vespa/document/update/tensor_remove_update.cpp +++ b/document/src/vespa/document/update/tensor_remove_update.cpp @@ -120,9 +120,11 @@ TensorRemoveUpdate::applyTo(FieldValue &value) const if (value.inherits(TensorFieldValue::classId)) { TensorFieldValue &tensorFieldValue = static_cast<TensorFieldValue &>(value); auto &oldTensor = tensorFieldValue.getAsTensorPtr(); - auto newTensor = applyTo(*oldTensor); - if (newTensor) { - tensorFieldValue = std::move(newTensor); + if (oldTensor) { + auto newTensor = applyTo(*oldTensor); + if (newTensor) { + tensorFieldValue = std::move(newTensor); + } } } else { vespalib::string err = make_string("Unable to perform a tensor remove update on a '%s' field value", |