summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorGeir Storli <geirst@verizonmedia.com>2019-03-08 08:19:48 +0100
committerGitHub <noreply@github.com>2019-03-08 08:19:48 +0100
commit85e978f6c57edafb6d24c4a9f47ae9760dc65c53 (patch)
treea748b9436dc1dbe425b2eb253f931b573df58719
parentb6d30eb9c7370d35accf7b3a709259ae2916a03e (diff)
parent9a86bbe4f3353c04794400179d9296b7a5d60a53 (diff)
Merge pull request #8705 from vespa-engine/toregge/convert-tensor-type-for-tensor-remove-updates
Convert field tensor type to tensor type for tensor remove updates.
-rw-r--r--document/src/tests/documentupdatetestcase.cpp2
-rw-r--r--document/src/vespa/document/update/tensor_remove_update.cpp40
-rw-r--r--document/src/vespa/document/update/tensor_remove_update.h4
3 files changed, 38 insertions, 8 deletions
diff --git a/document/src/tests/documentupdatetestcase.cpp b/document/src/tests/documentupdatetestcase.cpp
index 71bc545897b..b9568d546c5 100644
--- a/document/src/tests/documentupdatetestcase.cpp
+++ b/document/src/tests/documentupdatetestcase.cpp
@@ -1042,7 +1042,7 @@ TEST(DocumentUpdateTest, tensor_remove_update_throws_if_address_tensor_is_not_sp
auto addressTensor = f.makeTensor(f.spec().add({{"x", 0}}, 2)); // creates a dense address tensor
ASSERT_THROW(
f.assertRoundtripSerialize(TensorRemoveUpdate(std::move(addressTensor))),
- vespalib::IllegalStateException);
+ document::WrongTensorTypeException);
}
TEST(DocumentUpdateTest, tensor_modify_update_throws_if_cells_tensor_is_not_sparse)
diff --git a/document/src/vespa/document/update/tensor_remove_update.cpp b/document/src/vespa/document/update/tensor_remove_update.cpp
index c72d776fa9f..24aba4ece5a 100644
--- a/document/src/vespa/document/update/tensor_remove_update.cpp
+++ b/document/src/vespa/document/update/tensor_remove_update.cpp
@@ -17,24 +17,45 @@ using vespalib::IllegalArgumentException;
using vespalib::IllegalStateException;
using vespalib::tensor::Tensor;
using vespalib::make_string;
+using vespalib::eval::ValueType;
namespace document {
+namespace {
+
+std::unique_ptr<const TensorDataType>
+convertToCompatibleType(const TensorDataType &tensorType)
+{
+ std::vector<ValueType::Dimension> list;
+ for (const auto &dim : tensorType.getTensorType().dimensions()) {
+ if (dim.is_mapped()) {
+ list.emplace_back(dim.name);
+ }
+ }
+ return std::make_unique<const TensorDataType>(ValueType::tensor_type(std::move(list)));
+}
+
+}
+
IMPLEMENT_IDENTIFIABLE(TensorRemoveUpdate, ValueUpdate);
TensorRemoveUpdate::TensorRemoveUpdate()
- : _tensor()
+ : _tensorType(),
+ _tensor()
{
}
TensorRemoveUpdate::TensorRemoveUpdate(const TensorRemoveUpdate &rhs)
- : _tensor(rhs._tensor->clone())
+ : _tensorType(rhs._tensorType->clone()),
+ _tensor(rhs._tensor->clone())
{
}
-TensorRemoveUpdate::TensorRemoveUpdate(std::unique_ptr<TensorFieldValue> &&tensor)
- : _tensor(std::move(tensor))
+TensorRemoveUpdate::TensorRemoveUpdate(std::unique_ptr<TensorFieldValue> tensor)
+ : _tensorType(Identifiable::cast<const TensorDataType &>(*tensor->getDataType()).clone()),
+ _tensor(Identifiable::cast<TensorFieldValue *>(_tensorType->createFieldValue().release()))
{
+ *_tensor = *tensor;
}
TensorRemoveUpdate::~TensorRemoveUpdate() = default;
@@ -42,13 +63,19 @@ TensorRemoveUpdate::~TensorRemoveUpdate() = default;
TensorRemoveUpdate &
TensorRemoveUpdate::operator=(const TensorRemoveUpdate &rhs)
{
- _tensor.reset(rhs._tensor->clone());
+ if (&rhs != this) {
+ _tensor.reset();
+ _tensorType.reset(rhs._tensorType->clone());
+ _tensor.reset(Identifiable::cast<TensorFieldValue *>(_tensorType->createFieldValue().release()));
+ *_tensor = *rhs._tensor;
+ }
return *this;
}
TensorRemoveUpdate &
TensorRemoveUpdate::operator=(TensorRemoveUpdate &&rhs)
{
+ _tensorType = std::move(rhs._tensorType);
_tensor = std::move(rhs._tensor);
return *this;
}
@@ -138,7 +165,8 @@ verifyAddressTensorIsSparse(const std::unique_ptr<Tensor> &addressTensor)
void
TensorRemoveUpdate::deserialize(const DocumentTypeRepo &repo, const DataType &type, nbostream &stream)
{
- auto tensor = type.createFieldValue();
+ _tensorType = convertToCompatibleType(Identifiable::cast<const TensorDataType &>(type));
+ auto tensor = _tensorType->createFieldValue();
if (tensor->inherits(TensorFieldValue::classId)) {
_tensor.reset(static_cast<TensorFieldValue *>(tensor.release()));
} else {
diff --git a/document/src/vespa/document/update/tensor_remove_update.h b/document/src/vespa/document/update/tensor_remove_update.h
index 809e9d42305..e75348fa829 100644
--- a/document/src/vespa/document/update/tensor_remove_update.h
+++ b/document/src/vespa/document/update/tensor_remove_update.h
@@ -6,6 +6,7 @@ namespace vespalib::tensor { class Tensor; }
namespace document {
+class TensorDataType;
class TensorFieldValue;
/**
@@ -16,6 +17,7 @@ class TensorFieldValue;
*/
class TensorRemoveUpdate : public ValueUpdate {
private:
+ std::unique_ptr<const TensorDataType> _tensorType;
std::unique_ptr<TensorFieldValue> _tensor;
TensorRemoveUpdate();
@@ -23,7 +25,7 @@ private:
ACCEPT_UPDATE_VISITOR;
public:
- TensorRemoveUpdate(std::unique_ptr<TensorFieldValue> &&tensor);
+ TensorRemoveUpdate(std::unique_ptr<TensorFieldValue> tensor);
~TensorRemoveUpdate() override;
TensorRemoveUpdate &operator=(const TensorRemoveUpdate &rhs);
TensorRemoveUpdate &operator=(TensorRemoveUpdate &&rhs);