From 2a051981b4f41c0d6e35f2c8d65ece7c47b994e3 Mon Sep 17 00:00:00 2001 From: Geir Storli Date: Fri, 2 Feb 2024 15:33:36 +0000 Subject: Tag hit estimates from attribute search contexts as unknown when applicable. --- .../attribute/searchcontext/searchcontext_test.cpp | 36 +++++++++++++++++++--- 1 file changed, 32 insertions(+), 4 deletions(-) (limited to 'searchlib/src/tests/attribute/searchcontext/searchcontext_test.cpp') diff --git a/searchlib/src/tests/attribute/searchcontext/searchcontext_test.cpp b/searchlib/src/tests/attribute/searchcontext/searchcontext_test.cpp index 741a86b0beb..85b13c20f88 100644 --- a/searchlib/src/tests/attribute/searchcontext/searchcontext_test.cpp +++ b/searchlib/src/tests/attribute/searchcontext/searchcontext_test.cpp @@ -60,6 +60,7 @@ using largeint_t = AttributeVector::largeint_t; using attribute::BasicType; using attribute::CollectionType; using attribute::Config; +using attribute::HitEstimate; using attribute::SearchContextParams; using attribute::test::AttributeBuilder; using fef::MatchData; @@ -87,6 +88,12 @@ public: DocSet::DocSet() noexcept = default; DocSet::~DocSet() = default; +bool is_flag_attribute(const Config& cfg) { + return cfg.fastSearch() && + (cfg.basicType() == BasicType::INT8) && + (cfg.collectionType() == CollectionType::ARRAY); +} + template class PostingList { @@ -104,6 +111,21 @@ public: DocSet & getHits() { return _hits; } const DocSet & getHits() const { return _hits; } uint32_t getHitCount() const { return _hits.size(); } + attribute::HitEstimate expected_hit_estimate() const { + if (getHitCount() == 0) { + return HitEstimate(0); + } + uint32_t docid_limit = _vec->getStatus().getNumDocs(); + if (is_flag_attribute(_vec->getConfig())) { + return HitEstimate::unknown(docid_limit); + } else if (_vec->getConfig().fastSearch()) { + return HitEstimate(getHitCount()); + } else if (_vec->getConfig().collectionType() == CollectionType::SINGLE) { + return HitEstimate::unknown(docid_limit); + } else { + return HitEstimate::unknown(std::max((uint64_t)docid_limit, _vec->getStatus().getNumValues())); + } + } }; template @@ -166,7 +188,7 @@ private: void testSearchIteratorConformance(); // test search functionality template - void testFind(const PostingList & first); + void testFind(const PostingList & first, bool verify_hit_estimate); template void testSearch(V & attribute, uint32_t numDocs, const std::vector & values); @@ -536,10 +558,16 @@ SearchContextTest::checkResultSet(const ResultSet & rs, const DocSet & expected, //----------------------------------------------------------------------------- template void -SearchContextTest::testFind(const PostingList & pl) +SearchContextTest::testFind(const PostingList & pl, bool verify_hit_estimate) { { // strict search iterator SearchContextPtr sc = getSearch(pl.getAttribute(), pl.getValue()); + if (verify_hit_estimate) { + auto act_est = sc->calc_hit_estimate(); + auto exp_est = pl.expected_hit_estimate(); + EXPECT_EQUAL(exp_est.est_hits(), act_est.est_hits()); + EXPECT_EQUAL(exp_est.is_unknown(), act_est.is_unknown()); + } sc->fetchPostings(queryeval::ExecuteInfo::TRUE); TermFieldMatchData dummy; SearchBasePtr sb = sc->createIterator(&dummy, true); @@ -571,7 +599,7 @@ SearchContextTest::testSearch(V & attribute, uint32_t numDocs, const std::vector // test find() for (const auto & list : lists) { - testFind(list); + testFind(list, true); } } @@ -591,7 +619,7 @@ SearchContextTest::testMultiValueSearchHelper(V & vec, const std::vector & va for (const auto & list : lists) { //std::cout << "testFind(lists[" << i << "]): value = " << lists[i].getValue() // << ", hit count = " << lists[i].getHitCount() << std::endl; - testFind(list); + testFind(list, false); } } -- cgit v1.2.3