diff options
author | Tor Egge <Tor.Egge@broadpark.no> | 2019-03-06 18:14:03 +0100 |
---|---|---|
committer | Tor Egge <Tor.Egge@broadpark.no> | 2019-03-07 13:28:48 +0100 |
commit | 93c5cb44ba9f8fa36314b8a4d6d57b75422f8c29 (patch) | |
tree | 0aabf97b84256fe9c10d3d22de0b9ddcf73d9129 /document | |
parent | 99bcfb517bd0b57c24f81478c3767f1b8d369fb3 (diff) |
Check for assignable tensor type when setting tensor in TensorFieldValue.
Diffstat (limited to 'document')
7 files changed, 197 insertions, 25 deletions
diff --git a/document/src/tests/documentupdatetestcase.cpp b/document/src/tests/documentupdatetestcase.cpp index 295372a42ca..71bc545897b 100644 --- a/document/src/tests/documentupdatetestcase.cpp +++ b/document/src/tests/documentupdatetestcase.cpp @@ -1,6 +1,7 @@ // Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. #include <vespa/document/base/testdocman.h> +#include <vespa/document/base/exceptions.h> #include <vespa/document/datatype/tensor_data_type.h> #include <vespa/document/fieldvalue/fieldvalues.h> #include <vespa/document/repo/configbuilder.h> @@ -1050,7 +1051,7 @@ TEST(DocumentUpdateTest, tensor_modify_update_throws_if_cells_tensor_is_not_spar auto cellsTensor = f.makeTensor(f.spec().add({{"x", 0}}, 2)); // creates a dense cells tensor ASSERT_THROW( f.assertRoundtripSerialize(TensorModifyUpdate(TensorModifyUpdate::Operation::REPLACE, std::move(cellsTensor))), - vespalib::IllegalStateException); + document::WrongTensorTypeException); } struct TensorUpdateSerializeFixture { diff --git a/document/src/tests/serialization/vespadocumentserializer_test.cpp b/document/src/tests/serialization/vespadocumentserializer_test.cpp index 624cdeff04e..c573eef6147 100644 --- a/document/src/tests/serialization/vespadocumentserializer_test.cpp +++ b/document/src/tests/serialization/vespadocumentserializer_test.cpp @@ -10,6 +10,7 @@ #include <vespa/document/datatype/documenttype.h> #include <vespa/document/datatype/weightedsetdatatype.h> #include <vespa/document/datatype/mapdatatype.h> +#include <vespa/document/datatype/tensor_data_type.h> #include <vespa/document/fieldvalue/annotationreferencefieldvalue.h> #include <vespa/document/fieldvalue/arrayfieldvalue.h> #include <vespa/document/fieldvalue/bytefieldvalue.h> @@ -139,6 +140,11 @@ template <> ReferenceFieldValue newFieldValue(const ReferenceFieldValue& value) *value.getDataType())); } +template <> TensorFieldValue newFieldValue(const TensorFieldValue& value) { + return TensorFieldValue(dynamic_cast<const TensorDataType&>( + *value.getDataType())); +} + template<typename T> void testDeserializeAndClone(const T& value, const nbostream &stream, bool checkEqual=true) { T read_value = newFieldValue(value); @@ -833,13 +839,14 @@ createTensor(const TensorCells &cells, const TensorDimensions &dimensions) { TEST("Require that tensors can be serialized") { - TensorFieldValue noTensorValue; - TensorFieldValue emptyTensorValue; - TensorFieldValue twoCellsTwoDimsValue; + TensorDataType xySparseTensorDataType(vespalib::eval::ValueType::from_spec("tensor(x{},y{})")); + TensorFieldValue noTensorValue(xySparseTensorDataType); + TensorFieldValue emptyTensorValue(xySparseTensorDataType); + TensorFieldValue twoCellsTwoDimsValue(xySparseTensorDataType); nbostream stream; serializeAndDeserialize(noTensorValue, stream); stream.clear(); - emptyTensorValue = createTensor({}, {}); + emptyTensorValue = createTensor({}, {"x", "y"}); serializeAndDeserialize(emptyTensorValue, stream); stream.clear(); twoCellsTwoDimsValue = createTensor({ {{{"y", "3"}}, 3}, @@ -855,19 +862,25 @@ TEST("Require that tensors can be serialized") const int tensor_doc_type_id = 321; const string tensor_field_name = "my_tensor"; -DocumenttypesConfig getTensorDocTypesConfig() { +DocumenttypesConfig getTensorDocTypesConfig(const vespalib::string &tensorType) { DocumenttypesConfigBuilderHelper builder; builder.document(tensor_doc_type_id, "my_type", Struct("my_type.header"), Struct("my_type.body") - .addField(tensor_field_name, DataType::T_TENSOR)); + .addTensorField(tensor_field_name, tensorType)); return builder.config(); } +DocumenttypesConfig getTensorDocTypesConfig() { + return getTensorDocTypesConfig("tensor(dimX{},dimY{})"); +} + const DocumentTypeRepo tensor_doc_repo(getTensorDocTypesConfig()); const FixedTypeRepo tensor_repo(tensor_doc_repo, *tensor_doc_repo.getDocumentType(doc_type_id)); +const DocumentTypeRepo tensor_doc_repo1(getTensorDocTypesConfig("tensor(dimX{})")); + void serializeToFile(TensorFieldValue &value, const string &file_name) { const DocumentType *type = tensor_doc_repo.getDocumentType(tensor_doc_type_id); @@ -881,7 +894,8 @@ void deserializeAndCheck(const string &file_name, TensorFieldValue &value) { void checkDeserialization(const string &name, std::unique_ptr<Tensor> tensor) { const string data_dir = TEST_PATH("../../test/resources/tensor/"); - TensorFieldValue value; + TensorDataType valueType(tensor ? tensor->type() : vespalib::eval::ValueType::error_type()); + TensorFieldValue value(valueType); if (tensor) { value = tensor->clone(); } @@ -901,6 +915,92 @@ TEST("Require that tensor deserialization matches Java") { { "dimX", "dimY" })); } +struct TensorDocFixture { + const DocumentTypeRepo &_docTypeRepo; + const DocumentType *_docType; + std::unique_ptr<Tensor> _tensor; + Document _doc; + vespalib::nbostream _blob; + + TensorDocFixture(const DocumentTypeRepo &docTypeRepo, + std::unique_ptr<Tensor> tensor); + ~TensorDocFixture(); +}; + +TensorDocFixture::TensorDocFixture(const DocumentTypeRepo &docTypeRepo, + std::unique_ptr<Tensor> tensor) + : _docTypeRepo(docTypeRepo), + _docType(_docTypeRepo.getDocumentType(tensor_doc_type_id)), + _tensor(std::move(tensor)), + _doc(*_docType, DocumentId("id:test:my_type::foo")), + _blob() +{ + auto fv = _doc.getField(tensor_field_name).createValue(); + dynamic_cast<TensorFieldValue &>(*fv) = _tensor->clone(); + _doc.setValue(tensor_field_name, *fv); + _doc.serialize(_blob); +} + +TensorDocFixture::~TensorDocFixture() = default; + +struct DeserializedTensorDoc +{ + std::unique_ptr<Document> _doc; + std::unique_ptr<FieldValue> _fieldValue; + + DeserializedTensorDoc(); + ~DeserializedTensorDoc(); + + void setup(const DocumentTypeRepo &docTypeRepo, const vespalib::nbostream &blob); + const Tensor *getTensor() const; +}; + +DeserializedTensorDoc::DeserializedTensorDoc() + : _doc(), + _fieldValue() +{ +} + +DeserializedTensorDoc::~DeserializedTensorDoc() = default; + +void +DeserializedTensorDoc::setup(const DocumentTypeRepo &docTypeRepo, const vespalib::nbostream &blob) +{ + vespalib::nbostream wrapStream(blob.peek(), blob.size()); + _doc = std::make_unique<Document>(docTypeRepo, wrapStream, nullptr); + _fieldValue = _doc->getValue(tensor_field_name); +} + +const Tensor * +DeserializedTensorDoc::getTensor() const +{ + return dynamic_cast<const TensorFieldValue &>(*_fieldValue).getAsTensorPtr().get(); +} + +TEST("Require that wrong tensor type hides tensor") +{ + TensorDocFixture f(tensor_doc_repo, + createTensor({ {{{"dimX", "a"},{"dimY", "bb"}}, 2.0 }, + {{{"dimX", "ccc"},{"dimY", "dddd"}}, 3.0}, + {{{"dimX", "e"},{"dimY","ff"}}, 5.0} }, + { "dimX", "dimY" })); + TensorDocFixture f1(tensor_doc_repo1, + createTensor({ {{{"dimX", "a"}}, 20.0 }, + {{{"dimX", "ccc"}}, 30.0} }, + { "dimX" })); + DeserializedTensorDoc doc; + doc.setup(tensor_doc_repo, f._blob); + EXPECT_TRUE(doc.getTensor() != nullptr); + EXPECT_TRUE(doc.getTensor()->equals(*f._tensor)); + doc.setup(tensor_doc_repo, f1._blob); + EXPECT_TRUE(doc.getTensor() == nullptr); + doc.setup(tensor_doc_repo1, f._blob); + EXPECT_TRUE(doc.getTensor() == nullptr); + doc.setup(tensor_doc_repo1, f1._blob); + EXPECT_TRUE(doc.getTensor() != nullptr); + EXPECT_TRUE(doc.getTensor()->equals(*f1._tensor)); +} + struct RefFixture { const DocumentType* ref_doc_type{doc_repo.getDocumentType(doc_with_ref_type_id)}; FixedTypeRepo fixed_repo{doc_repo, *ref_doc_type}; diff --git a/document/src/tests/tensor_fieldvalue/tensor_fieldvalue_test.cpp b/document/src/tests/tensor_fieldvalue/tensor_fieldvalue_test.cpp index 7a929ae26b4..08e615bccaf 100644 --- a/document/src/tests/tensor_fieldvalue/tensor_fieldvalue_test.cpp +++ b/document/src/tests/tensor_fieldvalue/tensor_fieldvalue_test.cpp @@ -4,26 +4,42 @@ #include <vespa/log/log.h> LOG_SETUP("fieldvalue_test"); +#include <vespa/document/base/exceptions.h> +#include <vespa/document/datatype/tensor_data_type.h> #include <vespa/document/fieldvalue/tensorfieldvalue.h> #include <vespa/eval/tensor/tensor.h> #include <vespa/eval/tensor/types.h> #include <vespa/eval/tensor/default_tensor.h> #include <vespa/eval/tensor/tensor_factory.h> +#include <vespa/eval/tensor/test/test_utils.h> #include <vespa/vespalib/testkit/testapp.h> using namespace document; using namespace vespalib::tensor; +using vespalib::eval::TensorSpec; +using vespalib::eval::ValueType; +using vespalib::tensor::test::makeTensor; namespace { +TensorDataType xSparseTensorDataType(ValueType::from_spec("tensor(x{})")); +TensorDataType xySparseTensorDataType(ValueType::from_spec("tensor(x{},y{})")); + Tensor::UP createTensor(const TensorCells &cells, const TensorDimensions &dimensions) { vespalib::tensor::DefaultTensor::builder builder; return vespalib::tensor::TensorFactory::create(cells, dimensions, builder); } +std::unique_ptr<Tensor> +makeSimpleTensor() +{ + return makeTensor<Tensor>(TensorSpec("tensor(x{},y{})"). + add({{"x", "4"}, {"y", "5"}}, 7)); +} + FieldValue::UP clone(FieldValue &fv) { auto ret = FieldValue::UP(fv.clone()); EXPECT_NOT_EQUAL(ret.get(), &fv); @@ -35,10 +51,10 @@ FieldValue::UP clone(FieldValue &fv) { } TEST("require that TensorFieldValue can be assigned tensors and cloned") { - TensorFieldValue noTensorValue; - TensorFieldValue emptyTensorValue; - TensorFieldValue twoCellsTwoDimsValue; - emptyTensorValue = createTensor({}, {}); + TensorFieldValue noTensorValue(xySparseTensorDataType); + TensorFieldValue emptyTensorValue(xySparseTensorDataType); + TensorFieldValue twoCellsTwoDimsValue(xySparseTensorDataType); + emptyTensorValue = createTensor({}, {"x", "y"}); twoCellsTwoDimsValue = createTensor({ {{{"y", "3"}}, 3}, {{{"x", "4"}, {"y", "5"}}, 7} }, {"x", "y"}); @@ -57,7 +73,7 @@ TEST("require that TensorFieldValue can be assigned tensors and cloned") { EXPECT_NOT_EQUAL(*emptyClone, *twoClone); EXPECT_NOT_EQUAL(*twoClone, *noneClone); EXPECT_NOT_EQUAL(*twoClone, *emptyClone); - TensorFieldValue twoCellsTwoDimsValue2; + TensorFieldValue twoCellsTwoDimsValue2(xySparseTensorDataType); twoCellsTwoDimsValue2 = createTensor({ {{{"y", "3"}}, 3}, {{{"x", "4"}, {"y", "5"}}, 7} }, @@ -69,11 +85,36 @@ TEST("require that TensorFieldValue can be assigned tensors and cloned") { TEST("require that TensorFieldValue::toString works") { - TensorFieldValue tensorFieldValue; + TensorFieldValue tensorFieldValue(xSparseTensorDataType); EXPECT_EQUAL("{TensorFieldValue: null}", tensorFieldValue.toString()); tensorFieldValue = createTensor({{{{"x","a"}}, 3}}, {"x"}); EXPECT_EQUAL("{TensorFieldValue: {\"dimensions\":[\"x\"],\"cells\":[{\"address\":{\"x\":\"a\"},\"value\":3}]}}", tensorFieldValue.toString()); } +TEST("require that wrong tensor type for special case assign throws exception") +{ + TensorFieldValue tensorFieldValue(xSparseTensorDataType); + EXPECT_EXCEPTION(tensorFieldValue = makeSimpleTensor(), + document::WrongTensorTypeException, + "WrongTensorTypeException: Field tensor type is 'tensor(x{})' but tensor type is 'tensor(x{},y{})'"); +} + +TEST("require that wrong tensor type for copy assign throws exception") +{ + TensorFieldValue tensorFieldValue(xSparseTensorDataType); + TensorFieldValue simpleTensorFieldValue(xySparseTensorDataType); + simpleTensorFieldValue = makeSimpleTensor(); + EXPECT_EXCEPTION(tensorFieldValue = simpleTensorFieldValue, + document::WrongTensorTypeException, + "WrongTensorTypeException: Field tensor type is 'tensor(x{})' but tensor type is 'tensor(x{},y{})'"); +} + +TEST("require that wrong tensor type for assignDeserialized throws exception") +{ + TensorFieldValue tensorFieldValue(xSparseTensorDataType); + EXPECT_EXCEPTION(tensorFieldValue.assignDeserialized(makeSimpleTensor()), + document::WrongTensorTypeException, + "WrongTensorTypeException: Field tensor type is 'tensor(x{})' but tensor type is 'tensor(x{},y{})'"); +} TEST_MAIN() { TEST_RUN_ALL(); } diff --git a/document/src/vespa/document/base/exceptions.cpp b/document/src/vespa/document/base/exceptions.cpp index 1dd891fc500..22c772cd987 100644 --- a/document/src/vespa/document/base/exceptions.cpp +++ b/document/src/vespa/document/base/exceptions.cpp @@ -104,4 +104,6 @@ DocumentTypeNotFoundException::~DocumentTypeNotFoundException() throw() { } +VESPA_IMPLEMENT_EXCEPTION(WrongTensorTypeException, vespalib::Exception); + } diff --git a/document/src/vespa/document/base/exceptions.h b/document/src/vespa/document/base/exceptions.h index 49ddecc8441..f79d4f4da0e 100644 --- a/document/src/vespa/document/base/exceptions.h +++ b/document/src/vespa/document/base/exceptions.h @@ -138,5 +138,6 @@ public: VESPA_DEFINE_EXCEPTION_SPINE(FieldNotFoundException); }; -} +VESPA_DEFINE_EXCEPTION(WrongTensorTypeException, vespalib::Exception); +} diff --git a/document/src/vespa/document/fieldvalue/structfieldvalue.cpp b/document/src/vespa/document/fieldvalue/structfieldvalue.cpp index f2e2896d9c3..ac37970213c 100644 --- a/document/src/vespa/document/fieldvalue/structfieldvalue.cpp +++ b/document/src/vespa/document/fieldvalue/structfieldvalue.cpp @@ -176,8 +176,13 @@ void createFV(FieldValue & value, const DocumentTypeRepo & repo, nbostream & stream, const DocumentType & doc_type, uint32_t version) { FixedTypeRepo frepo(repo, doc_type); - VespaDocumentDeserializer deserializer(frepo, stream, version); - deserializer.read(value); + try { + VespaDocumentDeserializer deserializer(frepo, stream, version); + deserializer.read(value); + } catch (WrongTensorTypeException &) { + // A tensor field will appear to have no tensor if the stored tensor + // cannot be assigned to the tensor field. + } } } diff --git a/document/src/vespa/document/fieldvalue/tensorfieldvalue.cpp b/document/src/vespa/document/fieldvalue/tensorfieldvalue.cpp index 99ee030942f..399720e2354 100644 --- a/document/src/vespa/document/fieldvalue/tensorfieldvalue.cpp +++ b/document/src/vespa/document/fieldvalue/tensorfieldvalue.cpp @@ -1,6 +1,7 @@ // Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. #include "tensorfieldvalue.h" +#include <vespa/document/base/exceptions.h> #include <vespa/document/datatype/tensor_data_type.h> #include <vespa/vespalib/util/xmlstream.h> #include <vespa/eval/tensor/tensor.h> @@ -12,6 +13,7 @@ using vespalib::slime::JsonFormat; using vespalib::tensor::Tensor; using vespalib::tensor::SlimeBinaryFormat; +using vespalib::eval::ValueType; using namespace vespalib::xml; namespace document { @@ -20,6 +22,13 @@ namespace { TensorDataType emptyTensorDataType; +vespalib::string makeWrongTensorTypeMsg(const ValueType &fieldTensorType, const ValueType &tensorType) +{ + return vespalib::make_string("Field tensor type is '%s' but tensor type is '%s'", + fieldTensorType.to_spec().c_str(), + tensorType.to_spec().c_str()); +} + } TensorFieldValue::TensorFieldValue() @@ -66,12 +75,17 @@ TensorFieldValue & TensorFieldValue::operator=(const TensorFieldValue &rhs) { if (this != &rhs) { - if (rhs._tensor) { - _tensor = rhs._tensor->clone(); + if (&_dataType == &rhs._dataType || !rhs._tensor || + _dataType.isAssignableType(rhs._tensor->type())) { + if (rhs._tensor) { + _tensor = rhs._tensor->clone(); + } else { + _tensor.reset(); + } + _altered = true; } else { - _tensor.reset(); + throw WrongTensorTypeException(makeWrongTensorTypeMsg(_dataType.getTensorType(), rhs._tensor->type()), VESPA_STRLOC); } - _altered = true; } return *this; } @@ -80,8 +94,12 @@ TensorFieldValue::operator=(const TensorFieldValue &rhs) TensorFieldValue & TensorFieldValue::operator=(std::unique_ptr<Tensor> rhs) { - _tensor = std::move(rhs); - _altered = true; + if (!rhs || _dataType.isAssignableType(rhs->type())) { + _tensor = std::move(rhs); + _altered = true; + } else { + throw WrongTensorTypeException(makeWrongTensorTypeMsg(_dataType.getTensorType(), rhs->type()), VESPA_STRLOC); + } return *this; } @@ -165,8 +183,12 @@ TensorFieldValue::assign(const FieldValue &value) void TensorFieldValue::assignDeserialized(std::unique_ptr<Tensor> rhs) { - _tensor = std::move(rhs); - _altered = false; // Serialized form already exists + if (!rhs || _dataType.isAssignableType(rhs->type())) { + _tensor = std::move(rhs); + _altered = false; // Serialized form already exists + } else { + throw WrongTensorTypeException(makeWrongTensorTypeMsg(_dataType.getTensorType(), rhs->type()), VESPA_STRLOC); + } } |