diff options
10 files changed, 236 insertions, 1 deletions
diff --git a/document/src/tests/documentupdatetestcase.cpp b/document/src/tests/documentupdatetestcase.cpp index e7283849178..c7660e5d527 100644 --- a/document/src/tests/documentupdatetestcase.cpp +++ b/document/src/tests/documentupdatetestcase.cpp @@ -12,6 +12,7 @@ #include <vespa/document/update/fieldupdate.h> #include <vespa/document/update/mapvalueupdate.h> #include <vespa/document/update/removevalueupdate.h> +#include <vespa/document/update/tensoraddupdate.h> #include <vespa/document/update/tensormodifyupdate.h> #include <vespa/document/update/valueupdate.h> #include <vespa/document/serialization/vespadocumentserializer.h> @@ -60,6 +61,7 @@ struct DocumentUpdateTest : public CppUnit::TestFixture { void testMapValueUpdate(); void testTensorAssignUpdate(); void testTensorClearUpdate(); + void testTensorAddUpdate(); void testTensorModifyUpdate(); void testThatDocumentUpdateFlagsIsWorking(); void testThatCreateIfNonExistentFlagIsSerialized50AndDeserialized50(); @@ -89,6 +91,7 @@ struct DocumentUpdateTest : public CppUnit::TestFixture { CPPUNIT_TEST(testMapValueUpdate); CPPUNIT_TEST(testTensorAssignUpdate); CPPUNIT_TEST(testTensorClearUpdate); + CPPUNIT_TEST(testTensorAddUpdate); CPPUNIT_TEST(testTensorModifyUpdate); CPPUNIT_TEST(testThatDocumentUpdateFlagsIsWorking); CPPUNIT_TEST(testThatCreateIfNonExistentFlagIsSerialized50AndDeserialized50); @@ -181,6 +184,14 @@ FieldValue::UP createTensorFieldValueWith2Cells() { return std::move(fv); } +std::unique_ptr<TensorAddUpdate> createTensorAddUpdate() { + auto tensorFieldValue(std::make_unique<TensorFieldValue>()); + *tensorFieldValue = createTensor({ {{{"x", "8"}, {"y", "8"}}, 2}, + {{{"x", "8"}, {"y", "9"}}, 2} }, {"x", "y"}); + auto update = std::make_unique<TensorAddUpdate>(std::move(tensorFieldValue)); + return update; +} + std::unique_ptr<TensorModifyUpdate> createTensorModifyUpdate() { auto tensorFieldValue(std::make_unique<TensorFieldValue>()); *tensorFieldValue = createTensor({ {{{"x", "8"}, {"y", "9"}}, 2} }, {"x", "y"}); @@ -953,6 +964,27 @@ DocumentUpdateTest::testTensorClearUpdate() } void +DocumentUpdateTest::testTensorAddUpdate() +{ + TestDocMan docMan; + Document::UP doc(docMan.createDocument()); + Document updated(*doc); + auto oldTensor = createTensorFieldValueWith2Cells(); + updated.setValue(updated.getField("tensor"), *oldTensor); + CPPUNIT_ASSERT(*doc != updated); + testValueUpdate(*createTensorAddUpdate(), *DataType::TENSOR); + DocumentUpdate upd(docMan.getTypeRepo(), *doc->getDataType(), doc->getId()); + upd.addUpdate(FieldUpdate(upd.getType().getField("tensor")).addUpdate(*createTensorAddUpdate())); + upd.applyTo(updated); + FieldValue::UP fval(updated.getValue("tensor")); + CPPUNIT_ASSERT(fval); + auto &tensor = asTensor(*fval); + // Note: Placeholder value for now + auto expectedUpdatedTensor = createTensorWith2Cells(); + CPPUNIT_ASSERT(tensor.equals(*expectedUpdatedTensor)); +} + +void DocumentUpdateTest::testTensorModifyUpdate() { TestDocMan docMan; diff --git a/document/src/vespa/document/serialization/vespadocumentserializer.cpp b/document/src/vespa/document/serialization/vespadocumentserializer.cpp index 7960cc7934a..8364b560198 100644 --- a/document/src/vespa/document/serialization/vespadocumentserializer.cpp +++ b/document/src/vespa/document/serialization/vespadocumentserializer.cpp @@ -581,4 +581,17 @@ VespaDocumentSerializer::visit(const TensorModifyUpdate &value) write(value); } +void +VespaDocumentSerializer::write(const TensorAddUpdate &value) +{ + _stream << TensorAddUpdate::classId; + write(value.getTensor()); +} + +void +VespaDocumentSerializer::visit(const TensorAddUpdate &value) +{ + write(value); +} + } // namespace document diff --git a/document/src/vespa/document/serialization/vespadocumentserializer.h b/document/src/vespa/document/serialization/vespadocumentserializer.h index b2885d10968..08fe7ccdad9 100644 --- a/document/src/vespa/document/serialization/vespadocumentserializer.h +++ b/document/src/vespa/document/serialization/vespadocumentserializer.h @@ -75,6 +75,7 @@ private: void write(const AssignFieldPathUpdate &value); void write(const RemoveFieldPathUpdate &value); void write(const TensorModifyUpdate &value); + void write(const TensorAddUpdate &value); void visit(const DocumentUpdate &value) override { writeHEAD(value); } void visit(const FieldUpdate &value) override { write(value); } @@ -88,6 +89,7 @@ private: void visit(const AssignFieldPathUpdate &value) override { write(value); } void visit(const RemoveFieldPathUpdate &value) override { write(value); } void visit(const TensorModifyUpdate &value) override; + void visit(const TensorAddUpdate &value) override; void visit(const AnnotationReferenceFieldValue &value) override { write(value); } void visit(const ArrayFieldValue &value) override { write(value); } diff --git a/document/src/vespa/document/update/CMakeLists.txt b/document/src/vespa/document/update/CMakeLists.txt index fc3a6fb5495..34f539ee4aa 100644 --- a/document/src/vespa/document/update/CMakeLists.txt +++ b/document/src/vespa/document/update/CMakeLists.txt @@ -13,6 +13,7 @@ vespa_add_library(document_updates OBJECT mapvalueupdate.cpp removefieldpathupdate.cpp removevalueupdate.cpp + tensoraddupdate.cpp tensormodifyupdate.cpp valueupdate.cpp DEPENDS diff --git a/document/src/vespa/document/update/tensoraddupdate.cpp b/document/src/vespa/document/update/tensoraddupdate.cpp new file mode 100644 index 00000000000..eb708d9f651 --- /dev/null +++ b/document/src/vespa/document/update/tensoraddupdate.cpp @@ -0,0 +1,142 @@ +// Copyright 2019 Oath Inc. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. + +#include "tensoraddupdate.h" +#include <vespa/document/base/exceptions.h> +#include <vespa/document/base/field.h> +#include <vespa/document/fieldvalue/document.h> +#include <vespa/document/fieldvalue/tensorfieldvalue.h> +#include <vespa/document/serialization/vespadocumentdeserializer.h> +#include <vespa/document/util/serializableexceptions.h> +#include <vespa/eval/tensor/tensor.h> +#include <vespa/vespalib/objects/nbostream.h> +#include <vespa/vespalib/stllike/asciistream.h> +#include <vespa/vespalib/util/stringfmt.h> +#include <vespa/vespalib/util/xmlstream.h> +#include <ostream> + +using vespalib::IllegalArgumentException; +using vespalib::IllegalStateException; +using vespalib::tensor::Tensor; +using vespalib::make_string; + +namespace document { + +IMPLEMENT_IDENTIFIABLE(TensorAddUpdate, ValueUpdate); + +TensorAddUpdate::TensorAddUpdate() + : _tensor() +{ +} + +TensorAddUpdate::TensorAddUpdate(const TensorAddUpdate &rhs) + : _tensor(rhs._tensor->clone()) +{ +} + +TensorAddUpdate::TensorAddUpdate(std::unique_ptr<TensorFieldValue> &&tensor) + : _tensor(std::move(tensor)) +{ +} + +TensorAddUpdate::~TensorAddUpdate() = default; + +TensorAddUpdate & +TensorAddUpdate::operator=(const TensorAddUpdate &rhs) +{ + _tensor.reset(rhs._tensor->clone()); + return *this; +} + +TensorAddUpdate & +TensorAddUpdate::operator=(TensorAddUpdate &&rhs) +{ + _tensor = std::move(rhs._tensor); + return *this; +} + +bool +TensorAddUpdate::operator==(const ValueUpdate &other) const +{ + if (other.getClass().id() != TensorAddUpdate::classId) { + return false; + } + const TensorAddUpdate& o(static_cast<const TensorAddUpdate&>(other)); + if (*_tensor != *o._tensor) { + return false; + } + return true; +} + + +void +TensorAddUpdate::checkCompatibility(const Field& field) const +{ + if (field.getDataType() != *DataType::TENSOR) { + throw IllegalArgumentException(make_string( + "Can not perform tensor add update on non-tensor field '%s'.", + field.getName().data()), VESPA_STRLOC); + } +} + +std::unique_ptr<Tensor> +TensorAddUpdate::applyTo(const Tensor &tensor) const +{ + return tensor.clone(); +} + +bool +TensorAddUpdate::applyTo(FieldValue& value) const +{ + if (value.inherits(TensorFieldValue::classId)) { + TensorFieldValue &tensorFieldValue = static_cast<TensorFieldValue &>(value); + auto &oldTensor = tensorFieldValue.getAsTensorPtr(); + auto newTensor = applyTo(*oldTensor); + if (newTensor) { + tensorFieldValue = std::move(newTensor); + } + } else { + std::string err = make_string( + "Unable to perform a tensor add update on a \"%s\" field " + "value.", value.getClass().name()); + throw IllegalStateException(err, VESPA_STRLOC); + } + return true; +} + +void +TensorAddUpdate::printXml(XmlOutputStream& xos) const +{ + xos << "{TensorAddUpdate::printXml not yet implemented}"; +} + +void +TensorAddUpdate::print(std::ostream& out, bool verbose, const std::string& indent) const +{ + (void) verbose; + (void) indent; + out << "{TensorAddUpdate::print not yet implemented}"; +} + +void +TensorAddUpdate::deserialize(const DocumentTypeRepo &repo, const DataType &type, nbostream & stream) +{ + auto tensor = type.createFieldValue(); + if (tensor->inherits(TensorFieldValue::classId)) { + _tensor.reset(static_cast<TensorFieldValue *>(tensor.release())); + } else { + std::string err = make_string( + "Expected tensor field value, got a \"%s\" field " + "value.", tensor->getClass().name()); + throw IllegalStateException(err, VESPA_STRLOC); + } + VespaDocumentDeserializer deserializer(repo, stream, Document::getNewestSerializationVersion()); + deserializer.read(*_tensor); +} + +TensorAddUpdate* +TensorAddUpdate::clone() const +{ + return new TensorAddUpdate(*this); +} + +} diff --git a/document/src/vespa/document/update/tensoraddupdate.h b/document/src/vespa/document/update/tensoraddupdate.h new file mode 100644 index 00000000000..52e44ea33f3 --- /dev/null +++ b/document/src/vespa/document/update/tensoraddupdate.h @@ -0,0 +1,40 @@ +// Copyright 2019 Oath Inc. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. + +#include "valueupdate.h" + +namespace vespalib::tensor { struct Tensor; } + +namespace document { + +class TensorFieldValue; + +/* + * An update used to add cells to a sparse tensor (has only mapped dimensions). + * + * The cells to add are contained in a sparse tensor as well. + */ +class TensorAddUpdate : public ValueUpdate { + std::unique_ptr<TensorFieldValue> _tensor; + + TensorAddUpdate(); + TensorAddUpdate(const TensorAddUpdate &rhs); + ACCEPT_UPDATE_VISITOR; +public: + TensorAddUpdate(std::unique_ptr<TensorFieldValue> &&tensor); + ~TensorAddUpdate() override; + TensorAddUpdate &operator=(const TensorAddUpdate &rhs); + TensorAddUpdate &operator=(TensorAddUpdate &&rhs); + bool operator==(const ValueUpdate &other) const override; + const TensorFieldValue &getTensor() const { return *_tensor; } + void checkCompatibility(const Field &field) const override; + std::unique_ptr<vespalib::tensor::Tensor> applyTo(const vespalib::tensor::Tensor &tensor) const; + bool applyTo(FieldValue &value) const override; + void printXml(XmlOutputStream &xos) const override; + void print(std::ostream &out, bool verbose, const std::string &indent) const override; + void deserialize(const DocumentTypeRepo &repo, const DataType &type, nbostream &stream) override; + TensorAddUpdate* clone() const override; + + DECLARE_IDENTIFIABLE(TensorAddUpdate); +}; + +} diff --git a/document/src/vespa/document/update/updates.h b/document/src/vespa/document/update/updates.h index 013ce70b6b6..1609c5bc3a3 100644 --- a/document/src/vespa/document/update/updates.h +++ b/document/src/vespa/document/update/updates.h @@ -11,4 +11,5 @@ #include "mapvalueupdate.h" #include "removevalueupdate.h" #include "tensormodifyupdate.h" +#include "tensoraddupdate.h" diff --git a/document/src/vespa/document/update/updatevisitor.h b/document/src/vespa/document/update/updatevisitor.h index e6291f90f69..f41e985f7c8 100644 --- a/document/src/vespa/document/update/updatevisitor.h +++ b/document/src/vespa/document/update/updatevisitor.h @@ -15,6 +15,7 @@ class MapValueUpdate; class AddFieldPathUpdate; class AssignFieldPathUpdate; class RemoveFieldPathUpdate; +class TensorAddUpdate; class TensorModifyUpdate; struct UpdateVisitor { @@ -32,6 +33,7 @@ struct UpdateVisitor { virtual void visit(const AssignFieldPathUpdate &value) = 0; virtual void visit(const RemoveFieldPathUpdate &value) = 0; virtual void visit(const TensorModifyUpdate &value) = 0; + virtual void visit(const TensorAddUpdate &value) = 0; }; #define ACCEPT_UPDATE_VISITOR void accept(UpdateVisitor & visitor) const override { visitor.visit(*this); } diff --git a/document/src/vespa/document/update/valueupdate.h b/document/src/vespa/document/update/valueupdate.h index ceb711074f4..0e15943f8e4 100644 --- a/document/src/vespa/document/update/valueupdate.h +++ b/document/src/vespa/document/update/valueupdate.h @@ -54,7 +54,8 @@ public: Clear = IDENTIFIABLE_CLASSID(ClearValueUpdate), Map = IDENTIFIABLE_CLASSID(MapValueUpdate), Remove = IDENTIFIABLE_CLASSID(RemoveValueUpdate), - TensorModifyUpdate = IDENTIFIABLE_CLASSID(TensorModifyUpdate) + TensorModifyUpdate = IDENTIFIABLE_CLASSID(TensorModifyUpdate), + TensorAddUpdate = IDENTIFIABLE_CLASSID(TensorAddUpdate) }; ValueUpdate() diff --git a/document/src/vespa/document/util/identifiableid.h b/document/src/vespa/document/util/identifiableid.h index 80d43369a0a..c8859cedb2e 100644 --- a/document/src/vespa/document/util/identifiableid.h +++ b/document/src/vespa/document/util/identifiableid.h @@ -69,6 +69,7 @@ #define CID_RemoveFieldPathUpdate DOCUMENT_CID(88) #define CID_TensorModifyUpdate DOCUMENT_CID(100) +#define CID_TensorAddUpdate DOCUMENT_CID(101) #define CID_document_DocumentUpdate DOCUMENT_CID(999) |