diff options
author | Geir Storli <geirst@yahooinc.com> | 2023-08-28 13:05:54 +0000 |
---|---|---|
committer | Geir Storli <geirst@yahooinc.com> | 2023-08-28 13:13:08 +0000 |
commit | 82f69abca5404f13e5e4a49c716ac089f8d03602 (patch) | |
tree | 17249bf1fb48c32f623f054f78dcfbabd6344019 /document/src | |
parent | b85c362192bdfdafe44c7fe1257c463d5ad4340f (diff) |
Handle tensor modify update with "create: true" for non-existing tensor.
Diffstat (limited to 'document/src')
-rw-r--r-- | document/src/tests/documentupdatetestcase.cpp | 12 | ||||
-rw-r--r-- | document/src/vespa/document/fieldvalue/tensorfieldvalue.h | 1 | ||||
-rw-r--r-- | document/src/vespa/document/update/tensor_modify_update.cpp | 43 |
3 files changed, 43 insertions, 13 deletions
diff --git a/document/src/tests/documentupdatetestcase.cpp b/document/src/tests/documentupdatetestcase.cpp index b225ca6677b..25815684b7e 100644 --- a/document/src/tests/documentupdatetestcase.cpp +++ b/document/src/tests/documentupdatetestcase.cpp @@ -1038,13 +1038,23 @@ TEST(DocumentUpdateTest, tensor_modify_update_with_create_non_existing_cells_can .add({{"x", "c"}}, 6)); } -TEST(DocumentUpdateTest, tensor_modify_update_can_be_applied_to_nonexisting_tensor) +TEST(DocumentUpdateTest, tensor_modify_update_is_ignored_when_applied_to_nonexisting_tensor) { TensorUpdateFixture f; f.assertApplyUpdateNonExisting(std::make_unique<TensorModifyUpdate>(TensorModifyUpdate::Operation::ADD, f.makeTensor(f.spec().add({{"x", "b"}}, 5)))); } +TEST(DocumentUpdateTest, tensor_modify_update_with_create_non_existing_cells_is_applied_to_nonexisting_tensor) +{ + TensorUpdateFixture f; + f.assertApplyUpdateNonExisting(std::make_unique<TensorModifyUpdate>(TensorModifyUpdate::Operation::ADD, + f.makeTensor(f.spec().add({{"x", "b"}}, 5) + .add({{"x", "c"}}, 6)), 0.0), + f.spec().add({{"x", "b"}}, 5) + .add({{"x", "c"}}, 6)); +} + TEST(DocumentUpdateTest, tensor_assign_update_can_be_roundtrip_serialized) { TensorUpdateFixture f; diff --git a/document/src/vespa/document/fieldvalue/tensorfieldvalue.h b/document/src/vespa/document/fieldvalue/tensorfieldvalue.h index 52b27346ff8..7b025ea21a9 100644 --- a/document/src/vespa/document/fieldvalue/tensorfieldvalue.h +++ b/document/src/vespa/document/fieldvalue/tensorfieldvalue.h @@ -32,6 +32,7 @@ public: void accept(FieldValueVisitor &visitor) override; void accept(ConstFieldValueVisitor &visitor) const override; const DataType *getDataType() const override; + const TensorDataType& get_tensor_data_type() const { return _dataType; } TensorFieldValue* clone() const override; void print(std::ostream& out, bool verbose, const std::string& indent) const override; void printXml(XmlOutputStream& out) const override; diff --git a/document/src/vespa/document/update/tensor_modify_update.cpp b/document/src/vespa/document/update/tensor_modify_update.cpp index ad1e3095269..198ee1c67c3 100644 --- a/document/src/vespa/document/update/tensor_modify_update.cpp +++ b/document/src/vespa/document/update/tensor_modify_update.cpp @@ -9,9 +9,11 @@ #include <vespa/document/fieldvalue/tensorfieldvalue.h> #include <vespa/document/serialization/vespadocumentdeserializer.h> #include <vespa/document/util/serializableexceptions.h> +#include <vespa/eval/eval/fast_value.h> #include <vespa/eval/eval/operation.h> +#include <vespa/eval/eval/tensor_spec.h> #include <vespa/eval/eval/value.h> -#include <vespa/eval/eval/fast_value.h> +#include <vespa/eval/eval/value_codec.h> #include <vespa/vespalib/stllike/asciistream.h> #include <vespa/vespalib/util/stringfmt.h> #include <vespa/vespalib/util/xmlstream.h> @@ -19,10 +21,11 @@ using vespalib::IllegalArgumentException; using vespalib::IllegalStateException; -using vespalib::make_string; -using vespalib::eval::ValueType; using vespalib::eval::CellType; using vespalib::eval::FastValueBuilderFactory; +using vespalib::eval::Value; +using vespalib::eval::ValueType; +using vespalib::make_string; using join_fun_t = double (*)(double, double); @@ -145,13 +148,13 @@ TensorModifyUpdate::checkCompatibility(const Field& field) const } } -std::unique_ptr<vespalib::eval::Value> -TensorModifyUpdate::applyTo(const vespalib::eval::Value &tensor) const +std::unique_ptr<Value> +TensorModifyUpdate::applyTo(const Value &tensor) const { return apply_to(tensor, FastValueBuilderFactory::get()); } -std::unique_ptr<vespalib::eval::Value> +std::unique_ptr<Value> TensorModifyUpdate::apply_to(const Value &old_tensor, const ValueBuilderFactory &factory) const { @@ -166,17 +169,33 @@ TensorModifyUpdate::apply_to(const Value &old_tensor, return {}; } +namespace { + +std::unique_ptr<Value> +create_empty_tensor(const ValueType& type) +{ + const auto& factory = FastValueBuilderFactory::get(); + vespalib::eval::TensorSpec empty_spec(type.to_spec()); + return vespalib::eval::value_from_spec(empty_spec, factory); +} + +} + bool TensorModifyUpdate::applyTo(FieldValue& value) const { if (value.isA(FieldValue::Type::TENSOR)) { TensorFieldValue &tensorFieldValue = static_cast<TensorFieldValue &>(value); - auto oldTensor = tensorFieldValue.getAsTensorPtr(); - if (oldTensor) { - auto newTensor = applyTo(*oldTensor); - if (newTensor) { - tensorFieldValue = std::move(newTensor); - } + auto old_tensor = tensorFieldValue.getAsTensorPtr(); + std::unique_ptr<Value> new_tensor; + if (old_tensor) { + new_tensor = applyTo(*old_tensor); + } else if (_default_cell_value.has_value()) { + auto empty_tensor = create_empty_tensor(tensorFieldValue.get_tensor_data_type().getTensorType()); + new_tensor = applyTo(*empty_tensor); + } + if (new_tensor) { + tensorFieldValue = std::move(new_tensor); } } else { vespalib::string err = make_string("Unable to perform a tensor modify update on a '%s' field value", |