diff options
author | Arne Juul <arnej@verizonmedia.com> | 2020-10-16 08:59:58 +0000 |
---|---|---|
committer | Arne Juul <arnej@verizonmedia.com> | 2020-10-16 08:59:58 +0000 |
commit | 12460f8ef17572800c268001205f621e6a0aaf9c (patch) | |
tree | 8a59bf56cbbb9e18d59031e88f23ddd85ac582bf /document/src | |
parent | a4e825b9d0da0143700b253b47b01fe79c635684 (diff) |
use compatibility code in TensorPartialUpdate
Diffstat (limited to 'document/src')
3 files changed, 7 insertions, 69 deletions
diff --git a/document/src/vespa/document/update/tensor_add_update.cpp b/document/src/vespa/document/update/tensor_add_update.cpp index 9a89c9850e6..3ae599f22a0 100644 --- a/document/src/vespa/document/update/tensor_add_update.cpp +++ b/document/src/vespa/document/update/tensor_add_update.cpp @@ -82,32 +82,13 @@ TensorAddUpdate::checkCompatibility(const Field& field) const } } -namespace { - -std::unique_ptr<vespalib::eval::Value> -old_add(const vespalib::eval::Value *input, - const vespalib::eval::Value *add_cells) -{ - auto a = dynamic_cast<const vespalib::tensor::Tensor *>(input); - assert(a); - auto b = dynamic_cast<const vespalib::tensor::Tensor *>(add_cells); - assert(b); - return a->add(*b); -} - -} // namespace - std::unique_ptr<vespalib::eval::Value> TensorAddUpdate::applyTo(const vespalib::eval::Value &tensor) const { auto addTensor = _tensor->getAsTensorPtr(); if (addTensor) { auto engine = EngineOrFactory::get(); - if (engine.is_factory()) { - return TensorPartialUpdate::add(tensor, *addTensor, engine.factory()); - } else { - return old_add(&tensor, addTensor); - } + return TensorPartialUpdate::add(tensor, *addTensor, engine); } return {}; } diff --git a/document/src/vespa/document/update/tensor_modify_update.cpp b/document/src/vespa/document/update/tensor_modify_update.cpp index 044c5a14298..2ff45b11b07 100644 --- a/document/src/vespa/document/update/tensor_modify_update.cpp +++ b/document/src/vespa/document/update/tensor_modify_update.cpp @@ -160,30 +160,13 @@ TensorModifyUpdate::checkCompatibility(const Field& field) const } } - -std::unique_ptr<vespalib::eval::Value> -old_modify(const vespalib::eval::Value *input, - const vespalib::eval::Value *modify_spec, - join_fun_t function) -{ - auto a = dynamic_cast<const vespalib::tensor::Tensor *>(input); - // Cells tensor being sparse was validated during deserialize(). - auto b = dynamic_cast<const vespalib::tensor::SparseTensor *>(modify_spec); - vespalib::tensor::CellValues cellValues(*b); - return a->modify(function, cellValues); -} - std::unique_ptr<vespalib::eval::Value> TensorModifyUpdate::applyTo(const vespalib::eval::Value &tensor) const { auto cellsTensor = _tensor->getAsTensorPtr(); if (cellsTensor) { auto engine = EngineOrFactory::get(); - if (engine.is_factory()) { - return TensorPartialUpdate::modify(tensor, getJoinFunction(_operation), *cellsTensor, engine.factory()); - } else { - return old_modify(&tensor, cellsTensor, getJoinFunction(_operation)); - } + return TensorPartialUpdate::modify(tensor, getJoinFunction(_operation), *cellsTensor, engine); } return {}; } @@ -233,14 +216,8 @@ verifyCellsTensorIsSparse(const vespalib::eval::Value *cellsTensor) return; } auto engine = EngineOrFactory::get(); - if (engine.is_factory()) { - if (cellsTensor->type().is_sparse()) { - return; - } - } else { - if (dynamic_cast<const vespalib::tensor::SparseTensor *>(cellsTensor)) { - return; - } + if (TensorPartialUpdate::check_suitably_sparse(*cellsTensor, engine)) { + return; } vespalib::string err = make_string("Expected cells tensor to be sparse, but has type '%s'", cellsTensor->type().to_spec().c_str()); diff --git a/document/src/vespa/document/update/tensor_remove_update.cpp b/document/src/vespa/document/update/tensor_remove_update.cpp index 1270e430750..5d85b8956fa 100644 --- a/document/src/vespa/document/update/tensor_remove_update.cpp +++ b/document/src/vespa/document/update/tensor_remove_update.cpp @@ -40,16 +40,6 @@ convertToCompatibleType(const TensorDataType &tensorType) return std::make_unique<const TensorDataType>(ValueType::tensor_type(std::move(list), tensorType.getTensorType().cell_type())); } -std::unique_ptr<vespalib::eval::Value> -old_remove(const vespalib::eval::Value *input, - const vespalib::eval::Value *remove_spec) -{ - auto a = dynamic_cast<const vespalib::tensor::Tensor *>(input); - auto b = dynamic_cast<const vespalib::tensor::SparseTensor *>(remove_spec); - vespalib::tensor::CellValues cellAddresses(*b); - return a->remove(cellAddresses); -} - } IMPLEMENT_IDENTIFIABLE(TensorRemoveUpdate, ValueUpdate); @@ -123,11 +113,7 @@ TensorRemoveUpdate::applyTo(const vespalib::eval::Value &tensor) const auto addressTensor = _tensor->getAsTensorPtr(); if (addressTensor) { auto engine = EngineOrFactory::get(); - if (engine.is_factory()) { - return TensorPartialUpdate::remove(tensor, *addressTensor, engine.factory()); - } else { - return old_remove(&tensor, addressTensor); - } + return TensorPartialUpdate::remove(tensor, *addressTensor, engine); } return {}; } @@ -177,14 +163,8 @@ verifyAddressTensorIsSparse(const vespalib::eval::Value *addressTensor) return; } auto engine = EngineOrFactory::get(); - if (engine.is_factory()) { - if (addressTensor->type().is_sparse()) { - return; - } - } else { - if (dynamic_cast<const vespalib::tensor::SparseTensor *>(addressTensor)) { - return; - } + if (TensorPartialUpdate::check_suitably_sparse(*addressTensor, engine)) { + return; } vespalib::string err = make_string("Expected address tensor to be sparse, but has type '%s'", addressTensor->type().to_spec().c_str()); |