diff options
author | Geir Storli <geirst@oath.com> | 2018-04-19 11:53:50 +0000 |
---|---|---|
committer | Geir Storli <geirst@oath.com> | 2018-04-19 11:56:30 +0000 |
commit | 33ef61f554a5768c4457e78207c689fbf5661220 (patch) | |
tree | b33ee25db5ff23fb800badd1fc47bd78309b1b55 | |
parent | 939a0c130bbcea1cc5fb743711c7a021f467fcae (diff) |
Add type-safe down-cast to ITensorAttribute in IAttributeVector and use this instead of dynamic_cast.
14 files changed, 49 insertions, 14 deletions
diff --git a/searchcommon/src/vespa/searchcommon/attribute/iattributevector.h b/searchcommon/src/vespa/searchcommon/attribute/iattributevector.h index f8ab03fdabb..26322e78480 100644 --- a/searchcommon/src/vespa/searchcommon/attribute/iattributevector.h +++ b/searchcommon/src/vespa/searchcommon/attribute/iattributevector.h @@ -12,6 +12,10 @@ namespace search { class IDocumentWeightAttribute; class QueryTermSimple; +namespace tensor { +class ITensorAttribute; +} + namespace attribute { class ISearchContext; @@ -273,6 +277,13 @@ public: virtual const IDocumentWeightAttribute *asDocumentWeightAttribute() const = 0; /** + * Type-safe down-cast to a tensor attribute. + * + * @return tensor attribute or nullptr if not supported. + */ + virtual const tensor::ITensorAttribute *asTensorAttribute() const = 0; + + /** * Returns the basic type of this attribute vector. * * @return basic type diff --git a/searchlib/src/tests/attribute/imported_attribute_vector/imported_attribute_vector_test.cpp b/searchlib/src/tests/attribute/imported_attribute_vector/imported_attribute_vector_test.cpp index 90b63138fde..b7b6eac94ed 100644 --- a/searchlib/src/tests/attribute/imported_attribute_vector/imported_attribute_vector_test.cpp +++ b/searchlib/src/tests/attribute/imported_attribute_vector/imported_attribute_vector_test.cpp @@ -141,6 +141,10 @@ TEST_F("asDocumentWeightAttribute() returns nullptr", Fixture) { EXPECT_TRUE(f.get_imported_attr()->asDocumentWeightAttribute() == nullptr); } +TEST_F("asTensorAttribute() returns nullptr", Fixture) { + EXPECT_TRUE(f.get_imported_attr()->asTensorAttribute() == nullptr); +} + TEST_F("Multi-valued integer attribute values can be retrieved via reference", Fixture) { const std::vector<int64_t> doc3_values({1234}); const std::vector<int64_t> doc7_values({5678, 9876, 555, 777}); @@ -510,8 +514,9 @@ struct TensorAttrFixture : Fixture { } Tensor::UP getTensor(DocId docId) { auto imp_attr = this->get_imported_attr(); - const ITensorAttribute & tensorAttr = dynamic_cast<const ITensorAttribute &>(*imp_attr); - return tensorAttr.getTensor(docId); + const ITensorAttribute *tensorAttr = imp_attr->asTensorAttribute(); + ASSERT_TRUE(tensorAttr != nullptr); + return tensorAttr->getTensor(docId); } void assertNoTensor(DocId docId) { auto tensor = getTensor(docId); diff --git a/searchlib/src/vespa/searchlib/attribute/attributevector.cpp b/searchlib/src/vespa/searchlib/attribute/attributevector.cpp index b3a25207e00..e94eaa3f542 100644 --- a/searchlib/src/vespa/searchlib/attribute/attributevector.cpp +++ b/searchlib/src/vespa/searchlib/attribute/attributevector.cpp @@ -699,6 +699,7 @@ AttributeVector::enableEnumeratedSave(bool enable) { attribute::IPostingListAttributeBase *AttributeVector::getIPostingListAttributeBase() { return nullptr; } const attribute::IPostingListAttributeBase *AttributeVector::getIPostingListAttributeBase() const { return nullptr; } const IDocumentWeightAttribute * AttributeVector::asDocumentWeightAttribute() const { return nullptr; } +const tensor::ITensorAttribute *AttributeVector::asTensorAttribute() const { return nullptr; } bool AttributeVector::hasPostings() { return getIPostingListAttributeBase() != nullptr; } uint64_t AttributeVector::getUniqueValueCount() const { return getTotalValueCount(); } uint64_t AttributeVector::getTotalValueCount() const { return getNumDocs(); } diff --git a/searchlib/src/vespa/searchlib/attribute/attributevector.h b/searchlib/src/vespa/searchlib/attribute/attributevector.h index efec7d4edb3..5917972b4fe 100644 --- a/searchlib/src/vespa/searchlib/attribute/attributevector.h +++ b/searchlib/src/vespa/searchlib/attribute/attributevector.h @@ -514,6 +514,8 @@ public: // type-safe down-cast to attribute supporting direct document weight iterators const IDocumentWeightAttribute *asDocumentWeightAttribute() const override; + const tensor::ITensorAttribute *asTensorAttribute() const override; + /** - Search for equality - Range search diff --git a/searchlib/src/vespa/searchlib/attribute/imported_attribute_vector_read_guard.cpp b/searchlib/src/vespa/searchlib/attribute/imported_attribute_vector_read_guard.cpp index da404d6f904..0a752980a19 100644 --- a/searchlib/src/vespa/searchlib/attribute/imported_attribute_vector_read_guard.cpp +++ b/searchlib/src/vespa/searchlib/attribute/imported_attribute_vector_read_guard.cpp @@ -113,6 +113,10 @@ const IDocumentWeightAttribute *ImportedAttributeVectorReadGuard::asDocumentWeig return nullptr; } +const tensor::ITensorAttribute *ImportedAttributeVectorReadGuard::asTensorAttribute() const { + return nullptr; +} + BasicType::Type ImportedAttributeVectorReadGuard::getBasicType() const { return _target_attribute.getBasicType(); } diff --git a/searchlib/src/vespa/searchlib/attribute/imported_attribute_vector_read_guard.h b/searchlib/src/vespa/searchlib/attribute/imported_attribute_vector_read_guard.h index 96f46fa684c..92d4af518f2 100644 --- a/searchlib/src/vespa/searchlib/attribute/imported_attribute_vector_read_guard.h +++ b/searchlib/src/vespa/searchlib/attribute/imported_attribute_vector_read_guard.h @@ -65,6 +65,7 @@ public: virtual std::unique_ptr<ISearchContext> createSearchContext(std::unique_ptr<QueryTermSimple> term, const SearchContextParams ¶ms) const override; virtual const IDocumentWeightAttribute *asDocumentWeightAttribute() const override; + virtual const tensor::ITensorAttribute *asTensorAttribute() const override; virtual BasicType::Type getBasicType() const override; virtual size_t getFixedWidth() const override; virtual CollectionType::Type getCollectionType() const override; diff --git a/searchlib/src/vespa/searchlib/features/attributefeature.cpp b/searchlib/src/vespa/searchlib/features/attributefeature.cpp index cd7a27b37b5..b400233bd99 100644 --- a/searchlib/src/vespa/searchlib/features/attributefeature.cpp +++ b/searchlib/src/vespa/searchlib/features/attributefeature.cpp @@ -391,7 +391,7 @@ createTensorAttributeExecutor(const IAttributeVector *attribute, const vespalib: " Returning empty tensor.", attribute->getName().c_str()); return ConstantTensorExecutor::createEmpty(tensorType, stash); } - const ITensorAttribute *tensorAttribute = dynamic_cast<const ITensorAttribute *>(attribute); + const ITensorAttribute *tensorAttribute = attribute->asTensorAttribute(); if (tensorAttribute == nullptr) { LOG(warning, "The attribute vector '%s' could not be converted to a tensor attribute." " Returning empty tensor.", attribute->getName().c_str()); diff --git a/searchlib/src/vespa/searchlib/tensor/imported_tensor_attribute_vector.cpp b/searchlib/src/vespa/searchlib/tensor/imported_tensor_attribute_vector.cpp index c84f61ccb70..71aca30ca5e 100644 --- a/searchlib/src/vespa/searchlib/tensor/imported_tensor_attribute_vector.cpp +++ b/searchlib/src/vespa/searchlib/tensor/imported_tensor_attribute_vector.cpp @@ -16,8 +16,7 @@ ImportedTensorAttributeVector::ImportedTensorAttributeVector(vespalib::stringref : ImportedAttributeVector(name, std::move(reference_attribute), std::move(target_attribute), std::move(document_meta_store), - use_search_cache), - _target_tensor_attribute(dynamic_cast<const ITensorAttribute &>(*_target_attribute)) + use_search_cache) { } @@ -29,8 +28,7 @@ ImportedTensorAttributeVector::ImportedTensorAttributeVector(vespalib::stringref : ImportedAttributeVector(name, std::move(reference_attribute), std::move(target_attribute), std::move(document_meta_store), - std::move(search_cache)), - _target_tensor_attribute(dynamic_cast<const ITensorAttribute &>(*_target_attribute)) + std::move(search_cache)) { } diff --git a/searchlib/src/vespa/searchlib/tensor/imported_tensor_attribute_vector.h b/searchlib/src/vespa/searchlib/tensor/imported_tensor_attribute_vector.h index be946066363..b9f643b179f 100644 --- a/searchlib/src/vespa/searchlib/tensor/imported_tensor_attribute_vector.h +++ b/searchlib/src/vespa/searchlib/tensor/imported_tensor_attribute_vector.h @@ -15,7 +15,7 @@ class ImportedTensorAttributeVector : public attribute::ImportedAttributeVector { using ReferenceAttribute = attribute::ReferenceAttribute; using BitVectorSearchCache = attribute::BitVectorSearchCache; - const ITensorAttribute &_target_tensor_attribute; + public: ImportedTensorAttributeVector(vespalib::stringref name, std::shared_ptr<ReferenceAttribute> reference_attribute, diff --git a/searchlib/src/vespa/searchlib/tensor/imported_tensor_attribute_vector_read_guard.cpp b/searchlib/src/vespa/searchlib/tensor/imported_tensor_attribute_vector_read_guard.cpp index a927f4fac6d..9ea598d3578 100644 --- a/searchlib/src/vespa/searchlib/tensor/imported_tensor_attribute_vector_read_guard.cpp +++ b/searchlib/src/vespa/searchlib/tensor/imported_tensor_attribute_vector_read_guard.cpp @@ -11,7 +11,7 @@ ImportedTensorAttributeVectorReadGuard::ImportedTensorAttributeVectorReadGuard(c bool stableEnumGuard) : ImportedAttributeVectorReadGuard(imported_attribute, stableEnumGuard), - _target_tensor_attribute(dynamic_cast<const ITensorAttribute &>(*imported_attribute.getTargetAttribute())) + _target_tensor_attribute(*imported_attribute.getTargetAttribute()->asTensorAttribute()) { } @@ -19,6 +19,12 @@ ImportedTensorAttributeVectorReadGuard::~ImportedTensorAttributeVectorReadGuard( { } +const ITensorAttribute * +ImportedTensorAttributeVectorReadGuard::asTensorAttribute() const +{ + return this; +} + std::unique_ptr<Tensor> ImportedTensorAttributeVectorReadGuard::getTensor(uint32_t docId) const { diff --git a/searchlib/src/vespa/searchlib/tensor/imported_tensor_attribute_vector_read_guard.h b/searchlib/src/vespa/searchlib/tensor/imported_tensor_attribute_vector_read_guard.h index 2715e117927..d538135bb99 100644 --- a/searchlib/src/vespa/searchlib/tensor/imported_tensor_attribute_vector_read_guard.h +++ b/searchlib/src/vespa/searchlib/tensor/imported_tensor_attribute_vector_read_guard.h @@ -28,10 +28,12 @@ public: bool stableEnumGuard); ~ImportedTensorAttributeVectorReadGuard(); + const ITensorAttribute *asTensorAttribute() const override; + virtual std::unique_ptr<Tensor> getTensor(uint32_t docId) const override; virtual std::unique_ptr<Tensor> getEmptyTensor() const override; virtual void getTensor(uint32_t docId, vespalib::tensor::MutableDenseTensorView &tensor) const override; virtual vespalib::eval::ValueType getTensorType() const override; }; -} // namespace search::tensor +} diff --git a/searchlib/src/vespa/searchlib/tensor/tensor_attribute.cpp b/searchlib/src/vespa/searchlib/tensor/tensor_attribute.cpp index 0ca0d2fa776..2c005fb3277 100644 --- a/searchlib/src/vespa/searchlib/tensor/tensor_attribute.cpp +++ b/searchlib/src/vespa/searchlib/tensor/tensor_attribute.cpp @@ -60,7 +60,11 @@ TensorAttribute::~TensorAttribute() { } - +const ITensorAttribute * +TensorAttribute::asTensorAttribute() const +{ + return this; +} uint32_t TensorAttribute::clearDoc(DocId docId) diff --git a/searchlib/src/vespa/searchlib/tensor/tensor_attribute.h b/searchlib/src/vespa/searchlib/tensor/tensor_attribute.h index 35dbb3ab21b..c5a76014485 100644 --- a/searchlib/src/vespa/searchlib/tensor/tensor_attribute.h +++ b/searchlib/src/vespa/searchlib/tensor/tensor_attribute.h @@ -35,6 +35,8 @@ public: TensorAttribute(const vespalib::stringref &baseFileName, const Config &cfg, TensorStore &tensorStore); virtual ~TensorAttribute(); + virtual const ITensorAttribute *asTensorAttribute() const override; + virtual uint32_t clearDoc(DocId docId) override; virtual void onCommit() override; virtual void onUpdateStat() override; diff --git a/searchsummary/src/vespa/searchsummary/docsummary/attributedfw.cpp b/searchsummary/src/vespa/searchsummary/docsummary/attributedfw.cpp index a15e0e0e0c0..5e8e4ee0584 100644 --- a/searchsummary/src/vespa/searchsummary/docsummary/attributedfw.cpp +++ b/searchsummary/src/vespa/searchsummary/docsummary/attributedfw.cpp @@ -136,9 +136,8 @@ SingleAttrDFW::insertField(uint32_t docid, BasicType::Type t = v.getBasicType(); switch (t) { case BasicType::TENSOR: { - const tensor::ITensorAttribute &tv = - dynamic_cast<const tensor::ITensorAttribute &>(v); - const auto tensor = tv.getTensor(docid); + const tensor::ITensorAttribute *tv = v.asTensorAttribute(); + const auto tensor = tv->getTensor(docid); if (tensor) { vespalib::nbostream str; vespalib::tensor::TypedBinaryFormat::serialize(str, *tensor); |