summaryrefslogtreecommitdiffstats
path: root/document
diff options
context:
space:
mode:
authorGeir Storli <geirst@verizonmedia.com>2019-02-04 08:44:15 +0100
committerGitHub <noreply@github.com>2019-02-04 08:44:15 +0100
commit1568c4effccafd2115ee95b1c691582ce381093f (patch)
tree9f014e9b7a6bd96538fc297ff87f94c9a45acae0 /document
parent07391639c56c639ecc6dbf74a5f6317f1caad458 (diff)
parent597afd85869374ed41d5b807e784e6de4c548163 (diff)
Merge pull request #8348 from vespa-engine/toregge/tensor-update-end-to-end
Tensor modify update end to end
Diffstat (limited to 'document')
-rw-r--r--document/src/tests/documentupdatetestcase.cpp13
-rw-r--r--document/src/vespa/document/update/tensormodifyupdate.cpp53
-rw-r--r--document/src/vespa/document/update/tensormodifyupdate.h3
-rw-r--r--document/src/vespa/document/update/valueupdate.h3
4 files changed, 62 insertions, 10 deletions
diff --git a/document/src/tests/documentupdatetestcase.cpp b/document/src/tests/documentupdatetestcase.cpp
index e0c6f8572e0..e7283849178 100644
--- a/document/src/tests/documentupdatetestcase.cpp
+++ b/document/src/tests/documentupdatetestcase.cpp
@@ -170,6 +170,11 @@ std::unique_ptr<Tensor> createTensorWith2Cells() {
{{{"x", "9"}, {"y", "9"}}, 11} }, {"x", "y"});
}
+std::unique_ptr<Tensor> createExpectedUpdatedTensorWith2Cells() {
+ return createTensor({ {{{"x", "8"}, {"y", "9"}}, 2},
+ {{{"x", "9"}, {"y", "9"}}, 11} }, {"x", "y"});
+}
+
FieldValue::UP createTensorFieldValueWith2Cells() {
auto fv(std::make_unique<TensorFieldValue>());
*fv = createTensorWith2Cells();
@@ -953,7 +958,8 @@ DocumentUpdateTest::testTensorModifyUpdate()
TestDocMan docMan;
Document::UP doc(docMan.createDocument());
Document updated(*doc);
- updated.setValue(updated.getField("tensor"), *createTensorFieldValueWith2Cells());
+ auto oldTensor = createTensorFieldValueWith2Cells();
+ updated.setValue(updated.getField("tensor"), *oldTensor);
CPPUNIT_ASSERT(*doc != updated);
testValueUpdate(*createTensorModifyUpdate(), *DataType::TENSOR);
DocumentUpdate upd(docMan.getTypeRepo(), *doc->getDataType(), doc->getId());
@@ -962,9 +968,8 @@ DocumentUpdateTest::testTensorModifyUpdate()
FieldValue::UP fval(updated.getValue("tensor"));
CPPUNIT_ASSERT(fval);
auto &tensor = asTensor(*fval);
- // TODO: Check that tensor is correctly modified.
- // For now, value is unchanged.
- CPPUNIT_ASSERT(tensor.equals(*createTensorWith2Cells()));
+ auto expectedUpdatedTensor = createExpectedUpdatedTensorWith2Cells();
+ CPPUNIT_ASSERT(tensor.equals(*expectedUpdatedTensor));
}
void
diff --git a/document/src/vespa/document/update/tensormodifyupdate.cpp b/document/src/vespa/document/update/tensormodifyupdate.cpp
index 87da385a57a..a02379e4991 100644
--- a/document/src/vespa/document/update/tensormodifyupdate.cpp
+++ b/document/src/vespa/document/update/tensormodifyupdate.cpp
@@ -1,12 +1,14 @@
// Copyright 2019 Oath Inc. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
#include "tensormodifyupdate.h"
-#include <vespa/document/base/field.h>
#include <vespa/document/base/exceptions.h>
+#include <vespa/document/base/field.h>
#include <vespa/document/fieldvalue/document.h>
#include <vespa/document/fieldvalue/tensorfieldvalue.h>
-#include <vespa/document/util/serializableexceptions.h>
#include <vespa/document/serialization/vespadocumentdeserializer.h>
+#include <vespa/document/util/serializableexceptions.h>
+#include <vespa/eval/eval/operation.h>
+#include <vespa/eval/tensor/cell_values.h>
#include <vespa/eval/tensor/tensor.h>
#include <vespa/vespalib/objects/nbostream.h>
#include <vespa/vespalib/stllike/asciistream.h>
@@ -19,8 +21,37 @@ using vespalib::IllegalStateException;
using vespalib::tensor::Tensor;
using vespalib::make_string;
+using join_fun_t = double (*)(double, double);
+
namespace document {
+namespace {
+
+double
+replace(double, double b)
+{
+ return b;
+}
+
+join_fun_t
+getJoinFunction(TensorModifyUpdate::Operation operation)
+{
+ using Operation = TensorModifyUpdate::Operation;
+
+ switch (operation) {
+ case Operation::REPLACE:
+ return replace;
+ case Operation::ADD:
+ return vespalib::eval::operation::Add::f;
+ case Operation::MUL:
+ return vespalib::eval::operation::Mul::f;
+ default:
+ throw IllegalArgumentException("Bad operation", VESPA_STRLOC);
+ }
+}
+
+}
+
IMPLEMENT_IDENTIFIABLE(TensorModifyUpdate, ValueUpdate);
TensorModifyUpdate::TensorModifyUpdate()
@@ -86,15 +117,27 @@ TensorModifyUpdate::checkCompatibility(const Field& field) const
}
}
+std::unique_ptr<Tensor>
+TensorModifyUpdate::applyTo(const Tensor &tensor) const
+{
+ auto &cellTensor = _tensor->getAsTensorPtr();
+ if (cellTensor) {
+ vespalib::tensor::CellValues cellValues(static_cast<const vespalib::tensor::SparseTensor &>(*cellTensor));
+ return tensor.modify(getJoinFunction(_operation), cellValues);
+ }
+ return std::unique_ptr<Tensor>();
+}
+
bool
TensorModifyUpdate::applyTo(FieldValue& value) const
{
if (value.inherits(TensorFieldValue::classId)) {
TensorFieldValue &tensorFieldValue = static_cast<TensorFieldValue &>(value);
auto &oldTensor = tensorFieldValue.getAsTensorPtr();
- // TODO: Apply operation with tensor
- auto newTensor = oldTensor->clone();
- tensorFieldValue = std::move(newTensor);
+ auto newTensor = applyTo(*oldTensor);
+ if (newTensor) {
+ tensorFieldValue = std::move(newTensor);
+ }
} else {
std::string err = make_string(
"Unable to perform a tensor modify update on a \"%s\" field "
diff --git a/document/src/vespa/document/update/tensormodifyupdate.h b/document/src/vespa/document/update/tensormodifyupdate.h
index fd89c9da47b..dcb9bcf0470 100644
--- a/document/src/vespa/document/update/tensormodifyupdate.h
+++ b/document/src/vespa/document/update/tensormodifyupdate.h
@@ -2,6 +2,8 @@
#include "valueupdate.h"
+namespace vespalib::tensor { class Tensor; }
+
namespace document {
class TensorFieldValue;
@@ -37,6 +39,7 @@ public:
Operation getOperation() const { return _operation; }
const TensorFieldValue &getTensor() const { return *_tensor; }
void checkCompatibility(const Field &field) const override;
+ std::unique_ptr<vespalib::tensor::Tensor> applyTo(const vespalib::tensor::Tensor &tensor) const;
bool applyTo(FieldValue &value) const override;
void printXml(XmlOutputStream &xos) const override;
void print(std::ostream &out, bool verbose, const std::string &indent) const override;
diff --git a/document/src/vespa/document/update/valueupdate.h b/document/src/vespa/document/update/valueupdate.h
index 963e1ad1d96..ceb711074f4 100644
--- a/document/src/vespa/document/update/valueupdate.h
+++ b/document/src/vespa/document/update/valueupdate.h
@@ -53,7 +53,8 @@ public:
Assign = IDENTIFIABLE_CLASSID(AssignValueUpdate),
Clear = IDENTIFIABLE_CLASSID(ClearValueUpdate),
Map = IDENTIFIABLE_CLASSID(MapValueUpdate),
- Remove = IDENTIFIABLE_CLASSID(RemoveValueUpdate)
+ Remove = IDENTIFIABLE_CLASSID(RemoveValueUpdate),
+ TensorModifyUpdate = IDENTIFIABLE_CLASSID(TensorModifyUpdate)
};
ValueUpdate()