summaryrefslogtreecommitdiffstats
path: root/document
diff options
context:
space:
mode:
authorTor Egge <Tor.Egge@broadpark.no>2019-03-06 18:14:03 +0100
committerTor Egge <Tor.Egge@broadpark.no>2019-03-07 13:28:48 +0100
commit93c5cb44ba9f8fa36314b8a4d6d57b75422f8c29 (patch)
tree0aabf97b84256fe9c10d3d22de0b9ddcf73d9129 /document
parent99bcfb517bd0b57c24f81478c3767f1b8d369fb3 (diff)
Check for assignable tensor type when setting tensor in TensorFieldValue.
Diffstat (limited to 'document')
-rw-r--r--document/src/tests/documentupdatetestcase.cpp3
-rw-r--r--document/src/tests/serialization/vespadocumentserializer_test.cpp114
-rw-r--r--document/src/tests/tensor_fieldvalue/tensor_fieldvalue_test.cpp53
-rw-r--r--document/src/vespa/document/base/exceptions.cpp2
-rw-r--r--document/src/vespa/document/base/exceptions.h3
-rw-r--r--document/src/vespa/document/fieldvalue/structfieldvalue.cpp9
-rw-r--r--document/src/vespa/document/fieldvalue/tensorfieldvalue.cpp38
7 files changed, 197 insertions, 25 deletions
diff --git a/document/src/tests/documentupdatetestcase.cpp b/document/src/tests/documentupdatetestcase.cpp
index 295372a42ca..71bc545897b 100644
--- a/document/src/tests/documentupdatetestcase.cpp
+++ b/document/src/tests/documentupdatetestcase.cpp
@@ -1,6 +1,7 @@
// Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
#include <vespa/document/base/testdocman.h>
+#include <vespa/document/base/exceptions.h>
#include <vespa/document/datatype/tensor_data_type.h>
#include <vespa/document/fieldvalue/fieldvalues.h>
#include <vespa/document/repo/configbuilder.h>
@@ -1050,7 +1051,7 @@ TEST(DocumentUpdateTest, tensor_modify_update_throws_if_cells_tensor_is_not_spar
auto cellsTensor = f.makeTensor(f.spec().add({{"x", 0}}, 2)); // creates a dense cells tensor
ASSERT_THROW(
f.assertRoundtripSerialize(TensorModifyUpdate(TensorModifyUpdate::Operation::REPLACE, std::move(cellsTensor))),
- vespalib::IllegalStateException);
+ document::WrongTensorTypeException);
}
struct TensorUpdateSerializeFixture {
diff --git a/document/src/tests/serialization/vespadocumentserializer_test.cpp b/document/src/tests/serialization/vespadocumentserializer_test.cpp
index 624cdeff04e..c573eef6147 100644
--- a/document/src/tests/serialization/vespadocumentserializer_test.cpp
+++ b/document/src/tests/serialization/vespadocumentserializer_test.cpp
@@ -10,6 +10,7 @@
#include <vespa/document/datatype/documenttype.h>
#include <vespa/document/datatype/weightedsetdatatype.h>
#include <vespa/document/datatype/mapdatatype.h>
+#include <vespa/document/datatype/tensor_data_type.h>
#include <vespa/document/fieldvalue/annotationreferencefieldvalue.h>
#include <vespa/document/fieldvalue/arrayfieldvalue.h>
#include <vespa/document/fieldvalue/bytefieldvalue.h>
@@ -139,6 +140,11 @@ template <> ReferenceFieldValue newFieldValue(const ReferenceFieldValue& value)
*value.getDataType()));
}
+template <> TensorFieldValue newFieldValue(const TensorFieldValue& value) {
+ return TensorFieldValue(dynamic_cast<const TensorDataType&>(
+ *value.getDataType()));
+}
+
template<typename T>
void testDeserializeAndClone(const T& value, const nbostream &stream, bool checkEqual=true) {
T read_value = newFieldValue(value);
@@ -833,13 +839,14 @@ createTensor(const TensorCells &cells, const TensorDimensions &dimensions) {
TEST("Require that tensors can be serialized")
{
- TensorFieldValue noTensorValue;
- TensorFieldValue emptyTensorValue;
- TensorFieldValue twoCellsTwoDimsValue;
+ TensorDataType xySparseTensorDataType(vespalib::eval::ValueType::from_spec("tensor(x{},y{})"));
+ TensorFieldValue noTensorValue(xySparseTensorDataType);
+ TensorFieldValue emptyTensorValue(xySparseTensorDataType);
+ TensorFieldValue twoCellsTwoDimsValue(xySparseTensorDataType);
nbostream stream;
serializeAndDeserialize(noTensorValue, stream);
stream.clear();
- emptyTensorValue = createTensor({}, {});
+ emptyTensorValue = createTensor({}, {"x", "y"});
serializeAndDeserialize(emptyTensorValue, stream);
stream.clear();
twoCellsTwoDimsValue = createTensor({ {{{"y", "3"}}, 3},
@@ -855,19 +862,25 @@ TEST("Require that tensors can be serialized")
const int tensor_doc_type_id = 321;
const string tensor_field_name = "my_tensor";
-DocumenttypesConfig getTensorDocTypesConfig() {
+DocumenttypesConfig getTensorDocTypesConfig(const vespalib::string &tensorType) {
DocumenttypesConfigBuilderHelper builder;
builder.document(tensor_doc_type_id, "my_type",
Struct("my_type.header"),
Struct("my_type.body")
- .addField(tensor_field_name, DataType::T_TENSOR));
+ .addTensorField(tensor_field_name, tensorType));
return builder.config();
}
+DocumenttypesConfig getTensorDocTypesConfig() {
+ return getTensorDocTypesConfig("tensor(dimX{},dimY{})");
+}
+
const DocumentTypeRepo tensor_doc_repo(getTensorDocTypesConfig());
const FixedTypeRepo tensor_repo(tensor_doc_repo,
*tensor_doc_repo.getDocumentType(doc_type_id));
+const DocumentTypeRepo tensor_doc_repo1(getTensorDocTypesConfig("tensor(dimX{})"));
+
void serializeToFile(TensorFieldValue &value, const string &file_name) {
const DocumentType *type =
tensor_doc_repo.getDocumentType(tensor_doc_type_id);
@@ -881,7 +894,8 @@ void deserializeAndCheck(const string &file_name, TensorFieldValue &value) {
void checkDeserialization(const string &name, std::unique_ptr<Tensor> tensor) {
const string data_dir = TEST_PATH("../../test/resources/tensor/");
- TensorFieldValue value;
+ TensorDataType valueType(tensor ? tensor->type() : vespalib::eval::ValueType::error_type());
+ TensorFieldValue value(valueType);
if (tensor) {
value = tensor->clone();
}
@@ -901,6 +915,92 @@ TEST("Require that tensor deserialization matches Java") {
{ "dimX", "dimY" }));
}
+struct TensorDocFixture {
+ const DocumentTypeRepo &_docTypeRepo;
+ const DocumentType *_docType;
+ std::unique_ptr<Tensor> _tensor;
+ Document _doc;
+ vespalib::nbostream _blob;
+
+ TensorDocFixture(const DocumentTypeRepo &docTypeRepo,
+ std::unique_ptr<Tensor> tensor);
+ ~TensorDocFixture();
+};
+
+TensorDocFixture::TensorDocFixture(const DocumentTypeRepo &docTypeRepo,
+ std::unique_ptr<Tensor> tensor)
+ : _docTypeRepo(docTypeRepo),
+ _docType(_docTypeRepo.getDocumentType(tensor_doc_type_id)),
+ _tensor(std::move(tensor)),
+ _doc(*_docType, DocumentId("id:test:my_type::foo")),
+ _blob()
+{
+ auto fv = _doc.getField(tensor_field_name).createValue();
+ dynamic_cast<TensorFieldValue &>(*fv) = _tensor->clone();
+ _doc.setValue(tensor_field_name, *fv);
+ _doc.serialize(_blob);
+}
+
+TensorDocFixture::~TensorDocFixture() = default;
+
+struct DeserializedTensorDoc
+{
+ std::unique_ptr<Document> _doc;
+ std::unique_ptr<FieldValue> _fieldValue;
+
+ DeserializedTensorDoc();
+ ~DeserializedTensorDoc();
+
+ void setup(const DocumentTypeRepo &docTypeRepo, const vespalib::nbostream &blob);
+ const Tensor *getTensor() const;
+};
+
+DeserializedTensorDoc::DeserializedTensorDoc()
+ : _doc(),
+ _fieldValue()
+{
+}
+
+DeserializedTensorDoc::~DeserializedTensorDoc() = default;
+
+void
+DeserializedTensorDoc::setup(const DocumentTypeRepo &docTypeRepo, const vespalib::nbostream &blob)
+{
+ vespalib::nbostream wrapStream(blob.peek(), blob.size());
+ _doc = std::make_unique<Document>(docTypeRepo, wrapStream, nullptr);
+ _fieldValue = _doc->getValue(tensor_field_name);
+}
+
+const Tensor *
+DeserializedTensorDoc::getTensor() const
+{
+ return dynamic_cast<const TensorFieldValue &>(*_fieldValue).getAsTensorPtr().get();
+}
+
+TEST("Require that wrong tensor type hides tensor")
+{
+ TensorDocFixture f(tensor_doc_repo,
+ createTensor({ {{{"dimX", "a"},{"dimY", "bb"}}, 2.0 },
+ {{{"dimX", "ccc"},{"dimY", "dddd"}}, 3.0},
+ {{{"dimX", "e"},{"dimY","ff"}}, 5.0} },
+ { "dimX", "dimY" }));
+ TensorDocFixture f1(tensor_doc_repo1,
+ createTensor({ {{{"dimX", "a"}}, 20.0 },
+ {{{"dimX", "ccc"}}, 30.0} },
+ { "dimX" }));
+ DeserializedTensorDoc doc;
+ doc.setup(tensor_doc_repo, f._blob);
+ EXPECT_TRUE(doc.getTensor() != nullptr);
+ EXPECT_TRUE(doc.getTensor()->equals(*f._tensor));
+ doc.setup(tensor_doc_repo, f1._blob);
+ EXPECT_TRUE(doc.getTensor() == nullptr);
+ doc.setup(tensor_doc_repo1, f._blob);
+ EXPECT_TRUE(doc.getTensor() == nullptr);
+ doc.setup(tensor_doc_repo1, f1._blob);
+ EXPECT_TRUE(doc.getTensor() != nullptr);
+ EXPECT_TRUE(doc.getTensor()->equals(*f1._tensor));
+}
+
struct RefFixture {
const DocumentType* ref_doc_type{doc_repo.getDocumentType(doc_with_ref_type_id)};
FixedTypeRepo fixed_repo{doc_repo, *ref_doc_type};
diff --git a/document/src/tests/tensor_fieldvalue/tensor_fieldvalue_test.cpp b/document/src/tests/tensor_fieldvalue/tensor_fieldvalue_test.cpp
index 7a929ae26b4..08e615bccaf 100644
--- a/document/src/tests/tensor_fieldvalue/tensor_fieldvalue_test.cpp
+++ b/document/src/tests/tensor_fieldvalue/tensor_fieldvalue_test.cpp
@@ -4,26 +4,42 @@
#include <vespa/log/log.h>
LOG_SETUP("fieldvalue_test");
+#include <vespa/document/base/exceptions.h>
+#include <vespa/document/datatype/tensor_data_type.h>
#include <vespa/document/fieldvalue/tensorfieldvalue.h>
#include <vespa/eval/tensor/tensor.h>
#include <vespa/eval/tensor/types.h>
#include <vespa/eval/tensor/default_tensor.h>
#include <vespa/eval/tensor/tensor_factory.h>
+#include <vespa/eval/tensor/test/test_utils.h>
#include <vespa/vespalib/testkit/testapp.h>
using namespace document;
using namespace vespalib::tensor;
+using vespalib::eval::TensorSpec;
+using vespalib::eval::ValueType;
+using vespalib::tensor::test::makeTensor;
namespace
{
+TensorDataType xSparseTensorDataType(ValueType::from_spec("tensor(x{})"));
+TensorDataType xySparseTensorDataType(ValueType::from_spec("tensor(x{},y{})"));
+
Tensor::UP
createTensor(const TensorCells &cells, const TensorDimensions &dimensions) {
vespalib::tensor::DefaultTensor::builder builder;
return vespalib::tensor::TensorFactory::create(cells, dimensions, builder);
}
+std::unique_ptr<Tensor>
+makeSimpleTensor()
+{
+ return makeTensor<Tensor>(TensorSpec("tensor(x{},y{})").
+ add({{"x", "4"}, {"y", "5"}}, 7));
+}
+
FieldValue::UP clone(FieldValue &fv) {
auto ret = FieldValue::UP(fv.clone());
EXPECT_NOT_EQUAL(ret.get(), &fv);
@@ -35,10 +51,10 @@ FieldValue::UP clone(FieldValue &fv) {
}
TEST("require that TensorFieldValue can be assigned tensors and cloned") {
- TensorFieldValue noTensorValue;
- TensorFieldValue emptyTensorValue;
- TensorFieldValue twoCellsTwoDimsValue;
- emptyTensorValue = createTensor({}, {});
+ TensorFieldValue noTensorValue(xySparseTensorDataType);
+ TensorFieldValue emptyTensorValue(xySparseTensorDataType);
+ TensorFieldValue twoCellsTwoDimsValue(xySparseTensorDataType);
+ emptyTensorValue = createTensor({}, {"x", "y"});
twoCellsTwoDimsValue = createTensor({ {{{"y", "3"}}, 3},
{{{"x", "4"}, {"y", "5"}}, 7} },
{"x", "y"});
@@ -57,7 +73,7 @@ TEST("require that TensorFieldValue can be assigned tensors and cloned") {
EXPECT_NOT_EQUAL(*emptyClone, *twoClone);
EXPECT_NOT_EQUAL(*twoClone, *noneClone);
EXPECT_NOT_EQUAL(*twoClone, *emptyClone);
- TensorFieldValue twoCellsTwoDimsValue2;
+ TensorFieldValue twoCellsTwoDimsValue2(xySparseTensorDataType);
twoCellsTwoDimsValue2 =
createTensor({ {{{"y", "3"}}, 3},
{{{"x", "4"}, {"y", "5"}}, 7} },
@@ -69,11 +85,36 @@ TEST("require that TensorFieldValue can be assigned tensors and cloned") {
TEST("require that TensorFieldValue::toString works")
{
- TensorFieldValue tensorFieldValue;
+ TensorFieldValue tensorFieldValue(xSparseTensorDataType);
EXPECT_EQUAL("{TensorFieldValue: null}", tensorFieldValue.toString());
tensorFieldValue = createTensor({{{{"x","a"}}, 3}}, {"x"});
EXPECT_EQUAL("{TensorFieldValue: {\"dimensions\":[\"x\"],\"cells\":[{\"address\":{\"x\":\"a\"},\"value\":3}]}}", tensorFieldValue.toString());
}
+TEST("require that wrong tensor type for special case assign throws exception")
+{
+ TensorFieldValue tensorFieldValue(xSparseTensorDataType);
+ EXPECT_EXCEPTION(tensorFieldValue = makeSimpleTensor(),
+ document::WrongTensorTypeException,
+ "WrongTensorTypeException: Field tensor type is 'tensor(x{})' but tensor type is 'tensor(x{},y{})'");
+}
+
+TEST("require that wrong tensor type for copy assign throws exception")
+{
+ TensorFieldValue tensorFieldValue(xSparseTensorDataType);
+ TensorFieldValue simpleTensorFieldValue(xySparseTensorDataType);
+ simpleTensorFieldValue = makeSimpleTensor();
+ EXPECT_EXCEPTION(tensorFieldValue = simpleTensorFieldValue,
+ document::WrongTensorTypeException,
+ "WrongTensorTypeException: Field tensor type is 'tensor(x{})' but tensor type is 'tensor(x{},y{})'");
+}
+
+TEST("require that wrong tensor type for assignDeserialized throws exception")
+{
+ TensorFieldValue tensorFieldValue(xSparseTensorDataType);
+ EXPECT_EXCEPTION(tensorFieldValue.assignDeserialized(makeSimpleTensor()),
+ document::WrongTensorTypeException,
+ "WrongTensorTypeException: Field tensor type is 'tensor(x{})' but tensor type is 'tensor(x{},y{})'");
+}
TEST_MAIN() { TEST_RUN_ALL(); }
diff --git a/document/src/vespa/document/base/exceptions.cpp b/document/src/vespa/document/base/exceptions.cpp
index 1dd891fc500..22c772cd987 100644
--- a/document/src/vespa/document/base/exceptions.cpp
+++ b/document/src/vespa/document/base/exceptions.cpp
@@ -104,4 +104,6 @@ DocumentTypeNotFoundException::~DocumentTypeNotFoundException() throw()
{
}
+VESPA_IMPLEMENT_EXCEPTION(WrongTensorTypeException, vespalib::Exception);
+
}
diff --git a/document/src/vespa/document/base/exceptions.h b/document/src/vespa/document/base/exceptions.h
index 49ddecc8441..f79d4f4da0e 100644
--- a/document/src/vespa/document/base/exceptions.h
+++ b/document/src/vespa/document/base/exceptions.h
@@ -138,5 +138,6 @@ public:
VESPA_DEFINE_EXCEPTION_SPINE(FieldNotFoundException);
};
-}
+VESPA_DEFINE_EXCEPTION(WrongTensorTypeException, vespalib::Exception);
+}
diff --git a/document/src/vespa/document/fieldvalue/structfieldvalue.cpp b/document/src/vespa/document/fieldvalue/structfieldvalue.cpp
index f2e2896d9c3..ac37970213c 100644
--- a/document/src/vespa/document/fieldvalue/structfieldvalue.cpp
+++ b/document/src/vespa/document/fieldvalue/structfieldvalue.cpp
@@ -176,8 +176,13 @@ void
createFV(FieldValue & value, const DocumentTypeRepo & repo, nbostream & stream, const DocumentType & doc_type, uint32_t version)
{
FixedTypeRepo frepo(repo, doc_type);
- VespaDocumentDeserializer deserializer(frepo, stream, version);
- deserializer.read(value);
+ try {
+ VespaDocumentDeserializer deserializer(frepo, stream, version);
+ deserializer.read(value);
+ } catch (WrongTensorTypeException &) {
+ // A tensor field will appear to have no tensor if the stored tensor
+ // cannot be assigned to the tensor field.
+ }
}
}
diff --git a/document/src/vespa/document/fieldvalue/tensorfieldvalue.cpp b/document/src/vespa/document/fieldvalue/tensorfieldvalue.cpp
index 99ee030942f..399720e2354 100644
--- a/document/src/vespa/document/fieldvalue/tensorfieldvalue.cpp
+++ b/document/src/vespa/document/fieldvalue/tensorfieldvalue.cpp
@@ -1,6 +1,7 @@
// Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
#include "tensorfieldvalue.h"
+#include <vespa/document/base/exceptions.h>
#include <vespa/document/datatype/tensor_data_type.h>
#include <vespa/vespalib/util/xmlstream.h>
#include <vespa/eval/tensor/tensor.h>
@@ -12,6 +13,7 @@
using vespalib::slime::JsonFormat;
using vespalib::tensor::Tensor;
using vespalib::tensor::SlimeBinaryFormat;
+using vespalib::eval::ValueType;
using namespace vespalib::xml;
namespace document {
@@ -20,6 +22,13 @@ namespace {
TensorDataType emptyTensorDataType;
+vespalib::string makeWrongTensorTypeMsg(const ValueType &fieldTensorType, const ValueType &tensorType)
+{
+ return vespalib::make_string("Field tensor type is '%s' but tensor type is '%s'",
+ fieldTensorType.to_spec().c_str(),
+ tensorType.to_spec().c_str());
+}
+
}
TensorFieldValue::TensorFieldValue()
@@ -66,12 +75,17 @@ TensorFieldValue &
TensorFieldValue::operator=(const TensorFieldValue &rhs)
{
if (this != &rhs) {
- if (rhs._tensor) {
- _tensor = rhs._tensor->clone();
+ if (&_dataType == &rhs._dataType || !rhs._tensor ||
+ _dataType.isAssignableType(rhs._tensor->type())) {
+ if (rhs._tensor) {
+ _tensor = rhs._tensor->clone();
+ } else {
+ _tensor.reset();
+ }
+ _altered = true;
} else {
- _tensor.reset();
+ throw WrongTensorTypeException(makeWrongTensorTypeMsg(_dataType.getTensorType(), rhs._tensor->type()), VESPA_STRLOC);
}
- _altered = true;
}
return *this;
}
@@ -80,8 +94,12 @@ TensorFieldValue::operator=(const TensorFieldValue &rhs)
TensorFieldValue &
TensorFieldValue::operator=(std::unique_ptr<Tensor> rhs)
{
- _tensor = std::move(rhs);
- _altered = true;
+ if (!rhs || _dataType.isAssignableType(rhs->type())) {
+ _tensor = std::move(rhs);
+ _altered = true;
+ } else {
+ throw WrongTensorTypeException(makeWrongTensorTypeMsg(_dataType.getTensorType(), rhs->type()), VESPA_STRLOC);
+ }
return *this;
}
@@ -165,8 +183,12 @@ TensorFieldValue::assign(const FieldValue &value)
void
TensorFieldValue::assignDeserialized(std::unique_ptr<Tensor> rhs)
{
- _tensor = std::move(rhs);
- _altered = false; // Serialized form already exists
+ if (!rhs || _dataType.isAssignableType(rhs->type())) {
+ _tensor = std::move(rhs);
+ _altered = false; // Serialized form already exists
+ } else {
+ throw WrongTensorTypeException(makeWrongTensorTypeMsg(_dataType.getTensorType(), rhs->type()), VESPA_STRLOC);
+ }
}