diff options
author | Geir Storli <geirst@verizonmedia.com> | 2019-03-08 08:19:48 +0100 |
---|---|---|
committer | GitHub <noreply@github.com> | 2019-03-08 08:19:48 +0100 |
commit | 85e978f6c57edafb6d24c4a9f47ae9760dc65c53 (patch) | |
tree | a748b9436dc1dbe425b2eb253f931b573df58719 | |
parent | b6d30eb9c7370d35accf7b3a709259ae2916a03e (diff) | |
parent | 9a86bbe4f3353c04794400179d9296b7a5d60a53 (diff) |
Merge pull request #8705 from vespa-engine/toregge/convert-tensor-type-for-tensor-remove-updates
Convert field tensor type to tensor type for tensor remove updates.
-rw-r--r-- | document/src/tests/documentupdatetestcase.cpp | 2 | ||||
-rw-r--r-- | document/src/vespa/document/update/tensor_remove_update.cpp | 40 | ||||
-rw-r--r-- | document/src/vespa/document/update/tensor_remove_update.h | 4 |
3 files changed, 38 insertions, 8 deletions
diff --git a/document/src/tests/documentupdatetestcase.cpp b/document/src/tests/documentupdatetestcase.cpp index 71bc545897b..b9568d546c5 100644 --- a/document/src/tests/documentupdatetestcase.cpp +++ b/document/src/tests/documentupdatetestcase.cpp @@ -1042,7 +1042,7 @@ TEST(DocumentUpdateTest, tensor_remove_update_throws_if_address_tensor_is_not_sp auto addressTensor = f.makeTensor(f.spec().add({{"x", 0}}, 2)); // creates a dense address tensor ASSERT_THROW( f.assertRoundtripSerialize(TensorRemoveUpdate(std::move(addressTensor))), - vespalib::IllegalStateException); + document::WrongTensorTypeException); } TEST(DocumentUpdateTest, tensor_modify_update_throws_if_cells_tensor_is_not_sparse) diff --git a/document/src/vespa/document/update/tensor_remove_update.cpp b/document/src/vespa/document/update/tensor_remove_update.cpp index c72d776fa9f..24aba4ece5a 100644 --- a/document/src/vespa/document/update/tensor_remove_update.cpp +++ b/document/src/vespa/document/update/tensor_remove_update.cpp @@ -17,24 +17,45 @@ using vespalib::IllegalArgumentException; using vespalib::IllegalStateException; using vespalib::tensor::Tensor; using vespalib::make_string; +using vespalib::eval::ValueType; namespace document { +namespace { + +std::unique_ptr<const TensorDataType> +convertToCompatibleType(const TensorDataType &tensorType) +{ + std::vector<ValueType::Dimension> list; + for (const auto &dim : tensorType.getTensorType().dimensions()) { + if (dim.is_mapped()) { + list.emplace_back(dim.name); + } + } + return std::make_unique<const TensorDataType>(ValueType::tensor_type(std::move(list))); +} + +} + IMPLEMENT_IDENTIFIABLE(TensorRemoveUpdate, ValueUpdate); TensorRemoveUpdate::TensorRemoveUpdate() - : _tensor() + : _tensorType(), + _tensor() { } TensorRemoveUpdate::TensorRemoveUpdate(const TensorRemoveUpdate &rhs) - : _tensor(rhs._tensor->clone()) + : _tensorType(rhs._tensorType->clone()), + _tensor(rhs._tensor->clone()) { } -TensorRemoveUpdate::TensorRemoveUpdate(std::unique_ptr<TensorFieldValue> &&tensor) - : _tensor(std::move(tensor)) +TensorRemoveUpdate::TensorRemoveUpdate(std::unique_ptr<TensorFieldValue> tensor) + : _tensorType(Identifiable::cast<const TensorDataType &>(*tensor->getDataType()).clone()), + _tensor(Identifiable::cast<TensorFieldValue *>(_tensorType->createFieldValue().release())) { + *_tensor = *tensor; } TensorRemoveUpdate::~TensorRemoveUpdate() = default; @@ -42,13 +63,19 @@ TensorRemoveUpdate::~TensorRemoveUpdate() = default; TensorRemoveUpdate & TensorRemoveUpdate::operator=(const TensorRemoveUpdate &rhs) { - _tensor.reset(rhs._tensor->clone()); + if (&rhs != this) { + _tensor.reset(); + _tensorType.reset(rhs._tensorType->clone()); + _tensor.reset(Identifiable::cast<TensorFieldValue *>(_tensorType->createFieldValue().release())); + *_tensor = *rhs._tensor; + } return *this; } TensorRemoveUpdate & TensorRemoveUpdate::operator=(TensorRemoveUpdate &&rhs) { + _tensorType = std::move(rhs._tensorType); _tensor = std::move(rhs._tensor); return *this; } @@ -138,7 +165,8 @@ verifyAddressTensorIsSparse(const std::unique_ptr<Tensor> &addressTensor) void TensorRemoveUpdate::deserialize(const DocumentTypeRepo &repo, const DataType &type, nbostream &stream) { - auto tensor = type.createFieldValue(); + _tensorType = convertToCompatibleType(Identifiable::cast<const TensorDataType &>(type)); + auto tensor = _tensorType->createFieldValue(); if (tensor->inherits(TensorFieldValue::classId)) { _tensor.reset(static_cast<TensorFieldValue *>(tensor.release())); } else { diff --git a/document/src/vespa/document/update/tensor_remove_update.h b/document/src/vespa/document/update/tensor_remove_update.h index 809e9d42305..e75348fa829 100644 --- a/document/src/vespa/document/update/tensor_remove_update.h +++ b/document/src/vespa/document/update/tensor_remove_update.h @@ -6,6 +6,7 @@ namespace vespalib::tensor { class Tensor; } namespace document { +class TensorDataType; class TensorFieldValue; /** @@ -16,6 +17,7 @@ class TensorFieldValue; */ class TensorRemoveUpdate : public ValueUpdate { private: + std::unique_ptr<const TensorDataType> _tensorType; std::unique_ptr<TensorFieldValue> _tensor; TensorRemoveUpdate(); @@ -23,7 +25,7 @@ private: ACCEPT_UPDATE_VISITOR; public: - TensorRemoveUpdate(std::unique_ptr<TensorFieldValue> &&tensor); + TensorRemoveUpdate(std::unique_ptr<TensorFieldValue> tensor); ~TensorRemoveUpdate() override; TensorRemoveUpdate &operator=(const TensorRemoveUpdate &rhs); TensorRemoveUpdate &operator=(TensorRemoveUpdate &&rhs); |