summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--document/src/tests/documentupdatetestcase.cpp32
-rw-r--r--document/src/vespa/document/serialization/vespadocumentserializer.cpp13
-rw-r--r--document/src/vespa/document/serialization/vespadocumentserializer.h2
-rw-r--r--document/src/vespa/document/update/CMakeLists.txt1
-rw-r--r--document/src/vespa/document/update/tensoraddupdate.cpp142
-rw-r--r--document/src/vespa/document/update/tensoraddupdate.h40
-rw-r--r--document/src/vespa/document/update/updates.h1
-rw-r--r--document/src/vespa/document/update/updatevisitor.h2
-rw-r--r--document/src/vespa/document/update/valueupdate.h3
-rw-r--r--document/src/vespa/document/util/identifiableid.h1
10 files changed, 236 insertions, 1 deletions
diff --git a/document/src/tests/documentupdatetestcase.cpp b/document/src/tests/documentupdatetestcase.cpp
index e7283849178..c7660e5d527 100644
--- a/document/src/tests/documentupdatetestcase.cpp
+++ b/document/src/tests/documentupdatetestcase.cpp
@@ -12,6 +12,7 @@
#include <vespa/document/update/fieldupdate.h>
#include <vespa/document/update/mapvalueupdate.h>
#include <vespa/document/update/removevalueupdate.h>
+#include <vespa/document/update/tensoraddupdate.h>
#include <vespa/document/update/tensormodifyupdate.h>
#include <vespa/document/update/valueupdate.h>
#include <vespa/document/serialization/vespadocumentserializer.h>
@@ -60,6 +61,7 @@ struct DocumentUpdateTest : public CppUnit::TestFixture {
void testMapValueUpdate();
void testTensorAssignUpdate();
void testTensorClearUpdate();
+ void testTensorAddUpdate();
void testTensorModifyUpdate();
void testThatDocumentUpdateFlagsIsWorking();
void testThatCreateIfNonExistentFlagIsSerialized50AndDeserialized50();
@@ -89,6 +91,7 @@ struct DocumentUpdateTest : public CppUnit::TestFixture {
CPPUNIT_TEST(testMapValueUpdate);
CPPUNIT_TEST(testTensorAssignUpdate);
CPPUNIT_TEST(testTensorClearUpdate);
+ CPPUNIT_TEST(testTensorAddUpdate);
CPPUNIT_TEST(testTensorModifyUpdate);
CPPUNIT_TEST(testThatDocumentUpdateFlagsIsWorking);
CPPUNIT_TEST(testThatCreateIfNonExistentFlagIsSerialized50AndDeserialized50);
@@ -181,6 +184,14 @@ FieldValue::UP createTensorFieldValueWith2Cells() {
return std::move(fv);
}
+std::unique_ptr<TensorAddUpdate> createTensorAddUpdate() {
+ auto tensorFieldValue(std::make_unique<TensorFieldValue>());
+ *tensorFieldValue = createTensor({ {{{"x", "8"}, {"y", "8"}}, 2},
+ {{{"x", "8"}, {"y", "9"}}, 2} }, {"x", "y"});
+ auto update = std::make_unique<TensorAddUpdate>(std::move(tensorFieldValue));
+ return update;
+}
+
std::unique_ptr<TensorModifyUpdate> createTensorModifyUpdate() {
auto tensorFieldValue(std::make_unique<TensorFieldValue>());
*tensorFieldValue = createTensor({ {{{"x", "8"}, {"y", "9"}}, 2} }, {"x", "y"});
@@ -953,6 +964,27 @@ DocumentUpdateTest::testTensorClearUpdate()
}
void
+DocumentUpdateTest::testTensorAddUpdate()
+{
+ TestDocMan docMan;
+ Document::UP doc(docMan.createDocument());
+ Document updated(*doc);
+ auto oldTensor = createTensorFieldValueWith2Cells();
+ updated.setValue(updated.getField("tensor"), *oldTensor);
+ CPPUNIT_ASSERT(*doc != updated);
+ testValueUpdate(*createTensorAddUpdate(), *DataType::TENSOR);
+ DocumentUpdate upd(docMan.getTypeRepo(), *doc->getDataType(), doc->getId());
+ upd.addUpdate(FieldUpdate(upd.getType().getField("tensor")).addUpdate(*createTensorAddUpdate()));
+ upd.applyTo(updated);
+ FieldValue::UP fval(updated.getValue("tensor"));
+ CPPUNIT_ASSERT(fval);
+ auto &tensor = asTensor(*fval);
+ // Note: Placeholder value for now
+ auto expectedUpdatedTensor = createTensorWith2Cells();
+ CPPUNIT_ASSERT(tensor.equals(*expectedUpdatedTensor));
+}
+
+void
DocumentUpdateTest::testTensorModifyUpdate()
{
TestDocMan docMan;
diff --git a/document/src/vespa/document/serialization/vespadocumentserializer.cpp b/document/src/vespa/document/serialization/vespadocumentserializer.cpp
index 7960cc7934a..8364b560198 100644
--- a/document/src/vespa/document/serialization/vespadocumentserializer.cpp
+++ b/document/src/vespa/document/serialization/vespadocumentserializer.cpp
@@ -581,4 +581,17 @@ VespaDocumentSerializer::visit(const TensorModifyUpdate &value)
write(value);
}
+void
+VespaDocumentSerializer::write(const TensorAddUpdate &value)
+{
+ _stream << TensorAddUpdate::classId;
+ write(value.getTensor());
+}
+
+void
+VespaDocumentSerializer::visit(const TensorAddUpdate &value)
+{
+ write(value);
+}
+
} // namespace document
diff --git a/document/src/vespa/document/serialization/vespadocumentserializer.h b/document/src/vespa/document/serialization/vespadocumentserializer.h
index b2885d10968..08fe7ccdad9 100644
--- a/document/src/vespa/document/serialization/vespadocumentserializer.h
+++ b/document/src/vespa/document/serialization/vespadocumentserializer.h
@@ -75,6 +75,7 @@ private:
void write(const AssignFieldPathUpdate &value);
void write(const RemoveFieldPathUpdate &value);
void write(const TensorModifyUpdate &value);
+ void write(const TensorAddUpdate &value);
void visit(const DocumentUpdate &value) override { writeHEAD(value); }
void visit(const FieldUpdate &value) override { write(value); }
@@ -88,6 +89,7 @@ private:
void visit(const AssignFieldPathUpdate &value) override { write(value); }
void visit(const RemoveFieldPathUpdate &value) override { write(value); }
void visit(const TensorModifyUpdate &value) override;
+ void visit(const TensorAddUpdate &value) override;
void visit(const AnnotationReferenceFieldValue &value) override { write(value); }
void visit(const ArrayFieldValue &value) override { write(value); }
diff --git a/document/src/vespa/document/update/CMakeLists.txt b/document/src/vespa/document/update/CMakeLists.txt
index fc3a6fb5495..34f539ee4aa 100644
--- a/document/src/vespa/document/update/CMakeLists.txt
+++ b/document/src/vespa/document/update/CMakeLists.txt
@@ -13,6 +13,7 @@ vespa_add_library(document_updates OBJECT
mapvalueupdate.cpp
removefieldpathupdate.cpp
removevalueupdate.cpp
+ tensoraddupdate.cpp
tensormodifyupdate.cpp
valueupdate.cpp
DEPENDS
diff --git a/document/src/vespa/document/update/tensoraddupdate.cpp b/document/src/vespa/document/update/tensoraddupdate.cpp
new file mode 100644
index 00000000000..eb708d9f651
--- /dev/null
+++ b/document/src/vespa/document/update/tensoraddupdate.cpp
@@ -0,0 +1,142 @@
+// Copyright 2019 Oath Inc. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+
+#include "tensoraddupdate.h"
+#include <vespa/document/base/exceptions.h>
+#include <vespa/document/base/field.h>
+#include <vespa/document/fieldvalue/document.h>
+#include <vespa/document/fieldvalue/tensorfieldvalue.h>
+#include <vespa/document/serialization/vespadocumentdeserializer.h>
+#include <vespa/document/util/serializableexceptions.h>
+#include <vespa/eval/tensor/tensor.h>
+#include <vespa/vespalib/objects/nbostream.h>
+#include <vespa/vespalib/stllike/asciistream.h>
+#include <vespa/vespalib/util/stringfmt.h>
+#include <vespa/vespalib/util/xmlstream.h>
+#include <ostream>
+
+using vespalib::IllegalArgumentException;
+using vespalib::IllegalStateException;
+using vespalib::tensor::Tensor;
+using vespalib::make_string;
+
+namespace document {
+
+IMPLEMENT_IDENTIFIABLE(TensorAddUpdate, ValueUpdate);
+
+TensorAddUpdate::TensorAddUpdate()
+ : _tensor()
+{
+}
+
+TensorAddUpdate::TensorAddUpdate(const TensorAddUpdate &rhs)
+ : _tensor(rhs._tensor->clone())
+{
+}
+
+TensorAddUpdate::TensorAddUpdate(std::unique_ptr<TensorFieldValue> &&tensor)
+ : _tensor(std::move(tensor))
+{
+}
+
+TensorAddUpdate::~TensorAddUpdate() = default;
+
+TensorAddUpdate &
+TensorAddUpdate::operator=(const TensorAddUpdate &rhs)
+{
+ _tensor.reset(rhs._tensor->clone());
+ return *this;
+}
+
+TensorAddUpdate &
+TensorAddUpdate::operator=(TensorAddUpdate &&rhs)
+{
+ _tensor = std::move(rhs._tensor);
+ return *this;
+}
+
+bool
+TensorAddUpdate::operator==(const ValueUpdate &other) const
+{
+ if (other.getClass().id() != TensorAddUpdate::classId) {
+ return false;
+ }
+ const TensorAddUpdate& o(static_cast<const TensorAddUpdate&>(other));
+ if (*_tensor != *o._tensor) {
+ return false;
+ }
+ return true;
+}
+
+
+void
+TensorAddUpdate::checkCompatibility(const Field& field) const
+{
+ if (field.getDataType() != *DataType::TENSOR) {
+ throw IllegalArgumentException(make_string(
+ "Can not perform tensor add update on non-tensor field '%s'.",
+ field.getName().data()), VESPA_STRLOC);
+ }
+}
+
+std::unique_ptr<Tensor>
+TensorAddUpdate::applyTo(const Tensor &tensor) const
+{
+ return tensor.clone();
+}
+
+bool
+TensorAddUpdate::applyTo(FieldValue& value) const
+{
+ if (value.inherits(TensorFieldValue::classId)) {
+ TensorFieldValue &tensorFieldValue = static_cast<TensorFieldValue &>(value);
+ auto &oldTensor = tensorFieldValue.getAsTensorPtr();
+ auto newTensor = applyTo(*oldTensor);
+ if (newTensor) {
+ tensorFieldValue = std::move(newTensor);
+ }
+ } else {
+ std::string err = make_string(
+ "Unable to perform a tensor add update on a \"%s\" field "
+ "value.", value.getClass().name());
+ throw IllegalStateException(err, VESPA_STRLOC);
+ }
+ return true;
+}
+
+void
+TensorAddUpdate::printXml(XmlOutputStream& xos) const
+{
+ xos << "{TensorAddUpdate::printXml not yet implemented}";
+}
+
+void
+TensorAddUpdate::print(std::ostream& out, bool verbose, const std::string& indent) const
+{
+ (void) verbose;
+ (void) indent;
+ out << "{TensorAddUpdate::print not yet implemented}";
+}
+
+void
+TensorAddUpdate::deserialize(const DocumentTypeRepo &repo, const DataType &type, nbostream & stream)
+{
+ auto tensor = type.createFieldValue();
+ if (tensor->inherits(TensorFieldValue::classId)) {
+ _tensor.reset(static_cast<TensorFieldValue *>(tensor.release()));
+ } else {
+ std::string err = make_string(
+ "Expected tensor field value, got a \"%s\" field "
+ "value.", tensor->getClass().name());
+ throw IllegalStateException(err, VESPA_STRLOC);
+ }
+ VespaDocumentDeserializer deserializer(repo, stream, Document::getNewestSerializationVersion());
+ deserializer.read(*_tensor);
+}
+
+TensorAddUpdate*
+TensorAddUpdate::clone() const
+{
+ return new TensorAddUpdate(*this);
+}
+
+}
diff --git a/document/src/vespa/document/update/tensoraddupdate.h b/document/src/vespa/document/update/tensoraddupdate.h
new file mode 100644
index 00000000000..52e44ea33f3
--- /dev/null
+++ b/document/src/vespa/document/update/tensoraddupdate.h
@@ -0,0 +1,40 @@
+// Copyright 2019 Oath Inc. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+
+#include "valueupdate.h"
+
+namespace vespalib::tensor { struct Tensor; }
+
+namespace document {
+
+class TensorFieldValue;
+
+/*
+ * An update used to add cells to a sparse tensor (has only mapped dimensions).
+ *
+ * The cells to add are contained in a sparse tensor as well.
+ */
+class TensorAddUpdate : public ValueUpdate {
+ std::unique_ptr<TensorFieldValue> _tensor;
+
+ TensorAddUpdate();
+ TensorAddUpdate(const TensorAddUpdate &rhs);
+ ACCEPT_UPDATE_VISITOR;
+public:
+ TensorAddUpdate(std::unique_ptr<TensorFieldValue> &&tensor);
+ ~TensorAddUpdate() override;
+ TensorAddUpdate &operator=(const TensorAddUpdate &rhs);
+ TensorAddUpdate &operator=(TensorAddUpdate &&rhs);
+ bool operator==(const ValueUpdate &other) const override;
+ const TensorFieldValue &getTensor() const { return *_tensor; }
+ void checkCompatibility(const Field &field) const override;
+ std::unique_ptr<vespalib::tensor::Tensor> applyTo(const vespalib::tensor::Tensor &tensor) const;
+ bool applyTo(FieldValue &value) const override;
+ void printXml(XmlOutputStream &xos) const override;
+ void print(std::ostream &out, bool verbose, const std::string &indent) const override;
+ void deserialize(const DocumentTypeRepo &repo, const DataType &type, nbostream &stream) override;
+ TensorAddUpdate* clone() const override;
+
+ DECLARE_IDENTIFIABLE(TensorAddUpdate);
+};
+
+}
diff --git a/document/src/vespa/document/update/updates.h b/document/src/vespa/document/update/updates.h
index 013ce70b6b6..1609c5bc3a3 100644
--- a/document/src/vespa/document/update/updates.h
+++ b/document/src/vespa/document/update/updates.h
@@ -11,4 +11,5 @@
#include "mapvalueupdate.h"
#include "removevalueupdate.h"
#include "tensormodifyupdate.h"
+#include "tensoraddupdate.h"
diff --git a/document/src/vespa/document/update/updatevisitor.h b/document/src/vespa/document/update/updatevisitor.h
index e6291f90f69..f41e985f7c8 100644
--- a/document/src/vespa/document/update/updatevisitor.h
+++ b/document/src/vespa/document/update/updatevisitor.h
@@ -15,6 +15,7 @@ class MapValueUpdate;
class AddFieldPathUpdate;
class AssignFieldPathUpdate;
class RemoveFieldPathUpdate;
+class TensorAddUpdate;
class TensorModifyUpdate;
struct UpdateVisitor {
@@ -32,6 +33,7 @@ struct UpdateVisitor {
virtual void visit(const AssignFieldPathUpdate &value) = 0;
virtual void visit(const RemoveFieldPathUpdate &value) = 0;
virtual void visit(const TensorModifyUpdate &value) = 0;
+ virtual void visit(const TensorAddUpdate &value) = 0;
};
#define ACCEPT_UPDATE_VISITOR void accept(UpdateVisitor & visitor) const override { visitor.visit(*this); }
diff --git a/document/src/vespa/document/update/valueupdate.h b/document/src/vespa/document/update/valueupdate.h
index ceb711074f4..0e15943f8e4 100644
--- a/document/src/vespa/document/update/valueupdate.h
+++ b/document/src/vespa/document/update/valueupdate.h
@@ -54,7 +54,8 @@ public:
Clear = IDENTIFIABLE_CLASSID(ClearValueUpdate),
Map = IDENTIFIABLE_CLASSID(MapValueUpdate),
Remove = IDENTIFIABLE_CLASSID(RemoveValueUpdate),
- TensorModifyUpdate = IDENTIFIABLE_CLASSID(TensorModifyUpdate)
+ TensorModifyUpdate = IDENTIFIABLE_CLASSID(TensorModifyUpdate),
+ TensorAddUpdate = IDENTIFIABLE_CLASSID(TensorAddUpdate)
};
ValueUpdate()
diff --git a/document/src/vespa/document/util/identifiableid.h b/document/src/vespa/document/util/identifiableid.h
index 80d43369a0a..c8859cedb2e 100644
--- a/document/src/vespa/document/util/identifiableid.h
+++ b/document/src/vespa/document/util/identifiableid.h
@@ -69,6 +69,7 @@
#define CID_RemoveFieldPathUpdate DOCUMENT_CID(88)
#define CID_TensorModifyUpdate DOCUMENT_CID(100)
+#define CID_TensorAddUpdate DOCUMENT_CID(101)
#define CID_document_DocumentUpdate DOCUMENT_CID(999)