aboutsummaryrefslogtreecommitdiffstats
path: root/document/src/tests/serialization/vespadocumentserializer_test.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'document/src/tests/serialization/vespadocumentserializer_test.cpp')
-rw-r--r--document/src/tests/serialization/vespadocumentserializer_test.cpp114
1 files changed, 107 insertions, 7 deletions
diff --git a/document/src/tests/serialization/vespadocumentserializer_test.cpp b/document/src/tests/serialization/vespadocumentserializer_test.cpp
index 624cdeff04e..c573eef6147 100644
--- a/document/src/tests/serialization/vespadocumentserializer_test.cpp
+++ b/document/src/tests/serialization/vespadocumentserializer_test.cpp
@@ -10,6 +10,7 @@
#include <vespa/document/datatype/documenttype.h>
#include <vespa/document/datatype/weightedsetdatatype.h>
#include <vespa/document/datatype/mapdatatype.h>
+#include <vespa/document/datatype/tensor_data_type.h>
#include <vespa/document/fieldvalue/annotationreferencefieldvalue.h>
#include <vespa/document/fieldvalue/arrayfieldvalue.h>
#include <vespa/document/fieldvalue/bytefieldvalue.h>
@@ -139,6 +140,11 @@ template <> ReferenceFieldValue newFieldValue(const ReferenceFieldValue& value)
*value.getDataType()));
}
+template <> TensorFieldValue newFieldValue(const TensorFieldValue& value) {
+ return TensorFieldValue(dynamic_cast<const TensorDataType&>(
+ *value.getDataType()));
+}
+
template<typename T>
void testDeserializeAndClone(const T& value, const nbostream &stream, bool checkEqual=true) {
T read_value = newFieldValue(value);
@@ -833,13 +839,14 @@ createTensor(const TensorCells &cells, const TensorDimensions &dimensions) {
TEST("Require that tensors can be serialized")
{
- TensorFieldValue noTensorValue;
- TensorFieldValue emptyTensorValue;
- TensorFieldValue twoCellsTwoDimsValue;
+ TensorDataType xySparseTensorDataType(vespalib::eval::ValueType::from_spec("tensor(x{},y{})"));
+ TensorFieldValue noTensorValue(xySparseTensorDataType);
+ TensorFieldValue emptyTensorValue(xySparseTensorDataType);
+ TensorFieldValue twoCellsTwoDimsValue(xySparseTensorDataType);
nbostream stream;
serializeAndDeserialize(noTensorValue, stream);
stream.clear();
- emptyTensorValue = createTensor({}, {});
+ emptyTensorValue = createTensor({}, {"x", "y"});
serializeAndDeserialize(emptyTensorValue, stream);
stream.clear();
twoCellsTwoDimsValue = createTensor({ {{{"y", "3"}}, 3},
@@ -855,19 +862,25 @@ TEST("Require that tensors can be serialized")
const int tensor_doc_type_id = 321;
const string tensor_field_name = "my_tensor";
-DocumenttypesConfig getTensorDocTypesConfig() {
+DocumenttypesConfig getTensorDocTypesConfig(const vespalib::string &tensorType) {
DocumenttypesConfigBuilderHelper builder;
builder.document(tensor_doc_type_id, "my_type",
Struct("my_type.header"),
Struct("my_type.body")
- .addField(tensor_field_name, DataType::T_TENSOR));
+ .addTensorField(tensor_field_name, tensorType));
return builder.config();
}
+DocumenttypesConfig getTensorDocTypesConfig() {
+ return getTensorDocTypesConfig("tensor(dimX{},dimY{})");
+}
+
const DocumentTypeRepo tensor_doc_repo(getTensorDocTypesConfig());
const FixedTypeRepo tensor_repo(tensor_doc_repo,
*tensor_doc_repo.getDocumentType(doc_type_id));
+const DocumentTypeRepo tensor_doc_repo1(getTensorDocTypesConfig("tensor(dimX{})"));
+
void serializeToFile(TensorFieldValue &value, const string &file_name) {
const DocumentType *type =
tensor_doc_repo.getDocumentType(tensor_doc_type_id);
@@ -881,7 +894,8 @@ void deserializeAndCheck(const string &file_name, TensorFieldValue &value) {
void checkDeserialization(const string &name, std::unique_ptr<Tensor> tensor) {
const string data_dir = TEST_PATH("../../test/resources/tensor/");
- TensorFieldValue value;
+ TensorDataType valueType(tensor ? tensor->type() : vespalib::eval::ValueType::error_type());
+ TensorFieldValue value(valueType);
if (tensor) {
value = tensor->clone();
}
@@ -901,6 +915,92 @@ TEST("Require that tensor deserialization matches Java") {
{ "dimX", "dimY" }));
}
+struct TensorDocFixture {
+ const DocumentTypeRepo &_docTypeRepo;
+ const DocumentType *_docType;
+ std::unique_ptr<Tensor> _tensor;
+ Document _doc;
+ vespalib::nbostream _blob;
+
+ TensorDocFixture(const DocumentTypeRepo &docTypeRepo,
+ std::unique_ptr<Tensor> tensor);
+ ~TensorDocFixture();
+};
+
+TensorDocFixture::TensorDocFixture(const DocumentTypeRepo &docTypeRepo,
+ std::unique_ptr<Tensor> tensor)
+ : _docTypeRepo(docTypeRepo),
+ _docType(_docTypeRepo.getDocumentType(tensor_doc_type_id)),
+ _tensor(std::move(tensor)),
+ _doc(*_docType, DocumentId("id:test:my_type::foo")),
+ _blob()
+{
+ auto fv = _doc.getField(tensor_field_name).createValue();
+ dynamic_cast<TensorFieldValue &>(*fv) = _tensor->clone();
+ _doc.setValue(tensor_field_name, *fv);
+ _doc.serialize(_blob);
+}
+
+TensorDocFixture::~TensorDocFixture() = default;
+
+struct DeserializedTensorDoc
+{
+ std::unique_ptr<Document> _doc;
+ std::unique_ptr<FieldValue> _fieldValue;
+
+ DeserializedTensorDoc();
+ ~DeserializedTensorDoc();
+
+ void setup(const DocumentTypeRepo &docTypeRepo, const vespalib::nbostream &blob);
+ const Tensor *getTensor() const;
+};
+
+DeserializedTensorDoc::DeserializedTensorDoc()
+ : _doc(),
+ _fieldValue()
+{
+}
+
+DeserializedTensorDoc::~DeserializedTensorDoc() = default;
+
+void
+DeserializedTensorDoc::setup(const DocumentTypeRepo &docTypeRepo, const vespalib::nbostream &blob)
+{
+ vespalib::nbostream wrapStream(blob.peek(), blob.size());
+ _doc = std::make_unique<Document>(docTypeRepo, wrapStream, nullptr);
+ _fieldValue = _doc->getValue(tensor_field_name);
+}
+
+const Tensor *
+DeserializedTensorDoc::getTensor() const
+{
+ return dynamic_cast<const TensorFieldValue &>(*_fieldValue).getAsTensorPtr().get();
+}
+
+TEST("Require that wrong tensor type hides tensor")
+{
+ TensorDocFixture f(tensor_doc_repo,
+ createTensor({ {{{"dimX", "a"},{"dimY", "bb"}}, 2.0 },
+ {{{"dimX", "ccc"},{"dimY", "dddd"}}, 3.0},
+ {{{"dimX", "e"},{"dimY","ff"}}, 5.0} },
+ { "dimX", "dimY" }));
+ TensorDocFixture f1(tensor_doc_repo1,
+ createTensor({ {{{"dimX", "a"}}, 20.0 },
+ {{{"dimX", "ccc"}}, 30.0} },
+ { "dimX" }));
+ DeserializedTensorDoc doc;
+ doc.setup(tensor_doc_repo, f._blob);
+ EXPECT_TRUE(doc.getTensor() != nullptr);
+ EXPECT_TRUE(doc.getTensor()->equals(*f._tensor));
+ doc.setup(tensor_doc_repo, f1._blob);
+ EXPECT_TRUE(doc.getTensor() == nullptr);
+ doc.setup(tensor_doc_repo1, f._blob);
+ EXPECT_TRUE(doc.getTensor() == nullptr);
+ doc.setup(tensor_doc_repo1, f1._blob);
+ EXPECT_TRUE(doc.getTensor() != nullptr);
+ EXPECT_TRUE(doc.getTensor()->equals(*f1._tensor));
+}
+
struct RefFixture {
const DocumentType* ref_doc_type{doc_repo.getDocumentType(doc_with_ref_type_id)};
FixedTypeRepo fixed_repo{doc_repo, *ref_doc_type};