diff options
Diffstat (limited to 'searchlib/src/tests/tensor')
-rw-r--r-- | searchlib/src/tests/tensor/hnsw_index/hnsw_index_test.cpp | 10 |
1 files changed, 9 insertions, 1 deletions
diff --git a/searchlib/src/tests/tensor/hnsw_index/hnsw_index_test.cpp b/searchlib/src/tests/tensor/hnsw_index/hnsw_index_test.cpp index 805012d224c..1b821a05c84 100644 --- a/searchlib/src/tests/tensor/hnsw_index/hnsw_index_test.cpp +++ b/searchlib/src/tests/tensor/hnsw_index/hnsw_index_test.cpp @@ -87,7 +87,9 @@ public: EXPECT_EQ(exp_levels, act_node.levels()); } void expect_top_3(uint32_t docid, std::vector<uint32_t> exp_hits) { - auto rv = index->top_k_candidates(vectors.get_vector(docid), 3); + uint32_t k = 3; + auto qv = vectors.get_vector(docid); + auto rv = index->top_k_candidates(qv, k); size_t idx = 0; for (const auto & hit : rv) { // fprintf(stderr, "found docid %u dist %.1f\n", hit.docid, hit.distance); @@ -95,6 +97,12 @@ public: EXPECT_EQ(hit.docid, exp_hits[idx++]); } } + if (exp_hits.size() == k) { + std::vector<uint32_t> expected_by_docid = exp_hits; + std::sort(expected_by_docid.begin(), expected_by_docid.end()); + std::vector<uint32_t> got_by_docid = index->find_top_k(qv, k); + EXPECT_EQ(expected_by_docid, got_by_docid); + } } }; |