summaryrefslogtreecommitdiffstats
path: root/searchlib/src/tests/attribute/searchcontext/searchcontext_test.cpp
diff options
context:
space:
mode:
authorGeir Storli <geirst@yahooinc.com>2024-02-02 15:33:36 +0000
committerGeir Storli <geirst@yahooinc.com>2024-02-02 15:33:36 +0000
commit2a051981b4f41c0d6e35f2c8d65ece7c47b994e3 (patch)
tree556260c5c7ea2c691b776114fe7600c53973b55f /searchlib/src/tests/attribute/searchcontext/searchcontext_test.cpp
parentfdff9a2b553d2c3c0c81aca3bd8bb6d9a491e443 (diff)
Tag hit estimates from attribute search contexts as unknown when applicable.
Diffstat (limited to 'searchlib/src/tests/attribute/searchcontext/searchcontext_test.cpp')
-rw-r--r--searchlib/src/tests/attribute/searchcontext/searchcontext_test.cpp36
1 files changed, 32 insertions, 4 deletions
diff --git a/searchlib/src/tests/attribute/searchcontext/searchcontext_test.cpp b/searchlib/src/tests/attribute/searchcontext/searchcontext_test.cpp
index 741a86b0beb..85b13c20f88 100644
--- a/searchlib/src/tests/attribute/searchcontext/searchcontext_test.cpp
+++ b/searchlib/src/tests/attribute/searchcontext/searchcontext_test.cpp
@@ -60,6 +60,7 @@ using largeint_t = AttributeVector::largeint_t;
using attribute::BasicType;
using attribute::CollectionType;
using attribute::Config;
+using attribute::HitEstimate;
using attribute::SearchContextParams;
using attribute::test::AttributeBuilder;
using fef::MatchData;
@@ -87,6 +88,12 @@ public:
DocSet::DocSet() noexcept = default;
DocSet::~DocSet() = default;
+bool is_flag_attribute(const Config& cfg) {
+ return cfg.fastSearch() &&
+ (cfg.basicType() == BasicType::INT8) &&
+ (cfg.collectionType() == CollectionType::ARRAY);
+}
+
template <typename V, typename T>
class PostingList
{
@@ -104,6 +111,21 @@ public:
DocSet & getHits() { return _hits; }
const DocSet & getHits() const { return _hits; }
uint32_t getHitCount() const { return _hits.size(); }
+ attribute::HitEstimate expected_hit_estimate() const {
+ if (getHitCount() == 0) {
+ return HitEstimate(0);
+ }
+ uint32_t docid_limit = _vec->getStatus().getNumDocs();
+ if (is_flag_attribute(_vec->getConfig())) {
+ return HitEstimate::unknown(docid_limit);
+ } else if (_vec->getConfig().fastSearch()) {
+ return HitEstimate(getHitCount());
+ } else if (_vec->getConfig().collectionType() == CollectionType::SINGLE) {
+ return HitEstimate::unknown(docid_limit);
+ } else {
+ return HitEstimate::unknown(std::max((uint64_t)docid_limit, _vec->getStatus().getNumValues()));
+ }
+ }
};
template <typename V, typename T>
@@ -166,7 +188,7 @@ private:
void testSearchIteratorConformance();
// test search functionality
template <typename V, typename T>
- void testFind(const PostingList<V, T> & first);
+ void testFind(const PostingList<V, T> & first, bool verify_hit_estimate);
template <typename V, typename T>
void testSearch(V & attribute, uint32_t numDocs, const std::vector<T> & values);
@@ -536,10 +558,16 @@ SearchContextTest::checkResultSet(const ResultSet & rs, const DocSet & expected,
//-----------------------------------------------------------------------------
template <typename V, typename T>
void
-SearchContextTest::testFind(const PostingList<V, T> & pl)
+SearchContextTest::testFind(const PostingList<V, T> & pl, bool verify_hit_estimate)
{
{ // strict search iterator
SearchContextPtr sc = getSearch(pl.getAttribute(), pl.getValue());
+ if (verify_hit_estimate) {
+ auto act_est = sc->calc_hit_estimate();
+ auto exp_est = pl.expected_hit_estimate();
+ EXPECT_EQUAL(exp_est.est_hits(), act_est.est_hits());
+ EXPECT_EQUAL(exp_est.is_unknown(), act_est.is_unknown());
+ }
sc->fetchPostings(queryeval::ExecuteInfo::TRUE);
TermFieldMatchData dummy;
SearchBasePtr sb = sc->createIterator(&dummy, true);
@@ -571,7 +599,7 @@ SearchContextTest::testSearch(V & attribute, uint32_t numDocs, const std::vector
// test find()
for (const auto & list : lists) {
- testFind(list);
+ testFind(list, true);
}
}
@@ -591,7 +619,7 @@ SearchContextTest::testMultiValueSearchHelper(V & vec, const std::vector<T> & va
for (const auto & list : lists) {
//std::cout << "testFind(lists[" << i << "]): value = " << lists[i].getValue()
// << ", hit count = " << lists[i].getHitCount() << std::endl;
- testFind(list);
+ testFind(list, false);
}
}