diff options
Diffstat (limited to 'searchlib/src/tests/tensor/hnsw_index/hnsw_index_test.cpp')
-rw-r--r-- | searchlib/src/tests/tensor/hnsw_index/hnsw_index_test.cpp | 49 |
1 files changed, 43 insertions, 6 deletions
diff --git a/searchlib/src/tests/tensor/hnsw_index/hnsw_index_test.cpp b/searchlib/src/tests/tensor/hnsw_index/hnsw_index_test.cpp index c6246bb8434..52f45860c1e 100644 --- a/searchlib/src/tests/tensor/hnsw_index/hnsw_index_test.cpp +++ b/searchlib/src/tests/tensor/hnsw_index/hnsw_index_test.cpp @@ -48,8 +48,7 @@ using HnswIndexUP = std::unique_ptr<HnswIndex>; class HnswIndexTest : public ::testing::Test { public: FloatVectors vectors; - FloatSqEuclideanDistance distance_func; - LevelGenerator level_generator; + LevelGenerator* level_generator; HnswIndexUP index; HnswIndexTest() @@ -62,11 +61,14 @@ public: .set(7, {3, 5}).set(8, {0, 3}).set(9, {4, 5}); } void init(bool heuristic_select_neighbors) { - index = std::make_unique<HnswIndex>(vectors, distance_func, level_generator, + auto generator = std::make_unique<LevelGenerator>(); + level_generator = generator.get(); + index = std::make_unique<HnswIndex>(vectors, std::make_unique<FloatSqEuclideanDistance>(), + std::move(generator), HnswIndex::Config(2, 1, 10, heuristic_select_neighbors)); } void add_document(uint32_t docid, uint32_t max_level = 0) { - level_generator.level = max_level; + level_generator->level = max_level; index->add_document(docid); } void remove_document(uint32_t docid) { @@ -100,8 +102,10 @@ public: if (exp_hits.size() == k) { std::vector<uint32_t> expected_by_docid = exp_hits; std::sort(expected_by_docid.begin(), expected_by_docid.end()); - std::vector<uint32_t> got_by_docid = index->find_top_k(k, qv, k); - EXPECT_EQ(expected_by_docid, got_by_docid); + auto got_by_docid = index->find_top_k(k, qv, k); + for (idx = 0; idx < k; ++idx) { + EXPECT_EQ(expected_by_docid[idx], got_by_docid[idx].docid); + } } } }; @@ -262,5 +266,38 @@ TEST_F(HnswIndexTest, 2d_vectors_inserted_in_hierarchic_graph_with_heuristic_sel expect_level_0(7, {3, 6}); } +TEST_F(HnswIndexTest, manual_insert) +{ + init(false); + + std::vector<uint32_t> nbl; + HnswNode empty{nbl}; + index->set_node(1, empty); + index->set_node(2, empty); + + HnswNode three{{1,2}}; + index->set_node(3, three); + expect_level_0(1, {3}); + expect_level_0(2, {3}); + expect_level_0(3, {1,2}); + + expect_entry_point(1, 0); + + HnswNode twolevels{{{1},nbl}}; + index->set_node(4, twolevels); + + expect_entry_point(4, 1); + expect_level_0(1, {3,4}); + + HnswNode five{{{1,2}, {4}}}; + index->set_node(5, five); + + expect_levels(1, {{3,4,5}}); + expect_levels(2, {{3,5}}); + expect_levels(3, {{1,2}}); + expect_levels(4, {{1}, {5}}); + expect_levels(5, {{1,2}, {4}}); +} + GTEST_MAIN_RUN_ALL_TESTS() |