summaryrefslogtreecommitdiffstats
path: root/document
diff options
context:
space:
mode:
authorHåvard Pettersen <havardpe@oath.com>2020-02-05 15:06:53 +0000
committerHåvard Pettersen <havardpe@oath.com>2020-02-06 11:45:29 +0000
commit64359a8b55d9cfb24b2c39a890758a4f1ae8dda5 (patch)
tree5966f5969da7c87150050447193eea8210bc55ac /document
parente310eff84716e8ff7b681d04d22ac432f395b3e8 (diff)
more robust tensor update
Diffstat (limited to 'document')
-rw-r--r--document/src/tests/documentupdatetestcase.cpp43
-rw-r--r--document/src/vespa/document/fieldvalue/tensorfieldvalue.cpp13
-rw-r--r--document/src/vespa/document/fieldvalue/tensorfieldvalue.h2
-rw-r--r--document/src/vespa/document/update/tensor_add_update.cpp1
-rw-r--r--document/src/vespa/document/update/tensor_modify_update.cpp8
-rw-r--r--document/src/vespa/document/update/tensor_remove_update.cpp8
6 files changed, 69 insertions, 6 deletions
diff --git a/document/src/tests/documentupdatetestcase.cpp b/document/src/tests/documentupdatetestcase.cpp
index 5543cb48ba4..18001c35da5 100644
--- a/document/src/tests/documentupdatetestcase.cpp
+++ b/document/src/tests/documentupdatetestcase.cpp
@@ -872,6 +872,13 @@ struct TensorUpdateFixture {
EXPECT_EQ(actTensor, expTensor);
}
+ void assertTensorNull() {
+ auto field = getTensor();
+ auto tensor_field = dynamic_cast<TensorFieldValue*>(field.get());
+ ASSERT_TRUE(tensor_field);
+ EXPECT_TRUE(tensor_field->getAsTensorPtr().get() == nullptr);
+ }
+
void assertTensor(const TensorSpec &expSpec) {
auto expTensor = makeTensor(expSpec);
assertTensor(*expTensor);
@@ -886,6 +893,19 @@ struct TensorUpdateFixture {
assertTensor(expTensor);
}
+ void assertApplyUpdateNonExisting(const ValueUpdate &update,
+ const TensorSpec &expTensor) {
+ applyUpdate(update);
+ assertDocumentUpdated();
+ assertTensor(expTensor);
+ }
+
+ void assertApplyUpdateNonExisting(const ValueUpdate &update) {
+ applyUpdate(update);
+ assertDocumentUpdated();
+ assertTensorNull();
+ }
+
template <typename ValueUpdateType>
void assertRoundtripSerialize(const ValueUpdateType &valueUpdate) {
testRoundtripSerialize(valueUpdate, tensorDataType);
@@ -933,6 +953,16 @@ TEST(DocumentUpdateTest, tensor_add_update_can_be_applied)
.add({{"x", "c"}}, 7));
}
+TEST(DocumentUpdateTest, tensor_add_update_can_be_applied_to_nonexisting_tensor)
+{
+ TensorUpdateFixture f;
+ f.assertApplyUpdateNonExisting(TensorAddUpdate(f.makeTensor(f.spec().add({{"x", "b"}}, 5)
+ .add({{"x", "c"}}, 7))),
+
+ f.spec().add({{"x", "b"}}, 5)
+ .add({{"x", "c"}}, 7));
+}
+
TEST(DocumentUpdateTest, tensor_remove_update_can_be_applied)
{
TensorUpdateFixture f;
@@ -944,6 +974,12 @@ TEST(DocumentUpdateTest, tensor_remove_update_can_be_applied)
f.spec().add({{"x", "a"}}, 2));
}
+TEST(DocumentUpdateTest, tensor_remove_update_can_be_applied_to_nonexisting_tensor)
+{
+ TensorUpdateFixture f;
+ f.assertApplyUpdateNonExisting(TensorRemoveUpdate(f.makeTensor(f.spec().add({{"x", "b"}}, 1))));
+}
+
TEST(DocumentUpdateTest, tensor_modify_update_can_be_applied)
{
TensorUpdateFixture f;
@@ -970,6 +1006,13 @@ TEST(DocumentUpdateTest, tensor_modify_update_can_be_applied)
.add({{"x", "b"}}, 15));
}
+TEST(DocumentUpdateTest, tensor_modify_update_can_be_applied_to_nonexisting_tensor)
+{
+ TensorUpdateFixture f;
+ f.assertApplyUpdateNonExisting(TensorModifyUpdate(TensorModifyUpdate::Operation::ADD,
+ f.makeTensor(f.spec().add({{"x", "b"}}, 5))));
+}
+
TEST(DocumentUpdateTest, tensor_assign_update_can_be_roundtrip_serialized)
{
TensorUpdateFixture f;
diff --git a/document/src/vespa/document/fieldvalue/tensorfieldvalue.cpp b/document/src/vespa/document/fieldvalue/tensorfieldvalue.cpp
index 9b318a39f0a..56d7b6ab078 100644
--- a/document/src/vespa/document/fieldvalue/tensorfieldvalue.cpp
+++ b/document/src/vespa/document/fieldvalue/tensorfieldvalue.cpp
@@ -104,6 +104,19 @@ TensorFieldValue::operator=(std::unique_ptr<Tensor> rhs)
}
+void
+TensorFieldValue::make_empty_if_not_existing()
+{
+ if (!_tensor) {
+ TensorSpec empty_spec(_dataType.getTensorType().to_spec());
+ auto empty_value = Engine::ref().from_spec(empty_spec);
+ auto tensor_ptr = dynamic_cast<Tensor*>(empty_value.get());
+ assert(tensor_ptr != nullptr);
+ _tensor.reset(tensor_ptr);
+ empty_value.release();
+ }
+}
+
void
TensorFieldValue::accept(FieldValueVisitor &visitor)
diff --git a/document/src/vespa/document/fieldvalue/tensorfieldvalue.h b/document/src/vespa/document/fieldvalue/tensorfieldvalue.h
index 80b18be55a0..ea3f8dea9be 100644
--- a/document/src/vespa/document/fieldvalue/tensorfieldvalue.h
+++ b/document/src/vespa/document/fieldvalue/tensorfieldvalue.h
@@ -28,6 +28,8 @@ public:
TensorFieldValue &operator=(const TensorFieldValue &rhs);
TensorFieldValue &operator=(std::unique_ptr<vespalib::tensor::Tensor> rhs);
+ void make_empty_if_not_existing();
+
virtual void accept(FieldValueVisitor &visitor) override;
virtual void accept(ConstFieldValueVisitor &visitor) const override;
virtual const DataType *getDataType() const override;
diff --git a/document/src/vespa/document/update/tensor_add_update.cpp b/document/src/vespa/document/update/tensor_add_update.cpp
index c35a8058133..2e5fa194c20 100644
--- a/document/src/vespa/document/update/tensor_add_update.cpp
+++ b/document/src/vespa/document/update/tensor_add_update.cpp
@@ -93,6 +93,7 @@ TensorAddUpdate::applyTo(FieldValue& value) const
{
if (value.inherits(TensorFieldValue::classId)) {
TensorFieldValue &tensorFieldValue = static_cast<TensorFieldValue &>(value);
+ tensorFieldValue.make_empty_if_not_existing();
auto &oldTensor = tensorFieldValue.getAsTensorPtr();
auto newTensor = applyTo(*oldTensor);
if (newTensor) {
diff --git a/document/src/vespa/document/update/tensor_modify_update.cpp b/document/src/vespa/document/update/tensor_modify_update.cpp
index 0d821de8922..dfc7479e5cd 100644
--- a/document/src/vespa/document/update/tensor_modify_update.cpp
+++ b/document/src/vespa/document/update/tensor_modify_update.cpp
@@ -174,9 +174,11 @@ TensorModifyUpdate::applyTo(FieldValue& value) const
if (value.inherits(TensorFieldValue::classId)) {
TensorFieldValue &tensorFieldValue = static_cast<TensorFieldValue &>(value);
auto &oldTensor = tensorFieldValue.getAsTensorPtr();
- auto newTensor = applyTo(*oldTensor);
- if (newTensor) {
- tensorFieldValue = std::move(newTensor);
+ if (oldTensor) {
+ auto newTensor = applyTo(*oldTensor);
+ if (newTensor) {
+ tensorFieldValue = std::move(newTensor);
+ }
}
} else {
vespalib::string err = make_string("Unable to perform a tensor modify update on a '%s' field value",
diff --git a/document/src/vespa/document/update/tensor_remove_update.cpp b/document/src/vespa/document/update/tensor_remove_update.cpp
index f3bf7da7a0b..91b4c0a6ca3 100644
--- a/document/src/vespa/document/update/tensor_remove_update.cpp
+++ b/document/src/vespa/document/update/tensor_remove_update.cpp
@@ -120,9 +120,11 @@ TensorRemoveUpdate::applyTo(FieldValue &value) const
if (value.inherits(TensorFieldValue::classId)) {
TensorFieldValue &tensorFieldValue = static_cast<TensorFieldValue &>(value);
auto &oldTensor = tensorFieldValue.getAsTensorPtr();
- auto newTensor = applyTo(*oldTensor);
- if (newTensor) {
- tensorFieldValue = std::move(newTensor);
+ if (oldTensor) {
+ auto newTensor = applyTo(*oldTensor);
+ if (newTensor) {
+ tensorFieldValue = std::move(newTensor);
+ }
}
} else {
vespalib::string err = make_string("Unable to perform a tensor remove update on a '%s' field value",