summaryrefslogtreecommitdiffstats
path: root/document
diff options
context:
space:
mode:
authorTor Egge <Tor.Egge@broadpark.no>2019-02-18 14:12:24 +0100
committerTor Egge <Tor.Egge@broadpark.no>2019-02-18 14:15:09 +0100
commitf14afd23ff875cc64bcff1a447580c485264fdc1 (patch)
tree4fc77424edd51456f87384a2717febc59de23dc2 /document
parent7baac9a29d01a23893b32d54b672001281bd3d96 (diff)
Use converted tensor type in TensorModifyUpdate.
Diffstat (limited to 'document')
-rw-r--r--document/src/vespa/document/update/tensor_modify_update.cpp35
-rw-r--r--document/src/vespa/document/update/tensor_modify_update.h4
2 files changed, 32 insertions, 7 deletions
diff --git a/document/src/vespa/document/update/tensor_modify_update.cpp b/document/src/vespa/document/update/tensor_modify_update.cpp
index b75b2ceec60..a7f82cad3c7 100644
--- a/document/src/vespa/document/update/tensor_modify_update.cpp
+++ b/document/src/vespa/document/update/tensor_modify_update.cpp
@@ -21,6 +21,7 @@ using vespalib::IllegalArgumentException;
using vespalib::IllegalStateException;
using vespalib::tensor::Tensor;
using vespalib::make_string;
+using vespalib::eval::ValueType;
using join_fun_t = double (*)(double, double);
@@ -68,26 +69,41 @@ getJoinFunctionName(TensorModifyUpdate::Operation operation)
}
}
+std::unique_ptr<const TensorDataType>
+convertToCompatibleType(const TensorDataType &tensorType)
+{
+ std::vector<ValueType::Dimension> list;
+ for (const auto &dim : tensorType.getTensorType().dimensions()) {
+ list.emplace_back(dim.name);
+ }
+ return std::make_unique<const TensorDataType>(ValueType::tensor_type(std::move(list)));
+}
+
}
IMPLEMENT_IDENTIFIABLE(TensorModifyUpdate, ValueUpdate);
TensorModifyUpdate::TensorModifyUpdate()
: _operation(Operation::MAX_NUM_OPERATIONS),
+ _tensorType(),
_tensor()
{
}
TensorModifyUpdate::TensorModifyUpdate(const TensorModifyUpdate &rhs)
: _operation(rhs._operation),
- _tensor(rhs._tensor->clone())
+ _tensorType(rhs._tensorType->clone()),
+ _tensor(Identifiable::cast<TensorFieldValue *>(_tensorType->createFieldValue().release()))
{
+ *_tensor = *rhs._tensor;
}
-TensorModifyUpdate::TensorModifyUpdate(Operation operation, std::unique_ptr<TensorFieldValue> &&tensor)
+TensorModifyUpdate::TensorModifyUpdate(Operation operation, std::unique_ptr<TensorFieldValue> tensor)
: _operation(operation),
- _tensor(std::move(tensor))
+ _tensorType(Identifiable::cast<const TensorDataType &>(*tensor->getDataType()).clone()),
+ _tensor(Identifiable::cast<TensorFieldValue *>(_tensorType->createFieldValue().release()))
{
+ *_tensor = *tensor;
}
TensorModifyUpdate::~TensorModifyUpdate() = default;
@@ -95,8 +111,13 @@ TensorModifyUpdate::~TensorModifyUpdate() = default;
TensorModifyUpdate &
TensorModifyUpdate::operator=(const TensorModifyUpdate &rhs)
{
- _operation = rhs._operation;
- _tensor.reset(rhs._tensor->clone());
+ if (&rhs != this) {
+ _operation = rhs._operation;
+ _tensor.reset();
+ _tensorType.reset(rhs._tensorType->clone());
+ _tensor.reset(Identifiable::cast<TensorFieldValue *>(_tensorType->createFieldValue().release()));
+ *_tensor = *rhs._tensor;
+ }
return *this;
}
@@ -104,6 +125,7 @@ TensorModifyUpdate &
TensorModifyUpdate::operator=(TensorModifyUpdate &&rhs)
{
_operation = rhs._operation;
+ _tensorType = std::move(rhs._tensorType);
_tensor = std::move(rhs._tensor);
return *this;
}
@@ -192,7 +214,8 @@ TensorModifyUpdate::deserialize(const DocumentTypeRepo &repo, const DataType &ty
throw DeserializeException(msg.str(), VESPA_STRLOC);
}
_operation = static_cast<Operation>(op);
- auto tensor = type.createFieldValue();
+ _tensorType = convertToCompatibleType(Identifiable::cast<const TensorDataType &>(type));
+ auto tensor = _tensorType->createFieldValue();
if (tensor->inherits(TensorFieldValue::classId)) {
_tensor.reset(static_cast<TensorFieldValue *>(tensor.release()));
} else {
diff --git a/document/src/vespa/document/update/tensor_modify_update.h b/document/src/vespa/document/update/tensor_modify_update.h
index c937e356ce4..c2d61d3e69b 100644
--- a/document/src/vespa/document/update/tensor_modify_update.h
+++ b/document/src/vespa/document/update/tensor_modify_update.h
@@ -6,6 +6,7 @@ namespace vespalib::tensor { class Tensor; }
namespace document {
+class TensorDataType;
class TensorFieldValue;
/*
@@ -25,13 +26,14 @@ public:
};
private:
Operation _operation;
+ std::unique_ptr<const TensorDataType> _tensorType;
std::unique_ptr<TensorFieldValue> _tensor;
TensorModifyUpdate();
TensorModifyUpdate(const TensorModifyUpdate &rhs);
ACCEPT_UPDATE_VISITOR;
public:
- TensorModifyUpdate(Operation operation, std::unique_ptr<TensorFieldValue> &&tensor);
+ TensorModifyUpdate(Operation operation, std::unique_ptr<TensorFieldValue> tensor);
~TensorModifyUpdate() override;
TensorModifyUpdate &operator=(const TensorModifyUpdate &rhs);
TensorModifyUpdate &operator=(TensorModifyUpdate &&rhs);