diff options
author | Geir Storli <geirst@verizonmedia.com> | 2019-02-26 13:57:34 +0000 |
---|---|---|
committer | Geir Storli <geirst@verizonmedia.com> | 2019-02-26 13:57:34 +0000 |
commit | e0114f9fa644e21d26946b0bc444ea21f66d291f (patch) | |
tree | 1476ae23d34693cae3b14b09d9b8a5f3b8838b22 /document | |
parent | e56fe867e5d4bc2b50219c2c5c10e4ea04fac024 (diff) |
Verify during deserialize() that cells and address tensors are sparse.
Diffstat (limited to 'document')
5 files changed, 69 insertions, 23 deletions
diff --git a/document/src/tests/documentupdatetestcase.cpp b/document/src/tests/documentupdatetestcase.cpp index 743babdf5e1..382d6f9a83b 100644 --- a/document/src/tests/documentupdatetestcase.cpp +++ b/document/src/tests/documentupdatetestcase.cpp @@ -1016,6 +1016,24 @@ TEST(DocumentUpdateTest, tensor_modify_update_throws_on_non_tensor_field) f.assertThrowOnNonTensorField(TensorModifyUpdate(TensorModifyUpdate::Operation::REPLACE, f.makeBaselineTensor())); } +TEST(DocumentUpdateTest, tensor_remove_update_throws_if_address_tensor_is_not_sparse) +{ + TensorUpdateFixture f("dense_tensor"); + auto addressTensor = f.makeTensor(f.spec().add({{"x", 0}}, 2)); // creates a dense address tensor + ASSERT_THROW( + f.assertRoundtripSerialize(TensorRemoveUpdate(std::move(addressTensor))), + vespalib::IllegalStateException); +} + +TEST(DocumentUpdateTest, tensor_modify_update_throws_if_cells_tensor_is_not_sparse) +{ + TensorUpdateFixture f("dense_tensor"); + auto cellsTensor = f.makeTensor(f.spec().add({{"x", 0}}, 2)); // creates a dense cells tensor + ASSERT_THROW( + f.assertRoundtripSerialize(TensorModifyUpdate(TensorModifyUpdate::Operation::REPLACE, std::move(cellsTensor))), + vespalib::IllegalStateException); +} + void assertDocumentUpdateFlag(bool createIfNonExistent, int value) diff --git a/document/src/vespa/document/base/testdocrepo.cpp b/document/src/vespa/document/base/testdocrepo.cpp index c8041d8c254..68a58ea1a86 100644 --- a/document/src/vespa/document/base/testdocrepo.cpp +++ b/document/src/vespa/document/base/testdocrepo.cpp @@ -52,7 +52,8 @@ DocumenttypesConfig TestDocRepo::getDefaultConfig() { .addField("content", DataType::T_STRING) .addField("rawarray", Array(DataType::T_RAW)) .addField("structarray", structarray_id) - .addTensorField("sparse_tensor", "tensor(x{})")); + .addTensorField("sparse_tensor", "tensor(x{})") + .addTensorField("dense_tensor", "tensor(x[2])")); builder.document(type2_id, "testdoctype2", Struct("testdoctype2.header") .addField("onlyinchild", DataType::T_INT), diff --git a/document/src/vespa/document/update/tensor_add_update.cpp b/document/src/vespa/document/update/tensor_add_update.cpp index b8e36922d8c..c35a8058133 100644 --- a/document/src/vespa/document/update/tensor_add_update.cpp +++ b/document/src/vespa/document/update/tensor_add_update.cpp @@ -99,8 +99,8 @@ TensorAddUpdate::applyTo(FieldValue& value) const tensorFieldValue = std::move(newTensor); } } else { - std::string err = make_string("Unable to perform a tensor add update on a '%s' field value", - value.getClass().name()); + vespalib::string err = make_string("Unable to perform a tensor add update on a '%s' field value", + value.getClass().name()); throw IllegalStateException(err, VESPA_STRLOC); } return true; @@ -129,8 +129,8 @@ TensorAddUpdate::deserialize(const DocumentTypeRepo &repo, const DataType &type, if (tensor->inherits(TensorFieldValue::classId)) { _tensor.reset(static_cast<TensorFieldValue *>(tensor.release())); } else { - std::string err = make_string("Expected tensor field value, got a '%s' field value", - tensor->getClass().name()); + vespalib::string err = make_string("Expected tensor field value, got a '%s' field value", + tensor->getClass().name()); throw IllegalStateException(err, VESPA_STRLOC); } VespaDocumentDeserializer deserializer(repo, stream, Document::getNewestSerializationVersion()); diff --git a/document/src/vespa/document/update/tensor_modify_update.cpp b/document/src/vespa/document/update/tensor_modify_update.cpp index 64fc57d5287..37842b13cf4 100644 --- a/document/src/vespa/document/update/tensor_modify_update.cpp +++ b/document/src/vespa/document/update/tensor_modify_update.cpp @@ -159,9 +159,10 @@ TensorModifyUpdate::checkCompatibility(const Field& field) const std::unique_ptr<Tensor> TensorModifyUpdate::applyTo(const Tensor &tensor) const { - auto &cellTensor = _tensor->getAsTensorPtr(); - if (cellTensor) { - vespalib::tensor::CellValues cellValues(static_cast<const vespalib::tensor::SparseTensor &>(*cellTensor)); + auto &cellsTensor = _tensor->getAsTensorPtr(); + if (cellsTensor) { + // Cells tensor being sparse was validated during deserialize(). + vespalib::tensor::CellValues cellValues(static_cast<const vespalib::tensor::SparseTensor &>(*cellsTensor)); return tensor.modify(getJoinFunction(_operation), cellValues); } return std::unique_ptr<Tensor>(); @@ -178,8 +179,8 @@ TensorModifyUpdate::applyTo(FieldValue& value) const tensorFieldValue = std::move(newTensor); } } else { - std::string err = make_string("Unable to perform a tensor modify update on a '%s' field value", - value.getClass().name()); + vespalib::string err = make_string("Unable to perform a tensor modify update on a '%s' field value", + value.getClass().name()); throw IllegalStateException(err, VESPA_STRLOC); } return true; @@ -201,6 +202,20 @@ TensorModifyUpdate::print(std::ostream& out, bool verbose, const std::string& in out << ")"; } +namespace { + +void +verifyCellsTensorIsSparse(const std::unique_ptr<Tensor> &cellsTensor) +{ + if (cellsTensor && !dynamic_cast<const vespalib::tensor::SparseTensor *>(cellsTensor.get())) { + vespalib::string err = make_string("Expected cell values tensor to be sparse, but has type '%s'", + cellsTensor->type().to_spec().c_str()); + throw IllegalStateException(err, VESPA_STRLOC); + } +} + +} + void TensorModifyUpdate::deserialize(const DocumentTypeRepo &repo, const DataType &type, nbostream & stream) { @@ -217,12 +232,13 @@ TensorModifyUpdate::deserialize(const DocumentTypeRepo &repo, const DataType &ty if (tensor->inherits(TensorFieldValue::classId)) { _tensor.reset(static_cast<TensorFieldValue *>(tensor.release())); } else { - std::string err = make_string("Expected tensor field value, got a '%s' field value", - tensor->getClass().name()); + vespalib::string err = make_string("Expected tensor field value, got a '%s' field value", + tensor->getClass().name()); throw IllegalStateException(err, VESPA_STRLOC); } VespaDocumentDeserializer deserializer(repo, stream, Document::getNewestSerializationVersion()); deserializer.read(*_tensor); + verifyCellsTensorIsSparse(_tensor->getAsTensorPtr()); } TensorModifyUpdate* diff --git a/document/src/vespa/document/update/tensor_remove_update.cpp b/document/src/vespa/document/update/tensor_remove_update.cpp index 7ae0604f3ca..c72d776fa9f 100644 --- a/document/src/vespa/document/update/tensor_remove_update.cpp +++ b/document/src/vespa/document/update/tensor_remove_update.cpp @@ -80,13 +80,9 @@ TensorRemoveUpdate::applyTo(const Tensor &tensor) const { auto &addressTensor = _tensor->getAsTensorPtr(); if (addressTensor) { - if (const auto *sparseTensor = dynamic_cast<const vespalib::tensor::SparseTensor *>(addressTensor.get())) { - vespalib::tensor::CellValues cellAddresses(*sparseTensor); - return tensor.remove(cellAddresses); - } else { - throw IllegalArgumentException(make_string("Expected address tensor to be sparse, but has type '%s'", - addressTensor->type().to_spec().c_str())); - } + // Address tensor being sparse was validated during deserialize(). + vespalib::tensor::CellValues cellAddresses(static_cast<const vespalib::tensor::SparseTensor &>(*addressTensor)); + return tensor.remove(cellAddresses); } return std::unique_ptr<Tensor>(); } @@ -102,8 +98,8 @@ TensorRemoveUpdate::applyTo(FieldValue &value) const tensorFieldValue = std::move(newTensor); } } else { - std::string err = make_string("Unable to perform a tensor remove update on a '%s' field value", - value.getClass().name()); + vespalib::string err = make_string("Unable to perform a tensor remove update on a '%s' field value", + value.getClass().name()); throw IllegalStateException(err, VESPA_STRLOC); } return true; @@ -125,6 +121,20 @@ TensorRemoveUpdate::print(std::ostream &out, bool verbose, const std::string &in out << ")"; } +namespace { + +void +verifyAddressTensorIsSparse(const std::unique_ptr<Tensor> &addressTensor) +{ + if (addressTensor && !dynamic_cast<const vespalib::tensor::SparseTensor *>(addressTensor.get())) { + vespalib::string err = make_string("Expected address tensor to be sparse, but has type '%s'", + addressTensor->type().to_spec().c_str()); + throw IllegalStateException(err, VESPA_STRLOC); + } +} + +} + void TensorRemoveUpdate::deserialize(const DocumentTypeRepo &repo, const DataType &type, nbostream &stream) { @@ -132,12 +142,13 @@ TensorRemoveUpdate::deserialize(const DocumentTypeRepo &repo, const DataType &ty if (tensor->inherits(TensorFieldValue::classId)) { _tensor.reset(static_cast<TensorFieldValue *>(tensor.release())); } else { - std::string err = make_string("Expected tensor field value, got a '%s' field value", - tensor->getClass().name()); + vespalib::string err = make_string("Expected tensor field value, got a '%s' field value", + tensor->getClass().name()); throw IllegalStateException(err, VESPA_STRLOC); } VespaDocumentDeserializer deserializer(repo, stream, Document::getNewestSerializationVersion()); deserializer.read(*_tensor); + verifyAddressTensorIsSparse(_tensor->getAsTensorPtr()); } TensorRemoveUpdate * |