diff options
author | Arne Juul <arnej@verizonmedia.com> | 2021-01-12 14:25:31 +0000 |
---|---|---|
committer | Arne Juul <arnej@verizonmedia.com> | 2021-01-12 14:28:07 +0000 |
commit | f5a24c0ffd9dbb2e4fa5a325562c5a016ed027e0 (patch) | |
tree | 0e238e458a13b6976dbc33ac85e5cd9940f4ceca | |
parent | 00df4dedab5ea94283cbd2d2f359b01774402ffb (diff) |
avoid extra Value copy when updating DirectTensorAttribute
12 files changed, 94 insertions, 25 deletions
diff --git a/document/src/vespa/document/update/tensor_add_update.cpp b/document/src/vespa/document/update/tensor_add_update.cpp index c8ce728172e..d6c1b3667a6 100644 --- a/document/src/vespa/document/update/tensor_add_update.cpp +++ b/document/src/vespa/document/update/tensor_add_update.cpp @@ -83,10 +83,15 @@ TensorAddUpdate::checkCompatibility(const Field& field) const std::unique_ptr<vespalib::eval::Value> TensorAddUpdate::applyTo(const vespalib::eval::Value &tensor) const { - auto addTensor = _tensor->getAsTensorPtr(); - if (addTensor) { - const auto &factory = FastValueBuilderFactory::get(); - return TensorPartialUpdate::add(tensor, *addTensor, factory); + return apply_to(tensor, FastValueBuilderFactory::get()); +} + +std::unique_ptr<vespalib::eval::Value> +TensorAddUpdate::apply_to(const Value &old_tensor, + const ValueBuilderFactory &factory) const +{ + if (auto addTensor = _tensor->getAsTensorPtr()) { + return TensorPartialUpdate::add(old_tensor, *addTensor, factory); } return {}; } @@ -98,6 +103,7 @@ TensorAddUpdate::applyTo(FieldValue& value) const TensorFieldValue &tensorFieldValue = static_cast<TensorFieldValue &>(value); tensorFieldValue.make_empty_if_not_existing(); auto oldTensor = tensorFieldValue.getAsTensorPtr(); + assert(oldTensor); auto newTensor = applyTo(*oldTensor); if (newTensor) { tensorFieldValue = std::move(newTensor); diff --git a/document/src/vespa/document/update/tensor_add_update.h b/document/src/vespa/document/update/tensor_add_update.h index 59fe8a845ac..1da90446643 100644 --- a/document/src/vespa/document/update/tensor_add_update.h +++ b/document/src/vespa/document/update/tensor_add_update.h @@ -1,8 +1,9 @@ // Copyright 2019 Oath Inc. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +#include "tensor_update.h" #include "valueupdate.h" -namespace vespalib::eval { struct Value; } +namespace vespalib::eval { struct Value; struct ValueBuilderFactory; } namespace document { @@ -13,7 +14,7 @@ class TensorFieldValue; * * The cells to add are contained in a tensor of the same type. */ -class TensorAddUpdate : public ValueUpdate { +class TensorAddUpdate : public ValueUpdate, public TensorUpdate { std::unique_ptr<TensorFieldValue> _tensor; TensorAddUpdate(); @@ -28,6 +29,8 @@ public: const TensorFieldValue &getTensor() const { return *_tensor; } void checkCompatibility(const Field &field) const override; std::unique_ptr<vespalib::eval::Value> applyTo(const vespalib::eval::Value &tensor) const; + std::unique_ptr<Value> apply_to(const Value &tensor, + const ValueBuilderFactory &factory) const override; bool applyTo(FieldValue &value) const override; void printXml(XmlOutputStream &xos) const override; void print(std::ostream &out, bool verbose, const std::string &indent) const override; diff --git a/document/src/vespa/document/update/tensor_modify_update.cpp b/document/src/vespa/document/update/tensor_modify_update.cpp index f7fca784ab2..791c3efe872 100644 --- a/document/src/vespa/document/update/tensor_modify_update.cpp +++ b/document/src/vespa/document/update/tensor_modify_update.cpp @@ -160,10 +160,16 @@ TensorModifyUpdate::checkCompatibility(const Field& field) const std::unique_ptr<vespalib::eval::Value> TensorModifyUpdate::applyTo(const vespalib::eval::Value &tensor) const { - auto cellsTensor = _tensor->getAsTensorPtr(); - if (cellsTensor) { - const auto &factory = FastValueBuilderFactory::get(); - return TensorPartialUpdate::modify(tensor, getJoinFunction(_operation), *cellsTensor, factory); + return apply_to(tensor, FastValueBuilderFactory::get()); +} + +std::unique_ptr<vespalib::eval::Value> +TensorModifyUpdate::apply_to(const Value &old_tensor, + const ValueBuilderFactory &factory) const +{ + if (auto cellsTensor = _tensor->getAsTensorPtr()) { + auto op = getJoinFunction(_operation); + return TensorPartialUpdate::modify(old_tensor, op, *cellsTensor, factory); } return {}; } diff --git a/document/src/vespa/document/update/tensor_modify_update.h b/document/src/vespa/document/update/tensor_modify_update.h index fc8fac720ac..633b2b32db0 100644 --- a/document/src/vespa/document/update/tensor_modify_update.h +++ b/document/src/vespa/document/update/tensor_modify_update.h @@ -1,5 +1,6 @@ // Copyright 2019 Oath Inc. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +#include "tensor_update.h" #include "valueupdate.h" namespace vespalib::eval { struct Value; } @@ -15,7 +16,7 @@ class TensorFieldValue; * The operand is represented as a tensor field value containing a * mapped (aka sparse) tensor. */ -class TensorModifyUpdate : public ValueUpdate { +class TensorModifyUpdate : public ValueUpdate, public TensorUpdate { public: /** Declare all types of tensor modify updates. */ enum class Operation { // Operation to be applied to matching tensor cells @@ -42,6 +43,8 @@ public: const TensorFieldValue &getTensor() const { return *_tensor; } void checkCompatibility(const Field &field) const override; std::unique_ptr<vespalib::eval::Value> applyTo(const vespalib::eval::Value &tensor) const; + std::unique_ptr<Value> apply_to(const Value &tensor, + const ValueBuilderFactory &factory) const override; bool applyTo(FieldValue &value) const override; void printXml(XmlOutputStream &xos) const override; void print(std::ostream &out, bool verbose, const std::string &indent) const override; diff --git a/document/src/vespa/document/update/tensor_remove_update.cpp b/document/src/vespa/document/update/tensor_remove_update.cpp index b3a7e93c86a..5c8c5c07116 100644 --- a/document/src/vespa/document/update/tensor_remove_update.cpp +++ b/document/src/vespa/document/update/tensor_remove_update.cpp @@ -107,10 +107,15 @@ TensorRemoveUpdate::checkCompatibility(const Field &field) const std::unique_ptr<vespalib::eval::Value> TensorRemoveUpdate::applyTo(const vespalib::eval::Value &tensor) const { - auto addressTensor = _tensor->getAsTensorPtr(); - if (addressTensor) { - const auto &factory = FastValueBuilderFactory::get(); - return TensorPartialUpdate::remove(tensor, *addressTensor, factory); + return apply_to(tensor, FastValueBuilderFactory::get()); +} + +std::unique_ptr<vespalib::eval::Value> +TensorRemoveUpdate::apply_to(const Value &old_tensor, + const ValueBuilderFactory &factory) const +{ + if (auto addressTensor = _tensor->getAsTensorPtr()) { + return TensorPartialUpdate::remove(old_tensor, *addressTensor, factory); } return {}; } diff --git a/document/src/vespa/document/update/tensor_remove_update.h b/document/src/vespa/document/update/tensor_remove_update.h index 3efc4e37e80..5590e7d307f 100644 --- a/document/src/vespa/document/update/tensor_remove_update.h +++ b/document/src/vespa/document/update/tensor_remove_update.h @@ -1,5 +1,6 @@ // Copyright 2019 Oath Inc. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +#include "tensor_update.h" #include "valueupdate.h" namespace vespalib::eval { struct Value; } @@ -15,7 +16,7 @@ class TensorFieldValue; * The cells to remove are contained in a sparse tensor (with all mapped dimensions) where cell values are set to 1.0. * When used on a mixed tensor the entire dense sub-space (pointed to by a cell in the sparse tensor) is removed. */ -class TensorRemoveUpdate : public ValueUpdate { +class TensorRemoveUpdate : public ValueUpdate, public TensorUpdate { private: std::unique_ptr<const TensorDataType> _tensorType; std::unique_ptr<TensorFieldValue> _tensor; @@ -31,7 +32,8 @@ public: TensorRemoveUpdate &operator=(TensorRemoveUpdate &&rhs); const TensorFieldValue &getTensor() const { return *_tensor; } std::unique_ptr<vespalib::eval::Value> applyTo(const vespalib::eval::Value &tensor) const; - + std::unique_ptr<Value> apply_to(const Value &tensor, + const ValueBuilderFactory &factory) const override; bool operator==(const ValueUpdate &other) const override; void checkCompatibility(const Field &field) const override; bool applyTo(FieldValue &value) const override; diff --git a/document/src/vespa/document/update/tensor_update.h b/document/src/vespa/document/update/tensor_update.h new file mode 100644 index 00000000000..ecb99b849c0 --- /dev/null +++ b/document/src/vespa/document/update/tensor_update.h @@ -0,0 +1,21 @@ +// Copyright Verizon Media. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. + +#pragma once + +#include <functional> +#include <memory> + +namespace vespalib::eval { struct Value; struct ValueBuilderFactory; } + +namespace document { + +struct TensorUpdate { +protected: + ~TensorUpdate() = default; +public: + using Value = vespalib::eval::Value; + using ValueBuilderFactory = vespalib::eval::ValueBuilderFactory; + virtual std::unique_ptr<Value> apply_to(const Value &tensor, const ValueBuilderFactory &factory) const = 0; +}; + +} diff --git a/searchcore/src/vespa/searchcore/proton/common/attribute_updater.cpp b/searchcore/src/vespa/searchcore/proton/common/attribute_updater.cpp index 063ecffe729..8fb165cb0b8 100644 --- a/searchcore/src/vespa/searchcore/proton/common/attribute_updater.cpp +++ b/searchcore/src/vespa/searchcore/proton/common/attribute_updater.cpp @@ -207,9 +207,8 @@ AttributeUpdater::handleUpdate(PredicateAttribute &vec, uint32_t lid, const Valu namespace { -template <typename TensorUpdateType> void -applyTensorUpdate(TensorAttribute &vec, uint32_t lid, const TensorUpdateType &update, +applyTensorUpdate(TensorAttribute &vec, uint32_t lid, const document::TensorUpdate &update, bool create_empty_if_non_existing) { auto oldTensor = vec.getTensor(lid); @@ -217,10 +216,7 @@ applyTensorUpdate(TensorAttribute &vec, uint32_t lid, const TensorUpdateType &up oldTensor = vec.getEmptyTensor(); } if (oldTensor) { - auto newTensor = update.applyTo(*oldTensor); - if (newTensor) { - vec.setTensor(lid, *newTensor); - } + vec.update_tensor(lid, update, *oldTensor); } } diff --git a/searchlib/src/vespa/searchlib/tensor/direct_tensor_attribute.cpp b/searchlib/src/vespa/searchlib/tensor/direct_tensor_attribute.cpp index 8cda62682d0..f4010857c76 100644 --- a/searchlib/src/vespa/searchlib/tensor/direct_tensor_attribute.cpp +++ b/searchlib/src/vespa/searchlib/tensor/direct_tensor_attribute.cpp @@ -76,6 +76,17 @@ DirectTensorAttribute::setTensor(DocId lid, const vespalib::eval::Value &tensor) set_tensor(lid, FastValueBuilderFactory::get().copy(tensor)); } +void +DirectTensorAttribute::update_tensor(DocId docId, + const document::TensorUpdate &update, + const vespalib::eval::Value &old_tensor) +{ + auto new_value = update.apply_to(old_tensor, FastValueBuilderFactory::get()); + if (new_value) { + set_tensor(docId, std::move(new_value)); + } +} + std::unique_ptr<vespalib::eval::Value> DirectTensorAttribute::getTensor(DocId docId) const { diff --git a/searchlib/src/vespa/searchlib/tensor/direct_tensor_attribute.h b/searchlib/src/vespa/searchlib/tensor/direct_tensor_attribute.h index a49b3c751d9..a87526342ef 100644 --- a/searchlib/src/vespa/searchlib/tensor/direct_tensor_attribute.h +++ b/searchlib/src/vespa/searchlib/tensor/direct_tensor_attribute.h @@ -17,6 +17,9 @@ public: DirectTensorAttribute(vespalib::stringref baseFileName, const Config &cfg); virtual ~DirectTensorAttribute(); virtual void setTensor(DocId docId, const vespalib::eval::Value &tensor) override; + void update_tensor(DocId docId, + const document::TensorUpdate &update, + const vespalib::eval::Value &old_tensor) override; virtual std::unique_ptr<vespalib::eval::Value> getTensor(DocId docId) const override; virtual bool onLoad() override; virtual std::unique_ptr<AttributeSaver> onInitSave(vespalib::stringref fileName) override; diff --git a/searchlib/src/vespa/searchlib/tensor/tensor_attribute.cpp b/searchlib/src/vespa/searchlib/tensor/tensor_attribute.cpp index e0b21290284..f5de1de640f 100644 --- a/searchlib/src/vespa/searchlib/tensor/tensor_attribute.cpp +++ b/searchlib/src/vespa/searchlib/tensor/tensor_attribute.cpp @@ -13,6 +13,7 @@ #include <vespa/eval/eval/value.h> using document::TensorDataType; +using document::TensorUpdate; using document::WrongTensorTypeException; using vespalib::eval::FastValueBuilderFactory; using vespalib::eval::TensorSpec; @@ -250,6 +251,15 @@ TensorAttribute::getRefCopy() const return RefCopyVector(&_refVector[0], &_refVector[0] + size); } +void +TensorAttribute::update_tensor(DocId docId, + const document::TensorUpdate &update, + const vespalib::eval::Value &old_tensor) +{ + auto new_value = update.apply_to(old_tensor, FastValueBuilderFactory::get()); + setTensor(docId, *new_value); +} + std::unique_ptr<PrepareResult> TensorAttribute::prepare_set_tensor(DocId docid, const vespalib::eval::Value& tensor) const { diff --git a/searchlib/src/vespa/searchlib/tensor/tensor_attribute.h b/searchlib/src/vespa/searchlib/tensor/tensor_attribute.h index 7abfe66a2e4..adb9e7bca8c 100644 --- a/searchlib/src/vespa/searchlib/tensor/tensor_attribute.h +++ b/searchlib/src/vespa/searchlib/tensor/tensor_attribute.h @@ -7,8 +7,9 @@ #include "tensor_store.h" #include <vespa/searchlib/attribute/not_implemented_attribute.h> #include <vespa/vespalib/util/rcuvector.h> +#include <vespa/document/update/tensor_update.h> -namespace vespalib::eval { struct Value; } +namespace vespalib::eval { struct Value; struct ValueBuilderFactory; } namespace search::tensor { @@ -58,7 +59,9 @@ public: uint32_t getVersion() const override; RefCopyVector getRefCopy() const; virtual void setTensor(DocId docId, const vespalib::eval::Value &tensor) = 0; - + virtual void update_tensor(DocId docId, + const document::TensorUpdate &update, + const vespalib::eval::Value &oldTensor); /** * Performs the prepare step in a two-phase operation to set a tensor for a document. * |