summaryrefslogtreecommitdiffstats
path: root/searchlib/src/tests/tensor/hnsw_index/hnsw_index_test.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'searchlib/src/tests/tensor/hnsw_index/hnsw_index_test.cpp')
-rw-r--r--searchlib/src/tests/tensor/hnsw_index/hnsw_index_test.cpp15
1 files changed, 3 insertions, 12 deletions
diff --git a/searchlib/src/tests/tensor/hnsw_index/hnsw_index_test.cpp b/searchlib/src/tests/tensor/hnsw_index/hnsw_index_test.cpp
index 193bb04843c..7877b488065 100644
--- a/searchlib/src/tests/tensor/hnsw_index/hnsw_index_test.cpp
+++ b/searchlib/src/tests/tensor/hnsw_index/hnsw_index_test.cpp
@@ -82,10 +82,6 @@ public:
~HnswIndexTest() {}
- const GlobalFilter *global_filter_ptr() const {
- return global_filter->is_active() ? global_filter.get() : nullptr;
- }
-
void init(bool heuristic_select_neighbors) {
auto generator = std::make_unique<LevelGenerator>();
level_generator = generator.get();
@@ -110,12 +106,7 @@ public:
}
void set_filter(std::vector<uint32_t> docids) {
uint32_t sz = 10;
- auto bit_vector = BitVector::create(sz);
- for (uint32_t id : docids) {
- EXPECT_LT(id, sz);
- bit_vector->setBit(id);
- }
- global_filter = GlobalFilter::create(std::move(bit_vector));
+ global_filter = GlobalFilter::create(docids, sz);
}
GenerationHandler::Guard take_read_guard() {
return gen_handler.takeGuard();
@@ -149,7 +140,7 @@ public:
void expect_top_3(uint32_t docid, std::vector<uint32_t> exp_hits) {
uint32_t k = 3;
auto qv = vectors.get_vector(docid);
- auto rv = index->top_k_candidates(qv, k, global_filter_ptr()).peek();
+ auto rv = index->top_k_candidates(qv, k, global_filter->ptr_if_active()).peek();
std::sort(rv.begin(), rv.end(), LesserDistance());
size_t idx = 0;
for (const auto & hit : rv) {
@@ -170,7 +161,7 @@ public:
void check_with_distance_threshold(uint32_t docid) {
auto qv = vectors.get_vector(docid);
uint32_t k = 3;
- auto rv = index->top_k_candidates(qv, k, global_filter_ptr()).peek();
+ auto rv = index->top_k_candidates(qv, k, global_filter->ptr_if_active()).peek();
std::sort(rv.begin(), rv.end(), LesserDistance());
EXPECT_EQ(rv.size(), 3);
EXPECT_LE(rv[0].distance, rv[1].distance);