diff options
author | Geir Storli <geirst@verizonmedia.com> | 2019-02-15 12:06:40 +0000 |
---|---|---|
committer | Geir Storli <geirst@verizonmedia.com> | 2019-02-15 12:06:40 +0000 |
commit | 9600398492f6f30013669740a46da07b2b7b94cf (patch) | |
tree | c57a2d3e022e4fe56a52023995777927d1b37edf /document | |
parent | 4abb67ad861c6c0c1aa1fa231b0f9412a0309578 (diff) |
Move common code to test fixture and setup tensors used in test more explicitly.
Diffstat (limited to 'document')
-rw-r--r-- | document/src/tests/documentupdatetestcase.cpp | 198 | ||||
-rw-r--r-- | document/src/vespa/document/datatype/tensor_data_type.cpp | 2 | ||||
-rw-r--r-- | document/src/vespa/document/datatype/tensor_data_type.h | 3 |
3 files changed, 106 insertions, 97 deletions
diff --git a/document/src/tests/documentupdatetestcase.cpp b/document/src/tests/documentupdatetestcase.cpp index e666876c357..e1cc3991b38 100644 --- a/document/src/tests/documentupdatetestcase.cpp +++ b/document/src/tests/documentupdatetestcase.cpp @@ -160,77 +160,8 @@ void testRoundtripSerialize(const UpdateType& update, const DataType &type) { } } -const std::string tensorType2DMapped = "tensor(x{},y{})"; -TensorDataType tensorDataType2DMapped(vespalib::eval::ValueType::from_spec(tensorType2DMapped)); - -std::unique_ptr<Tensor> -makeTensor(const TensorSpec &spec) -{ - auto result = DefaultTensorEngine::ref().from_spec(spec); - return std::unique_ptr<Tensor>(dynamic_cast<Tensor*>(result.release())); -} - -std::unique_ptr<TensorFieldValue> -makeTensorFieldValue(const TensorSpec &spec) -{ - auto tensor = makeTensor(spec); - auto result = std::make_unique<TensorFieldValue>(tensorDataType2DMapped); - result->assignDeserialized(std::move(tensor)); - return result; -} - -FieldValue::UP createTensorFieldValue() { - return makeTensorFieldValue(TensorSpec(tensorType2DMapped) - .add({{"x", "8"}, {"y", "9"}}, 11)); -} - -std::unique_ptr<Tensor> createTensorWith2Cells() { - return makeTensor(TensorSpec(tensorType2DMapped) - .add({{"x", "8"}, {"y", "9"}}, 11) - .add({{"x", "9"}, {"y", "9"}}, 11)); -} - -std::unique_ptr<Tensor> createExpectedUpdatedTensorWith2Cells() { - return makeTensor(TensorSpec(tensorType2DMapped) - .add({{"x", "8"}, {"y", "9"}}, 2) - .add({{"x", "9"}, {"y", "9"}}, 11)); -} - -std::unique_ptr<Tensor> createExpectedAddUpdatedTensorWith3Cells() { - return makeTensor(TensorSpec(tensorType2DMapped) - .add({{"x", "8"}, {"y", "8"}}, 2) - .add({{"x", "8"}, {"y", "9"}}, 2) - .add({{"x", "9"}, {"y", "9"}}, 11)); -} - -FieldValue::UP createTensorFieldValueWith2Cells() { - auto result = std::make_unique<TensorFieldValue>(tensorDataType2DMapped); - result->assignDeserialized(createTensorWith2Cells()); - return result; -} - -std::unique_ptr<TensorAddUpdate> createTensorAddUpdate() { - auto tensorFieldValue = makeTensorFieldValue(TensorSpec(tensorType2DMapped) - .add({{"x", "8"}, {"y", "8"}}, 2) - .add({{"x", "8"}, {"y", "9"}}, 2)); - return std::make_unique<TensorAddUpdate>(std::move(tensorFieldValue)); -} - -std::unique_ptr<TensorModifyUpdate> createTensorModifyUpdate() { - auto tensorFieldValue = makeTensorFieldValue(TensorSpec(tensorType2DMapped) - .add({{"x", "8"}, {"y", "9"}}, 2)); - return std::make_unique<TensorModifyUpdate>(TensorModifyUpdate::Operation::REPLACE, std::move(tensorFieldValue)); -} - -const Tensor &asTensor(const FieldValue &fieldValue) { - auto &tensorFieldValue = dynamic_cast<const TensorFieldValue &>(fieldValue); - auto &tensor = tensorFieldValue.getAsTensorPtr(); - CPPUNIT_ASSERT(tensor); - return *tensor; } -} // namespace - void DocumentUpdateTest::testSimpleUsage() { DocumenttypesConfigBuilderHelper builder; @@ -952,37 +883,82 @@ DocumentUpdateTest::testMapValueUpdate() CPPUNIT_ASSERT(fv4->find(StringFieldValue("apple")) == fv4->end()); } +std::unique_ptr<Tensor> +makeTensor(const TensorSpec &spec) +{ + auto result = DefaultTensorEngine::ref().from_spec(spec); + return std::unique_ptr<Tensor>(dynamic_cast<Tensor*>(result.release())); +} + +std::unique_ptr<TensorFieldValue> +makeTensorFieldValue(const TensorSpec &spec, const TensorDataType &dataType) +{ + auto tensor = makeTensor(spec); + auto result = std::make_unique<TensorFieldValue>(dataType); + result->assignDeserialized(std::move(tensor)); + return result; +} + +const Tensor &asTensor(const FieldValue &fieldValue) { + auto &tensorFieldValue = dynamic_cast<const TensorFieldValue &>(fieldValue); + auto &tensor = tensorFieldValue.getAsTensorPtr(); + CPPUNIT_ASSERT(tensor); + return *tensor; +} + struct TensorUpdateFixture { TestDocMan docMan; Document::UP emptyDoc; Document updatedDoc; vespalib::string fieldName; + vespalib::string tensorType; + TensorDataType tensorDataType; - TensorUpdateFixture(const vespalib::string &fieldName_ = "tensor") + TensorUpdateFixture(const vespalib::string &fieldName_ = "tensor", + const vespalib::string &tensorType_ = "tensor(x{},y{})") : docMan(), emptyDoc(docMan.createDocument()), updatedDoc(*emptyDoc), - fieldName(fieldName_) + fieldName(fieldName_), + tensorType(tensorType_), + tensorDataType(vespalib::eval::ValueType::from_spec(tensorType)) { CPPUNIT_ASSERT(!emptyDoc->getValue(fieldName)); } ~TensorUpdateFixture() {} - void applyUpdate(const ValueUpdate &update) { - DocumentUpdate docUpdate(docMan.getTypeRepo(), *emptyDoc->getDataType(), emptyDoc->getId()); - docUpdate.addUpdate(FieldUpdate(docUpdate.getType().getField(fieldName)).addUpdate(update)); - docUpdate.applyTo(updatedDoc); + TensorSpec spec() { + return TensorSpec(tensorType); } FieldValue::UP getTensor() { return updatedDoc.getValue(fieldName); } - void setTensor(const FieldValue &tensor) { - updatedDoc.setValue(updatedDoc.getField(fieldName), tensor); + void setTensor(const TensorFieldValue &tensorValue) { + updatedDoc.setValue(updatedDoc.getField(fieldName), tensorValue); assertDocumentUpdated(); } + void setTensor(const TensorSpec &spec) { + setTensor(*makeTensor(spec)); + } + + std::unique_ptr<TensorFieldValue> makeTensor(const TensorSpec &spec) { + return makeTensorFieldValue(spec, tensorDataType); + } + + std::unique_ptr<TensorFieldValue> makeBaselineTensor() { + return makeTensor(spec().add({{"x", "a"}, {"y", "a"}}, 2) + .add({{"x", "a"}, {"y", "b"}}, 3)); + } + + void applyUpdate(const ValueUpdate &update) { + DocumentUpdate docUpdate(docMan.getTypeRepo(), *emptyDoc->getDataType(), emptyDoc->getId()); + docUpdate.addUpdate(FieldUpdate(docUpdate.getType().getField(fieldName)).addUpdate(update)); + docUpdate.applyTo(updatedDoc); + } + void assertDocumentUpdated() { CPPUNIT_ASSERT(*emptyDoc != updatedDoc); } @@ -991,17 +967,32 @@ struct TensorUpdateFixture { CPPUNIT_ASSERT(*emptyDoc == updatedDoc); } - void assertTensor(const FieldValue &expTensorValue) { + void assertTensor(const TensorFieldValue &expTensorValue) { auto actTensorValue = getTensor(); CPPUNIT_ASSERT(actTensorValue); CPPUNIT_ASSERT(*actTensorValue == expTensorValue); + auto &actTensor = asTensor(*actTensorValue); + auto &expTensor = asTensor(expTensorValue); + CPPUNIT_ASSERT(actTensor == expTensor); } - void assertTensor(const Tensor &expTensor) { - auto actTensorValue = getTensor(); - CPPUNIT_ASSERT(actTensorValue); - auto &actTensor = asTensor(*actTensorValue); - CPPUNIT_ASSERT(actTensor.equals(expTensor)); + void assertTensor(const TensorSpec &expSpec) { + auto expTensor = makeTensor(expSpec); + assertTensor(*expTensor); + } + + void assertApplyUpdate(const TensorSpec& initialTensor, + const ValueUpdate& update, + const TensorSpec& expTensor) { + setTensor(initialTensor); + applyUpdate(update); + assertDocumentUpdated(); + assertTensor(expTensor); + } + + template <typename ValueUpdateType> + void assertRoundtripSerialize(const ValueUpdateType &valueUpdate) { + testRoundtripSerialize(valueUpdate, tensorDataType); } }; @@ -1010,7 +1001,7 @@ void DocumentUpdateTest::tensor_assign_update_can_be_applied() { TensorUpdateFixture f; - auto newTensor = createTensorFieldValue(); + auto newTensor = f.makeBaselineTensor(); f.applyUpdate(AssignValueUpdate(*newTensor)); f.assertDocumentUpdated(); f.assertTensor(*newTensor); @@ -1020,7 +1011,7 @@ void DocumentUpdateTest::tensor_clear_update_can_be_applied() { TensorUpdateFixture f; - f.setTensor(*createTensorFieldValue()); + f.setTensor(*f.makeBaselineTensor()); f.applyUpdate(ClearValueUpdate()); f.assertDocumentNotUpdated(); CPPUNIT_ASSERT(!f.getTensor()); @@ -1030,38 +1021,53 @@ void DocumentUpdateTest::tensor_add_update_can_be_applied() { TensorUpdateFixture f; - f.setTensor(*createTensorFieldValueWith2Cells()); - f.applyUpdate(*createTensorAddUpdate()); - f.assertDocumentUpdated(); - f.assertTensor(*createExpectedAddUpdatedTensorWith3Cells()); + f.assertApplyUpdate(f.spec().add({{"x", "a"}, {"y", "a"}}, 2) + .add({{"x", "a"}, {"y", "b"}}, 3), + + TensorAddUpdate(f.makeTensor(f.spec().add({{"x", "a"}, {"y", "b"}}, 5) + .add({{"x", "a"}, {"y", "c"}}, 7))), + + f.spec().add({{"x", "a"}, {"y", "a"}}, 2) + .add({{"x", "a"}, {"y", "b"}}, 5) + .add({{"x", "a"}, {"y", "c"}}, 7)); } void DocumentUpdateTest::tensor_modify_update_can_be_applied() { TensorUpdateFixture f; - f.setTensor(*createTensorFieldValueWith2Cells()); - f.applyUpdate(*createTensorModifyUpdate()); - f.assertDocumentUpdated(); - f.assertTensor(*createExpectedUpdatedTensorWith2Cells()); + f.assertApplyUpdate(f.spec().add({{"x", "a"}, {"y", "a"}}, 2) + .add({{"x", "a"}, {"y", "b"}}, 3), + + TensorModifyUpdate(TensorModifyUpdate::Operation::REPLACE, + f.makeTensor(f.spec().add({{"x", "a"}, {"y", "b"}}, 5) + .add({{"x", "a"}, {"y", "c"}}, 7))), + + f.spec().add({{"x", "a"}, {"y", "a"}}, 2) + .add({{"x", "a"}, {"y", "b"}}, 5)); } void DocumentUpdateTest::tensor_assign_update_can_be_roundtrip_serialized() { - testRoundtripSerialize(AssignValueUpdate(*createTensorFieldValue()), tensorDataType2DMapped); + TensorUpdateFixture f; + f.assertRoundtripSerialize(AssignValueUpdate(*f.makeBaselineTensor())); } void DocumentUpdateTest::tensor_add_update_can_be_roundtrip_serialized() { - testRoundtripSerialize(*createTensorAddUpdate(), tensorDataType2DMapped); + TensorUpdateFixture f; + f.assertRoundtripSerialize(TensorAddUpdate(f.makeBaselineTensor())); } void DocumentUpdateTest::tensor_modify_update_can_be_roundtrip_serialized() { - testRoundtripSerialize(*createTensorModifyUpdate(), tensorDataType2DMapped); + TensorUpdateFixture f; + f.assertRoundtripSerialize(TensorModifyUpdate(TensorModifyUpdate::Operation::REPLACE, f.makeBaselineTensor())); + f.assertRoundtripSerialize(TensorModifyUpdate(TensorModifyUpdate::Operation::ADD, f.makeBaselineTensor())); + f.assertRoundtripSerialize(TensorModifyUpdate(TensorModifyUpdate::Operation::MUL, f.makeBaselineTensor())); } diff --git a/document/src/vespa/document/datatype/tensor_data_type.cpp b/document/src/vespa/document/datatype/tensor_data_type.cpp index 0a3fdf71f3b..d3d747c045f 100644 --- a/document/src/vespa/document/datatype/tensor_data_type.cpp +++ b/document/src/vespa/document/datatype/tensor_data_type.cpp @@ -22,6 +22,8 @@ TensorDataType::TensorDataType(ValueType tensorType) { } +TensorDataType::~TensorDataType() = default; + FieldValue::UP TensorDataType::createFieldValue() const { diff --git a/document/src/vespa/document/datatype/tensor_data_type.h b/document/src/vespa/document/datatype/tensor_data_type.h index a5f92176074..acb243e96ff 100644 --- a/document/src/vespa/document/datatype/tensor_data_type.h +++ b/document/src/vespa/document/datatype/tensor_data_type.h @@ -14,7 +14,8 @@ class TensorDataType : public PrimitiveDataType { public: TensorDataType(); TensorDataType(vespalib::eval::ValueType tensorType); - + ~TensorDataType(); + std::unique_ptr<FieldValue> createFieldValue() const override; TensorDataType* clone() const override; void print(std::ostream&, bool verbose, const std::string& indent) const override; |