aboutsummaryrefslogtreecommitdiffstats
path: root/document/src
diff options
context:
space:
mode:
authorArne Juul <arnej@verizonmedia.com>2020-10-16 08:59:58 +0000
committerArne Juul <arnej@verizonmedia.com>2020-10-16 08:59:58 +0000
commit12460f8ef17572800c268001205f621e6a0aaf9c (patch)
tree8a59bf56cbbb9e18d59031e88f23ddd85ac582bf /document/src
parenta4e825b9d0da0143700b253b47b01fe79c635684 (diff)
use compatibility code in TensorPartialUpdate
Diffstat (limited to 'document/src')
-rw-r--r--document/src/vespa/document/update/tensor_add_update.cpp21
-rw-r--r--document/src/vespa/document/update/tensor_modify_update.cpp29
-rw-r--r--document/src/vespa/document/update/tensor_remove_update.cpp26
3 files changed, 7 insertions, 69 deletions
diff --git a/document/src/vespa/document/update/tensor_add_update.cpp b/document/src/vespa/document/update/tensor_add_update.cpp
index 9a89c9850e6..3ae599f22a0 100644
--- a/document/src/vespa/document/update/tensor_add_update.cpp
+++ b/document/src/vespa/document/update/tensor_add_update.cpp
@@ -82,32 +82,13 @@ TensorAddUpdate::checkCompatibility(const Field& field) const
}
}
-namespace {
-
-std::unique_ptr<vespalib::eval::Value>
-old_add(const vespalib::eval::Value *input,
- const vespalib::eval::Value *add_cells)
-{
- auto a = dynamic_cast<const vespalib::tensor::Tensor *>(input);
- assert(a);
- auto b = dynamic_cast<const vespalib::tensor::Tensor *>(add_cells);
- assert(b);
- return a->add(*b);
-}
-
-} // namespace
-
std::unique_ptr<vespalib::eval::Value>
TensorAddUpdate::applyTo(const vespalib::eval::Value &tensor) const
{
auto addTensor = _tensor->getAsTensorPtr();
if (addTensor) {
auto engine = EngineOrFactory::get();
- if (engine.is_factory()) {
- return TensorPartialUpdate::add(tensor, *addTensor, engine.factory());
- } else {
- return old_add(&tensor, addTensor);
- }
+ return TensorPartialUpdate::add(tensor, *addTensor, engine);
}
return {};
}
diff --git a/document/src/vespa/document/update/tensor_modify_update.cpp b/document/src/vespa/document/update/tensor_modify_update.cpp
index 044c5a14298..2ff45b11b07 100644
--- a/document/src/vespa/document/update/tensor_modify_update.cpp
+++ b/document/src/vespa/document/update/tensor_modify_update.cpp
@@ -160,30 +160,13 @@ TensorModifyUpdate::checkCompatibility(const Field& field) const
}
}
-
-std::unique_ptr<vespalib::eval::Value>
-old_modify(const vespalib::eval::Value *input,
- const vespalib::eval::Value *modify_spec,
- join_fun_t function)
-{
- auto a = dynamic_cast<const vespalib::tensor::Tensor *>(input);
- // Cells tensor being sparse was validated during deserialize().
- auto b = dynamic_cast<const vespalib::tensor::SparseTensor *>(modify_spec);
- vespalib::tensor::CellValues cellValues(*b);
- return a->modify(function, cellValues);
-}
-
std::unique_ptr<vespalib::eval::Value>
TensorModifyUpdate::applyTo(const vespalib::eval::Value &tensor) const
{
auto cellsTensor = _tensor->getAsTensorPtr();
if (cellsTensor) {
auto engine = EngineOrFactory::get();
- if (engine.is_factory()) {
- return TensorPartialUpdate::modify(tensor, getJoinFunction(_operation), *cellsTensor, engine.factory());
- } else {
- return old_modify(&tensor, cellsTensor, getJoinFunction(_operation));
- }
+ return TensorPartialUpdate::modify(tensor, getJoinFunction(_operation), *cellsTensor, engine);
}
return {};
}
@@ -233,14 +216,8 @@ verifyCellsTensorIsSparse(const vespalib::eval::Value *cellsTensor)
return;
}
auto engine = EngineOrFactory::get();
- if (engine.is_factory()) {
- if (cellsTensor->type().is_sparse()) {
- return;
- }
- } else {
- if (dynamic_cast<const vespalib::tensor::SparseTensor *>(cellsTensor)) {
- return;
- }
+ if (TensorPartialUpdate::check_suitably_sparse(*cellsTensor, engine)) {
+ return;
}
vespalib::string err = make_string("Expected cells tensor to be sparse, but has type '%s'",
cellsTensor->type().to_spec().c_str());
diff --git a/document/src/vespa/document/update/tensor_remove_update.cpp b/document/src/vespa/document/update/tensor_remove_update.cpp
index 1270e430750..5d85b8956fa 100644
--- a/document/src/vespa/document/update/tensor_remove_update.cpp
+++ b/document/src/vespa/document/update/tensor_remove_update.cpp
@@ -40,16 +40,6 @@ convertToCompatibleType(const TensorDataType &tensorType)
return std::make_unique<const TensorDataType>(ValueType::tensor_type(std::move(list), tensorType.getTensorType().cell_type()));
}
-std::unique_ptr<vespalib::eval::Value>
-old_remove(const vespalib::eval::Value *input,
- const vespalib::eval::Value *remove_spec)
-{
- auto a = dynamic_cast<const vespalib::tensor::Tensor *>(input);
- auto b = dynamic_cast<const vespalib::tensor::SparseTensor *>(remove_spec);
- vespalib::tensor::CellValues cellAddresses(*b);
- return a->remove(cellAddresses);
-}
-
}
IMPLEMENT_IDENTIFIABLE(TensorRemoveUpdate, ValueUpdate);
@@ -123,11 +113,7 @@ TensorRemoveUpdate::applyTo(const vespalib::eval::Value &tensor) const
auto addressTensor = _tensor->getAsTensorPtr();
if (addressTensor) {
auto engine = EngineOrFactory::get();
- if (engine.is_factory()) {
- return TensorPartialUpdate::remove(tensor, *addressTensor, engine.factory());
- } else {
- return old_remove(&tensor, addressTensor);
- }
+ return TensorPartialUpdate::remove(tensor, *addressTensor, engine);
}
return {};
}
@@ -177,14 +163,8 @@ verifyAddressTensorIsSparse(const vespalib::eval::Value *addressTensor)
return;
}
auto engine = EngineOrFactory::get();
- if (engine.is_factory()) {
- if (addressTensor->type().is_sparse()) {
- return;
- }
- } else {
- if (dynamic_cast<const vespalib::tensor::SparseTensor *>(addressTensor)) {
- return;
- }
+ if (TensorPartialUpdate::check_suitably_sparse(*addressTensor, engine)) {
+ return;
}
vespalib::string err = make_string("Expected address tensor to be sparse, but has type '%s'",
addressTensor->type().to_spec().c_str());