summaryrefslogtreecommitdiffstats
path: root/document
diff options
context:
space:
mode:
authorGeir Storli <geirst@verizonmedia.com>2019-02-26 13:57:34 +0000
committerGeir Storli <geirst@verizonmedia.com>2019-02-26 13:57:34 +0000
commite0114f9fa644e21d26946b0bc444ea21f66d291f (patch)
tree1476ae23d34693cae3b14b09d9b8a5f3b8838b22 /document
parente56fe867e5d4bc2b50219c2c5c10e4ea04fac024 (diff)
Verify during deserialize() that cells and address tensors are sparse.
Diffstat (limited to 'document')
-rw-r--r--document/src/tests/documentupdatetestcase.cpp18
-rw-r--r--document/src/vespa/document/base/testdocrepo.cpp3
-rw-r--r--document/src/vespa/document/update/tensor_add_update.cpp8
-rw-r--r--document/src/vespa/document/update/tensor_modify_update.cpp30
-rw-r--r--document/src/vespa/document/update/tensor_remove_update.cpp33
5 files changed, 69 insertions, 23 deletions
diff --git a/document/src/tests/documentupdatetestcase.cpp b/document/src/tests/documentupdatetestcase.cpp
index 743babdf5e1..382d6f9a83b 100644
--- a/document/src/tests/documentupdatetestcase.cpp
+++ b/document/src/tests/documentupdatetestcase.cpp
@@ -1016,6 +1016,24 @@ TEST(DocumentUpdateTest, tensor_modify_update_throws_on_non_tensor_field)
f.assertThrowOnNonTensorField(TensorModifyUpdate(TensorModifyUpdate::Operation::REPLACE, f.makeBaselineTensor()));
}
+TEST(DocumentUpdateTest, tensor_remove_update_throws_if_address_tensor_is_not_sparse)
+{
+ TensorUpdateFixture f("dense_tensor");
+ auto addressTensor = f.makeTensor(f.spec().add({{"x", 0}}, 2)); // creates a dense address tensor
+ ASSERT_THROW(
+ f.assertRoundtripSerialize(TensorRemoveUpdate(std::move(addressTensor))),
+ vespalib::IllegalStateException);
+}
+
+TEST(DocumentUpdateTest, tensor_modify_update_throws_if_cells_tensor_is_not_sparse)
+{
+ TensorUpdateFixture f("dense_tensor");
+ auto cellsTensor = f.makeTensor(f.spec().add({{"x", 0}}, 2)); // creates a dense cells tensor
+ ASSERT_THROW(
+ f.assertRoundtripSerialize(TensorModifyUpdate(TensorModifyUpdate::Operation::REPLACE, std::move(cellsTensor))),
+ vespalib::IllegalStateException);
+}
+
void
assertDocumentUpdateFlag(bool createIfNonExistent, int value)
diff --git a/document/src/vespa/document/base/testdocrepo.cpp b/document/src/vespa/document/base/testdocrepo.cpp
index c8041d8c254..68a58ea1a86 100644
--- a/document/src/vespa/document/base/testdocrepo.cpp
+++ b/document/src/vespa/document/base/testdocrepo.cpp
@@ -52,7 +52,8 @@ DocumenttypesConfig TestDocRepo::getDefaultConfig() {
.addField("content", DataType::T_STRING)
.addField("rawarray", Array(DataType::T_RAW))
.addField("structarray", structarray_id)
- .addTensorField("sparse_tensor", "tensor(x{})"));
+ .addTensorField("sparse_tensor", "tensor(x{})")
+ .addTensorField("dense_tensor", "tensor(x[2])"));
builder.document(type2_id, "testdoctype2",
Struct("testdoctype2.header")
.addField("onlyinchild", DataType::T_INT),
diff --git a/document/src/vespa/document/update/tensor_add_update.cpp b/document/src/vespa/document/update/tensor_add_update.cpp
index b8e36922d8c..c35a8058133 100644
--- a/document/src/vespa/document/update/tensor_add_update.cpp
+++ b/document/src/vespa/document/update/tensor_add_update.cpp
@@ -99,8 +99,8 @@ TensorAddUpdate::applyTo(FieldValue& value) const
tensorFieldValue = std::move(newTensor);
}
} else {
- std::string err = make_string("Unable to perform a tensor add update on a '%s' field value",
- value.getClass().name());
+ vespalib::string err = make_string("Unable to perform a tensor add update on a '%s' field value",
+ value.getClass().name());
throw IllegalStateException(err, VESPA_STRLOC);
}
return true;
@@ -129,8 +129,8 @@ TensorAddUpdate::deserialize(const DocumentTypeRepo &repo, const DataType &type,
if (tensor->inherits(TensorFieldValue::classId)) {
_tensor.reset(static_cast<TensorFieldValue *>(tensor.release()));
} else {
- std::string err = make_string("Expected tensor field value, got a '%s' field value",
- tensor->getClass().name());
+ vespalib::string err = make_string("Expected tensor field value, got a '%s' field value",
+ tensor->getClass().name());
throw IllegalStateException(err, VESPA_STRLOC);
}
VespaDocumentDeserializer deserializer(repo, stream, Document::getNewestSerializationVersion());
diff --git a/document/src/vespa/document/update/tensor_modify_update.cpp b/document/src/vespa/document/update/tensor_modify_update.cpp
index 64fc57d5287..37842b13cf4 100644
--- a/document/src/vespa/document/update/tensor_modify_update.cpp
+++ b/document/src/vespa/document/update/tensor_modify_update.cpp
@@ -159,9 +159,10 @@ TensorModifyUpdate::checkCompatibility(const Field& field) const
std::unique_ptr<Tensor>
TensorModifyUpdate::applyTo(const Tensor &tensor) const
{
- auto &cellTensor = _tensor->getAsTensorPtr();
- if (cellTensor) {
- vespalib::tensor::CellValues cellValues(static_cast<const vespalib::tensor::SparseTensor &>(*cellTensor));
+ auto &cellsTensor = _tensor->getAsTensorPtr();
+ if (cellsTensor) {
+ // Cells tensor being sparse was validated during deserialize().
+ vespalib::tensor::CellValues cellValues(static_cast<const vespalib::tensor::SparseTensor &>(*cellsTensor));
return tensor.modify(getJoinFunction(_operation), cellValues);
}
return std::unique_ptr<Tensor>();
@@ -178,8 +179,8 @@ TensorModifyUpdate::applyTo(FieldValue& value) const
tensorFieldValue = std::move(newTensor);
}
} else {
- std::string err = make_string("Unable to perform a tensor modify update on a '%s' field value",
- value.getClass().name());
+ vespalib::string err = make_string("Unable to perform a tensor modify update on a '%s' field value",
+ value.getClass().name());
throw IllegalStateException(err, VESPA_STRLOC);
}
return true;
@@ -201,6 +202,20 @@ TensorModifyUpdate::print(std::ostream& out, bool verbose, const std::string& in
out << ")";
}
+namespace {
+
+void
+verifyCellsTensorIsSparse(const std::unique_ptr<Tensor> &cellsTensor)
+{
+ if (cellsTensor && !dynamic_cast<const vespalib::tensor::SparseTensor *>(cellsTensor.get())) {
+ vespalib::string err = make_string("Expected cell values tensor to be sparse, but has type '%s'",
+ cellsTensor->type().to_spec().c_str());
+ throw IllegalStateException(err, VESPA_STRLOC);
+ }
+}
+
+}
+
void
TensorModifyUpdate::deserialize(const DocumentTypeRepo &repo, const DataType &type, nbostream & stream)
{
@@ -217,12 +232,13 @@ TensorModifyUpdate::deserialize(const DocumentTypeRepo &repo, const DataType &ty
if (tensor->inherits(TensorFieldValue::classId)) {
_tensor.reset(static_cast<TensorFieldValue *>(tensor.release()));
} else {
- std::string err = make_string("Expected tensor field value, got a '%s' field value",
- tensor->getClass().name());
+ vespalib::string err = make_string("Expected tensor field value, got a '%s' field value",
+ tensor->getClass().name());
throw IllegalStateException(err, VESPA_STRLOC);
}
VespaDocumentDeserializer deserializer(repo, stream, Document::getNewestSerializationVersion());
deserializer.read(*_tensor);
+ verifyCellsTensorIsSparse(_tensor->getAsTensorPtr());
}
TensorModifyUpdate*
diff --git a/document/src/vespa/document/update/tensor_remove_update.cpp b/document/src/vespa/document/update/tensor_remove_update.cpp
index 7ae0604f3ca..c72d776fa9f 100644
--- a/document/src/vespa/document/update/tensor_remove_update.cpp
+++ b/document/src/vespa/document/update/tensor_remove_update.cpp
@@ -80,13 +80,9 @@ TensorRemoveUpdate::applyTo(const Tensor &tensor) const
{
auto &addressTensor = _tensor->getAsTensorPtr();
if (addressTensor) {
- if (const auto *sparseTensor = dynamic_cast<const vespalib::tensor::SparseTensor *>(addressTensor.get())) {
- vespalib::tensor::CellValues cellAddresses(*sparseTensor);
- return tensor.remove(cellAddresses);
- } else {
- throw IllegalArgumentException(make_string("Expected address tensor to be sparse, but has type '%s'",
- addressTensor->type().to_spec().c_str()));
- }
+ // Address tensor being sparse was validated during deserialize().
+ vespalib::tensor::CellValues cellAddresses(static_cast<const vespalib::tensor::SparseTensor &>(*addressTensor));
+ return tensor.remove(cellAddresses);
}
return std::unique_ptr<Tensor>();
}
@@ -102,8 +98,8 @@ TensorRemoveUpdate::applyTo(FieldValue &value) const
tensorFieldValue = std::move(newTensor);
}
} else {
- std::string err = make_string("Unable to perform a tensor remove update on a '%s' field value",
- value.getClass().name());
+ vespalib::string err = make_string("Unable to perform a tensor remove update on a '%s' field value",
+ value.getClass().name());
throw IllegalStateException(err, VESPA_STRLOC);
}
return true;
@@ -125,6 +121,20 @@ TensorRemoveUpdate::print(std::ostream &out, bool verbose, const std::string &in
out << ")";
}
+namespace {
+
+void
+verifyAddressTensorIsSparse(const std::unique_ptr<Tensor> &addressTensor)
+{
+ if (addressTensor && !dynamic_cast<const vespalib::tensor::SparseTensor *>(addressTensor.get())) {
+ vespalib::string err = make_string("Expected address tensor to be sparse, but has type '%s'",
+ addressTensor->type().to_spec().c_str());
+ throw IllegalStateException(err, VESPA_STRLOC);
+ }
+}
+
+}
+
void
TensorRemoveUpdate::deserialize(const DocumentTypeRepo &repo, const DataType &type, nbostream &stream)
{
@@ -132,12 +142,13 @@ TensorRemoveUpdate::deserialize(const DocumentTypeRepo &repo, const DataType &ty
if (tensor->inherits(TensorFieldValue::classId)) {
_tensor.reset(static_cast<TensorFieldValue *>(tensor.release()));
} else {
- std::string err = make_string("Expected tensor field value, got a '%s' field value",
- tensor->getClass().name());
+ vespalib::string err = make_string("Expected tensor field value, got a '%s' field value",
+ tensor->getClass().name());
throw IllegalStateException(err, VESPA_STRLOC);
}
VespaDocumentDeserializer deserializer(repo, stream, Document::getNewestSerializationVersion());
deserializer.read(*_tensor);
+ verifyAddressTensorIsSparse(_tensor->getAsTensorPtr());
}
TensorRemoveUpdate *