diff options
Diffstat (limited to 'document/src/tests/serialization/vespadocumentserializer_test.cpp')
-rw-r--r-- | document/src/tests/serialization/vespadocumentserializer_test.cpp | 114 |
1 files changed, 107 insertions, 7 deletions
diff --git a/document/src/tests/serialization/vespadocumentserializer_test.cpp b/document/src/tests/serialization/vespadocumentserializer_test.cpp index 624cdeff04e..c573eef6147 100644 --- a/document/src/tests/serialization/vespadocumentserializer_test.cpp +++ b/document/src/tests/serialization/vespadocumentserializer_test.cpp @@ -10,6 +10,7 @@ #include <vespa/document/datatype/documenttype.h> #include <vespa/document/datatype/weightedsetdatatype.h> #include <vespa/document/datatype/mapdatatype.h> +#include <vespa/document/datatype/tensor_data_type.h> #include <vespa/document/fieldvalue/annotationreferencefieldvalue.h> #include <vespa/document/fieldvalue/arrayfieldvalue.h> #include <vespa/document/fieldvalue/bytefieldvalue.h> @@ -139,6 +140,11 @@ template <> ReferenceFieldValue newFieldValue(const ReferenceFieldValue& value) *value.getDataType())); } +template <> TensorFieldValue newFieldValue(const TensorFieldValue& value) { + return TensorFieldValue(dynamic_cast<const TensorDataType&>( + *value.getDataType())); +} + template<typename T> void testDeserializeAndClone(const T& value, const nbostream &stream, bool checkEqual=true) { T read_value = newFieldValue(value); @@ -833,13 +839,14 @@ createTensor(const TensorCells &cells, const TensorDimensions &dimensions) { TEST("Require that tensors can be serialized") { - TensorFieldValue noTensorValue; - TensorFieldValue emptyTensorValue; - TensorFieldValue twoCellsTwoDimsValue; + TensorDataType xySparseTensorDataType(vespalib::eval::ValueType::from_spec("tensor(x{},y{})")); + TensorFieldValue noTensorValue(xySparseTensorDataType); + TensorFieldValue emptyTensorValue(xySparseTensorDataType); + TensorFieldValue twoCellsTwoDimsValue(xySparseTensorDataType); nbostream stream; serializeAndDeserialize(noTensorValue, stream); stream.clear(); - emptyTensorValue = createTensor({}, {}); + emptyTensorValue = createTensor({}, {"x", "y"}); serializeAndDeserialize(emptyTensorValue, stream); stream.clear(); twoCellsTwoDimsValue = createTensor({ {{{"y", "3"}}, 3}, @@ -855,19 +862,25 @@ TEST("Require that tensors can be serialized") const int tensor_doc_type_id = 321; const string tensor_field_name = "my_tensor"; -DocumenttypesConfig getTensorDocTypesConfig() { +DocumenttypesConfig getTensorDocTypesConfig(const vespalib::string &tensorType) { DocumenttypesConfigBuilderHelper builder; builder.document(tensor_doc_type_id, "my_type", Struct("my_type.header"), Struct("my_type.body") - .addField(tensor_field_name, DataType::T_TENSOR)); + .addTensorField(tensor_field_name, tensorType)); return builder.config(); } +DocumenttypesConfig getTensorDocTypesConfig() { + return getTensorDocTypesConfig("tensor(dimX{},dimY{})"); +} + const DocumentTypeRepo tensor_doc_repo(getTensorDocTypesConfig()); const FixedTypeRepo tensor_repo(tensor_doc_repo, *tensor_doc_repo.getDocumentType(doc_type_id)); +const DocumentTypeRepo tensor_doc_repo1(getTensorDocTypesConfig("tensor(dimX{})")); + void serializeToFile(TensorFieldValue &value, const string &file_name) { const DocumentType *type = tensor_doc_repo.getDocumentType(tensor_doc_type_id); @@ -881,7 +894,8 @@ void deserializeAndCheck(const string &file_name, TensorFieldValue &value) { void checkDeserialization(const string &name, std::unique_ptr<Tensor> tensor) { const string data_dir = TEST_PATH("../../test/resources/tensor/"); - TensorFieldValue value; + TensorDataType valueType(tensor ? tensor->type() : vespalib::eval::ValueType::error_type()); + TensorFieldValue value(valueType); if (tensor) { value = tensor->clone(); } @@ -901,6 +915,92 @@ TEST("Require that tensor deserialization matches Java") { { "dimX", "dimY" })); } +struct TensorDocFixture { + const DocumentTypeRepo &_docTypeRepo; + const DocumentType *_docType; + std::unique_ptr<Tensor> _tensor; + Document _doc; + vespalib::nbostream _blob; + + TensorDocFixture(const DocumentTypeRepo &docTypeRepo, + std::unique_ptr<Tensor> tensor); + ~TensorDocFixture(); +}; + +TensorDocFixture::TensorDocFixture(const DocumentTypeRepo &docTypeRepo, + std::unique_ptr<Tensor> tensor) + : _docTypeRepo(docTypeRepo), + _docType(_docTypeRepo.getDocumentType(tensor_doc_type_id)), + _tensor(std::move(tensor)), + _doc(*_docType, DocumentId("id:test:my_type::foo")), + _blob() +{ + auto fv = _doc.getField(tensor_field_name).createValue(); + dynamic_cast<TensorFieldValue &>(*fv) = _tensor->clone(); + _doc.setValue(tensor_field_name, *fv); + _doc.serialize(_blob); +} + +TensorDocFixture::~TensorDocFixture() = default; + +struct DeserializedTensorDoc +{ + std::unique_ptr<Document> _doc; + std::unique_ptr<FieldValue> _fieldValue; + + DeserializedTensorDoc(); + ~DeserializedTensorDoc(); + + void setup(const DocumentTypeRepo &docTypeRepo, const vespalib::nbostream &blob); + const Tensor *getTensor() const; +}; + +DeserializedTensorDoc::DeserializedTensorDoc() + : _doc(), + _fieldValue() +{ +} + +DeserializedTensorDoc::~DeserializedTensorDoc() = default; + +void +DeserializedTensorDoc::setup(const DocumentTypeRepo &docTypeRepo, const vespalib::nbostream &blob) +{ + vespalib::nbostream wrapStream(blob.peek(), blob.size()); + _doc = std::make_unique<Document>(docTypeRepo, wrapStream, nullptr); + _fieldValue = _doc->getValue(tensor_field_name); +} + +const Tensor * +DeserializedTensorDoc::getTensor() const +{ + return dynamic_cast<const TensorFieldValue &>(*_fieldValue).getAsTensorPtr().get(); +} + +TEST("Require that wrong tensor type hides tensor") +{ + TensorDocFixture f(tensor_doc_repo, + createTensor({ {{{"dimX", "a"},{"dimY", "bb"}}, 2.0 }, + {{{"dimX", "ccc"},{"dimY", "dddd"}}, 3.0}, + {{{"dimX", "e"},{"dimY","ff"}}, 5.0} }, + { "dimX", "dimY" })); + TensorDocFixture f1(tensor_doc_repo1, + createTensor({ {{{"dimX", "a"}}, 20.0 }, + {{{"dimX", "ccc"}}, 30.0} }, + { "dimX" })); + DeserializedTensorDoc doc; + doc.setup(tensor_doc_repo, f._blob); + EXPECT_TRUE(doc.getTensor() != nullptr); + EXPECT_TRUE(doc.getTensor()->equals(*f._tensor)); + doc.setup(tensor_doc_repo, f1._blob); + EXPECT_TRUE(doc.getTensor() == nullptr); + doc.setup(tensor_doc_repo1, f._blob); + EXPECT_TRUE(doc.getTensor() == nullptr); + doc.setup(tensor_doc_repo1, f1._blob); + EXPECT_TRUE(doc.getTensor() != nullptr); + EXPECT_TRUE(doc.getTensor()->equals(*f1._tensor)); +} + struct RefFixture { const DocumentType* ref_doc_type{doc_repo.getDocumentType(doc_with_ref_type_id)}; FixedTypeRepo fixed_repo{doc_repo, *ref_doc_type}; |