diff options
author | Geir Storli <geirst@verizonmedia.com> | 2019-02-18 16:03:47 +0100 |
---|---|---|
committer | GitHub <noreply@github.com> | 2019-02-18 16:03:47 +0100 |
commit | 9dfc17783ce64ce6d0857d3b6904875d7304531f (patch) | |
tree | b7c4d2692bdc27cbcb2693c0f06cba95b8df57fa | |
parent | 129010bad4bf802d33362a581b98bca74f4b036e (diff) | |
parent | f14afd23ff875cc64bcff1a447580c485264fdc1 (diff) |
Merge pull request #8537 from vespa-engine/toregge/use-converted-tensor-type-in-tensor-modify-update
Use converted tensor type in TensorModifyUpdate.
-rw-r--r-- | document/src/vespa/document/update/tensor_modify_update.cpp | 35 | ||||
-rw-r--r-- | document/src/vespa/document/update/tensor_modify_update.h | 4 |
2 files changed, 32 insertions, 7 deletions
diff --git a/document/src/vespa/document/update/tensor_modify_update.cpp b/document/src/vespa/document/update/tensor_modify_update.cpp index b75b2ceec60..a7f82cad3c7 100644 --- a/document/src/vespa/document/update/tensor_modify_update.cpp +++ b/document/src/vespa/document/update/tensor_modify_update.cpp @@ -21,6 +21,7 @@ using vespalib::IllegalArgumentException; using vespalib::IllegalStateException; using vespalib::tensor::Tensor; using vespalib::make_string; +using vespalib::eval::ValueType; using join_fun_t = double (*)(double, double); @@ -68,26 +69,41 @@ getJoinFunctionName(TensorModifyUpdate::Operation operation) } } +std::unique_ptr<const TensorDataType> +convertToCompatibleType(const TensorDataType &tensorType) +{ + std::vector<ValueType::Dimension> list; + for (const auto &dim : tensorType.getTensorType().dimensions()) { + list.emplace_back(dim.name); + } + return std::make_unique<const TensorDataType>(ValueType::tensor_type(std::move(list))); +} + } IMPLEMENT_IDENTIFIABLE(TensorModifyUpdate, ValueUpdate); TensorModifyUpdate::TensorModifyUpdate() : _operation(Operation::MAX_NUM_OPERATIONS), + _tensorType(), _tensor() { } TensorModifyUpdate::TensorModifyUpdate(const TensorModifyUpdate &rhs) : _operation(rhs._operation), - _tensor(rhs._tensor->clone()) + _tensorType(rhs._tensorType->clone()), + _tensor(Identifiable::cast<TensorFieldValue *>(_tensorType->createFieldValue().release())) { + *_tensor = *rhs._tensor; } -TensorModifyUpdate::TensorModifyUpdate(Operation operation, std::unique_ptr<TensorFieldValue> &&tensor) +TensorModifyUpdate::TensorModifyUpdate(Operation operation, std::unique_ptr<TensorFieldValue> tensor) : _operation(operation), - _tensor(std::move(tensor)) + _tensorType(Identifiable::cast<const TensorDataType &>(*tensor->getDataType()).clone()), + _tensor(Identifiable::cast<TensorFieldValue *>(_tensorType->createFieldValue().release())) { + *_tensor = *tensor; } TensorModifyUpdate::~TensorModifyUpdate() = default; @@ -95,8 +111,13 @@ TensorModifyUpdate::~TensorModifyUpdate() = default; TensorModifyUpdate & TensorModifyUpdate::operator=(const TensorModifyUpdate &rhs) { - _operation = rhs._operation; - _tensor.reset(rhs._tensor->clone()); + if (&rhs != this) { + _operation = rhs._operation; + _tensor.reset(); + _tensorType.reset(rhs._tensorType->clone()); + _tensor.reset(Identifiable::cast<TensorFieldValue *>(_tensorType->createFieldValue().release())); + *_tensor = *rhs._tensor; + } return *this; } @@ -104,6 +125,7 @@ TensorModifyUpdate & TensorModifyUpdate::operator=(TensorModifyUpdate &&rhs) { _operation = rhs._operation; + _tensorType = std::move(rhs._tensorType); _tensor = std::move(rhs._tensor); return *this; } @@ -192,7 +214,8 @@ TensorModifyUpdate::deserialize(const DocumentTypeRepo &repo, const DataType &ty throw DeserializeException(msg.str(), VESPA_STRLOC); } _operation = static_cast<Operation>(op); - auto tensor = type.createFieldValue(); + _tensorType = convertToCompatibleType(Identifiable::cast<const TensorDataType &>(type)); + auto tensor = _tensorType->createFieldValue(); if (tensor->inherits(TensorFieldValue::classId)) { _tensor.reset(static_cast<TensorFieldValue *>(tensor.release())); } else { diff --git a/document/src/vespa/document/update/tensor_modify_update.h b/document/src/vespa/document/update/tensor_modify_update.h index c937e356ce4..c2d61d3e69b 100644 --- a/document/src/vespa/document/update/tensor_modify_update.h +++ b/document/src/vespa/document/update/tensor_modify_update.h @@ -6,6 +6,7 @@ namespace vespalib::tensor { class Tensor; } namespace document { +class TensorDataType; class TensorFieldValue; /* @@ -25,13 +26,14 @@ public: }; private: Operation _operation; + std::unique_ptr<const TensorDataType> _tensorType; std::unique_ptr<TensorFieldValue> _tensor; TensorModifyUpdate(); TensorModifyUpdate(const TensorModifyUpdate &rhs); ACCEPT_UPDATE_VISITOR; public: - TensorModifyUpdate(Operation operation, std::unique_ptr<TensorFieldValue> &&tensor); + TensorModifyUpdate(Operation operation, std::unique_ptr<TensorFieldValue> tensor); ~TensorModifyUpdate() override; TensorModifyUpdate &operator=(const TensorModifyUpdate &rhs); TensorModifyUpdate &operator=(TensorModifyUpdate &&rhs); |