summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorGeir Storli <geirst@yahooinc.com>2023-08-28 13:05:54 +0000
committerGeir Storli <geirst@yahooinc.com>2023-08-28 13:13:08 +0000
commit82f69abca5404f13e5e4a49c716ac089f8d03602 (patch)
tree17249bf1fb48c32f623f054f78dcfbabd6344019
parentb85c362192bdfdafe44c7fe1257c463d5ad4340f (diff)
Handle tensor modify update with "create: true" for non-existing tensor.
-rw-r--r--document/src/tests/documentupdatetestcase.cpp12
-rw-r--r--document/src/vespa/document/fieldvalue/tensorfieldvalue.h1
-rw-r--r--document/src/vespa/document/update/tensor_modify_update.cpp43
-rw-r--r--searchcore/src/tests/proton/common/attribute_updater/attribute_updater_test.cpp9
-rw-r--r--searchcore/src/vespa/searchcore/proton/common/attribute_updater.cpp3
5 files changed, 54 insertions, 14 deletions
diff --git a/document/src/tests/documentupdatetestcase.cpp b/document/src/tests/documentupdatetestcase.cpp
index b225ca6677b..25815684b7e 100644
--- a/document/src/tests/documentupdatetestcase.cpp
+++ b/document/src/tests/documentupdatetestcase.cpp
@@ -1038,13 +1038,23 @@ TEST(DocumentUpdateTest, tensor_modify_update_with_create_non_existing_cells_can
.add({{"x", "c"}}, 6));
}
-TEST(DocumentUpdateTest, tensor_modify_update_can_be_applied_to_nonexisting_tensor)
+TEST(DocumentUpdateTest, tensor_modify_update_is_ignored_when_applied_to_nonexisting_tensor)
{
TensorUpdateFixture f;
f.assertApplyUpdateNonExisting(std::make_unique<TensorModifyUpdate>(TensorModifyUpdate::Operation::ADD,
f.makeTensor(f.spec().add({{"x", "b"}}, 5))));
}
+TEST(DocumentUpdateTest, tensor_modify_update_with_create_non_existing_cells_is_applied_to_nonexisting_tensor)
+{
+ TensorUpdateFixture f;
+ f.assertApplyUpdateNonExisting(std::make_unique<TensorModifyUpdate>(TensorModifyUpdate::Operation::ADD,
+ f.makeTensor(f.spec().add({{"x", "b"}}, 5)
+ .add({{"x", "c"}}, 6)), 0.0),
+ f.spec().add({{"x", "b"}}, 5)
+ .add({{"x", "c"}}, 6));
+}
+
TEST(DocumentUpdateTest, tensor_assign_update_can_be_roundtrip_serialized)
{
TensorUpdateFixture f;
diff --git a/document/src/vespa/document/fieldvalue/tensorfieldvalue.h b/document/src/vespa/document/fieldvalue/tensorfieldvalue.h
index 52b27346ff8..7b025ea21a9 100644
--- a/document/src/vespa/document/fieldvalue/tensorfieldvalue.h
+++ b/document/src/vespa/document/fieldvalue/tensorfieldvalue.h
@@ -32,6 +32,7 @@ public:
void accept(FieldValueVisitor &visitor) override;
void accept(ConstFieldValueVisitor &visitor) const override;
const DataType *getDataType() const override;
+ const TensorDataType& get_tensor_data_type() const { return _dataType; }
TensorFieldValue* clone() const override;
void print(std::ostream& out, bool verbose, const std::string& indent) const override;
void printXml(XmlOutputStream& out) const override;
diff --git a/document/src/vespa/document/update/tensor_modify_update.cpp b/document/src/vespa/document/update/tensor_modify_update.cpp
index ad1e3095269..198ee1c67c3 100644
--- a/document/src/vespa/document/update/tensor_modify_update.cpp
+++ b/document/src/vespa/document/update/tensor_modify_update.cpp
@@ -9,9 +9,11 @@
#include <vespa/document/fieldvalue/tensorfieldvalue.h>
#include <vespa/document/serialization/vespadocumentdeserializer.h>
#include <vespa/document/util/serializableexceptions.h>
+#include <vespa/eval/eval/fast_value.h>
#include <vespa/eval/eval/operation.h>
+#include <vespa/eval/eval/tensor_spec.h>
#include <vespa/eval/eval/value.h>
-#include <vespa/eval/eval/fast_value.h>
+#include <vespa/eval/eval/value_codec.h>
#include <vespa/vespalib/stllike/asciistream.h>
#include <vespa/vespalib/util/stringfmt.h>
#include <vespa/vespalib/util/xmlstream.h>
@@ -19,10 +21,11 @@
using vespalib::IllegalArgumentException;
using vespalib::IllegalStateException;
-using vespalib::make_string;
-using vespalib::eval::ValueType;
using vespalib::eval::CellType;
using vespalib::eval::FastValueBuilderFactory;
+using vespalib::eval::Value;
+using vespalib::eval::ValueType;
+using vespalib::make_string;
using join_fun_t = double (*)(double, double);
@@ -145,13 +148,13 @@ TensorModifyUpdate::checkCompatibility(const Field& field) const
}
}
-std::unique_ptr<vespalib::eval::Value>
-TensorModifyUpdate::applyTo(const vespalib::eval::Value &tensor) const
+std::unique_ptr<Value>
+TensorModifyUpdate::applyTo(const Value &tensor) const
{
return apply_to(tensor, FastValueBuilderFactory::get());
}
-std::unique_ptr<vespalib::eval::Value>
+std::unique_ptr<Value>
TensorModifyUpdate::apply_to(const Value &old_tensor,
const ValueBuilderFactory &factory) const
{
@@ -166,17 +169,33 @@ TensorModifyUpdate::apply_to(const Value &old_tensor,
return {};
}
+namespace {
+
+std::unique_ptr<Value>
+create_empty_tensor(const ValueType& type)
+{
+ const auto& factory = FastValueBuilderFactory::get();
+ vespalib::eval::TensorSpec empty_spec(type.to_spec());
+ return vespalib::eval::value_from_spec(empty_spec, factory);
+}
+
+}
+
bool
TensorModifyUpdate::applyTo(FieldValue& value) const
{
if (value.isA(FieldValue::Type::TENSOR)) {
TensorFieldValue &tensorFieldValue = static_cast<TensorFieldValue &>(value);
- auto oldTensor = tensorFieldValue.getAsTensorPtr();
- if (oldTensor) {
- auto newTensor = applyTo(*oldTensor);
- if (newTensor) {
- tensorFieldValue = std::move(newTensor);
- }
+ auto old_tensor = tensorFieldValue.getAsTensorPtr();
+ std::unique_ptr<Value> new_tensor;
+ if (old_tensor) {
+ new_tensor = applyTo(*old_tensor);
+ } else if (_default_cell_value.has_value()) {
+ auto empty_tensor = create_empty_tensor(tensorFieldValue.get_tensor_data_type().getTensorType());
+ new_tensor = applyTo(*empty_tensor);
+ }
+ if (new_tensor) {
+ tensorFieldValue = std::move(new_tensor);
}
} else {
vespalib::string err = make_string("Unable to perform a tensor modify update on a '%s' field value",
diff --git a/searchcore/src/tests/proton/common/attribute_updater/attribute_updater_test.cpp b/searchcore/src/tests/proton/common/attribute_updater/attribute_updater_test.cpp
index 892be2c874f..8c3ce4c5031 100644
--- a/searchcore/src/tests/proton/common/attribute_updater/attribute_updater_test.cpp
+++ b/searchcore/src/tests/proton/common/attribute_updater/attribute_updater_test.cpp
@@ -436,6 +436,15 @@ TEST_F("require that tensor modify update is applied",
f.assertTensor(TensorSpec(f.type).add({{"x", 0}}, 7).add({{"x", 1}}, 5));
}
+TEST_F("require that tensor modify update with 'create: true' is applied to non-existing tensor",
+ TensorFixture<DenseTensorAttribute>("tensor(x[2])", "dense_tensor"))
+{
+ f.applyValueUpdate(*f.attribute, 1,
+ std::make_unique<TensorModifyUpdate>(TensorModifyUpdate::Operation::ADD,
+ makeTensorFieldValue(TensorSpec("tensor(x{})").add({{"x", "1"}}, 3)), 0.0));
+ f.assertTensor(TensorSpec(f.type).add({{"x", 0}}, 0).add({{"x", 1}}, 3));
+}
+
TEST_F("require that tensor add update is applied",
TensorFixture<SerializedFastValueAttribute>("tensor(x{})", "sparse_tensor"))
{
diff --git a/searchcore/src/vespa/searchcore/proton/common/attribute_updater.cpp b/searchcore/src/vespa/searchcore/proton/common/attribute_updater.cpp
index ef9750b5f4c..2b78cdab966 100644
--- a/searchcore/src/vespa/searchcore/proton/common/attribute_updater.cpp
+++ b/searchcore/src/vespa/searchcore/proton/common/attribute_updater.cpp
@@ -223,7 +223,8 @@ AttributeUpdater::handleUpdate(TensorAttribute &vec, uint32_t lid, const ValueUp
updateValue(vec, lid, assign.getValue());
}
} else if (op == ValueUpdate::TensorModify) {
- vec.update_tensor(lid, static_cast<const TensorModifyUpdate &>(upd), false);
+ const auto& modify = static_cast<const TensorModifyUpdate&>(upd);
+ vec.update_tensor(lid, modify, modify.get_default_cell_value().has_value());
} else if (op == ValueUpdate::TensorAdd) {
vec.update_tensor(lid, static_cast<const TensorAddUpdate &>(upd), true);
} else if (op == ValueUpdate::TensorRemove) {