summaryrefslogtreecommitdiffstats
path: root/searchlib/src
diff options
context:
space:
mode:
authorArne Juul <arnej@verizonmedia.com>2021-01-08 09:38:00 +0000
committerArne Juul <arnej@verizonmedia.com>2021-01-08 10:56:14 +0000
commit290cc71ab1aa3b6c8eedd93d82351000873bb300 (patch)
tree6136e8ebcf4bb8305e4310764448e77c43620abc /searchlib/src
parent63824702076b738253dc14281ee11db89e976584 (diff)
make check_with_distance_threshold method
Diffstat (limited to 'searchlib/src')
-rw-r--r--searchlib/src/tests/tensor/hnsw_index/hnsw_index_test.cpp24
1 files changed, 16 insertions, 8 deletions
diff --git a/searchlib/src/tests/tensor/hnsw_index/hnsw_index_test.cpp b/searchlib/src/tests/tensor/hnsw_index/hnsw_index_test.cpp
index d081c299a43..20dc55df329 100644
--- a/searchlib/src/tests/tensor/hnsw_index/hnsw_index_test.cpp
+++ b/searchlib/src/tests/tensor/hnsw_index/hnsw_index_test.cpp
@@ -149,14 +149,22 @@ public:
EXPECT_EQ(expected_by_docid[idx], got_by_docid[idx].docid);
}
}
- if ((rv.size() > 1) && (rv[0].distance < rv[1].distance)) {
- double thr = (rv[0].distance + rv[1].distance) * 0.5;
- auto got_by_docid = index->find_top_k_with_filter(k, qv, *global_filter, k, thr);
- for (const auto & hit : got_by_docid) {
- printf("hit docid=%u dist=%g (thr %g)\n", hit.docid, hit.distance, thr);
- }
- EXPECT_EQ(got_by_docid.size(), 1);
- EXPECT_EQ(got_by_docid[0].docid, rv[0].docid);
+ check_with_distance_threshold(docid);
+ }
+ void check_with_distance_threshold(uint32_t docid) {
+ auto qv = vectors.get_vector(docid);
+ uint32_t k = 3;
+ auto rv = index->top_k_candidates(qv, k, global_filter.get()).peek();
+ std::sort(rv.begin(), rv.end(), LesserDistance());
+ EXPECT_EQ(rv.size(), 3);
+ EXPECT_LE(rv[0].distance, rv[1].distance);
+ double thr = (rv[0].distance + rv[1].distance) * 0.5;
+ auto got_by_docid = index->find_top_k_with_filter(k, qv, *global_filter, k, thr);
+ EXPECT_EQ(got_by_docid.size(), 1);
+ EXPECT_EQ(got_by_docid[0].docid, rv[0].docid);
+ for (const auto & hit : got_by_docid) {
+ LOG(debug, "from docid=%u found docid=%u dist=%g (threshold %g)\n",
+ docid, hit.docid, hit.distance, thr);
}
}
};