diff options
author | Geir Storli <geirst@verizonmedia.com> | 2019-02-15 13:13:43 +0000 |
---|---|---|
committer | Geir Storli <geirst@verizonmedia.com> | 2019-02-15 13:13:43 +0000 |
commit | 81999faf9b314681419ac0e00ff921efc5566d90 (patch) | |
tree | aee3a054d3240e0be5295d21f9408d563c478596 /document | |
parent | f0352c8c527ed03c7bc82ba22aaeee14d38ef516 (diff) |
Implement skeleton of TensorRemoveUpdate with support for (de)-serialization.
Diffstat (limited to 'document')
9 files changed, 213 insertions, 11 deletions
diff --git a/document/src/tests/documentupdatetestcase.cpp b/document/src/tests/documentupdatetestcase.cpp index ddc209234c5..b351299f2d1 100644 --- a/document/src/tests/documentupdatetestcase.cpp +++ b/document/src/tests/documentupdatetestcase.cpp @@ -17,6 +17,7 @@ #include <vespa/document/update/removevalueupdate.h> #include <vespa/document/update/tensor_add_update.h> #include <vespa/document/update/tensor_modify_update.h> +#include <vespa/document/update/tensor_remove_update.h> #include <vespa/document/update/valueupdate.h> #include <vespa/document/util/bytebuffer.h> #include <vespa/eval/tensor/default_tensor_engine.h> @@ -63,6 +64,7 @@ struct DocumentUpdateTest : public CppUnit::TestFixture { void tensor_modify_update_can_be_applied(); void tensor_assign_update_can_be_roundtrip_serialized(); void tensor_add_update_can_be_roundtrip_serialized(); + void tensor_remove_update_can_be_roundtrip_serialized(); void tensor_modify_update_can_be_roundtrip_serialized(); void testThatDocumentUpdateFlagsIsWorking(); void testThatCreateIfNonExistentFlagIsSerialized50AndDeserialized50(); @@ -96,6 +98,7 @@ struct DocumentUpdateTest : public CppUnit::TestFixture { CPPUNIT_TEST(tensor_modify_update_can_be_applied); CPPUNIT_TEST(tensor_assign_update_can_be_roundtrip_serialized); CPPUNIT_TEST(tensor_add_update_can_be_roundtrip_serialized); + CPPUNIT_TEST(tensor_remove_update_can_be_roundtrip_serialized); CPPUNIT_TEST(tensor_modify_update_can_be_roundtrip_serialized); CPPUNIT_TEST(testThatDocumentUpdateFlagsIsWorking); CPPUNIT_TEST(testThatCreateIfNonExistentFlagIsSerialized50AndDeserialized50); @@ -1062,6 +1065,13 @@ DocumentUpdateTest::tensor_add_update_can_be_roundtrip_serialized() } void +DocumentUpdateTest::tensor_remove_update_can_be_roundtrip_serialized() +{ + TensorUpdateFixture f; + f.assertRoundtripSerialize(TensorRemoveUpdate(f.makeBaselineTensor())); +} + +void DocumentUpdateTest::tensor_modify_update_can_be_roundtrip_serialized() { TensorUpdateFixture f; diff --git a/document/src/vespa/document/serialization/vespadocumentserializer.cpp b/document/src/vespa/document/serialization/vespadocumentserializer.cpp index 8364b560198..0d6703b4e97 100644 --- a/document/src/vespa/document/serialization/vespadocumentserializer.cpp +++ b/document/src/vespa/document/serialization/vespadocumentserializer.cpp @@ -1,9 +1,11 @@ // Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. -#include "vespadocumentserializer.h" #include "annotationserializer.h" #include "slime_output_to_vector.h" #include "util.h" +#include "vespadocumentserializer.h" +#include <vespa/document/datatype/weightedsetdatatype.h> +#include <vespa/document/fieldset/fieldsets.h> #include <vespa/document/fieldvalue/annotationreferencefieldvalue.h> #include <vespa/document/fieldvalue/arrayfieldvalue.h> #include <vespa/document/fieldvalue/boolfieldvalue.h> @@ -16,20 +18,18 @@ #include <vespa/document/fieldvalue/mapfieldvalue.h> #include <vespa/document/fieldvalue/predicatefieldvalue.h> #include <vespa/document/fieldvalue/rawfieldvalue.h> +#include <vespa/document/fieldvalue/referencefieldvalue.h> #include <vespa/document/fieldvalue/shortfieldvalue.h> #include <vespa/document/fieldvalue/stringfieldvalue.h> -#include <vespa/document/fieldvalue/weightedsetfieldvalue.h> #include <vespa/document/fieldvalue/tensorfieldvalue.h> -#include <vespa/document/fieldvalue/referencefieldvalue.h> -#include <vespa/document/datatype/weightedsetdatatype.h> -#include <vespa/document/update/updates.h> +#include <vespa/document/fieldvalue/weightedsetfieldvalue.h> #include <vespa/document/update/fieldpathupdates.h> +#include <vespa/document/update/updates.h> #include <vespa/document/util/bytebuffer.h> -#include <vespa/document/fieldset/fieldsets.h> +#include <vespa/eval/tensor/serialization/typed_binary_format.h> +#include <vespa/vespalib/data/databuffer.h> #include <vespa/vespalib/data/slime/binary_format.h> #include <vespa/vespalib/objects/nbostream.h> -#include <vespa/vespalib/data/databuffer.h> -#include <vespa/eval/tensor/serialization/typed_binary_format.h> #include <vespa/vespalib/util/compressor.h> using std::make_pair; @@ -594,4 +594,17 @@ VespaDocumentSerializer::visit(const TensorAddUpdate &value) write(value); } +void +VespaDocumentSerializer::write(const TensorRemoveUpdate &value) +{ + _stream << TensorRemoveUpdate::classId; + write(value.getTensor()); +} + +void +VespaDocumentSerializer::visit(const TensorRemoveUpdate &value) +{ + write(value); +} + } // namespace document diff --git a/document/src/vespa/document/serialization/vespadocumentserializer.h b/document/src/vespa/document/serialization/vespadocumentserializer.h index 08fe7ccdad9..ba3bf63afa7 100644 --- a/document/src/vespa/document/serialization/vespadocumentserializer.h +++ b/document/src/vespa/document/serialization/vespadocumentserializer.h @@ -76,6 +76,7 @@ private: void write(const RemoveFieldPathUpdate &value); void write(const TensorModifyUpdate &value); void write(const TensorAddUpdate &value); + void write(const TensorRemoveUpdate &value); void visit(const DocumentUpdate &value) override { writeHEAD(value); } void visit(const FieldUpdate &value) override { write(value); } @@ -90,6 +91,7 @@ private: void visit(const RemoveFieldPathUpdate &value) override { write(value); } void visit(const TensorModifyUpdate &value) override; void visit(const TensorAddUpdate &value) override; + void visit(const TensorRemoveUpdate &value) override; void visit(const AnnotationReferenceFieldValue &value) override { write(value); } void visit(const ArrayFieldValue &value) override { write(value); } diff --git a/document/src/vespa/document/update/CMakeLists.txt b/document/src/vespa/document/update/CMakeLists.txt index 2ece7877bdb..83374adefbc 100644 --- a/document/src/vespa/document/update/CMakeLists.txt +++ b/document/src/vespa/document/update/CMakeLists.txt @@ -15,6 +15,7 @@ vespa_add_library(document_updates OBJECT removevalueupdate.cpp tensor_add_update.cpp tensor_modify_update.cpp + tensor_remove_update.cpp valueupdate.cpp DEPENDS AFTER diff --git a/document/src/vespa/document/update/tensor_remove_update.cpp b/document/src/vespa/document/update/tensor_remove_update.cpp new file mode 100644 index 00000000000..3e2bb86c66b --- /dev/null +++ b/document/src/vespa/document/update/tensor_remove_update.cpp @@ -0,0 +1,130 @@ +// Copyright 2019 Oath Inc. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. + +#include "tensor_remove_update.h" +#include <vespa/document/base/exceptions.h> +#include <vespa/document/datatype/tensor_data_type.h> +#include <vespa/document/fieldvalue/document.h> +#include <vespa/document/fieldvalue/tensorfieldvalue.h> +#include <vespa/document/serialization/vespadocumentdeserializer.h> +#include <vespa/eval/tensor/tensor.h> +#include <vespa/vespalib/objects/nbostream.h> +#include <vespa/vespalib/util/xmlstream.h> +#include <ostream> + +using vespalib::IllegalArgumentException; +using vespalib::IllegalStateException; +using vespalib::tensor::Tensor; +using vespalib::make_string; + +namespace document { + +IMPLEMENT_IDENTIFIABLE(TensorRemoveUpdate, ValueUpdate); + +TensorRemoveUpdate::TensorRemoveUpdate() + : _tensor() +{ +} + +TensorRemoveUpdate::TensorRemoveUpdate(const TensorRemoveUpdate &rhs) + : _tensor(rhs._tensor->clone()) +{ +} + +TensorRemoveUpdate::TensorRemoveUpdate(std::unique_ptr<TensorFieldValue> &&tensor) + : _tensor(std::move(tensor)) +{ +} + +TensorRemoveUpdate::~TensorRemoveUpdate() = default; + +TensorRemoveUpdate & +TensorRemoveUpdate::operator=(const TensorRemoveUpdate &rhs) +{ + _tensor.reset(rhs._tensor->clone()); + return *this; +} + +TensorRemoveUpdate & +TensorRemoveUpdate::operator=(TensorRemoveUpdate &&rhs) +{ + _tensor = std::move(rhs._tensor); + return *this; +} + +bool +TensorRemoveUpdate::operator==(const ValueUpdate &other) const +{ + if (other.getClass().id() != TensorRemoveUpdate::classId) { + return false; + } + const TensorRemoveUpdate& o(static_cast<const TensorRemoveUpdate&>(other)); + if (*_tensor != *o._tensor) { + return false; + } + return true; +} + +void +TensorRemoveUpdate::checkCompatibility(const Field &field) const +{ + if (field.getDataType().getClass().id() != TensorDataType::classId) { + throw IllegalArgumentException(make_string( + "Can not perform tensor remove update on non-tensor field '%s'.", + field.getName().data()), VESPA_STRLOC); + } +} + +std::unique_ptr<Tensor> +TensorRemoveUpdate::applyTo(const Tensor &tensor) const +{ + // TODO: implement + (void) tensor; + return std::unique_ptr<Tensor>(); +} + +bool +TensorRemoveUpdate::applyTo(FieldValue &value) const +{ + // TODO: implement + (void) value; + return false; +} + +void +TensorRemoveUpdate::printXml(XmlOutputStream &xos) const +{ + xos << "{TensorRemoveUpdate::printXml not yet implemented}"; +} + +void +TensorRemoveUpdate::print(std::ostream &out, bool verbose, const std::string &indent) const +{ + out << indent << "TensorRemoveUpdate("; + if (_tensor) { + _tensor->print(out, verbose, indent); + } + out << ")"; +} + +void +TensorRemoveUpdate::deserialize(const DocumentTypeRepo &repo, const DataType &type, nbostream &stream) +{ + auto tensor = type.createFieldValue(); + if (tensor->inherits(TensorFieldValue::classId)) { + _tensor.reset(static_cast<TensorFieldValue *>(tensor.release())); + } else { + std::string err = make_string( + "Expected tensor field value, got a \"%s\" field value.", tensor->getClass().name()); + throw IllegalStateException(err, VESPA_STRLOC); + } + VespaDocumentDeserializer deserializer(repo, stream, Document::getNewestSerializationVersion()); + deserializer.read(*_tensor); +} + +TensorRemoveUpdate * +TensorRemoveUpdate::clone() const +{ + return new TensorRemoveUpdate(*this); +} + +} diff --git a/document/src/vespa/document/update/tensor_remove_update.h b/document/src/vespa/document/update/tensor_remove_update.h new file mode 100644 index 00000000000..7f2a32a8a3a --- /dev/null +++ b/document/src/vespa/document/update/tensor_remove_update.h @@ -0,0 +1,43 @@ +// Copyright 2019 Oath Inc. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. + +#include "valueupdate.h" + +namespace vespalib::tensor { class Tensor; } + +namespace document { + +class TensorFieldValue; + +/** + * An update used to remove cells from a sparse tensor (has only mapped dimensions). + * + * The cells to remove are contained in a sparse tensor as well. + */ +class TensorRemoveUpdate : public ValueUpdate { +private: + std::unique_ptr<TensorFieldValue> _tensor; + + TensorRemoveUpdate(); + TensorRemoveUpdate(const TensorRemoveUpdate &rhs); + ACCEPT_UPDATE_VISITOR; + +public: + TensorRemoveUpdate(std::unique_ptr<TensorFieldValue> &&tensor); + ~TensorRemoveUpdate() override; + TensorRemoveUpdate &operator=(const TensorRemoveUpdate &rhs); + TensorRemoveUpdate &operator=(TensorRemoveUpdate &&rhs); + const TensorFieldValue &getTensor() const { return *_tensor; } + std::unique_ptr<vespalib::tensor::Tensor> applyTo(const vespalib::tensor::Tensor &tensor) const; + + bool operator==(const ValueUpdate &other) const override; + void checkCompatibility(const Field &field) const override; + bool applyTo(FieldValue &value) const override; + void printXml(XmlOutputStream &xos) const override; + void print(std::ostream &out, bool verbose, const std::string &indent) const override; + void deserialize(const DocumentTypeRepo &repo, const DataType &type, nbostream &stream) override; + TensorRemoveUpdate* clone() const override; + + DECLARE_IDENTIFIABLE(TensorRemoveUpdate); +}; + +} diff --git a/document/src/vespa/document/update/updates.h b/document/src/vespa/document/update/updates.h index 4e520e61690..3d775f4d734 100644 --- a/document/src/vespa/document/update/updates.h +++ b/document/src/vespa/document/update/updates.h @@ -2,14 +2,14 @@ #pragma once -#include "documentupdate.h" -#include "fieldupdate.h" #include "addvalueupdate.h" #include "arithmeticvalueupdate.h" #include "assignvalueupdate.h" #include "clearvalueupdate.h" +#include "documentupdate.h" +#include "fieldupdate.h" #include "mapvalueupdate.h" #include "removevalueupdate.h" #include "tensor_add_update.h" #include "tensor_modify_update.h" - +#include "tensor_remove_update.h" diff --git a/document/src/vespa/document/update/updatevisitor.h b/document/src/vespa/document/update/updatevisitor.h index f41e985f7c8..823d749d1f0 100644 --- a/document/src/vespa/document/update/updatevisitor.h +++ b/document/src/vespa/document/update/updatevisitor.h @@ -17,6 +17,7 @@ class AssignFieldPathUpdate; class RemoveFieldPathUpdate; class TensorAddUpdate; class TensorModifyUpdate; +class TensorRemoveUpdate; struct UpdateVisitor { virtual ~UpdateVisitor() {} @@ -34,6 +35,7 @@ struct UpdateVisitor { virtual void visit(const RemoveFieldPathUpdate &value) = 0; virtual void visit(const TensorModifyUpdate &value) = 0; virtual void visit(const TensorAddUpdate &value) = 0; + virtual void visit(const TensorRemoveUpdate &value) = 0; }; #define ACCEPT_UPDATE_VISITOR void accept(UpdateVisitor & visitor) const override { visitor.visit(*this); } diff --git a/document/src/vespa/document/util/identifiableid.h b/document/src/vespa/document/util/identifiableid.h index c8859cedb2e..9368b6a7cb6 100644 --- a/document/src/vespa/document/util/identifiableid.h +++ b/document/src/vespa/document/util/identifiableid.h @@ -70,6 +70,7 @@ #define CID_TensorModifyUpdate DOCUMENT_CID(100) #define CID_TensorAddUpdate DOCUMENT_CID(101) +#define CID_TensorRemoveUpdate DOCUMENT_CID(102) #define CID_document_DocumentUpdate DOCUMENT_CID(999) |