aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorGeir Storli <geirst@verizonmedia.com>2019-02-15 13:13:43 +0000
committerGeir Storli <geirst@verizonmedia.com>2019-02-15 13:13:43 +0000
commit81999faf9b314681419ac0e00ff921efc5566d90 (patch)
treeaee3a054d3240e0be5295d21f9408d563c478596
parentf0352c8c527ed03c7bc82ba22aaeee14d38ef516 (diff)
Implement skeleton of TensorRemoveUpdate with support for (de)-serialization.
-rw-r--r--document/src/tests/documentupdatetestcase.cpp10
-rw-r--r--document/src/vespa/document/serialization/vespadocumentserializer.cpp29
-rw-r--r--document/src/vespa/document/serialization/vespadocumentserializer.h2
-rw-r--r--document/src/vespa/document/update/CMakeLists.txt1
-rw-r--r--document/src/vespa/document/update/tensor_remove_update.cpp130
-rw-r--r--document/src/vespa/document/update/tensor_remove_update.h43
-rw-r--r--document/src/vespa/document/update/updates.h6
-rw-r--r--document/src/vespa/document/update/updatevisitor.h2
-rw-r--r--document/src/vespa/document/util/identifiableid.h1
9 files changed, 213 insertions, 11 deletions
diff --git a/document/src/tests/documentupdatetestcase.cpp b/document/src/tests/documentupdatetestcase.cpp
index ddc209234c5..b351299f2d1 100644
--- a/document/src/tests/documentupdatetestcase.cpp
+++ b/document/src/tests/documentupdatetestcase.cpp
@@ -17,6 +17,7 @@
#include <vespa/document/update/removevalueupdate.h>
#include <vespa/document/update/tensor_add_update.h>
#include <vespa/document/update/tensor_modify_update.h>
+#include <vespa/document/update/tensor_remove_update.h>
#include <vespa/document/update/valueupdate.h>
#include <vespa/document/util/bytebuffer.h>
#include <vespa/eval/tensor/default_tensor_engine.h>
@@ -63,6 +64,7 @@ struct DocumentUpdateTest : public CppUnit::TestFixture {
void tensor_modify_update_can_be_applied();
void tensor_assign_update_can_be_roundtrip_serialized();
void tensor_add_update_can_be_roundtrip_serialized();
+ void tensor_remove_update_can_be_roundtrip_serialized();
void tensor_modify_update_can_be_roundtrip_serialized();
void testThatDocumentUpdateFlagsIsWorking();
void testThatCreateIfNonExistentFlagIsSerialized50AndDeserialized50();
@@ -96,6 +98,7 @@ struct DocumentUpdateTest : public CppUnit::TestFixture {
CPPUNIT_TEST(tensor_modify_update_can_be_applied);
CPPUNIT_TEST(tensor_assign_update_can_be_roundtrip_serialized);
CPPUNIT_TEST(tensor_add_update_can_be_roundtrip_serialized);
+ CPPUNIT_TEST(tensor_remove_update_can_be_roundtrip_serialized);
CPPUNIT_TEST(tensor_modify_update_can_be_roundtrip_serialized);
CPPUNIT_TEST(testThatDocumentUpdateFlagsIsWorking);
CPPUNIT_TEST(testThatCreateIfNonExistentFlagIsSerialized50AndDeserialized50);
@@ -1062,6 +1065,13 @@ DocumentUpdateTest::tensor_add_update_can_be_roundtrip_serialized()
}
void
+DocumentUpdateTest::tensor_remove_update_can_be_roundtrip_serialized()
+{
+ TensorUpdateFixture f;
+ f.assertRoundtripSerialize(TensorRemoveUpdate(f.makeBaselineTensor()));
+}
+
+void
DocumentUpdateTest::tensor_modify_update_can_be_roundtrip_serialized()
{
TensorUpdateFixture f;
diff --git a/document/src/vespa/document/serialization/vespadocumentserializer.cpp b/document/src/vespa/document/serialization/vespadocumentserializer.cpp
index 8364b560198..0d6703b4e97 100644
--- a/document/src/vespa/document/serialization/vespadocumentserializer.cpp
+++ b/document/src/vespa/document/serialization/vespadocumentserializer.cpp
@@ -1,9 +1,11 @@
// Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
-#include "vespadocumentserializer.h"
#include "annotationserializer.h"
#include "slime_output_to_vector.h"
#include "util.h"
+#include "vespadocumentserializer.h"
+#include <vespa/document/datatype/weightedsetdatatype.h>
+#include <vespa/document/fieldset/fieldsets.h>
#include <vespa/document/fieldvalue/annotationreferencefieldvalue.h>
#include <vespa/document/fieldvalue/arrayfieldvalue.h>
#include <vespa/document/fieldvalue/boolfieldvalue.h>
@@ -16,20 +18,18 @@
#include <vespa/document/fieldvalue/mapfieldvalue.h>
#include <vespa/document/fieldvalue/predicatefieldvalue.h>
#include <vespa/document/fieldvalue/rawfieldvalue.h>
+#include <vespa/document/fieldvalue/referencefieldvalue.h>
#include <vespa/document/fieldvalue/shortfieldvalue.h>
#include <vespa/document/fieldvalue/stringfieldvalue.h>
-#include <vespa/document/fieldvalue/weightedsetfieldvalue.h>
#include <vespa/document/fieldvalue/tensorfieldvalue.h>
-#include <vespa/document/fieldvalue/referencefieldvalue.h>
-#include <vespa/document/datatype/weightedsetdatatype.h>
-#include <vespa/document/update/updates.h>
+#include <vespa/document/fieldvalue/weightedsetfieldvalue.h>
#include <vespa/document/update/fieldpathupdates.h>
+#include <vespa/document/update/updates.h>
#include <vespa/document/util/bytebuffer.h>
-#include <vespa/document/fieldset/fieldsets.h>
+#include <vespa/eval/tensor/serialization/typed_binary_format.h>
+#include <vespa/vespalib/data/databuffer.h>
#include <vespa/vespalib/data/slime/binary_format.h>
#include <vespa/vespalib/objects/nbostream.h>
-#include <vespa/vespalib/data/databuffer.h>
-#include <vespa/eval/tensor/serialization/typed_binary_format.h>
#include <vespa/vespalib/util/compressor.h>
using std::make_pair;
@@ -594,4 +594,17 @@ VespaDocumentSerializer::visit(const TensorAddUpdate &value)
write(value);
}
+void
+VespaDocumentSerializer::write(const TensorRemoveUpdate &value)
+{
+ _stream << TensorRemoveUpdate::classId;
+ write(value.getTensor());
+}
+
+void
+VespaDocumentSerializer::visit(const TensorRemoveUpdate &value)
+{
+ write(value);
+}
+
} // namespace document
diff --git a/document/src/vespa/document/serialization/vespadocumentserializer.h b/document/src/vespa/document/serialization/vespadocumentserializer.h
index 08fe7ccdad9..ba3bf63afa7 100644
--- a/document/src/vespa/document/serialization/vespadocumentserializer.h
+++ b/document/src/vespa/document/serialization/vespadocumentserializer.h
@@ -76,6 +76,7 @@ private:
void write(const RemoveFieldPathUpdate &value);
void write(const TensorModifyUpdate &value);
void write(const TensorAddUpdate &value);
+ void write(const TensorRemoveUpdate &value);
void visit(const DocumentUpdate &value) override { writeHEAD(value); }
void visit(const FieldUpdate &value) override { write(value); }
@@ -90,6 +91,7 @@ private:
void visit(const RemoveFieldPathUpdate &value) override { write(value); }
void visit(const TensorModifyUpdate &value) override;
void visit(const TensorAddUpdate &value) override;
+ void visit(const TensorRemoveUpdate &value) override;
void visit(const AnnotationReferenceFieldValue &value) override { write(value); }
void visit(const ArrayFieldValue &value) override { write(value); }
diff --git a/document/src/vespa/document/update/CMakeLists.txt b/document/src/vespa/document/update/CMakeLists.txt
index 2ece7877bdb..83374adefbc 100644
--- a/document/src/vespa/document/update/CMakeLists.txt
+++ b/document/src/vespa/document/update/CMakeLists.txt
@@ -15,6 +15,7 @@ vespa_add_library(document_updates OBJECT
removevalueupdate.cpp
tensor_add_update.cpp
tensor_modify_update.cpp
+ tensor_remove_update.cpp
valueupdate.cpp
DEPENDS
AFTER
diff --git a/document/src/vespa/document/update/tensor_remove_update.cpp b/document/src/vespa/document/update/tensor_remove_update.cpp
new file mode 100644
index 00000000000..3e2bb86c66b
--- /dev/null
+++ b/document/src/vespa/document/update/tensor_remove_update.cpp
@@ -0,0 +1,130 @@
+// Copyright 2019 Oath Inc. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+
+#include "tensor_remove_update.h"
+#include <vespa/document/base/exceptions.h>
+#include <vespa/document/datatype/tensor_data_type.h>
+#include <vespa/document/fieldvalue/document.h>
+#include <vespa/document/fieldvalue/tensorfieldvalue.h>
+#include <vespa/document/serialization/vespadocumentdeserializer.h>
+#include <vespa/eval/tensor/tensor.h>
+#include <vespa/vespalib/objects/nbostream.h>
+#include <vespa/vespalib/util/xmlstream.h>
+#include <ostream>
+
+using vespalib::IllegalArgumentException;
+using vespalib::IllegalStateException;
+using vespalib::tensor::Tensor;
+using vespalib::make_string;
+
+namespace document {
+
+IMPLEMENT_IDENTIFIABLE(TensorRemoveUpdate, ValueUpdate);
+
+TensorRemoveUpdate::TensorRemoveUpdate()
+ : _tensor()
+{
+}
+
+TensorRemoveUpdate::TensorRemoveUpdate(const TensorRemoveUpdate &rhs)
+ : _tensor(rhs._tensor->clone())
+{
+}
+
+TensorRemoveUpdate::TensorRemoveUpdate(std::unique_ptr<TensorFieldValue> &&tensor)
+ : _tensor(std::move(tensor))
+{
+}
+
+TensorRemoveUpdate::~TensorRemoveUpdate() = default;
+
+TensorRemoveUpdate &
+TensorRemoveUpdate::operator=(const TensorRemoveUpdate &rhs)
+{
+ _tensor.reset(rhs._tensor->clone());
+ return *this;
+}
+
+TensorRemoveUpdate &
+TensorRemoveUpdate::operator=(TensorRemoveUpdate &&rhs)
+{
+ _tensor = std::move(rhs._tensor);
+ return *this;
+}
+
+bool
+TensorRemoveUpdate::operator==(const ValueUpdate &other) const
+{
+ if (other.getClass().id() != TensorRemoveUpdate::classId) {
+ return false;
+ }
+ const TensorRemoveUpdate& o(static_cast<const TensorRemoveUpdate&>(other));
+ if (*_tensor != *o._tensor) {
+ return false;
+ }
+ return true;
+}
+
+void
+TensorRemoveUpdate::checkCompatibility(const Field &field) const
+{
+ if (field.getDataType().getClass().id() != TensorDataType::classId) {
+ throw IllegalArgumentException(make_string(
+ "Can not perform tensor remove update on non-tensor field '%s'.",
+ field.getName().data()), VESPA_STRLOC);
+ }
+}
+
+std::unique_ptr<Tensor>
+TensorRemoveUpdate::applyTo(const Tensor &tensor) const
+{
+ // TODO: implement
+ (void) tensor;
+ return std::unique_ptr<Tensor>();
+}
+
+bool
+TensorRemoveUpdate::applyTo(FieldValue &value) const
+{
+ // TODO: implement
+ (void) value;
+ return false;
+}
+
+void
+TensorRemoveUpdate::printXml(XmlOutputStream &xos) const
+{
+ xos << "{TensorRemoveUpdate::printXml not yet implemented}";
+}
+
+void
+TensorRemoveUpdate::print(std::ostream &out, bool verbose, const std::string &indent) const
+{
+ out << indent << "TensorRemoveUpdate(";
+ if (_tensor) {
+ _tensor->print(out, verbose, indent);
+ }
+ out << ")";
+}
+
+void
+TensorRemoveUpdate::deserialize(const DocumentTypeRepo &repo, const DataType &type, nbostream &stream)
+{
+ auto tensor = type.createFieldValue();
+ if (tensor->inherits(TensorFieldValue::classId)) {
+ _tensor.reset(static_cast<TensorFieldValue *>(tensor.release()));
+ } else {
+ std::string err = make_string(
+ "Expected tensor field value, got a \"%s\" field value.", tensor->getClass().name());
+ throw IllegalStateException(err, VESPA_STRLOC);
+ }
+ VespaDocumentDeserializer deserializer(repo, stream, Document::getNewestSerializationVersion());
+ deserializer.read(*_tensor);
+}
+
+TensorRemoveUpdate *
+TensorRemoveUpdate::clone() const
+{
+ return new TensorRemoveUpdate(*this);
+}
+
+}
diff --git a/document/src/vespa/document/update/tensor_remove_update.h b/document/src/vespa/document/update/tensor_remove_update.h
new file mode 100644
index 00000000000..7f2a32a8a3a
--- /dev/null
+++ b/document/src/vespa/document/update/tensor_remove_update.h
@@ -0,0 +1,43 @@
+// Copyright 2019 Oath Inc. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+
+#include "valueupdate.h"
+
+namespace vespalib::tensor { class Tensor; }
+
+namespace document {
+
+class TensorFieldValue;
+
+/**
+ * An update used to remove cells from a sparse tensor (has only mapped dimensions).
+ *
+ * The cells to remove are contained in a sparse tensor as well.
+ */
+class TensorRemoveUpdate : public ValueUpdate {
+private:
+ std::unique_ptr<TensorFieldValue> _tensor;
+
+ TensorRemoveUpdate();
+ TensorRemoveUpdate(const TensorRemoveUpdate &rhs);
+ ACCEPT_UPDATE_VISITOR;
+
+public:
+ TensorRemoveUpdate(std::unique_ptr<TensorFieldValue> &&tensor);
+ ~TensorRemoveUpdate() override;
+ TensorRemoveUpdate &operator=(const TensorRemoveUpdate &rhs);
+ TensorRemoveUpdate &operator=(TensorRemoveUpdate &&rhs);
+ const TensorFieldValue &getTensor() const { return *_tensor; }
+ std::unique_ptr<vespalib::tensor::Tensor> applyTo(const vespalib::tensor::Tensor &tensor) const;
+
+ bool operator==(const ValueUpdate &other) const override;
+ void checkCompatibility(const Field &field) const override;
+ bool applyTo(FieldValue &value) const override;
+ void printXml(XmlOutputStream &xos) const override;
+ void print(std::ostream &out, bool verbose, const std::string &indent) const override;
+ void deserialize(const DocumentTypeRepo &repo, const DataType &type, nbostream &stream) override;
+ TensorRemoveUpdate* clone() const override;
+
+ DECLARE_IDENTIFIABLE(TensorRemoveUpdate);
+};
+
+}
diff --git a/document/src/vespa/document/update/updates.h b/document/src/vespa/document/update/updates.h
index 4e520e61690..3d775f4d734 100644
--- a/document/src/vespa/document/update/updates.h
+++ b/document/src/vespa/document/update/updates.h
@@ -2,14 +2,14 @@
#pragma once
-#include "documentupdate.h"
-#include "fieldupdate.h"
#include "addvalueupdate.h"
#include "arithmeticvalueupdate.h"
#include "assignvalueupdate.h"
#include "clearvalueupdate.h"
+#include "documentupdate.h"
+#include "fieldupdate.h"
#include "mapvalueupdate.h"
#include "removevalueupdate.h"
#include "tensor_add_update.h"
#include "tensor_modify_update.h"
-
+#include "tensor_remove_update.h"
diff --git a/document/src/vespa/document/update/updatevisitor.h b/document/src/vespa/document/update/updatevisitor.h
index f41e985f7c8..823d749d1f0 100644
--- a/document/src/vespa/document/update/updatevisitor.h
+++ b/document/src/vespa/document/update/updatevisitor.h
@@ -17,6 +17,7 @@ class AssignFieldPathUpdate;
class RemoveFieldPathUpdate;
class TensorAddUpdate;
class TensorModifyUpdate;
+class TensorRemoveUpdate;
struct UpdateVisitor {
virtual ~UpdateVisitor() {}
@@ -34,6 +35,7 @@ struct UpdateVisitor {
virtual void visit(const RemoveFieldPathUpdate &value) = 0;
virtual void visit(const TensorModifyUpdate &value) = 0;
virtual void visit(const TensorAddUpdate &value) = 0;
+ virtual void visit(const TensorRemoveUpdate &value) = 0;
};
#define ACCEPT_UPDATE_VISITOR void accept(UpdateVisitor & visitor) const override { visitor.visit(*this); }
diff --git a/document/src/vespa/document/util/identifiableid.h b/document/src/vespa/document/util/identifiableid.h
index c8859cedb2e..9368b6a7cb6 100644
--- a/document/src/vespa/document/util/identifiableid.h
+++ b/document/src/vespa/document/util/identifiableid.h
@@ -70,6 +70,7 @@
#define CID_TensorModifyUpdate DOCUMENT_CID(100)
#define CID_TensorAddUpdate DOCUMENT_CID(101)
+#define CID_TensorRemoveUpdate DOCUMENT_CID(102)
#define CID_document_DocumentUpdate DOCUMENT_CID(999)