diff options
author | Tor Egge <Tor.Egge@broadpark.no> | 2019-03-06 18:14:03 +0100 |
---|---|---|
committer | Tor Egge <Tor.Egge@broadpark.no> | 2019-03-07 13:28:48 +0100 |
commit | 93c5cb44ba9f8fa36314b8a4d6d57b75422f8c29 (patch) | |
tree | 0aabf97b84256fe9c10d3d22de0b9ddcf73d9129 | |
parent | 99bcfb517bd0b57c24f81478c3767f1b8d369fb3 (diff) |
Check for assignable tensor type when setting tensor in TensorFieldValue.
14 files changed, 281 insertions, 37 deletions
diff --git a/document/src/tests/documentupdatetestcase.cpp b/document/src/tests/documentupdatetestcase.cpp index 295372a42ca..71bc545897b 100644 --- a/document/src/tests/documentupdatetestcase.cpp +++ b/document/src/tests/documentupdatetestcase.cpp @@ -1,6 +1,7 @@ // Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. #include <vespa/document/base/testdocman.h> +#include <vespa/document/base/exceptions.h> #include <vespa/document/datatype/tensor_data_type.h> #include <vespa/document/fieldvalue/fieldvalues.h> #include <vespa/document/repo/configbuilder.h> @@ -1050,7 +1051,7 @@ TEST(DocumentUpdateTest, tensor_modify_update_throws_if_cells_tensor_is_not_spar auto cellsTensor = f.makeTensor(f.spec().add({{"x", 0}}, 2)); // creates a dense cells tensor ASSERT_THROW( f.assertRoundtripSerialize(TensorModifyUpdate(TensorModifyUpdate::Operation::REPLACE, std::move(cellsTensor))), - vespalib::IllegalStateException); + document::WrongTensorTypeException); } struct TensorUpdateSerializeFixture { diff --git a/document/src/tests/serialization/vespadocumentserializer_test.cpp b/document/src/tests/serialization/vespadocumentserializer_test.cpp index 624cdeff04e..c573eef6147 100644 --- a/document/src/tests/serialization/vespadocumentserializer_test.cpp +++ b/document/src/tests/serialization/vespadocumentserializer_test.cpp @@ -10,6 +10,7 @@ #include <vespa/document/datatype/documenttype.h> #include <vespa/document/datatype/weightedsetdatatype.h> #include <vespa/document/datatype/mapdatatype.h> +#include <vespa/document/datatype/tensor_data_type.h> #include <vespa/document/fieldvalue/annotationreferencefieldvalue.h> #include <vespa/document/fieldvalue/arrayfieldvalue.h> #include <vespa/document/fieldvalue/bytefieldvalue.h> @@ -139,6 +140,11 @@ template <> ReferenceFieldValue newFieldValue(const ReferenceFieldValue& value) *value.getDataType())); } +template <> TensorFieldValue newFieldValue(const TensorFieldValue& value) { + return TensorFieldValue(dynamic_cast<const TensorDataType&>( + *value.getDataType())); +} + template<typename T> void testDeserializeAndClone(const T& value, const nbostream &stream, bool checkEqual=true) { T read_value = newFieldValue(value); @@ -833,13 +839,14 @@ createTensor(const TensorCells &cells, const TensorDimensions &dimensions) { TEST("Require that tensors can be serialized") { - TensorFieldValue noTensorValue; - TensorFieldValue emptyTensorValue; - TensorFieldValue twoCellsTwoDimsValue; + TensorDataType xySparseTensorDataType(vespalib::eval::ValueType::from_spec("tensor(x{},y{})")); + TensorFieldValue noTensorValue(xySparseTensorDataType); + TensorFieldValue emptyTensorValue(xySparseTensorDataType); + TensorFieldValue twoCellsTwoDimsValue(xySparseTensorDataType); nbostream stream; serializeAndDeserialize(noTensorValue, stream); stream.clear(); - emptyTensorValue = createTensor({}, {}); + emptyTensorValue = createTensor({}, {"x", "y"}); serializeAndDeserialize(emptyTensorValue, stream); stream.clear(); twoCellsTwoDimsValue = createTensor({ {{{"y", "3"}}, 3}, @@ -855,19 +862,25 @@ TEST("Require that tensors can be serialized") const int tensor_doc_type_id = 321; const string tensor_field_name = "my_tensor"; -DocumenttypesConfig getTensorDocTypesConfig() { +DocumenttypesConfig getTensorDocTypesConfig(const vespalib::string &tensorType) { DocumenttypesConfigBuilderHelper builder; builder.document(tensor_doc_type_id, "my_type", Struct("my_type.header"), Struct("my_type.body") - .addField(tensor_field_name, DataType::T_TENSOR)); + .addTensorField(tensor_field_name, tensorType)); return builder.config(); } +DocumenttypesConfig getTensorDocTypesConfig() { + return getTensorDocTypesConfig("tensor(dimX{},dimY{})"); +} + const DocumentTypeRepo tensor_doc_repo(getTensorDocTypesConfig()); const FixedTypeRepo tensor_repo(tensor_doc_repo, *tensor_doc_repo.getDocumentType(doc_type_id)); +const DocumentTypeRepo tensor_doc_repo1(getTensorDocTypesConfig("tensor(dimX{})")); + void serializeToFile(TensorFieldValue &value, const string &file_name) { const DocumentType *type = tensor_doc_repo.getDocumentType(tensor_doc_type_id); @@ -881,7 +894,8 @@ void deserializeAndCheck(const string &file_name, TensorFieldValue &value) { void checkDeserialization(const string &name, std::unique_ptr<Tensor> tensor) { const string data_dir = TEST_PATH("../../test/resources/tensor/"); - TensorFieldValue value; + TensorDataType valueType(tensor ? tensor->type() : vespalib::eval::ValueType::error_type()); + TensorFieldValue value(valueType); if (tensor) { value = tensor->clone(); } @@ -901,6 +915,92 @@ TEST("Require that tensor deserialization matches Java") { { "dimX", "dimY" })); } +struct TensorDocFixture { + const DocumentTypeRepo &_docTypeRepo; + const DocumentType *_docType; + std::unique_ptr<Tensor> _tensor; + Document _doc; + vespalib::nbostream _blob; + + TensorDocFixture(const DocumentTypeRepo &docTypeRepo, + std::unique_ptr<Tensor> tensor); + ~TensorDocFixture(); +}; + +TensorDocFixture::TensorDocFixture(const DocumentTypeRepo &docTypeRepo, + std::unique_ptr<Tensor> tensor) + : _docTypeRepo(docTypeRepo), + _docType(_docTypeRepo.getDocumentType(tensor_doc_type_id)), + _tensor(std::move(tensor)), + _doc(*_docType, DocumentId("id:test:my_type::foo")), + _blob() +{ + auto fv = _doc.getField(tensor_field_name).createValue(); + dynamic_cast<TensorFieldValue &>(*fv) = _tensor->clone(); + _doc.setValue(tensor_field_name, *fv); + _doc.serialize(_blob); +} + +TensorDocFixture::~TensorDocFixture() = default; + +struct DeserializedTensorDoc +{ + std::unique_ptr<Document> _doc; + std::unique_ptr<FieldValue> _fieldValue; + + DeserializedTensorDoc(); + ~DeserializedTensorDoc(); + + void setup(const DocumentTypeRepo &docTypeRepo, const vespalib::nbostream &blob); + const Tensor *getTensor() const; +}; + +DeserializedTensorDoc::DeserializedTensorDoc() + : _doc(), + _fieldValue() +{ +} + +DeserializedTensorDoc::~DeserializedTensorDoc() = default; + +void +DeserializedTensorDoc::setup(const DocumentTypeRepo &docTypeRepo, const vespalib::nbostream &blob) +{ + vespalib::nbostream wrapStream(blob.peek(), blob.size()); + _doc = std::make_unique<Document>(docTypeRepo, wrapStream, nullptr); + _fieldValue = _doc->getValue(tensor_field_name); +} + +const Tensor * +DeserializedTensorDoc::getTensor() const +{ + return dynamic_cast<const TensorFieldValue &>(*_fieldValue).getAsTensorPtr().get(); +} + +TEST("Require that wrong tensor type hides tensor") +{ + TensorDocFixture f(tensor_doc_repo, + createTensor({ {{{"dimX", "a"},{"dimY", "bb"}}, 2.0 }, + {{{"dimX", "ccc"},{"dimY", "dddd"}}, 3.0}, + {{{"dimX", "e"},{"dimY","ff"}}, 5.0} }, + { "dimX", "dimY" })); + TensorDocFixture f1(tensor_doc_repo1, + createTensor({ {{{"dimX", "a"}}, 20.0 }, + {{{"dimX", "ccc"}}, 30.0} }, + { "dimX" })); + DeserializedTensorDoc doc; + doc.setup(tensor_doc_repo, f._blob); + EXPECT_TRUE(doc.getTensor() != nullptr); + EXPECT_TRUE(doc.getTensor()->equals(*f._tensor)); + doc.setup(tensor_doc_repo, f1._blob); + EXPECT_TRUE(doc.getTensor() == nullptr); + doc.setup(tensor_doc_repo1, f._blob); + EXPECT_TRUE(doc.getTensor() == nullptr); + doc.setup(tensor_doc_repo1, f1._blob); + EXPECT_TRUE(doc.getTensor() != nullptr); + EXPECT_TRUE(doc.getTensor()->equals(*f1._tensor)); +} + struct RefFixture { const DocumentType* ref_doc_type{doc_repo.getDocumentType(doc_with_ref_type_id)}; FixedTypeRepo fixed_repo{doc_repo, *ref_doc_type}; diff --git a/document/src/tests/tensor_fieldvalue/tensor_fieldvalue_test.cpp b/document/src/tests/tensor_fieldvalue/tensor_fieldvalue_test.cpp index 7a929ae26b4..08e615bccaf 100644 --- a/document/src/tests/tensor_fieldvalue/tensor_fieldvalue_test.cpp +++ b/document/src/tests/tensor_fieldvalue/tensor_fieldvalue_test.cpp @@ -4,26 +4,42 @@ #include <vespa/log/log.h> LOG_SETUP("fieldvalue_test"); +#include <vespa/document/base/exceptions.h> +#include <vespa/document/datatype/tensor_data_type.h> #include <vespa/document/fieldvalue/tensorfieldvalue.h> #include <vespa/eval/tensor/tensor.h> #include <vespa/eval/tensor/types.h> #include <vespa/eval/tensor/default_tensor.h> #include <vespa/eval/tensor/tensor_factory.h> +#include <vespa/eval/tensor/test/test_utils.h> #include <vespa/vespalib/testkit/testapp.h> using namespace document; using namespace vespalib::tensor; +using vespalib::eval::TensorSpec; +using vespalib::eval::ValueType; +using vespalib::tensor::test::makeTensor; namespace { +TensorDataType xSparseTensorDataType(ValueType::from_spec("tensor(x{})")); +TensorDataType xySparseTensorDataType(ValueType::from_spec("tensor(x{},y{})")); + Tensor::UP createTensor(const TensorCells &cells, const TensorDimensions &dimensions) { vespalib::tensor::DefaultTensor::builder builder; return vespalib::tensor::TensorFactory::create(cells, dimensions, builder); } +std::unique_ptr<Tensor> +makeSimpleTensor() +{ + return makeTensor<Tensor>(TensorSpec("tensor(x{},y{})"). + add({{"x", "4"}, {"y", "5"}}, 7)); +} + FieldValue::UP clone(FieldValue &fv) { auto ret = FieldValue::UP(fv.clone()); EXPECT_NOT_EQUAL(ret.get(), &fv); @@ -35,10 +51,10 @@ FieldValue::UP clone(FieldValue &fv) { } TEST("require that TensorFieldValue can be assigned tensors and cloned") { - TensorFieldValue noTensorValue; - TensorFieldValue emptyTensorValue; - TensorFieldValue twoCellsTwoDimsValue; - emptyTensorValue = createTensor({}, {}); + TensorFieldValue noTensorValue(xySparseTensorDataType); + TensorFieldValue emptyTensorValue(xySparseTensorDataType); + TensorFieldValue twoCellsTwoDimsValue(xySparseTensorDataType); + emptyTensorValue = createTensor({}, {"x", "y"}); twoCellsTwoDimsValue = createTensor({ {{{"y", "3"}}, 3}, {{{"x", "4"}, {"y", "5"}}, 7} }, {"x", "y"}); @@ -57,7 +73,7 @@ TEST("require that TensorFieldValue can be assigned tensors and cloned") { EXPECT_NOT_EQUAL(*emptyClone, *twoClone); EXPECT_NOT_EQUAL(*twoClone, *noneClone); EXPECT_NOT_EQUAL(*twoClone, *emptyClone); - TensorFieldValue twoCellsTwoDimsValue2; + TensorFieldValue twoCellsTwoDimsValue2(xySparseTensorDataType); twoCellsTwoDimsValue2 = createTensor({ {{{"y", "3"}}, 3}, {{{"x", "4"}, {"y", "5"}}, 7} }, @@ -69,11 +85,36 @@ TEST("require that TensorFieldValue can be assigned tensors and cloned") { TEST("require that TensorFieldValue::toString works") { - TensorFieldValue tensorFieldValue; + TensorFieldValue tensorFieldValue(xSparseTensorDataType); EXPECT_EQUAL("{TensorFieldValue: null}", tensorFieldValue.toString()); tensorFieldValue = createTensor({{{{"x","a"}}, 3}}, {"x"}); EXPECT_EQUAL("{TensorFieldValue: {\"dimensions\":[\"x\"],\"cells\":[{\"address\":{\"x\":\"a\"},\"value\":3}]}}", tensorFieldValue.toString()); } +TEST("require that wrong tensor type for special case assign throws exception") +{ + TensorFieldValue tensorFieldValue(xSparseTensorDataType); + EXPECT_EXCEPTION(tensorFieldValue = makeSimpleTensor(), + document::WrongTensorTypeException, + "WrongTensorTypeException: Field tensor type is 'tensor(x{})' but tensor type is 'tensor(x{},y{})'"); +} + +TEST("require that wrong tensor type for copy assign throws exception") +{ + TensorFieldValue tensorFieldValue(xSparseTensorDataType); + TensorFieldValue simpleTensorFieldValue(xySparseTensorDataType); + simpleTensorFieldValue = makeSimpleTensor(); + EXPECT_EXCEPTION(tensorFieldValue = simpleTensorFieldValue, + document::WrongTensorTypeException, + "WrongTensorTypeException: Field tensor type is 'tensor(x{})' but tensor type is 'tensor(x{},y{})'"); +} + +TEST("require that wrong tensor type for assignDeserialized throws exception") +{ + TensorFieldValue tensorFieldValue(xSparseTensorDataType); + EXPECT_EXCEPTION(tensorFieldValue.assignDeserialized(makeSimpleTensor()), + document::WrongTensorTypeException, + "WrongTensorTypeException: Field tensor type is 'tensor(x{})' but tensor type is 'tensor(x{},y{})'"); +} TEST_MAIN() { TEST_RUN_ALL(); } diff --git a/document/src/vespa/document/base/exceptions.cpp b/document/src/vespa/document/base/exceptions.cpp index 1dd891fc500..22c772cd987 100644 --- a/document/src/vespa/document/base/exceptions.cpp +++ b/document/src/vespa/document/base/exceptions.cpp @@ -104,4 +104,6 @@ DocumentTypeNotFoundException::~DocumentTypeNotFoundException() throw() { } +VESPA_IMPLEMENT_EXCEPTION(WrongTensorTypeException, vespalib::Exception); + } diff --git a/document/src/vespa/document/base/exceptions.h b/document/src/vespa/document/base/exceptions.h index 49ddecc8441..f79d4f4da0e 100644 --- a/document/src/vespa/document/base/exceptions.h +++ b/document/src/vespa/document/base/exceptions.h @@ -138,5 +138,6 @@ public: VESPA_DEFINE_EXCEPTION_SPINE(FieldNotFoundException); }; -} +VESPA_DEFINE_EXCEPTION(WrongTensorTypeException, vespalib::Exception); +} diff --git a/document/src/vespa/document/fieldvalue/structfieldvalue.cpp b/document/src/vespa/document/fieldvalue/structfieldvalue.cpp index f2e2896d9c3..ac37970213c 100644 --- a/document/src/vespa/document/fieldvalue/structfieldvalue.cpp +++ b/document/src/vespa/document/fieldvalue/structfieldvalue.cpp @@ -176,8 +176,13 @@ void createFV(FieldValue & value, const DocumentTypeRepo & repo, nbostream & stream, const DocumentType & doc_type, uint32_t version) { FixedTypeRepo frepo(repo, doc_type); - VespaDocumentDeserializer deserializer(frepo, stream, version); - deserializer.read(value); + try { + VespaDocumentDeserializer deserializer(frepo, stream, version); + deserializer.read(value); + } catch (WrongTensorTypeException &) { + // A tensor field will appear to have no tensor if the stored tensor + // cannot be assigned to the tensor field. + } } } diff --git a/document/src/vespa/document/fieldvalue/tensorfieldvalue.cpp b/document/src/vespa/document/fieldvalue/tensorfieldvalue.cpp index 99ee030942f..399720e2354 100644 --- a/document/src/vespa/document/fieldvalue/tensorfieldvalue.cpp +++ b/document/src/vespa/document/fieldvalue/tensorfieldvalue.cpp @@ -1,6 +1,7 @@ // Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. #include "tensorfieldvalue.h" +#include <vespa/document/base/exceptions.h> #include <vespa/document/datatype/tensor_data_type.h> #include <vespa/vespalib/util/xmlstream.h> #include <vespa/eval/tensor/tensor.h> @@ -12,6 +13,7 @@ using vespalib::slime::JsonFormat; using vespalib::tensor::Tensor; using vespalib::tensor::SlimeBinaryFormat; +using vespalib::eval::ValueType; using namespace vespalib::xml; namespace document { @@ -20,6 +22,13 @@ namespace { TensorDataType emptyTensorDataType; +vespalib::string makeWrongTensorTypeMsg(const ValueType &fieldTensorType, const ValueType &tensorType) +{ + return vespalib::make_string("Field tensor type is '%s' but tensor type is '%s'", + fieldTensorType.to_spec().c_str(), + tensorType.to_spec().c_str()); +} + } TensorFieldValue::TensorFieldValue() @@ -66,12 +75,17 @@ TensorFieldValue & TensorFieldValue::operator=(const TensorFieldValue &rhs) { if (this != &rhs) { - if (rhs._tensor) { - _tensor = rhs._tensor->clone(); + if (&_dataType == &rhs._dataType || !rhs._tensor || + _dataType.isAssignableType(rhs._tensor->type())) { + if (rhs._tensor) { + _tensor = rhs._tensor->clone(); + } else { + _tensor.reset(); + } + _altered = true; } else { - _tensor.reset(); + throw WrongTensorTypeException(makeWrongTensorTypeMsg(_dataType.getTensorType(), rhs._tensor->type()), VESPA_STRLOC); } - _altered = true; } return *this; } @@ -80,8 +94,12 @@ TensorFieldValue::operator=(const TensorFieldValue &rhs) TensorFieldValue & TensorFieldValue::operator=(std::unique_ptr<Tensor> rhs) { - _tensor = std::move(rhs); - _altered = true; + if (!rhs || _dataType.isAssignableType(rhs->type())) { + _tensor = std::move(rhs); + _altered = true; + } else { + throw WrongTensorTypeException(makeWrongTensorTypeMsg(_dataType.getTensorType(), rhs->type()), VESPA_STRLOC); + } return *this; } @@ -165,8 +183,12 @@ TensorFieldValue::assign(const FieldValue &value) void TensorFieldValue::assignDeserialized(std::unique_ptr<Tensor> rhs) { - _tensor = std::move(rhs); - _altered = false; // Serialized form already exists + if (!rhs || _dataType.isAssignableType(rhs->type())) { + _tensor = std::move(rhs); + _altered = false; // Serialized form already exists + } else { + throw WrongTensorTypeException(makeWrongTensorTypeMsg(_dataType.getTensorType(), rhs->type()), VESPA_STRLOC); + } } diff --git a/eval/src/vespa/eval/tensor/test/test_utils.h b/eval/src/vespa/eval/tensor/test/test_utils.h index f1bb4b7b603..5daae74284b 100644 --- a/eval/src/vespa/eval/tensor/test/test_utils.h +++ b/eval/src/vespa/eval/tensor/test/test_utils.h @@ -9,14 +9,14 @@ namespace vespalib::tensor::test { template <typename T> -std::unique_ptr<const T> +std::unique_ptr<T> makeTensor(const vespalib::eval::TensorSpec &spec) { auto value = DefaultTensorEngine::ref().from_spec(spec); - const T *tensor = dynamic_cast<const T *>(value->as_tensor()); + T *tensor = dynamic_cast<T *>(value.get()); ASSERT_TRUE(tensor); value.release(); - return std::unique_ptr<const T>(tensor); + return std::unique_ptr<T>(tensor); } } diff --git a/searchcore/src/tests/proton/attribute/attribute_test.cpp b/searchcore/src/tests/proton/attribute/attribute_test.cpp index a643dda01c6..7b7d25d2d52 100644 --- a/searchcore/src/tests/proton/attribute/attribute_test.cpp +++ b/searchcore/src/tests/proton/attribute/attribute_test.cpp @@ -1,6 +1,7 @@ // Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. #include <vespa/config-attributes.h> +#include <vespa/document/datatype/tensor_data_type.h> #include <vespa/document/fieldvalue/document.h> #include <vespa/document/predicate/predicate_slime_builder.h> #include <vespa/document/update/arithmeticvalueupdate.h> @@ -683,7 +684,8 @@ TEST_F("require that attribute writer handles tensor assign update", Fixture) DocumentUpdate upd(*builder.getDocumentTypeRepo(), dt, DocumentId("doc::1")); auto new_tensor = createTensor({ {{{"x", "8"}, {"y", "9"}}, 11} }, {"x", "y"}); - TensorFieldValue new_value; + TensorDataType xySparseTensorDataType(vespalib::eval::ValueType::from_spec("tensor(x{},y{})")); + TensorFieldValue new_value(xySparseTensorDataType); new_value = new_tensor->clone(); upd.addUpdate(FieldUpdate(upd.getType().getField("a1")) .addUpdate(AssignValueUpdate(new_value))); diff --git a/searchcore/src/tests/proton/docsummary/summaryfieldconverter_test.cpp b/searchcore/src/tests/proton/docsummary/summaryfieldconverter_test.cpp index 73d071be96f..b82fec85d47 100644 --- a/searchcore/src/tests/proton/docsummary/summaryfieldconverter_test.cpp +++ b/searchcore/src/tests/proton/docsummary/summaryfieldconverter_test.cpp @@ -16,6 +16,7 @@ #include <vespa/document/datatype/urldatatype.h> #include <vespa/document/datatype/weightedsetdatatype.h> #include <vespa/document/datatype/referencedatatype.h> +#include <vespa/document/datatype/tensor_data_type.h> #include <vespa/document/fieldvalue/arrayfieldvalue.h> #include <vespa/document/fieldvalue/bytefieldvalue.h> #include <vespa/document/fieldvalue/document.h> @@ -82,6 +83,7 @@ using document::SpanTree; using document::StringFieldValue; using document::StructDataType; using document::StructFieldValue; +using document::TensorDataType; using document::TensorFieldValue; using document::UrlDataType; using document::WeightedSetDataType; @@ -92,6 +94,7 @@ using search::linguistics::TERM; using vespa::config::search::SummarymapConfig; using vespa::config::search::SummarymapConfigBuilder; using vespalib::Slime; +using vespalib::eval::ValueType; using vespalib::geo::ZCurve; using vespalib::slime::Cursor; using vespalib::string; @@ -230,7 +233,7 @@ DocumenttypesConfig getDocumenttypesConfig() { .addField("float", DataType::T_FLOAT) .addField("chinese", DataType::T_STRING) .addField("predicate", DataType::T_PREDICATE) - .addField("tensor", DataType::T_TENSOR) + .addTensorField("tensor", "tensor(x{},y{})") .addField("ref", ref_type_id) .addField("nested", Struct("indexingdocument.header.nested") .addField("inner_ref", ref_type_id)), @@ -683,7 +686,8 @@ createTensor(const TensorCells &cells, const TensorDimensions &dimensions) { void Test::requireThatTensorIsNotConverted() { - TensorFieldValue tensorFieldValue; + TensorDataType tensorDataType(ValueType::from_spec("tensor(x{},y{})")); + TensorFieldValue tensorFieldValue(tensorDataType); tensorFieldValue = createTensor({ {{{"x", "4"}, {"y", "5"}}, 7} }, {"x", "y"}); Document doc(getDocType(), DocumentId("doc:scheme:")); diff --git a/searchcore/src/tests/proton/documentdb/feedhandler/feedhandler_test.cpp b/searchcore/src/tests/proton/documentdb/feedhandler/feedhandler_test.cpp index a1bce7174bf..b17f13eaea0 100644 --- a/searchcore/src/tests/proton/documentdb/feedhandler/feedhandler_test.cpp +++ b/searchcore/src/tests/proton/documentdb/feedhandler/feedhandler_test.cpp @@ -1,9 +1,13 @@ // Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. #include <vespa/persistence/spi/result.h> +#include <vespa/document/base/exceptions.h> +#include <vespa/document/datatype/tensor_data_type.h> #include <vespa/document/update/assignvalueupdate.h> #include <vespa/document/repo/documenttyperepo.h> #include <vespa/document/update/documentupdate.h> +#include <vespa/eval/tensor/tensor.h> +#include <vespa/eval/tensor/test/test_utils.h> #include <vespa/searchcore/proton/bucketdb/bucketdbhandler.h> #include <vespa/searchcore/proton/test/bucketfactory.h> #include <vespa/searchcore/proton/common/feedtoken.h> @@ -42,6 +46,8 @@ using document::DocumentType; using document::DocumentTypeRepo; using document::DocumentUpdate; using document::GlobalId; +using document::TensorDataType; +using document::TensorFieldValue; using search::IDestructorCallback; using search::SerialNum; using search::index::schema::CollectionType; @@ -58,6 +64,10 @@ using vespalib::ThreadStackExecutor; using vespalib::ThreadStackExecutorBase; using vespalib::makeClosure; using vespalib::makeTask; +using vespalib::eval::TensorSpec; +using vespalib::eval::ValueType; +using vespalib::tensor::test::makeTensor; +using vespalib::tensor::Tensor; using namespace proton; using namespace search::index; @@ -281,6 +291,8 @@ SchemaContext::SchemaContext() : schema(new Schema()), builder() { + schema->addAttributeField(Schema::AttributeField("tensor", DataType::TENSOR, CollectionType::SINGLE)); + schema->addAttributeField(Schema::AttributeField("tensor2", DataType::TENSOR, CollectionType::SINGLE)); addField("i1"); } @@ -312,6 +324,8 @@ struct TwoFieldsSchemaContext : public SchemaContext { } }; +TensorDataType tensor1DType(ValueType::from_spec("tensor(x{})")); + struct UpdateContext { DocumentUpdate::SP update; BucketId bucketId; @@ -324,7 +338,19 @@ struct UpdateContext { const auto &docType = update->getType(); const auto &field = docType.getField(fieldName); auto fieldValue = field.createValue(); - fieldValue->assign(document::StringFieldValue("new value")); + if (fieldName == "tensor") { + dynamic_cast<TensorFieldValue &>(*fieldValue) = + makeTensor<Tensor>(TensorSpec("tensor(x{},y{})"). + add({{"x","8"},{"y","9"}}, 11)); + } else if (fieldName == "tensor2") { + auto tensorFieldValue = std::make_unique<TensorFieldValue>(tensor1DType); + *tensorFieldValue = + makeTensor<Tensor>(TensorSpec("tensor(x{})"). + add({{"x","8"}}, 11)); + fieldValue = std::move(tensorFieldValue); + } else { + fieldValue->assign(document::StringFieldValue("new value")); + } document::AssignValueUpdate assignValueUpdate(*fieldValue); document::FieldUpdate fieldUpdate(field); fieldUpdate.addUpdate(assignValueUpdate); @@ -711,8 +737,13 @@ checkUpdate(FeedHandlerFixture &f, SchemaContext &schemaContext, if (expectReject) { TEST_DO(f.feedView.checkCounts(0, 0u, 0, 0u)); EXPECT_EQUAL(Result::TRANSIENT_ERROR, token.getResult()->getErrorCode()); - EXPECT_EQUAL("Update operation rejected for document 'id:test:searchdocument::foo' of type 'searchdocument': 'Field not found'", - token.getResult()->getErrorMessage()); + if (fieldName == "tensor2") { + EXPECT_EQUAL("Update operation rejected for document 'id:test:searchdocument::foo' of type 'searchdocument': 'Wrong tensor type: Field tensor type is 'tensor(x{},y{})' but tensor type is 'tensor(x{})''", + token.getResult()->getErrorMessage()); + } else { + EXPECT_EQUAL("Update operation rejected for document 'id:test:searchdocument::foo' of type 'searchdocument': 'Field not found'", + token.getResult()->getErrorMessage()); + } } else { if (existing) { TEST_DO(f.feedView.checkCounts(1, 16u, 0, 0u)); @@ -758,6 +789,18 @@ TEST_F("require that update with different document type repo can be rejected, p checkUpdate(f, schema, "i2", true, false); } +TEST_F("require that tensor update with correct tensor type works", FeedHandlerFixture) +{ + TwoFieldsSchemaContext schema; + checkUpdate(f, schema, "tensor", false, true); +} + +TEST_F("require that tensor update with wrong tensor type fails", FeedHandlerFixture) +{ + TwoFieldsSchemaContext schema; + checkUpdate(f, schema, "tensor2", true, true); +} + TEST_F("require that put with different document type repo is ok", FeedHandlerFixture) { TwoFieldsSchemaContext schema; diff --git a/searchcore/src/vespa/searchcore/proton/persistenceengine/persistenceengine.cpp b/searchcore/src/vespa/searchcore/proton/persistenceengine/persistenceengine.cpp index 42790254e53..dc95f3ddc04 100644 --- a/searchcore/src/vespa/searchcore/proton/persistenceengine/persistenceengine.cpp +++ b/searchcore/src/vespa/searchcore/proton/persistenceengine/persistenceengine.cpp @@ -394,6 +394,12 @@ PersistenceEngine::update(const Bucket& b, Timestamp t, const DocumentUpdate::SP make_string("Update operation rejected for document '%s' of type '%s'.", upd->getId().toString().c_str(), e.getDocumentTypeName().c_str())); + } catch (document::WrongTensorTypeException &e) { + return UpdateResult(Result::TRANSIENT_ERROR, + make_string("Update operation rejected for document '%s' of type '%s': 'Wrong tensor type: %s'", + upd->getId().toString().c_str(), + upd->getType().getName().c_str(), + e.getMessage().c_str())); } std::shared_lock<std::shared_timed_mutex> rguard(_rwMutex); DocTypeName docType(upd->getType()); diff --git a/searchcore/src/vespa/searchcore/proton/server/feedhandler.cpp b/searchcore/src/vespa/searchcore/proton/server/feedhandler.cpp index a0d62f2052d..fd38b74f584 100644 --- a/searchcore/src/vespa/searchcore/proton/server/feedhandler.cpp +++ b/searchcore/src/vespa/searchcore/proton/server/feedhandler.cpp @@ -539,6 +539,14 @@ FeedHandler::considerUpdateOperationForRejection(FeedToken &token, UpdateOperati token->setResult(make_unique<UpdateResult>(Result::TRANSIENT_ERROR, message), false); token->fail(); return true; + } catch (document::WrongTensorTypeException &e) { + auto message = make_string("Update operation rejected for document '%s' of type '%s': 'Wrong tensor type: %s'", + update.getId().toString().c_str(), + _docTypeName.toString().c_str(), + e.getMessage().c_str()); + token->setResult(make_unique<UpdateResult>(Result::TRANSIENT_ERROR, message), false); + token->fail(); + return true; } } return false; diff --git a/searchlib/src/vespa/searchlib/index/doctypebuilder.cpp b/searchlib/src/vespa/searchlib/index/doctypebuilder.cpp index 70d1f3de504..fb02ee32b98 100644 --- a/searchlib/src/vespa/searchlib/index/doctypebuilder.cpp +++ b/searchlib/src/vespa/searchlib/index/doctypebuilder.cpp @@ -10,7 +10,7 @@ using namespace document; namespace search::index { namespace { -TensorDataType tensorDataType; +TensorDataType tensorDataType(vespalib::eval::ValueType::from_spec("tensor(x{}, y{})")); const DataType *convert(Schema::DataType type) { switch (type) { @@ -254,6 +254,7 @@ struct TypeCache { return types.find(key)->second; } }; + } // namespace document::DocumenttypesConfig DocTypeBuilder::makeConfig() const { @@ -296,8 +297,12 @@ document::DocumenttypesConfig DocTypeBuilder::makeConfig() const { continue; // taken as index field const DataType *primitiveType = convert(field.getDataType()); - header_struct.addField(field.getName(), type_cache.getType( + if (primitiveType->getId() == DataType::T_TENSOR) { + header_struct.addTensorField(field.getName(), dynamic_cast<const TensorDataType &>(*primitiveType).getTensorType().to_spec()); + } else { + header_struct.addField(field.getName(), type_cache.getType( primitiveType->getId(), field.getCollectionType())); + } usedFields.insert(field.getName()); } @@ -307,8 +312,12 @@ document::DocumenttypesConfig DocTypeBuilder::makeConfig() const { if (usf != usedFields.end()) continue; // taken as index field or attribute field const DataType *primitiveType(convert(field.getDataType())); - header_struct.addField(field.getName(), type_cache.getType( + if (primitiveType->getId() == DataType::T_TENSOR) { + header_struct.addTensorField(field.getName(), dynamic_cast<const TensorDataType &>(*primitiveType).getTensorType().to_spec()); + } else { + header_struct.addField(field.getName(), type_cache.getType( primitiveType->getId(), field.getCollectionType())); + } usedFields.insert(field.getName()); } |