diff options
author | Tor Egge <Tor.Egge@broadpark.no> | 2019-02-18 14:12:24 +0100 |
---|---|---|
committer | Tor Egge <Tor.Egge@broadpark.no> | 2019-02-18 14:15:09 +0100 |
commit | f14afd23ff875cc64bcff1a447580c485264fdc1 (patch) | |
tree | 4fc77424edd51456f87384a2717febc59de23dc2 | |
parent | 7baac9a29d01a23893b32d54b672001281bd3d96 (diff) |
Use converted tensor type in TensorModifyUpdate.
-rw-r--r-- | document/src/vespa/document/update/tensor_modify_update.cpp | 35 | ||||
-rw-r--r-- | document/src/vespa/document/update/tensor_modify_update.h | 4 |
2 files changed, 32 insertions, 7 deletions
diff --git a/document/src/vespa/document/update/tensor_modify_update.cpp b/document/src/vespa/document/update/tensor_modify_update.cpp index b75b2ceec60..a7f82cad3c7 100644 --- a/document/src/vespa/document/update/tensor_modify_update.cpp +++ b/document/src/vespa/document/update/tensor_modify_update.cpp @@ -21,6 +21,7 @@ using vespalib::IllegalArgumentException; using vespalib::IllegalStateException; using vespalib::tensor::Tensor; using vespalib::make_string; +using vespalib::eval::ValueType; using join_fun_t = double (*)(double, double); @@ -68,26 +69,41 @@ getJoinFunctionName(TensorModifyUpdate::Operation operation) } } +std::unique_ptr<const TensorDataType> +convertToCompatibleType(const TensorDataType &tensorType) +{ + std::vector<ValueType::Dimension> list; + for (const auto &dim : tensorType.getTensorType().dimensions()) { + list.emplace_back(dim.name); + } + return std::make_unique<const TensorDataType>(ValueType::tensor_type(std::move(list))); +} + } IMPLEMENT_IDENTIFIABLE(TensorModifyUpdate, ValueUpdate); TensorModifyUpdate::TensorModifyUpdate() : _operation(Operation::MAX_NUM_OPERATIONS), + _tensorType(), _tensor() { } TensorModifyUpdate::TensorModifyUpdate(const TensorModifyUpdate &rhs) : _operation(rhs._operation), - _tensor(rhs._tensor->clone()) + _tensorType(rhs._tensorType->clone()), + _tensor(Identifiable::cast<TensorFieldValue *>(_tensorType->createFieldValue().release())) { + *_tensor = *rhs._tensor; } -TensorModifyUpdate::TensorModifyUpdate(Operation operation, std::unique_ptr<TensorFieldValue> &&tensor) +TensorModifyUpdate::TensorModifyUpdate(Operation operation, std::unique_ptr<TensorFieldValue> tensor) : _operation(operation), - _tensor(std::move(tensor)) + _tensorType(Identifiable::cast<const TensorDataType &>(*tensor->getDataType()).clone()), + _tensor(Identifiable::cast<TensorFieldValue *>(_tensorType->createFieldValue().release())) { + *_tensor = *tensor; } TensorModifyUpdate::~TensorModifyUpdate() = default; @@ -95,8 +111,13 @@ TensorModifyUpdate::~TensorModifyUpdate() = default; TensorModifyUpdate & TensorModifyUpdate::operator=(const TensorModifyUpdate &rhs) { - _operation = rhs._operation; - _tensor.reset(rhs._tensor->clone()); + if (&rhs != this) { + _operation = rhs._operation; + _tensor.reset(); + _tensorType.reset(rhs._tensorType->clone()); + _tensor.reset(Identifiable::cast<TensorFieldValue *>(_tensorType->createFieldValue().release())); + *_tensor = *rhs._tensor; + } return *this; } @@ -104,6 +125,7 @@ TensorModifyUpdate & TensorModifyUpdate::operator=(TensorModifyUpdate &&rhs) { _operation = rhs._operation; + _tensorType = std::move(rhs._tensorType); _tensor = std::move(rhs._tensor); return *this; } @@ -192,7 +214,8 @@ TensorModifyUpdate::deserialize(const DocumentTypeRepo &repo, const DataType &ty throw DeserializeException(msg.str(), VESPA_STRLOC); } _operation = static_cast<Operation>(op); - auto tensor = type.createFieldValue(); + _tensorType = convertToCompatibleType(Identifiable::cast<const TensorDataType &>(type)); + auto tensor = _tensorType->createFieldValue(); if (tensor->inherits(TensorFieldValue::classId)) { _tensor.reset(static_cast<TensorFieldValue *>(tensor.release())); } else { diff --git a/document/src/vespa/document/update/tensor_modify_update.h b/document/src/vespa/document/update/tensor_modify_update.h index c937e356ce4..c2d61d3e69b 100644 --- a/document/src/vespa/document/update/tensor_modify_update.h +++ b/document/src/vespa/document/update/tensor_modify_update.h @@ -6,6 +6,7 @@ namespace vespalib::tensor { class Tensor; } namespace document { +class TensorDataType; class TensorFieldValue; /* @@ -25,13 +26,14 @@ public: }; private: Operation _operation; + std::unique_ptr<const TensorDataType> _tensorType; std::unique_ptr<TensorFieldValue> _tensor; TensorModifyUpdate(); TensorModifyUpdate(const TensorModifyUpdate &rhs); ACCEPT_UPDATE_VISITOR; public: - TensorModifyUpdate(Operation operation, std::unique_ptr<TensorFieldValue> &&tensor); + TensorModifyUpdate(Operation operation, std::unique_ptr<TensorFieldValue> tensor); ~TensorModifyUpdate() override; TensorModifyUpdate &operator=(const TensorModifyUpdate &rhs); TensorModifyUpdate &operator=(TensorModifyUpdate &&rhs); |