diff options
author | Geir Storli <geirst@verizonmedia.com> | 2019-02-04 08:44:15 +0100 |
---|---|---|
committer | GitHub <noreply@github.com> | 2019-02-04 08:44:15 +0100 |
commit | 1568c4effccafd2115ee95b1c691582ce381093f (patch) | |
tree | 9f014e9b7a6bd96538fc297ff87f94c9a45acae0 /document | |
parent | 07391639c56c639ecc6dbf74a5f6317f1caad458 (diff) | |
parent | 597afd85869374ed41d5b807e784e6de4c548163 (diff) |
Merge pull request #8348 from vespa-engine/toregge/tensor-update-end-to-end
Tensor modify update end to end
Diffstat (limited to 'document')
4 files changed, 62 insertions, 10 deletions
diff --git a/document/src/tests/documentupdatetestcase.cpp b/document/src/tests/documentupdatetestcase.cpp index e0c6f8572e0..e7283849178 100644 --- a/document/src/tests/documentupdatetestcase.cpp +++ b/document/src/tests/documentupdatetestcase.cpp @@ -170,6 +170,11 @@ std::unique_ptr<Tensor> createTensorWith2Cells() { {{{"x", "9"}, {"y", "9"}}, 11} }, {"x", "y"}); } +std::unique_ptr<Tensor> createExpectedUpdatedTensorWith2Cells() { + return createTensor({ {{{"x", "8"}, {"y", "9"}}, 2}, + {{{"x", "9"}, {"y", "9"}}, 11} }, {"x", "y"}); +} + FieldValue::UP createTensorFieldValueWith2Cells() { auto fv(std::make_unique<TensorFieldValue>()); *fv = createTensorWith2Cells(); @@ -953,7 +958,8 @@ DocumentUpdateTest::testTensorModifyUpdate() TestDocMan docMan; Document::UP doc(docMan.createDocument()); Document updated(*doc); - updated.setValue(updated.getField("tensor"), *createTensorFieldValueWith2Cells()); + auto oldTensor = createTensorFieldValueWith2Cells(); + updated.setValue(updated.getField("tensor"), *oldTensor); CPPUNIT_ASSERT(*doc != updated); testValueUpdate(*createTensorModifyUpdate(), *DataType::TENSOR); DocumentUpdate upd(docMan.getTypeRepo(), *doc->getDataType(), doc->getId()); @@ -962,9 +968,8 @@ DocumentUpdateTest::testTensorModifyUpdate() FieldValue::UP fval(updated.getValue("tensor")); CPPUNIT_ASSERT(fval); auto &tensor = asTensor(*fval); - // TODO: Check that tensor is correctly modified. - // For now, value is unchanged. - CPPUNIT_ASSERT(tensor.equals(*createTensorWith2Cells())); + auto expectedUpdatedTensor = createExpectedUpdatedTensorWith2Cells(); + CPPUNIT_ASSERT(tensor.equals(*expectedUpdatedTensor)); } void diff --git a/document/src/vespa/document/update/tensormodifyupdate.cpp b/document/src/vespa/document/update/tensormodifyupdate.cpp index 87da385a57a..a02379e4991 100644 --- a/document/src/vespa/document/update/tensormodifyupdate.cpp +++ b/document/src/vespa/document/update/tensormodifyupdate.cpp @@ -1,12 +1,14 @@ // Copyright 2019 Oath Inc. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. #include "tensormodifyupdate.h" -#include <vespa/document/base/field.h> #include <vespa/document/base/exceptions.h> +#include <vespa/document/base/field.h> #include <vespa/document/fieldvalue/document.h> #include <vespa/document/fieldvalue/tensorfieldvalue.h> -#include <vespa/document/util/serializableexceptions.h> #include <vespa/document/serialization/vespadocumentdeserializer.h> +#include <vespa/document/util/serializableexceptions.h> +#include <vespa/eval/eval/operation.h> +#include <vespa/eval/tensor/cell_values.h> #include <vespa/eval/tensor/tensor.h> #include <vespa/vespalib/objects/nbostream.h> #include <vespa/vespalib/stllike/asciistream.h> @@ -19,8 +21,37 @@ using vespalib::IllegalStateException; using vespalib::tensor::Tensor; using vespalib::make_string; +using join_fun_t = double (*)(double, double); + namespace document { +namespace { + +double +replace(double, double b) +{ + return b; +} + +join_fun_t +getJoinFunction(TensorModifyUpdate::Operation operation) +{ + using Operation = TensorModifyUpdate::Operation; + + switch (operation) { + case Operation::REPLACE: + return replace; + case Operation::ADD: + return vespalib::eval::operation::Add::f; + case Operation::MUL: + return vespalib::eval::operation::Mul::f; + default: + throw IllegalArgumentException("Bad operation", VESPA_STRLOC); + } +} + +} + IMPLEMENT_IDENTIFIABLE(TensorModifyUpdate, ValueUpdate); TensorModifyUpdate::TensorModifyUpdate() @@ -86,15 +117,27 @@ TensorModifyUpdate::checkCompatibility(const Field& field) const } } +std::unique_ptr<Tensor> +TensorModifyUpdate::applyTo(const Tensor &tensor) const +{ + auto &cellTensor = _tensor->getAsTensorPtr(); + if (cellTensor) { + vespalib::tensor::CellValues cellValues(static_cast<const vespalib::tensor::SparseTensor &>(*cellTensor)); + return tensor.modify(getJoinFunction(_operation), cellValues); + } + return std::unique_ptr<Tensor>(); +} + bool TensorModifyUpdate::applyTo(FieldValue& value) const { if (value.inherits(TensorFieldValue::classId)) { TensorFieldValue &tensorFieldValue = static_cast<TensorFieldValue &>(value); auto &oldTensor = tensorFieldValue.getAsTensorPtr(); - // TODO: Apply operation with tensor - auto newTensor = oldTensor->clone(); - tensorFieldValue = std::move(newTensor); + auto newTensor = applyTo(*oldTensor); + if (newTensor) { + tensorFieldValue = std::move(newTensor); + } } else { std::string err = make_string( "Unable to perform a tensor modify update on a \"%s\" field " diff --git a/document/src/vespa/document/update/tensormodifyupdate.h b/document/src/vespa/document/update/tensormodifyupdate.h index fd89c9da47b..dcb9bcf0470 100644 --- a/document/src/vespa/document/update/tensormodifyupdate.h +++ b/document/src/vespa/document/update/tensormodifyupdate.h @@ -2,6 +2,8 @@ #include "valueupdate.h" +namespace vespalib::tensor { class Tensor; } + namespace document { class TensorFieldValue; @@ -37,6 +39,7 @@ public: Operation getOperation() const { return _operation; } const TensorFieldValue &getTensor() const { return *_tensor; } void checkCompatibility(const Field &field) const override; + std::unique_ptr<vespalib::tensor::Tensor> applyTo(const vespalib::tensor::Tensor &tensor) const; bool applyTo(FieldValue &value) const override; void printXml(XmlOutputStream &xos) const override; void print(std::ostream &out, bool verbose, const std::string &indent) const override; diff --git a/document/src/vespa/document/update/valueupdate.h b/document/src/vespa/document/update/valueupdate.h index 963e1ad1d96..ceb711074f4 100644 --- a/document/src/vespa/document/update/valueupdate.h +++ b/document/src/vespa/document/update/valueupdate.h @@ -53,7 +53,8 @@ public: Assign = IDENTIFIABLE_CLASSID(AssignValueUpdate), Clear = IDENTIFIABLE_CLASSID(ClearValueUpdate), Map = IDENTIFIABLE_CLASSID(MapValueUpdate), - Remove = IDENTIFIABLE_CLASSID(RemoveValueUpdate) + Remove = IDENTIFIABLE_CLASSID(RemoveValueUpdate), + TensorModifyUpdate = IDENTIFIABLE_CLASSID(TensorModifyUpdate) }; ValueUpdate() |