diff options
author | Arne Juul <arnej@verizonmedia.com> | 2020-10-13 14:11:28 +0000 |
---|---|---|
committer | Arne Juul <arnej@verizonmedia.com> | 2020-10-15 08:18:59 +0000 |
commit | 08393e9e14635f1c6a6c84650c25023a0db7ed0b (patch) | |
tree | 48aae1605140fc6ff7d571084f345d33a3189c62 /document/src | |
parent | 61eaea251e8cacd320ac10754ffd1513d8638043 (diff) |
handle both engine- and factory-based tensors
* use EngineOrFactory::get() instead of DefaultTensorEngine::ref()
* avoid direct use of DenseTensorView etc where possible
* use eval::Value instead of tensor::Tensor where possible
Diffstat (limited to 'document/src')
13 files changed, 182 insertions, 103 deletions
diff --git a/document/src/tests/documentupdatetestcase.cpp b/document/src/tests/documentupdatetestcase.cpp index 9d2567e93ed..ca519a2f7d0 100644 --- a/document/src/tests/documentupdatetestcase.cpp +++ b/document/src/tests/documentupdatetestcase.cpp @@ -21,8 +21,9 @@ #include <vespa/document/update/tensor_remove_update.h> #include <vespa/document/update/valueupdate.h> #include <vespa/document/util/bytebuffer.h> -#include <vespa/eval/tensor/default_tensor_engine.h> -#include <vespa/eval/tensor/tensor.h> +#include <vespa/eval/eval/engine_or_factory.h> +#include <vespa/eval/eval/value.h> +#include <vespa/eval/eval/test/value_compare.h> #include <vespa/vespalib/objects/nbostream.h> #include <vespa/vespalib/util/exception.h> #include <vespa/vespalib/util/exceptions.h> @@ -33,10 +34,10 @@ #include <unistd.h> using namespace document::config_builder; + using vespalib::eval::TensorSpec; using vespalib::eval::ValueType; -using vespalib::tensor::DefaultTensorEngine; -using vespalib::tensor::Tensor; +using vespalib::eval::EngineOrFactory; using vespalib::nbostream; namespace document { @@ -771,11 +772,12 @@ TEST(DocumentUpdateTest, testMapValueUpdate) EXPECT_EQ(fv4->find(StringFieldValue("apple")), fv4->end()); } -std::unique_ptr<Tensor> +std::unique_ptr<vespalib::eval::Value> makeTensor(const TensorSpec &spec) { - auto result = DefaultTensorEngine::ref().from_spec(spec); - return std::unique_ptr<Tensor>(dynamic_cast<Tensor*>(result.release())); + auto result = EngineOrFactory::get().from_spec(spec); + EXPECT_TRUE(result->is_tensor()); + return result; } std::unique_ptr<TensorFieldValue> @@ -787,7 +789,7 @@ makeTensorFieldValue(const TensorSpec &spec, const TensorDataType &dataType) return result; } -const Tensor &asTensor(const FieldValue &fieldValue) { +const vespalib::eval::Value &asTensor(const FieldValue &fieldValue) { auto &tensorFieldValue = dynamic_cast<const TensorFieldValue &>(fieldValue); auto tensor = tensorFieldValue.getAsTensorPtr(); assert(tensor); diff --git a/document/src/tests/serialization/vespadocumentserializer_test.cpp b/document/src/tests/serialization/vespadocumentserializer_test.cpp index 02f170cd5f1..13d5e7d8405 100644 --- a/document/src/tests/serialization/vespadocumentserializer_test.cpp +++ b/document/src/tests/serialization/vespadocumentserializer_test.cpp @@ -36,8 +36,9 @@ #include <vespa/document/serialization/vespadocumentdeserializer.h> #include <vespa/document/serialization/vespadocumentserializer.h> #include <vespa/document/serialization/annotationserializer.h> -#include <vespa/eval/tensor/tensor.h> -#include <vespa/eval/tensor/default_tensor_engine.h> +#include <vespa/eval/eval/engine_or_factory.h> +#include <vespa/eval/eval/value.h> +#include <vespa/eval/eval/test/value_compare.h> #include <vespa/vespalib/io/fileutil.h> #include <vespa/vespalib/objects/nbostream.h> #include <vespa/vespalib/testkit/testapp.h> @@ -50,8 +51,7 @@ using vespalib::nbostream; using vespalib::nbostream_longlivedbuf; using vespalib::slime::Cursor; using vespalib::eval::TensorSpec; -using vespalib::tensor::Tensor; -using vespalib::tensor::DefaultTensorEngine; +using vespalib::eval::EngineOrFactory; using vespalib::compression::CompressionConfig; using namespace document; using std::string; @@ -771,12 +771,10 @@ TEST("Require that predicate deserialization matches Java") { namespace { -Tensor::UP createTensor(const TensorSpec &spec) { - auto value = DefaultTensorEngine::ref().from_spec(spec); - Tensor *tensor = dynamic_cast<Tensor*>(value.get()); - ASSERT_TRUE(tensor != nullptr); - value.release(); - return Tensor::UP(tensor); +vespalib::eval::Value::UP createTensor(const TensorSpec &spec) { + auto value = EngineOrFactory::get().from_spec(spec); + ASSERT_TRUE(value->is_tensor()); + return value; } } @@ -836,13 +834,13 @@ void deserializeAndCheck(const string &file_name, TensorFieldValue &value) { deserializeAndCheck(file_name, value, tensor_repo, tensor_field_name); } -void checkDeserialization(const string &name, std::unique_ptr<Tensor> tensor) { +void checkDeserialization(const string &name, std::unique_ptr<vespalib::eval::Value> tensor) { const string data_dir = TEST_PATH("../../test/resources/tensor/"); TensorDataType valueType(tensor ? tensor->type() : vespalib::eval::ValueType::error_type()); TensorFieldValue value(valueType); if (tensor) { - value = tensor->clone(); + value = EngineOrFactory::get().copy(*tensor); } serializeToFile(value, data_dir + name + "__cpp"); deserializeAndCheck(data_dir + name + "__cpp", value); @@ -851,7 +849,7 @@ void checkDeserialization(const string &name, std::unique_ptr<Tensor> tensor) { TEST("Require that tensor deserialization matches Java") { - checkDeserialization("non_existing_tensor", std::unique_ptr<Tensor>()); + checkDeserialization("non_existing_tensor", std::unique_ptr<vespalib::eval::Value>()); checkDeserialization("empty_tensor", createTensor(TensorSpec("tensor(dimX{},dimY{})"))); checkDeserialization("multi_cell_tensor", createTensor(TensorSpec("tensor(dimX{},dimY{})") @@ -863,17 +861,17 @@ TEST("Require that tensor deserialization matches Java") { struct TensorDocFixture { const DocumentTypeRepo &_docTypeRepo; const DocumentType *_docType; - std::unique_ptr<Tensor> _tensor; + std::unique_ptr<vespalib::eval::Value> _tensor; Document _doc; vespalib::nbostream _blob; TensorDocFixture(const DocumentTypeRepo &docTypeRepo, - std::unique_ptr<Tensor> tensor); + std::unique_ptr<vespalib::eval::Value> tensor); ~TensorDocFixture(); }; TensorDocFixture::TensorDocFixture(const DocumentTypeRepo &docTypeRepo, - std::unique_ptr<Tensor> tensor) + std::unique_ptr<vespalib::eval::Value> tensor) : _docTypeRepo(docTypeRepo), _docType(_docTypeRepo.getDocumentType(tensor_doc_type_id)), _tensor(std::move(tensor)), @@ -881,7 +879,7 @@ TensorDocFixture::TensorDocFixture(const DocumentTypeRepo &docTypeRepo, _blob() { auto fv = _doc.getField(tensor_field_name).createValue(); - dynamic_cast<TensorFieldValue &>(*fv) = _tensor->clone(); + dynamic_cast<TensorFieldValue &>(*fv) = EngineOrFactory::get().copy(*_tensor); _doc.setValue(tensor_field_name, *fv); _doc.serialize(_blob); } @@ -897,7 +895,7 @@ struct DeserializedTensorDoc ~DeserializedTensorDoc(); void setup(const DocumentTypeRepo &docTypeRepo, const vespalib::nbostream &blob); - const Tensor *getTensor() const; + const vespalib::eval::Value *getTensor() const; }; DeserializedTensorDoc::DeserializedTensorDoc() @@ -916,7 +914,7 @@ DeserializedTensorDoc::setup(const DocumentTypeRepo &docTypeRepo, const vespalib _fieldValue = _doc->getValue(tensor_field_name); } -const Tensor * +const vespalib::eval::Value * DeserializedTensorDoc::getTensor() const { return dynamic_cast<const TensorFieldValue &>(*_fieldValue).getAsTensorPtr(); @@ -936,14 +934,14 @@ TEST("Require that wrong tensor type hides tensor") DeserializedTensorDoc doc; doc.setup(tensor_doc_repo, f._blob); EXPECT_TRUE(doc.getTensor() != nullptr); - EXPECT_TRUE(doc.getTensor()->equals(*f._tensor)); + EXPECT_TRUE((*doc.getTensor()) == (*f._tensor)); doc.setup(tensor_doc_repo, f1._blob); EXPECT_TRUE(doc.getTensor() == nullptr); doc.setup(tensor_doc_repo1, f._blob); EXPECT_TRUE(doc.getTensor() == nullptr); doc.setup(tensor_doc_repo1, f1._blob); EXPECT_TRUE(doc.getTensor() != nullptr); - EXPECT_TRUE(doc.getTensor()->equals(*f1._tensor)); + EXPECT_TRUE((*doc.getTensor()) == (*f1._tensor)); } struct RefFixture { diff --git a/document/src/tests/tensor_fieldvalue/tensor_fieldvalue_test.cpp b/document/src/tests/tensor_fieldvalue/tensor_fieldvalue_test.cpp index 9d2da9c983a..18afdb15bb8 100644 --- a/document/src/tests/tensor_fieldvalue/tensor_fieldvalue_test.cpp +++ b/document/src/tests/tensor_fieldvalue/tensor_fieldvalue_test.cpp @@ -7,9 +7,8 @@ LOG_SETUP("fieldvalue_test"); #include <vespa/document/base/exceptions.h> #include <vespa/document/datatype/tensor_data_type.h> #include <vespa/document/fieldvalue/tensorfieldvalue.h> -#include <vespa/eval/tensor/tensor.h> -#include <vespa/eval/tensor/types.h> -#include <vespa/eval/tensor/default_tensor_engine.h> +#include <vespa/eval/eval/engine_or_factory.h> +#include <vespa/eval/eval/value.h> #include <vespa/eval/tensor/test/test_utils.h> #include <vespa/vespalib/testkit/testapp.h> @@ -18,7 +17,7 @@ using namespace document; using namespace vespalib::tensor; using vespalib::eval::TensorSpec; using vespalib::eval::ValueType; -using vespalib::tensor::DefaultTensorEngine; +using vespalib::eval::EngineOrFactory; using vespalib::tensor::test::makeTensor; namespace @@ -27,19 +26,17 @@ namespace TensorDataType xSparseTensorDataType(ValueType::from_spec("tensor(x{})")); TensorDataType xySparseTensorDataType(ValueType::from_spec("tensor(x{},y{})")); -Tensor::UP createTensor(const TensorSpec &spec) { - auto value = DefaultTensorEngine::ref().from_spec(spec); - Tensor *tensor = dynamic_cast<Tensor*>(value.get()); - ASSERT_TRUE(tensor != nullptr); - value.release(); - return Tensor::UP(tensor); +vespalib::eval::Value::UP createTensor(const TensorSpec &spec) { + auto value = EngineOrFactory::get().from_spec(spec); + ASSERT_TRUE(value->is_tensor()); + return value; } -std::unique_ptr<Tensor> +std::unique_ptr<vespalib::eval::Value> makeSimpleTensor() { - return makeTensor<Tensor>(TensorSpec("tensor(x{},y{})"). - add({{"x", "4"}, {"y", "5"}}, 7)); + return makeTensor<vespalib::eval::Value>(TensorSpec("tensor(x{},y{})"). + add({{"x", "4"}, {"y", "5"}}, 7)); } FieldValue::UP clone(FieldValue &fv) { diff --git a/document/src/vespa/document/fieldvalue/tensorfieldvalue.cpp b/document/src/vespa/document/fieldvalue/tensorfieldvalue.cpp index c3b593732b9..2a66ea61966 100644 --- a/document/src/vespa/document/fieldvalue/tensorfieldvalue.cpp +++ b/document/src/vespa/document/fieldvalue/tensorfieldvalue.cpp @@ -6,16 +6,14 @@ #include <vespa/vespalib/util/xmlstream.h> #include <vespa/eval/eval/engine_or_factory.h> #include <vespa/eval/eval/tensor_spec.h> -#include <vespa/eval/tensor/tensor.h> -#include <vespa/eval/tensor/default_tensor_engine.h> +#include <vespa/eval/eval/value.h> +#include <vespa/eval/eval/engine_or_factory.h> #include <ostream> #include <cassert> -using vespalib::tensor::Tensor; using vespalib::eval::EngineOrFactory; using vespalib::eval::TensorSpec; using vespalib::eval::ValueType; -using Engine = vespalib::tensor::DefaultTensorEngine; using namespace vespalib::xml; namespace document { @@ -53,7 +51,7 @@ TensorFieldValue::TensorFieldValue(const TensorFieldValue &rhs) _altered(true) { if (rhs._tensor) { - _tensor = rhs._tensor->clone(); + _tensor = EngineOrFactory::get().copy(*rhs._tensor); } } @@ -80,7 +78,7 @@ TensorFieldValue::operator=(const TensorFieldValue &rhs) if (&_dataType == &rhs._dataType || !rhs._tensor || _dataType.isAssignableType(rhs._tensor->type())) { if (rhs._tensor) { - _tensor = rhs._tensor->clone(); + _tensor = EngineOrFactory::get().copy(*rhs._tensor); } else { _tensor.reset(); } @@ -94,7 +92,7 @@ TensorFieldValue::operator=(const TensorFieldValue &rhs) TensorFieldValue & -TensorFieldValue::operator=(std::unique_ptr<Tensor> rhs) +TensorFieldValue::operator=(std::unique_ptr<vespalib::eval::Value> rhs) { if (!rhs || _dataType.isAssignableType(rhs->type())) { _tensor = std::move(rhs); @@ -111,11 +109,7 @@ TensorFieldValue::make_empty_if_not_existing() { if (!_tensor) { TensorSpec empty_spec(_dataType.getTensorType().to_spec()); - auto empty_value = Engine::ref().from_spec(empty_spec); - auto tensor_ptr = dynamic_cast<Tensor*>(empty_value.get()); - assert(tensor_ptr != nullptr); - _tensor.reset(tensor_ptr); - empty_value.release(); + _tensor = EngineOrFactory::get().from_spec(empty_spec); } } @@ -163,7 +157,7 @@ TensorFieldValue::print(std::ostream& out, bool verbose, (void) indent; out << "{TensorFieldValue: "; if (_tensor) { - out << Engine::ref().to_spec(*_tensor).to_string(); + out << EngineOrFactory::get().to_spec(*_tensor).to_string(); } else { out << "null"; } @@ -192,7 +186,7 @@ TensorFieldValue::assign(const FieldValue &value) void -TensorFieldValue::assignDeserialized(std::unique_ptr<Tensor> rhs) +TensorFieldValue::assignDeserialized(std::unique_ptr<vespalib::eval::Value> rhs) { if (!rhs || _dataType.isAssignableType(rhs->type())) { _tensor = std::move(rhs); diff --git a/document/src/vespa/document/fieldvalue/tensorfieldvalue.h b/document/src/vespa/document/fieldvalue/tensorfieldvalue.h index 30cc10558b5..82a10e8aaa6 100644 --- a/document/src/vespa/document/fieldvalue/tensorfieldvalue.h +++ b/document/src/vespa/document/fieldvalue/tensorfieldvalue.h @@ -5,6 +5,7 @@ #include "fieldvalue.h" namespace vespalib { namespace tensor { class Tensor; } } +namespace vespalib::eval { class Value; } namespace document { @@ -16,7 +17,7 @@ class TensorDataType; class TensorFieldValue : public FieldValue { private: const TensorDataType &_dataType; - std::unique_ptr<vespalib::tensor::Tensor> _tensor; + std::unique_ptr<vespalib::eval::Value> _tensor; bool _altered; public: TensorFieldValue(); @@ -26,7 +27,7 @@ public: ~TensorFieldValue(); TensorFieldValue &operator=(const TensorFieldValue &rhs); - TensorFieldValue &operator=(std::unique_ptr<vespalib::tensor::Tensor> rhs); + TensorFieldValue &operator=(std::unique_ptr<vespalib::eval::Value> rhs); void make_empty_if_not_existing(); @@ -39,10 +40,10 @@ public: const std::string& indent) const override; virtual void printXml(XmlOutputStream& out) const override; virtual FieldValue &assign(const FieldValue &value) override; - const vespalib::tensor::Tensor *getAsTensorPtr() const { + const vespalib::eval::Value *getAsTensorPtr() const { return _tensor.get(); } - void assignDeserialized(std::unique_ptr<vespalib::tensor::Tensor> rhs); + void assignDeserialized(std::unique_ptr<vespalib::eval::Value> rhs); virtual int compare(const FieldValue& other) const override; DECLARE_IDENTIFIABLE(TensorFieldValue); diff --git a/document/src/vespa/document/serialization/vespadocumentdeserializer.cpp b/document/src/vespa/document/serialization/vespadocumentdeserializer.cpp index 94644438f5c..6ec9c52281f 100644 --- a/document/src/vespa/document/serialization/vespadocumentdeserializer.cpp +++ b/document/src/vespa/document/serialization/vespadocumentdeserializer.cpp @@ -22,8 +22,8 @@ #include <vespa/vespalib/data/slime/slime.h> #include <vespa/vespalib/stllike/asciistream.h> #include <vespa/vespalib/util/backtrace.h> -#include <vespa/eval/tensor/tensor.h> -#include <vespa/eval/tensor/serialization/typed_binary_format.h> +#include <vespa/eval/eval/engine_or_factory.h> +#include <vespa/eval/eval/value.h> #include <vespa/document/util/serializableexceptions.h> #include <vespa/document/base/exceptions.h> #include <vespa/vespalib/objects/nbostream.h> @@ -41,6 +41,7 @@ using vespalib::nbostream; using vespalib::Memory; using vespalib::stringref; using vespalib::compression::CompressionConfig; +using vespalib::eval::EngineOrFactory; namespace document { @@ -363,10 +364,10 @@ VespaDocumentDeserializer::read(TensorFieldValue &value) throw DeserializeException(vespalib::make_string("Stream failed size(%zu), needed(%zu) to deserialize tensor field value", _stream.size(), length), VESPA_STRLOC); } - std::unique_ptr<vespalib::tensor::Tensor> tensor; + std::unique_ptr<vespalib::eval::Value> tensor; if (length != 0) { nbostream wrapStream(_stream.peek(), length); - tensor = vespalib::tensor::TypedBinaryFormat::deserialize(wrapStream); + tensor = EngineOrFactory::get().decode(wrapStream); if (wrapStream.size() != 0) { throw DeserializeException("Leftover bytes deserializing tensor field value.", VESPA_STRLOC); } diff --git a/document/src/vespa/document/serialization/vespadocumentserializer.cpp b/document/src/vespa/document/serialization/vespadocumentserializer.cpp index 6d9c08578e7..882dc4e83f3 100644 --- a/document/src/vespa/document/serialization/vespadocumentserializer.cpp +++ b/document/src/vespa/document/serialization/vespadocumentserializer.cpp @@ -26,7 +26,8 @@ #include <vespa/document/update/fieldpathupdates.h> #include <vespa/document/update/updates.h> #include <vespa/document/util/bytebuffer.h> -#include <vespa/eval/tensor/serialization/typed_binary_format.h> +#include <vespa/eval/eval/value.h> +#include <vespa/eval/eval/engine_or_factory.h> #include <vespa/vespalib/data/databuffer.h> #include <vespa/vespalib/data/slime/binary_format.h> #include <vespa/vespalib/objects/nbostream.h> @@ -370,7 +371,7 @@ VespaDocumentSerializer::write(const TensorFieldValue &value) { vespalib::nbostream tmpStream; auto tensor = value.getAsTensorPtr(); if (tensor) { - vespalib::tensor::TypedBinaryFormat::serialize(tmpStream, *tensor); + vespalib::eval::EngineOrFactory::get().encode(*tensor, tmpStream); assert( ! tmpStream.empty()); _stream.putInt1_4Bytes(tmpStream.size()); _stream.write(tmpStream.peek(), tmpStream.size()); diff --git a/document/src/vespa/document/update/tensor_add_update.cpp b/document/src/vespa/document/update/tensor_add_update.cpp index d9bec7762b6..91b72329994 100644 --- a/document/src/vespa/document/update/tensor_add_update.cpp +++ b/document/src/vespa/document/update/tensor_add_update.cpp @@ -8,6 +8,9 @@ #include <vespa/document/fieldvalue/tensorfieldvalue.h> #include <vespa/document/serialization/vespadocumentdeserializer.h> #include <vespa/document/util/serializableexceptions.h> +#include <vespa/eval/eval/value.h> +#include <vespa/eval/eval/engine_or_factory.h> +#include <vespa/eval/tensor/partial_update.h> #include <vespa/eval/tensor/tensor.h> #include <vespa/vespalib/objects/nbostream.h> #include <vespa/vespalib/stllike/asciistream.h> @@ -17,8 +20,9 @@ using vespalib::IllegalArgumentException; using vespalib::IllegalStateException; -using vespalib::tensor::Tensor; using vespalib::make_string; +using vespalib::eval::EngineOrFactory; +using vespalib::tensor::TensorPartialUpdate; namespace document { @@ -78,14 +82,34 @@ TensorAddUpdate::checkCompatibility(const Field& field) const } } -std::unique_ptr<Tensor> -TensorAddUpdate::applyTo(const Tensor &tensor) const +namespace { + +std::unique_ptr<vespalib::eval::Value> +old_add(const vespalib::eval::Value *input, + const vespalib::eval::Value *add_cells) +{ + auto a = dynamic_cast<const vespalib::tensor::Tensor *>(input); + assert(a); + auto b = dynamic_cast<const vespalib::tensor::Tensor *>(add_cells); + assert(b); + return a->add(*b); +} + +} // namespace + +std::unique_ptr<vespalib::eval::Value> +TensorAddUpdate::applyTo(const vespalib::eval::Value &tensor) const { auto addTensor = _tensor->getAsTensorPtr(); if (addTensor) { - return tensor.add(*addTensor); + auto engine = EngineOrFactory::get(); + if (engine.is_factory()) { + return TensorPartialUpdate::add(tensor, *addTensor, engine.factory()); + } else { + return old_add(&tensor, addTensor); + } } - return std::unique_ptr<Tensor>(); + return std::unique_ptr<vespalib::eval::Value>(); } bool diff --git a/document/src/vespa/document/update/tensor_add_update.h b/document/src/vespa/document/update/tensor_add_update.h index 49519ee1ddd..8687967be49 100644 --- a/document/src/vespa/document/update/tensor_add_update.h +++ b/document/src/vespa/document/update/tensor_add_update.h @@ -2,7 +2,7 @@ #include "valueupdate.h" -namespace vespalib::tensor { class Tensor; } +namespace vespalib::eval { class Value; } namespace document { @@ -27,7 +27,7 @@ public: bool operator==(const ValueUpdate &other) const override; const TensorFieldValue &getTensor() const { return *_tensor; } void checkCompatibility(const Field &field) const override; - std::unique_ptr<vespalib::tensor::Tensor> applyTo(const vespalib::tensor::Tensor &tensor) const; + std::unique_ptr<vespalib::eval::Value> applyTo(const vespalib::eval::Value &tensor) const; bool applyTo(FieldValue &value) const override; void printXml(XmlOutputStream &xos) const override; void print(std::ostream &out, bool verbose, const std::string &indent) const override; diff --git a/document/src/vespa/document/update/tensor_modify_update.cpp b/document/src/vespa/document/update/tensor_modify_update.cpp index 5fbdc2467b3..292b4165540 100644 --- a/document/src/vespa/document/update/tensor_modify_update.cpp +++ b/document/src/vespa/document/update/tensor_modify_update.cpp @@ -9,8 +9,11 @@ #include <vespa/document/serialization/vespadocumentdeserializer.h> #include <vespa/document/util/serializableexceptions.h> #include <vespa/eval/eval/operation.h> -#include <vespa/eval/tensor/cell_values.h> +#include <vespa/eval/eval/value.h> +#include <vespa/eval/eval/engine_or_factory.h> +#include <vespa/eval/tensor/partial_update.h> #include <vespa/eval/tensor/tensor.h> +#include <vespa/eval/tensor/cell_values.h> #include <vespa/vespalib/objects/nbostream.h> #include <vespa/vespalib/stllike/asciistream.h> #include <vespa/vespalib/util/stringfmt.h> @@ -19,9 +22,10 @@ using vespalib::IllegalArgumentException; using vespalib::IllegalStateException; -using vespalib::tensor::Tensor; using vespalib::make_string; using vespalib::eval::ValueType; +using vespalib::eval::EngineOrFactory; +using vespalib::tensor::TensorPartialUpdate; using join_fun_t = double (*)(double, double); @@ -156,16 +160,32 @@ TensorModifyUpdate::checkCompatibility(const Field& field) const } } -std::unique_ptr<Tensor> -TensorModifyUpdate::applyTo(const Tensor &tensor) const + +std::unique_ptr<vespalib::eval::Value> +old_modify(const vespalib::eval::Value *input, + const vespalib::eval::Value *modify_spec, + join_fun_t function) +{ + auto a = dynamic_cast<const vespalib::tensor::Tensor *>(input); + // Cells tensor being sparse was validated during deserialize(). + auto b = dynamic_cast<const vespalib::tensor::SparseTensor *>(modify_spec); + vespalib::tensor::CellValues cellValues(*b); + return a->modify(function, cellValues); +} + +std::unique_ptr<vespalib::eval::Value> +TensorModifyUpdate::applyTo(const vespalib::eval::Value &tensor) const { auto cellsTensor = _tensor->getAsTensorPtr(); if (cellsTensor) { - // Cells tensor being sparse was validated during deserialize(). - vespalib::tensor::CellValues cellValues(static_cast<const vespalib::tensor::SparseTensor &>(*cellsTensor)); - return tensor.modify(getJoinFunction(_operation), cellValues); + auto engine = EngineOrFactory::get(); + if (engine.is_factory()) { + return TensorPartialUpdate::modify(tensor, getJoinFunction(_operation), *cellsTensor, engine.factory()); + } else { + return old_modify(&tensor, cellsTensor, getJoinFunction(_operation)); + } } - return std::unique_ptr<Tensor>(); + return std::unique_ptr<vespalib::eval::Value>(); } bool @@ -207,13 +227,24 @@ TensorModifyUpdate::print(std::ostream& out, bool verbose, const std::string& in namespace { void -verifyCellsTensorIsSparse(const Tensor *cellsTensor) +verifyCellsTensorIsSparse(const vespalib::eval::Value *cellsTensor) { - if (cellsTensor && !dynamic_cast<const vespalib::tensor::SparseTensor *>(cellsTensor)) { - vespalib::string err = make_string("Expected cell values tensor to be sparse, but has type '%s'", - cellsTensor->type().to_spec().c_str()); - throw IllegalStateException(err, VESPA_STRLOC); + if (cellsTensor == nullptr) { + return; + } + auto engine = EngineOrFactory::get(); + if (engine.is_factory()) { + if (cellsTensor->type().is_sparse()) { + return; + } + } else { + if (dynamic_cast<const vespalib::tensor::SparseTensor *>(cellsTensor)) { + return; + } } + vespalib::string err = make_string("Expected cells tensor to be sparse, but has type '%s'", + cellsTensor->type().to_spec().c_str()); + throw IllegalStateException(err, VESPA_STRLOC); } } diff --git a/document/src/vespa/document/update/tensor_modify_update.h b/document/src/vespa/document/update/tensor_modify_update.h index c2d61d3e69b..528ff8c95e9 100644 --- a/document/src/vespa/document/update/tensor_modify_update.h +++ b/document/src/vespa/document/update/tensor_modify_update.h @@ -2,7 +2,7 @@ #include "valueupdate.h" -namespace vespalib::tensor { class Tensor; } +namespace vespalib::eval { class Value; } namespace document { @@ -41,7 +41,7 @@ public: Operation getOperation() const { return _operation; } const TensorFieldValue &getTensor() const { return *_tensor; } void checkCompatibility(const Field &field) const override; - std::unique_ptr<vespalib::tensor::Tensor> applyTo(const vespalib::tensor::Tensor &tensor) const; + std::unique_ptr<vespalib::eval::Value> applyTo(const vespalib::eval::Value &tensor) const; bool applyTo(FieldValue &value) const override; void printXml(XmlOutputStream &xos) const override; void print(std::ostream &out, bool verbose, const std::string &indent) const override; diff --git a/document/src/vespa/document/update/tensor_remove_update.cpp b/document/src/vespa/document/update/tensor_remove_update.cpp index 34a6223e185..178bd1bd950 100644 --- a/document/src/vespa/document/update/tensor_remove_update.cpp +++ b/document/src/vespa/document/update/tensor_remove_update.cpp @@ -6,18 +6,23 @@ #include <vespa/document/fieldvalue/document.h> #include <vespa/document/fieldvalue/tensorfieldvalue.h> #include <vespa/document/serialization/vespadocumentdeserializer.h> +#include <vespa/eval/eval/engine_or_factory.h> +#include <vespa/eval/tensor/partial_update.h> +#include <vespa/eval/tensor/tensor.h> #include <vespa/eval/tensor/cell_values.h> #include <vespa/eval/tensor/sparse/sparse_tensor.h> -#include <vespa/eval/tensor/tensor.h> +#include <vespa/eval/eval/value.h> #include <vespa/vespalib/objects/nbostream.h> #include <vespa/vespalib/util/xmlstream.h> #include <ostream> +#include <cassert> using vespalib::IllegalArgumentException; using vespalib::IllegalStateException; -using vespalib::tensor::Tensor; using vespalib::make_string; using vespalib::eval::ValueType; +using vespalib::eval::EngineOrFactory; +using vespalib::tensor::TensorPartialUpdate; namespace document { @@ -35,6 +40,16 @@ convertToCompatibleType(const TensorDataType &tensorType) return std::make_unique<const TensorDataType>(ValueType::tensor_type(std::move(list), tensorType.getTensorType().cell_type())); } +std::unique_ptr<vespalib::eval::Value> +old_remove(const vespalib::eval::Value *input, + const vespalib::eval::Value *remove_spec) +{ + auto a = dynamic_cast<const vespalib::tensor::Tensor *>(input); + auto b = dynamic_cast<const vespalib::tensor::SparseTensor *>(remove_spec); + vespalib::tensor::CellValues cellAddresses(*b); + return a->remove(cellAddresses); +} + } IMPLEMENT_IDENTIFIABLE(TensorRemoveUpdate, ValueUpdate); @@ -102,16 +117,19 @@ TensorRemoveUpdate::checkCompatibility(const Field &field) const } } -std::unique_ptr<Tensor> -TensorRemoveUpdate::applyTo(const Tensor &tensor) const +std::unique_ptr<vespalib::eval::Value> +TensorRemoveUpdate::applyTo(const vespalib::eval::Value &tensor) const { auto addressTensor = _tensor->getAsTensorPtr(); if (addressTensor) { - // Address tensor being sparse was validated during deserialize(). - vespalib::tensor::CellValues cellAddresses(static_cast<const vespalib::tensor::SparseTensor &>(*addressTensor)); - return tensor.remove(cellAddresses); + auto engine = EngineOrFactory::get(); + if (engine.is_factory()) { + return TensorPartialUpdate::remove(tensor, *addressTensor, engine.factory()); + } else { + return old_remove(&tensor, addressTensor); + } } - return std::unique_ptr<Tensor>(); + return std::unique_ptr<vespalib::eval::Value>(); } bool @@ -153,15 +171,27 @@ TensorRemoveUpdate::print(std::ostream &out, bool verbose, const std::string &in namespace { void -verifyAddressTensorIsSparse(const Tensor *addressTensor) +verifyAddressTensorIsSparse(const vespalib::eval::Value *addressTensor) { - if (addressTensor && !dynamic_cast<const vespalib::tensor::SparseTensor *>(addressTensor)) { - vespalib::string err = make_string("Expected address tensor to be sparse, but has type '%s'", - addressTensor->type().to_spec().c_str()); - throw IllegalStateException(err, VESPA_STRLOC); + if (addressTensor == nullptr) { + return; + } + auto engine = EngineOrFactory::get(); + if (engine.is_factory()) { + if (addressTensor->type().is_sparse()) { + return; + } + } else { + if (dynamic_cast<const vespalib::tensor::SparseTensor *>(addressTensor)) { + return; + } } + vespalib::string err = make_string("Expected address tensor to be sparse, but has type '%s'", + addressTensor->type().to_spec().c_str()); + throw IllegalStateException(err, VESPA_STRLOC); } + } void diff --git a/document/src/vespa/document/update/tensor_remove_update.h b/document/src/vespa/document/update/tensor_remove_update.h index e75348fa829..6ab66048dd4 100644 --- a/document/src/vespa/document/update/tensor_remove_update.h +++ b/document/src/vespa/document/update/tensor_remove_update.h @@ -2,7 +2,7 @@ #include "valueupdate.h" -namespace vespalib::tensor { class Tensor; } +namespace vespalib::eval { class Value; } namespace document { @@ -30,7 +30,7 @@ public: TensorRemoveUpdate &operator=(const TensorRemoveUpdate &rhs); TensorRemoveUpdate &operator=(TensorRemoveUpdate &&rhs); const TensorFieldValue &getTensor() const { return *_tensor; } - std::unique_ptr<vespalib::tensor::Tensor> applyTo(const vespalib::tensor::Tensor &tensor) const; + std::unique_ptr<vespalib::eval::Value> applyTo(const vespalib::eval::Value &tensor) const; bool operator==(const ValueUpdate &other) const override; void checkCompatibility(const Field &field) const override; |