From f5a24c0ffd9dbb2e4fa5a325562c5a016ed027e0 Mon Sep 17 00:00:00 2001 From: Arne Juul Date: Tue, 12 Jan 2021 14:25:31 +0000 Subject: avoid extra Value copy when updating DirectTensorAttribute --- .../src/vespa/document/update/tensor_add_update.cpp | 14 ++++++++++---- .../src/vespa/document/update/tensor_add_update.h | 7 +++++-- .../vespa/document/update/tensor_modify_update.cpp | 14 ++++++++++---- .../vespa/document/update/tensor_modify_update.h | 5 ++++- .../vespa/document/update/tensor_remove_update.cpp | 13 +++++++++---- .../vespa/document/update/tensor_remove_update.h | 6 ++++-- document/src/vespa/document/update/tensor_update.h | 21 +++++++++++++++++++++ 7 files changed, 63 insertions(+), 17 deletions(-) create mode 100644 document/src/vespa/document/update/tensor_update.h (limited to 'document') diff --git a/document/src/vespa/document/update/tensor_add_update.cpp b/document/src/vespa/document/update/tensor_add_update.cpp index c8ce728172e..d6c1b3667a6 100644 --- a/document/src/vespa/document/update/tensor_add_update.cpp +++ b/document/src/vespa/document/update/tensor_add_update.cpp @@ -83,10 +83,15 @@ TensorAddUpdate::checkCompatibility(const Field& field) const std::unique_ptr TensorAddUpdate::applyTo(const vespalib::eval::Value &tensor) const { - auto addTensor = _tensor->getAsTensorPtr(); - if (addTensor) { - const auto &factory = FastValueBuilderFactory::get(); - return TensorPartialUpdate::add(tensor, *addTensor, factory); + return apply_to(tensor, FastValueBuilderFactory::get()); +} + +std::unique_ptr +TensorAddUpdate::apply_to(const Value &old_tensor, + const ValueBuilderFactory &factory) const +{ + if (auto addTensor = _tensor->getAsTensorPtr()) { + return TensorPartialUpdate::add(old_tensor, *addTensor, factory); } return {}; } @@ -98,6 +103,7 @@ TensorAddUpdate::applyTo(FieldValue& value) const TensorFieldValue &tensorFieldValue = static_cast(value); tensorFieldValue.make_empty_if_not_existing(); auto oldTensor = tensorFieldValue.getAsTensorPtr(); + assert(oldTensor); auto newTensor = applyTo(*oldTensor); if (newTensor) { tensorFieldValue = std::move(newTensor); diff --git a/document/src/vespa/document/update/tensor_add_update.h b/document/src/vespa/document/update/tensor_add_update.h index 59fe8a845ac..1da90446643 100644 --- a/document/src/vespa/document/update/tensor_add_update.h +++ b/document/src/vespa/document/update/tensor_add_update.h @@ -1,8 +1,9 @@ // Copyright 2019 Oath Inc. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +#include "tensor_update.h" #include "valueupdate.h" -namespace vespalib::eval { struct Value; } +namespace vespalib::eval { struct Value; struct ValueBuilderFactory; } namespace document { @@ -13,7 +14,7 @@ class TensorFieldValue; * * The cells to add are contained in a tensor of the same type. */ -class TensorAddUpdate : public ValueUpdate { +class TensorAddUpdate : public ValueUpdate, public TensorUpdate { std::unique_ptr _tensor; TensorAddUpdate(); @@ -28,6 +29,8 @@ public: const TensorFieldValue &getTensor() const { return *_tensor; } void checkCompatibility(const Field &field) const override; std::unique_ptr applyTo(const vespalib::eval::Value &tensor) const; + std::unique_ptr apply_to(const Value &tensor, + const ValueBuilderFactory &factory) const override; bool applyTo(FieldValue &value) const override; void printXml(XmlOutputStream &xos) const override; void print(std::ostream &out, bool verbose, const std::string &indent) const override; diff --git a/document/src/vespa/document/update/tensor_modify_update.cpp b/document/src/vespa/document/update/tensor_modify_update.cpp index f7fca784ab2..791c3efe872 100644 --- a/document/src/vespa/document/update/tensor_modify_update.cpp +++ b/document/src/vespa/document/update/tensor_modify_update.cpp @@ -160,10 +160,16 @@ TensorModifyUpdate::checkCompatibility(const Field& field) const std::unique_ptr TensorModifyUpdate::applyTo(const vespalib::eval::Value &tensor) const { - auto cellsTensor = _tensor->getAsTensorPtr(); - if (cellsTensor) { - const auto &factory = FastValueBuilderFactory::get(); - return TensorPartialUpdate::modify(tensor, getJoinFunction(_operation), *cellsTensor, factory); + return apply_to(tensor, FastValueBuilderFactory::get()); +} + +std::unique_ptr +TensorModifyUpdate::apply_to(const Value &old_tensor, + const ValueBuilderFactory &factory) const +{ + if (auto cellsTensor = _tensor->getAsTensorPtr()) { + auto op = getJoinFunction(_operation); + return TensorPartialUpdate::modify(old_tensor, op, *cellsTensor, factory); } return {}; } diff --git a/document/src/vespa/document/update/tensor_modify_update.h b/document/src/vespa/document/update/tensor_modify_update.h index fc8fac720ac..633b2b32db0 100644 --- a/document/src/vespa/document/update/tensor_modify_update.h +++ b/document/src/vespa/document/update/tensor_modify_update.h @@ -1,5 +1,6 @@ // Copyright 2019 Oath Inc. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +#include "tensor_update.h" #include "valueupdate.h" namespace vespalib::eval { struct Value; } @@ -15,7 +16,7 @@ class TensorFieldValue; * The operand is represented as a tensor field value containing a * mapped (aka sparse) tensor. */ -class TensorModifyUpdate : public ValueUpdate { +class TensorModifyUpdate : public ValueUpdate, public TensorUpdate { public: /** Declare all types of tensor modify updates. */ enum class Operation { // Operation to be applied to matching tensor cells @@ -42,6 +43,8 @@ public: const TensorFieldValue &getTensor() const { return *_tensor; } void checkCompatibility(const Field &field) const override; std::unique_ptr applyTo(const vespalib::eval::Value &tensor) const; + std::unique_ptr apply_to(const Value &tensor, + const ValueBuilderFactory &factory) const override; bool applyTo(FieldValue &value) const override; void printXml(XmlOutputStream &xos) const override; void print(std::ostream &out, bool verbose, const std::string &indent) const override; diff --git a/document/src/vespa/document/update/tensor_remove_update.cpp b/document/src/vespa/document/update/tensor_remove_update.cpp index b3a7e93c86a..5c8c5c07116 100644 --- a/document/src/vespa/document/update/tensor_remove_update.cpp +++ b/document/src/vespa/document/update/tensor_remove_update.cpp @@ -107,10 +107,15 @@ TensorRemoveUpdate::checkCompatibility(const Field &field) const std::unique_ptr TensorRemoveUpdate::applyTo(const vespalib::eval::Value &tensor) const { - auto addressTensor = _tensor->getAsTensorPtr(); - if (addressTensor) { - const auto &factory = FastValueBuilderFactory::get(); - return TensorPartialUpdate::remove(tensor, *addressTensor, factory); + return apply_to(tensor, FastValueBuilderFactory::get()); +} + +std::unique_ptr +TensorRemoveUpdate::apply_to(const Value &old_tensor, + const ValueBuilderFactory &factory) const +{ + if (auto addressTensor = _tensor->getAsTensorPtr()) { + return TensorPartialUpdate::remove(old_tensor, *addressTensor, factory); } return {}; } diff --git a/document/src/vespa/document/update/tensor_remove_update.h b/document/src/vespa/document/update/tensor_remove_update.h index 3efc4e37e80..5590e7d307f 100644 --- a/document/src/vespa/document/update/tensor_remove_update.h +++ b/document/src/vespa/document/update/tensor_remove_update.h @@ -1,5 +1,6 @@ // Copyright 2019 Oath Inc. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +#include "tensor_update.h" #include "valueupdate.h" namespace vespalib::eval { struct Value; } @@ -15,7 +16,7 @@ class TensorFieldValue; * The cells to remove are contained in a sparse tensor (with all mapped dimensions) where cell values are set to 1.0. * When used on a mixed tensor the entire dense sub-space (pointed to by a cell in the sparse tensor) is removed. */ -class TensorRemoveUpdate : public ValueUpdate { +class TensorRemoveUpdate : public ValueUpdate, public TensorUpdate { private: std::unique_ptr _tensorType; std::unique_ptr _tensor; @@ -31,7 +32,8 @@ public: TensorRemoveUpdate &operator=(TensorRemoveUpdate &&rhs); const TensorFieldValue &getTensor() const { return *_tensor; } std::unique_ptr applyTo(const vespalib::eval::Value &tensor) const; - + std::unique_ptr apply_to(const Value &tensor, + const ValueBuilderFactory &factory) const override; bool operator==(const ValueUpdate &other) const override; void checkCompatibility(const Field &field) const override; bool applyTo(FieldValue &value) const override; diff --git a/document/src/vespa/document/update/tensor_update.h b/document/src/vespa/document/update/tensor_update.h new file mode 100644 index 00000000000..ecb99b849c0 --- /dev/null +++ b/document/src/vespa/document/update/tensor_update.h @@ -0,0 +1,21 @@ +// Copyright Verizon Media. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. + +#pragma once + +#include +#include + +namespace vespalib::eval { struct Value; struct ValueBuilderFactory; } + +namespace document { + +struct TensorUpdate { +protected: + ~TensorUpdate() = default; +public: + using Value = vespalib::eval::Value; + using ValueBuilderFactory = vespalib::eval::ValueBuilderFactory; + virtual std::unique_ptr apply_to(const Value &tensor, const ValueBuilderFactory &factory) const = 0; +}; + +} -- cgit v1.2.3