diff options
Diffstat (limited to 'searchlib/src')
-rw-r--r-- | searchlib/src/tests/tensor/hnsw_index/hnsw_index_test.cpp | 76 | ||||
-rw-r--r-- | searchlib/src/vespa/searchlib/tensor/hnsw_index.cpp | 88 | ||||
-rw-r--r-- | searchlib/src/vespa/searchlib/tensor/hnsw_index.h | 12 |
3 files changed, 159 insertions, 17 deletions
diff --git a/searchlib/src/tests/tensor/hnsw_index/hnsw_index_test.cpp b/searchlib/src/tests/tensor/hnsw_index/hnsw_index_test.cpp index 37c4d02017f..2516950b0cc 100644 --- a/searchlib/src/tests/tensor/hnsw_index/hnsw_index_test.cpp +++ b/searchlib/src/tests/tensor/hnsw_index/hnsw_index_test.cpp @@ -70,7 +70,7 @@ public: level_generator = generator.get(); index = std::make_unique<HnswIndex>(vectors, std::make_unique<FloatSqEuclideanDistance>(), std::move(generator), - HnswIndex::Config(2, 1, 10, heuristic_select_neighbors)); + HnswIndex::Config(5, 2, 10, heuristic_select_neighbors)); } void add_document(uint32_t docid, uint32_t max_level = 0) { level_generator->level = max_level; @@ -367,5 +367,79 @@ TEST_F(HnswIndexTest, memory_is_put_on_hold_while_read_guard_is_held) EXPECT_EQ(0, mem.allocatedBytesOnHold()); } +TEST_F(HnswIndexTest, shrink_called_simple) +{ + init(false); + std::vector<uint32_t> nbl; + HnswNode empty{nbl}; + index->set_node(1, empty); + nbl.push_back(1); + HnswNode nb1{nbl}; + index->set_node(2, nb1); + index->set_node(3, nb1); + index->set_node(4, nb1); + index->set_node(5, nb1); + expect_level_0(1, {2,3,4,5}); + index->set_node(6, nb1); + expect_level_0(1, {2,3,4,5,6}); + expect_level_0(2, {1}); + expect_level_0(3, {1}); + expect_level_0(4, {1}); + expect_level_0(5, {1}); + expect_level_0(6, {1}); + index->set_node(7, nb1); + expect_level_0(1, {2,3,4,6,7}); + expect_level_0(5, {}); + expect_level_0(6, {1}); + index->set_node(8, nb1); + expect_level_0(1, {2,3,4,7,8}); + expect_level_0(6, {}); + index->set_node(9, nb1); + expect_level_0(1, {2,3,4,7,8}); + expect_level_0(2, {1}); + expect_level_0(3, {1}); + expect_level_0(4, {1}); + expect_level_0(5, {}); + expect_level_0(6, {}); + expect_level_0(7, {1}); + expect_level_0(8, {1}); + expect_level_0(9, {}); + EXPECT_TRUE(index->check_link_symmetry()); +} + +TEST_F(HnswIndexTest, shrink_called_heuristic) +{ + init(true); + std::vector<uint32_t> nbl; + HnswNode empty{nbl}; + index->set_node(1, empty); + nbl.push_back(1); + HnswNode nb1{nbl}; + index->set_node(2, nb1); + index->set_node(3, nb1); + index->set_node(4, nb1); + index->set_node(5, nb1); + expect_level_0(1, {2,3,4,5}); + index->set_node(6, nb1); + expect_level_0(1, {2,3,4,5,6}); + expect_level_0(2, {1}); + expect_level_0(3, {1}); + expect_level_0(4, {1}); + expect_level_0(5, {1}); + expect_level_0(6, {1}); + index->set_node(7, nb1); + expect_level_0(1, {2,3,4}); + expect_level_0(2, {1}); + expect_level_0(3, {1}); + expect_level_0(4, {1}); + expect_level_0(5, {}); + expect_level_0(6, {}); + expect_level_0(7, {}); + index->set_node(8, nb1); + index->set_node(9, nb1); + expect_level_0(1, {2,3,4,8,9}); + EXPECT_TRUE(index->check_link_symmetry()); +} + GTEST_MAIN_RUN_ALL_TESTS() diff --git a/searchlib/src/vespa/searchlib/tensor/hnsw_index.cpp b/searchlib/src/vespa/searchlib/tensor/hnsw_index.cpp index 467e41d83fc..a39fdb856f2 100644 --- a/searchlib/src/vespa/searchlib/tensor/hnsw_index.cpp +++ b/searchlib/src/vespa/searchlib/tensor/hnsw_index.cpp @@ -21,6 +21,13 @@ constexpr float alloc_grow_factor = 0.2; constexpr size_t max_level_array_size = 16; constexpr size_t max_link_array_size = 64; +bool has_link_to(vespalib::ConstArrayRef<uint32_t> links, uint32_t id) { + for (uint32_t link : links) { + if (link == id) return true; + } + return false; +} + } search::datastore::ArrayStoreConfig @@ -106,22 +113,26 @@ HnswIndex::have_closer_distance(HnswCandidate candidate, const LinkArray& result return false; } -HnswIndex::LinkArray +HnswIndex::SelectResult HnswIndex::select_neighbors_simple(const HnswCandidateVector& neighbors, uint32_t max_links) const { HnswCandidateVector sorted(neighbors); std::sort(sorted.begin(), sorted.end(), LesserDistance()); - LinkArray result; - for (size_t i = 0, m = std::min(static_cast<size_t>(max_links), sorted.size()); i < m; ++i) { - result.push_back(sorted[i].docid); + SelectResult result; + for (const auto & candidate : sorted) { + if (result.used.size() < max_links) { + result.used.push_back(candidate.docid); + } else { + result.unused.push_back(candidate.docid); + } } return result; } -HnswIndex::LinkArray +HnswIndex::SelectResult HnswIndex::select_neighbors_heuristic(const HnswCandidateVector& neighbors, uint32_t max_links) const { - LinkArray result; + SelectResult result; bool need_filtering = neighbors.size() > max_links; NearestPriQ nearest; for (const auto& entry : neighbors) { @@ -130,18 +141,23 @@ HnswIndex::select_neighbors_heuristic(const HnswCandidateVector& neighbors, uint while (!nearest.empty()) { auto candidate = nearest.top(); nearest.pop(); - if (need_filtering && have_closer_distance(candidate, result)) { + if (need_filtering && have_closer_distance(candidate, result.used)) { + result.unused.push_back(candidate.docid); continue; } - result.push_back(candidate.docid); - if (result.size() == max_links) { - return result; + result.used.push_back(candidate.docid); + if (result.used.size() == max_links) { + while (!nearest.empty()) { + candidate = nearest.top(); + nearest.pop(); + result.unused.push_back(candidate.docid); + } } } return result; } -HnswIndex::LinkArray +HnswIndex::SelectResult HnswIndex::select_neighbors(const HnswCandidateVector& neighbors, uint32_t max_links) const { if (_cfg.heuristic_select_neighbors()) { @@ -152,6 +168,25 @@ HnswIndex::select_neighbors(const HnswCandidateVector& neighbors, uint32_t max_l } void +HnswIndex::shrink_if_needed(uint32_t docid, uint32_t level) +{ + auto old_links = get_link_array(docid, level); + uint32_t max_links = max_links_for_level(level); + if (old_links.size() > max_links) { + HnswCandidateVector neighbors; + for (uint32_t neighbor_docid : old_links) { + double dist = calc_distance(docid, neighbor_docid); + neighbors.emplace_back(neighbor_docid, dist); + } + auto split = select_neighbors(neighbors, max_links); + set_link_array(docid, level, split.used); + for (uint32_t removed_docid : split.unused) { + remove_link_to(removed_docid, docid, level); + } + } +} + +void HnswIndex::connect_new_node(uint32_t docid, const LinkArrayRef &neighbors, uint32_t level) { set_link_array(docid, level, neighbors); @@ -161,6 +196,9 @@ HnswIndex::connect_new_node(uint32_t docid, const LinkArrayRef &neighbors, uint3 new_links.push_back(docid); set_link_array(neighbor_docid, level, new_links); } + for (uint32_t neighbor_docid : neighbors) { + shrink_if_needed(neighbor_docid, level); + } } void @@ -287,8 +325,8 @@ HnswIndex::add_document(uint32_t docid) while (search_level >= 0) { // TODO: Rename to search_level? search_layer(input, _cfg.neighbors_to_explore_at_construction(), best_neighbors, search_level); - auto neighbors = select_neighbors(best_neighbors.peek(), max_links_for_level(search_level)); - connect_new_node(docid, neighbors, search_level); + auto neighbors = select_neighbors(best_neighbors.peek(), _cfg.max_links_at_hierarchic_levels()); + connect_new_node(docid, neighbors.used, search_level); // TODO: Shrink neighbors if needed --search_level; } @@ -436,4 +474,28 @@ HnswIndex::set_node(uint32_t docid, const HnswNode &node) } } +bool +HnswIndex::check_link_symmetry() const +{ + bool all_sym = true; + for (size_t docid = 0; docid < _node_refs.size(); ++docid) { + auto node_ref = _node_refs[docid].load_acquire(); + if (node_ref.valid()) { + auto levels = _nodes.get(node_ref); + uint32_t level = 0; + for (const auto& links_ref : levels) { + auto links = _links.get(links_ref.load_acquire()); + for (auto neighbor_docid : links) { + auto neighbor_links = get_link_array(neighbor_docid, level); + if (! has_link_to(neighbor_links, docid)) { + all_sym = false; + } + } + ++level; + } + } + } + return all_sym; } + +} // namespace diff --git a/searchlib/src/vespa/searchlib/tensor/hnsw_index.h b/searchlib/src/vespa/searchlib/tensor/hnsw_index.h index 89c45d6b50c..b0c1cd1dcfd 100644 --- a/searchlib/src/vespa/searchlib/tensor/hnsw_index.h +++ b/searchlib/src/vespa/searchlib/tensor/hnsw_index.h @@ -108,9 +108,14 @@ protected: * Used by select_neighbors_heuristic(). */ bool have_closer_distance(HnswCandidate candidate, const LinkArray& curr_result) const; - LinkArray select_neighbors_heuristic(const HnswCandidateVector& neighbors, uint32_t max_links) const; - LinkArray select_neighbors_simple(const HnswCandidateVector& neighbors, uint32_t max_links) const; - LinkArray select_neighbors(const HnswCandidateVector& neighbors, uint32_t max_links) const; + struct SelectResult { + LinkArray used; + LinkArray unused; + }; + SelectResult select_neighbors_heuristic(const HnswCandidateVector& neighbors, uint32_t max_links) const; + SelectResult select_neighbors_simple(const HnswCandidateVector& neighbors, uint32_t max_links) const; + SelectResult select_neighbors(const HnswCandidateVector& neighbors, uint32_t max_links) const; + void shrink_if_needed(uint32_t docid, uint32_t level); void connect_new_node(uint32_t docid, const LinkArrayRef &neighbors, uint32_t level); void remove_link_to(uint32_t remove_from, uint32_t remove_id, uint32_t level); @@ -150,6 +155,7 @@ public: // Should only be used by unit tests. HnswNode get_node(uint32_t docid) const; void set_node(uint32_t docid, const HnswNode &node); + bool check_link_symmetry() const; }; } |