diff options
author | Henning Baldersheim <balder@yahoo-inc.com> | 2020-10-14 20:03:02 +0200 |
---|---|---|
committer | GitHub <noreply@github.com> | 2020-10-14 20:03:02 +0200 |
commit | 89ae1507eb41fc760fc93e3d341db24babccbe44 (patch) | |
tree | bf52be62ad94224ac9af19565da17199bd72cdd0 | |
parent | 29e5303e5c89ff50b10661a430bfd52d41d09774 (diff) | |
parent | 3d89b1eaf56c686a9e795f3a544050173ee503e3 (diff) |
Merge pull request #14869 from vespa-engine/arnej/minor-document-cleanup
Arnej/minor document cleanup
10 files changed, 37 insertions, 24 deletions
diff --git a/document/src/tests/documentupdatetestcase.cpp b/document/src/tests/documentupdatetestcase.cpp index 18001c35da5..9d2567e93ed 100644 --- a/document/src/tests/documentupdatetestcase.cpp +++ b/document/src/tests/documentupdatetestcase.cpp @@ -789,7 +789,7 @@ makeTensorFieldValue(const TensorSpec &spec, const TensorDataType &dataType) const Tensor &asTensor(const FieldValue &fieldValue) { auto &tensorFieldValue = dynamic_cast<const TensorFieldValue &>(fieldValue); - auto &tensor = tensorFieldValue.getAsTensorPtr(); + auto tensor = tensorFieldValue.getAsTensorPtr(); assert(tensor); return *tensor; } @@ -876,7 +876,7 @@ struct TensorUpdateFixture { auto field = getTensor(); auto tensor_field = dynamic_cast<TensorFieldValue*>(field.get()); ASSERT_TRUE(tensor_field); - EXPECT_TRUE(tensor_field->getAsTensorPtr().get() == nullptr); + EXPECT_TRUE(tensor_field->getAsTensorPtr() == nullptr); } void assertTensor(const TensorSpec &expSpec) { diff --git a/document/src/tests/serialization/vespadocumentserializer_test.cpp b/document/src/tests/serialization/vespadocumentserializer_test.cpp index c0ebdad6373..02f170cd5f1 100644 --- a/document/src/tests/serialization/vespadocumentserializer_test.cpp +++ b/document/src/tests/serialization/vespadocumentserializer_test.cpp @@ -919,7 +919,7 @@ DeserializedTensorDoc::setup(const DocumentTypeRepo &docTypeRepo, const vespalib const Tensor * DeserializedTensorDoc::getTensor() const { - return dynamic_cast<const TensorFieldValue &>(*_fieldValue).getAsTensorPtr().get(); + return dynamic_cast<const TensorFieldValue &>(*_fieldValue).getAsTensorPtr(); } TEST("Require that wrong tensor type hides tensor") diff --git a/document/src/vespa/document/fieldvalue/tensorfieldvalue.cpp b/document/src/vespa/document/fieldvalue/tensorfieldvalue.cpp index 56d7b6ab078..c3b593732b9 100644 --- a/document/src/vespa/document/fieldvalue/tensorfieldvalue.cpp +++ b/document/src/vespa/document/fieldvalue/tensorfieldvalue.cpp @@ -4,6 +4,7 @@ #include <vespa/document/base/exceptions.h> #include <vespa/document/datatype/tensor_data_type.h> #include <vespa/vespalib/util/xmlstream.h> +#include <vespa/eval/eval/engine_or_factory.h> #include <vespa/eval/eval/tensor_spec.h> #include <vespa/eval/tensor/tensor.h> #include <vespa/eval/tensor/default_tensor_engine.h> @@ -11,6 +12,7 @@ #include <cassert> using vespalib::tensor::Tensor; +using vespalib::eval::EngineOrFactory; using vespalib::eval::TensorSpec; using vespalib::eval::ValueType; using Engine = vespalib::tensor::DefaultTensorEngine; @@ -218,13 +220,24 @@ TensorFieldValue::compare(const FieldValue &other) const if (!rhs._tensor) { return 1; } - if (_tensor->equals(*rhs._tensor)) { + // equal pointers always means identical + if (_tensor.get() == rhs._tensor.get()) { return 0; } - assert(_tensor.get() != rhs._tensor.get()); - // XXX: Wrong, compares identity of tensors instead of values - // Note: sorting can be dangerous due to this. - return ((_tensor.get() < rhs._tensor.get()) ? -1 : 1); + // compare just the type first: + auto lhs_type = _tensor->type().to_spec(); + auto rhs_type = rhs._tensor->type().to_spec(); + int type_cmp = lhs_type.compare(rhs_type); + if (type_cmp != 0) { + return type_cmp; + } + // Compare the actual tensors by converting to TensorSpec strings. + // TODO: this can be very slow, check if it might be used for anything + // performance-critical. + auto engine = EngineOrFactory::get(); + auto lhs_spec = engine.to_spec(*_tensor).to_string(); + auto rhs_spec = engine.to_spec(*rhs._tensor).to_string(); + return lhs_spec.compare(rhs_spec); } IMPLEMENT_IDENTIFIABLE(TensorFieldValue, FieldValue); diff --git a/document/src/vespa/document/fieldvalue/tensorfieldvalue.h b/document/src/vespa/document/fieldvalue/tensorfieldvalue.h index ea3f8dea9be..30cc10558b5 100644 --- a/document/src/vespa/document/fieldvalue/tensorfieldvalue.h +++ b/document/src/vespa/document/fieldvalue/tensorfieldvalue.h @@ -39,8 +39,8 @@ public: const std::string& indent) const override; virtual void printXml(XmlOutputStream& out) const override; virtual FieldValue &assign(const FieldValue &value) override; - const std::unique_ptr<vespalib::tensor::Tensor> &getAsTensorPtr() const { - return _tensor; + const vespalib::tensor::Tensor *getAsTensorPtr() const { + return _tensor.get(); } void assignDeserialized(std::unique_ptr<vespalib::tensor::Tensor> rhs); virtual int compare(const FieldValue& other) const override; diff --git a/document/src/vespa/document/serialization/vespadocumentserializer.cpp b/document/src/vespa/document/serialization/vespadocumentserializer.cpp index eadbd4b5a8a..6d9c08578e7 100644 --- a/document/src/vespa/document/serialization/vespadocumentserializer.cpp +++ b/document/src/vespa/document/serialization/vespadocumentserializer.cpp @@ -368,7 +368,7 @@ VespaDocumentSerializer::write(const WeightedSetFieldValue &value) { void VespaDocumentSerializer::write(const TensorFieldValue &value) { vespalib::nbostream tmpStream; - auto &tensor = value.getAsTensorPtr(); + auto tensor = value.getAsTensorPtr(); if (tensor) { vespalib::tensor::TypedBinaryFormat::serialize(tmpStream, *tensor); assert( ! tmpStream.empty()); diff --git a/document/src/vespa/document/update/tensor_add_update.cpp b/document/src/vespa/document/update/tensor_add_update.cpp index 2e5fa194c20..d9bec7762b6 100644 --- a/document/src/vespa/document/update/tensor_add_update.cpp +++ b/document/src/vespa/document/update/tensor_add_update.cpp @@ -81,7 +81,7 @@ TensorAddUpdate::checkCompatibility(const Field& field) const std::unique_ptr<Tensor> TensorAddUpdate::applyTo(const Tensor &tensor) const { - auto &addTensor = _tensor->getAsTensorPtr(); + auto addTensor = _tensor->getAsTensorPtr(); if (addTensor) { return tensor.add(*addTensor); } @@ -94,7 +94,7 @@ TensorAddUpdate::applyTo(FieldValue& value) const if (value.inherits(TensorFieldValue::classId)) { TensorFieldValue &tensorFieldValue = static_cast<TensorFieldValue &>(value); tensorFieldValue.make_empty_if_not_existing(); - auto &oldTensor = tensorFieldValue.getAsTensorPtr(); + auto oldTensor = tensorFieldValue.getAsTensorPtr(); auto newTensor = applyTo(*oldTensor); if (newTensor) { tensorFieldValue = std::move(newTensor); diff --git a/document/src/vespa/document/update/tensor_modify_update.cpp b/document/src/vespa/document/update/tensor_modify_update.cpp index dfc7479e5cd..5fbdc2467b3 100644 --- a/document/src/vespa/document/update/tensor_modify_update.cpp +++ b/document/src/vespa/document/update/tensor_modify_update.cpp @@ -159,7 +159,7 @@ TensorModifyUpdate::checkCompatibility(const Field& field) const std::unique_ptr<Tensor> TensorModifyUpdate::applyTo(const Tensor &tensor) const { - auto &cellsTensor = _tensor->getAsTensorPtr(); + auto cellsTensor = _tensor->getAsTensorPtr(); if (cellsTensor) { // Cells tensor being sparse was validated during deserialize(). vespalib::tensor::CellValues cellValues(static_cast<const vespalib::tensor::SparseTensor &>(*cellsTensor)); @@ -173,7 +173,7 @@ TensorModifyUpdate::applyTo(FieldValue& value) const { if (value.inherits(TensorFieldValue::classId)) { TensorFieldValue &tensorFieldValue = static_cast<TensorFieldValue &>(value); - auto &oldTensor = tensorFieldValue.getAsTensorPtr(); + auto oldTensor = tensorFieldValue.getAsTensorPtr(); if (oldTensor) { auto newTensor = applyTo(*oldTensor); if (newTensor) { @@ -207,9 +207,9 @@ TensorModifyUpdate::print(std::ostream& out, bool verbose, const std::string& in namespace { void -verifyCellsTensorIsSparse(const std::unique_ptr<Tensor> &cellsTensor) +verifyCellsTensorIsSparse(const Tensor *cellsTensor) { - if (cellsTensor && !dynamic_cast<const vespalib::tensor::SparseTensor *>(cellsTensor.get())) { + if (cellsTensor && !dynamic_cast<const vespalib::tensor::SparseTensor *>(cellsTensor)) { vespalib::string err = make_string("Expected cell values tensor to be sparse, but has type '%s'", cellsTensor->type().to_spec().c_str()); throw IllegalStateException(err, VESPA_STRLOC); diff --git a/document/src/vespa/document/update/tensor_remove_update.cpp b/document/src/vespa/document/update/tensor_remove_update.cpp index 91b4c0a6ca3..34a6223e185 100644 --- a/document/src/vespa/document/update/tensor_remove_update.cpp +++ b/document/src/vespa/document/update/tensor_remove_update.cpp @@ -105,7 +105,7 @@ TensorRemoveUpdate::checkCompatibility(const Field &field) const std::unique_ptr<Tensor> TensorRemoveUpdate::applyTo(const Tensor &tensor) const { - auto &addressTensor = _tensor->getAsTensorPtr(); + auto addressTensor = _tensor->getAsTensorPtr(); if (addressTensor) { // Address tensor being sparse was validated during deserialize(). vespalib::tensor::CellValues cellAddresses(static_cast<const vespalib::tensor::SparseTensor &>(*addressTensor)); @@ -119,7 +119,7 @@ TensorRemoveUpdate::applyTo(FieldValue &value) const { if (value.inherits(TensorFieldValue::classId)) { TensorFieldValue &tensorFieldValue = static_cast<TensorFieldValue &>(value); - auto &oldTensor = tensorFieldValue.getAsTensorPtr(); + auto oldTensor = tensorFieldValue.getAsTensorPtr(); if (oldTensor) { auto newTensor = applyTo(*oldTensor); if (newTensor) { @@ -153,9 +153,9 @@ TensorRemoveUpdate::print(std::ostream &out, bool verbose, const std::string &in namespace { void -verifyAddressTensorIsSparse(const std::unique_ptr<Tensor> &addressTensor) +verifyAddressTensorIsSparse(const Tensor *addressTensor) { - if (addressTensor && !dynamic_cast<const vespalib::tensor::SparseTensor *>(addressTensor.get())) { + if (addressTensor && !dynamic_cast<const vespalib::tensor::SparseTensor *>(addressTensor)) { vespalib::string err = make_string("Expected address tensor to be sparse, but has type '%s'", addressTensor->type().to_spec().c_str()); throw IllegalStateException(err, VESPA_STRLOC); diff --git a/searchcore/src/tests/proton/docsummary/summaryfieldconverter_test.cpp b/searchcore/src/tests/proton/docsummary/summaryfieldconverter_test.cpp index b295926c64a..5f8f7a63dd0 100644 --- a/searchcore/src/tests/proton/docsummary/summaryfieldconverter_test.cpp +++ b/searchcore/src/tests/proton/docsummary/summaryfieldconverter_test.cpp @@ -468,8 +468,8 @@ void Test::checkTensor(const Tensor::UP &tensor, const FieldValue *value) { ASSERT_TRUE(value); const TensorFieldValue *s = dynamic_cast<const TensorFieldValue *>(value); ASSERT_TRUE(s); - const Tensor::UP &tvalue = s->getAsTensorPtr(); - EXPECT_EQUAL(tensor.get() != nullptr, tvalue.get() != nullptr); + auto tvalue = s->getAsTensorPtr(); + EXPECT_EQUAL(tensor.get() != nullptr, tvalue != nullptr); if (tensor) { EXPECT_EQUAL(*tensor, *tvalue); } diff --git a/searchcore/src/vespa/searchcore/proton/docsummary/documentstoreadapter.cpp b/searchcore/src/vespa/searchcore/proton/docsummary/documentstoreadapter.cpp index 080aee88f3f..0e6ce5e689b 100644 --- a/searchcore/src/vespa/searchcore/proton/docsummary/documentstoreadapter.cpp +++ b/searchcore/src/vespa/searchcore/proton/docsummary/documentstoreadapter.cpp @@ -86,7 +86,7 @@ DocumentStoreAdapter::writeField(const FieldValue &value, ResType type) vespalib::nbostream serialized; if (value.getClass().inherits(TensorFieldValue::classId)) { const auto &tvalue = static_cast<const TensorFieldValue &>(value); - const std::unique_ptr<Tensor> &tensor = tvalue.getAsTensorPtr(); + auto tensor = tvalue.getAsTensorPtr(); if (tensor) { vespalib::tensor::TypedBinaryFormat::serialize(serialized, *tensor); } |