summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorGeir Storli <geirst@verizonmedia.com>2019-02-18 16:03:47 +0100
committerGitHub <noreply@github.com>2019-02-18 16:03:47 +0100
commit9dfc17783ce64ce6d0857d3b6904875d7304531f (patch)
treeb7c4d2692bdc27cbcb2693c0f06cba95b8df57fa
parent129010bad4bf802d33362a581b98bca74f4b036e (diff)
parentf14afd23ff875cc64bcff1a447580c485264fdc1 (diff)
Merge pull request #8537 from vespa-engine/toregge/use-converted-tensor-type-in-tensor-modify-update
Use converted tensor type in TensorModifyUpdate.
-rw-r--r--document/src/vespa/document/update/tensor_modify_update.cpp35
-rw-r--r--document/src/vespa/document/update/tensor_modify_update.h4
2 files changed, 32 insertions, 7 deletions
diff --git a/document/src/vespa/document/update/tensor_modify_update.cpp b/document/src/vespa/document/update/tensor_modify_update.cpp
index b75b2ceec60..a7f82cad3c7 100644
--- a/document/src/vespa/document/update/tensor_modify_update.cpp
+++ b/document/src/vespa/document/update/tensor_modify_update.cpp
@@ -21,6 +21,7 @@ using vespalib::IllegalArgumentException;
using vespalib::IllegalStateException;
using vespalib::tensor::Tensor;
using vespalib::make_string;
+using vespalib::eval::ValueType;
using join_fun_t = double (*)(double, double);
@@ -68,26 +69,41 @@ getJoinFunctionName(TensorModifyUpdate::Operation operation)
}
}
+std::unique_ptr<const TensorDataType>
+convertToCompatibleType(const TensorDataType &tensorType)
+{
+ std::vector<ValueType::Dimension> list;
+ for (const auto &dim : tensorType.getTensorType().dimensions()) {
+ list.emplace_back(dim.name);
+ }
+ return std::make_unique<const TensorDataType>(ValueType::tensor_type(std::move(list)));
+}
+
}
IMPLEMENT_IDENTIFIABLE(TensorModifyUpdate, ValueUpdate);
TensorModifyUpdate::TensorModifyUpdate()
: _operation(Operation::MAX_NUM_OPERATIONS),
+ _tensorType(),
_tensor()
{
}
TensorModifyUpdate::TensorModifyUpdate(const TensorModifyUpdate &rhs)
: _operation(rhs._operation),
- _tensor(rhs._tensor->clone())
+ _tensorType(rhs._tensorType->clone()),
+ _tensor(Identifiable::cast<TensorFieldValue *>(_tensorType->createFieldValue().release()))
{
+ *_tensor = *rhs._tensor;
}
-TensorModifyUpdate::TensorModifyUpdate(Operation operation, std::unique_ptr<TensorFieldValue> &&tensor)
+TensorModifyUpdate::TensorModifyUpdate(Operation operation, std::unique_ptr<TensorFieldValue> tensor)
: _operation(operation),
- _tensor(std::move(tensor))
+ _tensorType(Identifiable::cast<const TensorDataType &>(*tensor->getDataType()).clone()),
+ _tensor(Identifiable::cast<TensorFieldValue *>(_tensorType->createFieldValue().release()))
{
+ *_tensor = *tensor;
}
TensorModifyUpdate::~TensorModifyUpdate() = default;
@@ -95,8 +111,13 @@ TensorModifyUpdate::~TensorModifyUpdate() = default;
TensorModifyUpdate &
TensorModifyUpdate::operator=(const TensorModifyUpdate &rhs)
{
- _operation = rhs._operation;
- _tensor.reset(rhs._tensor->clone());
+ if (&rhs != this) {
+ _operation = rhs._operation;
+ _tensor.reset();
+ _tensorType.reset(rhs._tensorType->clone());
+ _tensor.reset(Identifiable::cast<TensorFieldValue *>(_tensorType->createFieldValue().release()));
+ *_tensor = *rhs._tensor;
+ }
return *this;
}
@@ -104,6 +125,7 @@ TensorModifyUpdate &
TensorModifyUpdate::operator=(TensorModifyUpdate &&rhs)
{
_operation = rhs._operation;
+ _tensorType = std::move(rhs._tensorType);
_tensor = std::move(rhs._tensor);
return *this;
}
@@ -192,7 +214,8 @@ TensorModifyUpdate::deserialize(const DocumentTypeRepo &repo, const DataType &ty
throw DeserializeException(msg.str(), VESPA_STRLOC);
}
_operation = static_cast<Operation>(op);
- auto tensor = type.createFieldValue();
+ _tensorType = convertToCompatibleType(Identifiable::cast<const TensorDataType &>(type));
+ auto tensor = _tensorType->createFieldValue();
if (tensor->inherits(TensorFieldValue::classId)) {
_tensor.reset(static_cast<TensorFieldValue *>(tensor.release()));
} else {
diff --git a/document/src/vespa/document/update/tensor_modify_update.h b/document/src/vespa/document/update/tensor_modify_update.h
index c937e356ce4..c2d61d3e69b 100644
--- a/document/src/vespa/document/update/tensor_modify_update.h
+++ b/document/src/vespa/document/update/tensor_modify_update.h
@@ -6,6 +6,7 @@ namespace vespalib::tensor { class Tensor; }
namespace document {
+class TensorDataType;
class TensorFieldValue;
/*
@@ -25,13 +26,14 @@ public:
};
private:
Operation _operation;
+ std::unique_ptr<const TensorDataType> _tensorType;
std::unique_ptr<TensorFieldValue> _tensor;
TensorModifyUpdate();
TensorModifyUpdate(const TensorModifyUpdate &rhs);
ACCEPT_UPDATE_VISITOR;
public:
- TensorModifyUpdate(Operation operation, std::unique_ptr<TensorFieldValue> &&tensor);
+ TensorModifyUpdate(Operation operation, std::unique_ptr<TensorFieldValue> tensor);
~TensorModifyUpdate() override;
TensorModifyUpdate &operator=(const TensorModifyUpdate &rhs);
TensorModifyUpdate &operator=(TensorModifyUpdate &&rhs);