diff options
author | Arne Juul <arnej@verizonmedia.com> | 2021-01-08 09:38:00 +0000 |
---|---|---|
committer | Arne Juul <arnej@verizonmedia.com> | 2021-01-08 10:56:14 +0000 |
commit | 290cc71ab1aa3b6c8eedd93d82351000873bb300 (patch) | |
tree | 6136e8ebcf4bb8305e4310764448e77c43620abc /searchlib/src/tests/tensor/hnsw_index/hnsw_index_test.cpp | |
parent | 63824702076b738253dc14281ee11db89e976584 (diff) |
make check_with_distance_threshold method
Diffstat (limited to 'searchlib/src/tests/tensor/hnsw_index/hnsw_index_test.cpp')
-rw-r--r-- | searchlib/src/tests/tensor/hnsw_index/hnsw_index_test.cpp | 24 |
1 files changed, 16 insertions, 8 deletions
diff --git a/searchlib/src/tests/tensor/hnsw_index/hnsw_index_test.cpp b/searchlib/src/tests/tensor/hnsw_index/hnsw_index_test.cpp index d081c299a43..20dc55df329 100644 --- a/searchlib/src/tests/tensor/hnsw_index/hnsw_index_test.cpp +++ b/searchlib/src/tests/tensor/hnsw_index/hnsw_index_test.cpp @@ -149,14 +149,22 @@ public: EXPECT_EQ(expected_by_docid[idx], got_by_docid[idx].docid); } } - if ((rv.size() > 1) && (rv[0].distance < rv[1].distance)) { - double thr = (rv[0].distance + rv[1].distance) * 0.5; - auto got_by_docid = index->find_top_k_with_filter(k, qv, *global_filter, k, thr); - for (const auto & hit : got_by_docid) { - printf("hit docid=%u dist=%g (thr %g)\n", hit.docid, hit.distance, thr); - } - EXPECT_EQ(got_by_docid.size(), 1); - EXPECT_EQ(got_by_docid[0].docid, rv[0].docid); + check_with_distance_threshold(docid); + } + void check_with_distance_threshold(uint32_t docid) { + auto qv = vectors.get_vector(docid); + uint32_t k = 3; + auto rv = index->top_k_candidates(qv, k, global_filter.get()).peek(); + std::sort(rv.begin(), rv.end(), LesserDistance()); + EXPECT_EQ(rv.size(), 3); + EXPECT_LE(rv[0].distance, rv[1].distance); + double thr = (rv[0].distance + rv[1].distance) * 0.5; + auto got_by_docid = index->find_top_k_with_filter(k, qv, *global_filter, k, thr); + EXPECT_EQ(got_by_docid.size(), 1); + EXPECT_EQ(got_by_docid[0].docid, rv[0].docid); + for (const auto & hit : got_by_docid) { + LOG(debug, "from docid=%u found docid=%u dist=%g (threshold %g)\n", + docid, hit.docid, hit.distance, thr); } } }; |