summary | refs | log | tree | commit | diff | stats
path: root/searchlib/src
diff options
context:
space:
mode:
author Arne Juul <arnej@verizonmedia.com> 2020-02-25 11:41:02 +0000
committer Arne Juul <arnej@verizonmedia.com> 2020-02-26 11:33:31 +0000
commit b3eff781ae386d2f169dcde363ecf1f94e1397cd (patch)
tree 1d803bd78ba36488ec087e0c9d0909de03522846 /searchlib/src
parent 9432751c53285669b9a3b3ab605d673459bcbf37 (diff)
shrink links if needed
* select "M" links on all layers
* change unit test parameters to avoid triggering shrink early
* add symmetry validation method for use from unit test
Diffstat (limited to 'searchlib/src')
-rw-r--r-- searchlib/src/tests/tensor/hnsw_index/hnsw_index_test.cpp 76
-rw-r--r-- searchlib/src/vespa/searchlib/tensor/hnsw_index.cpp 88
-rw-r--r-- searchlib/src/vespa/searchlib/tensor/hnsw_index.h 12
3 files changed, 159 insertions, 17 deletions
diff --git a/searchlib/src/tests/tensor/hnsw_index/hnsw_index_test.cpp b/searchlib/src/tests/tensor/hnsw_index/hnsw_index_test.cpp
index 37c4d02017f..2516950b0cc 100644
--- a/searchlib/src/tests/tensor/hnsw_index/hnsw_index_test.cpp
+++ b/searchlib/src/tests/tensor/hnsw_index/hnsw_index_test.cpp
@@ -70,7 +70,7 @@ public:
level_generator = generator.get();
index = std::make_unique<HnswIndex>(vectors, std::make_unique<FloatSqEuclideanDistance>(),
std::move(generator),
- HnswIndex::Config(2, 1, 10, heuristic_select_neighbors));
+ HnswIndex::Config(5, 2, 10, heuristic_select_neighbors));
}
void add_document(uint32_t docid, uint32_t max_level = 0) {
level_generator->level = max_level;
@@ -367,5 +367,79 @@ TEST_F(HnswIndexTest, memory_is_put_on_hold_while_read_guard_is_held)
EXPECT_EQ(0, mem.allocatedBytesOnHold());
}
+TEST_F(HnswIndexTest, shrink_called_simple)
+{
+ init(false);
+ std::vector<uint32_t> nbl;
+ HnswNode empty{nbl};
+ index->set_node(1, empty);
+ nbl.push_back(1);
+ HnswNode nb1{nbl};
+ index->set_node(2, nb1);
+ index->set_node(3, nb1);
+ index->set_node(4, nb1);
+ index->set_node(5, nb1);
+ expect_level_0(1, {2,3,4,5});
+ index->set_node(6, nb1);
+ expect_level_0(1, {2,3,4,5,6});
+ expect_level_0(2, {1});
+ expect_level_0(3, {1});
+ expect_level_0(4, {1});
+ expect_level_0(5, {1});
+ expect_level_0(6, {1});
+ index->set_node(7, nb1);
+ expect_level_0(1, {2,3,4,6,7});
+ expect_level_0(5, {});
+ expect_level_0(6, {1});
+ index->set_node(8, nb1);
+ expect_level_0(1, {2,3,4,7,8});
+ expect_level_0(6, {});
+ index->set_node(9, nb1);
+ expect_level_0(1, {2,3,4,7,8});
+ expect_level_0(2, {1});
+ expect_level_0(3, {1});
+ expect_level_0(4, {1});
+ expect_level_0(5, {});
+ expect_level_0(6, {});
+ expect_level_0(7, {1});
+ expect_level_0(8, {1});
+ expect_level_0(9, {});
+ EXPECT_TRUE(index->check_link_symmetry());
+}
+
+TEST_F(HnswIndexTest, shrink_called_heuristic)
+{
+ init(true);
+ std::vector<uint32_t> nbl;
+ HnswNode empty{nbl};
+ index->set_node(1, empty);
+ nbl.push_back(1);
+ HnswNode nb1{nbl};
+ index->set_node(2, nb1);
+ index->set_node(3, nb1);
+ index->set_node(4, nb1);
+ index->set_node(5, nb1);
+ expect_level_0(1, {2,3,4,5});
+ index->set_node(6, nb1);
+ expect_level_0(1, {2,3,4,5,6});
+ expect_level_0(2, {1});
+ expect_level_0(3, {1});
+ expect_level_0(4, {1});
+ expect_level_0(5, {1});
+ expect_level_0(6, {1});
+ index->set_node(7, nb1);
+ expect_level_0(1, {2,3,4});
+ expect_level_0(2, {1});
+ expect_level_0(3, {1});
+ expect_level_0(4, {1});
+ expect_level_0(5, {});
+ expect_level_0(6, {});
+ expect_level_0(7, {});
+ index->set_node(8, nb1);
+ index->set_node(9, nb1);
+ expect_level_0(1, {2,3,4,8,9});
+ EXPECT_TRUE(index->check_link_symmetry());
+}
+
GTEST_MAIN_RUN_ALL_TESTS()
diff --git a/searchlib/src/vespa/searchlib/tensor/hnsw_index.cpp b/searchlib/src/vespa/searchlib/tensor/hnsw_index.cpp
index 467e41d83fc..a39fdb856f2 100644
--- a/searchlib/src/vespa/searchlib/tensor/hnsw_index.cpp
+++ b/searchlib/src/vespa/searchlib/tensor/hnsw_index.cpp
@@ -21,6 +21,13 @@ constexpr float alloc_grow_factor = 0.2;
constexpr size_t max_level_array_size = 16;
constexpr size_t max_link_array_size = 64;
+bool has_link_to(vespalib::ConstArrayRef<uint32_t> links, uint32_t id) {
+ for (uint32_t link : links) {
+ if (link == id) return true;
+ }
+ return false;
+}
+
}
search::datastore::ArrayStoreConfig
@@ -106,22 +113,26 @@ HnswIndex::have_closer_distance(HnswCandidate candidate, const LinkArray& result
return false;
}
-HnswIndex::LinkArray
+HnswIndex::SelectResult
HnswIndex::select_neighbors_simple(const HnswCandidateVector& neighbors, uint32_t max_links) const
{
HnswCandidateVector sorted(neighbors);
std::sort(sorted.begin(), sorted.end(), LesserDistance());
- LinkArray result;
- for (size_t i = 0, m = std::min(static_cast<size_t>(max_links), sorted.size()); i < m; ++i) {
- result.push_back(sorted[i].docid);
+ SelectResult result;
+ for (const auto & candidate : sorted) {
+ if (result.used.size() < max_links) {
+ result.used.push_back(candidate.docid);
+ } else {
+ result.unused.push_back(candidate.docid);
+ }
}
return result;
}
-HnswIndex::LinkArray
+HnswIndex::SelectResult
HnswIndex::select_neighbors_heuristic(const HnswCandidateVector& neighbors, uint32_t max_links) const
{
- LinkArray result;
+ SelectResult result;
bool need_filtering = neighbors.size() > max_links;
NearestPriQ nearest;
for (const auto& entry : neighbors) {
@@ -130,18 +141,23 @@ HnswIndex::select_neighbors_heuristic(const HnswCandidateVector& neighbors, uint
while (!nearest.empty()) {
auto candidate = nearest.top();
nearest.pop();
- if (need_filtering && have_closer_distance(candidate, result)) {
+ if (need_filtering && have_closer_distance(candidate, result.used)) {
+ result.unused.push_back(candidate.docid);
continue;
}
- result.push_back(candidate.docid);
- if (result.size() == max_links) {
- return result;
+ result.used.push_back(candidate.docid);
+ if (result.used.size() == max_links) {
+ while (!nearest.empty()) {
+ candidate = nearest.top();
+ nearest.pop();
+ result.unused.push_back(candidate.docid);
+ }
}
}
return result;
}
-HnswIndex::LinkArray
+HnswIndex::SelectResult
HnswIndex::select_neighbors(const HnswCandidateVector& neighbors, uint32_t max_links) const
{
if (_cfg.heuristic_select_neighbors()) {
@@ -152,6 +168,25 @@ HnswIndex::select_neighbors(const HnswCandidateVector& neighbors, uint32_t max_l
}
void
+HnswIndex::shrink_if_needed(uint32_t docid, uint32_t level)
+{
+ auto old_links = get_link_array(docid, level);
+ uint32_t max_links = max_links_for_level(level);
+ if (old_links.size() > max_links) {
+ HnswCandidateVector neighbors;
+ for (uint32_t neighbor_docid : old_links) {
+ double dist = calc_distance(docid, neighbor_docid);
+ neighbors.emplace_back(neighbor_docid, dist);
+ }
+ auto split = select_neighbors(neighbors, max_links);
+ set_link_array(docid, level, split.used);
+ for (uint32_t removed_docid : split.unused) {
+ remove_link_to(removed_docid, docid, level);
+ }
+ }
+}
+
+void
HnswIndex::connect_new_node(uint32_t docid, const LinkArrayRef &neighbors, uint32_t level)
{
set_link_array(docid, level, neighbors);
@@ -161,6 +196,9 @@ HnswIndex::connect_new_node(uint32_t docid, const LinkArrayRef &neighbors, uint3
new_links.push_back(docid);
set_link_array(neighbor_docid, level, new_links);
}
+ for (uint32_t neighbor_docid : neighbors) {
+ shrink_if_needed(neighbor_docid, level);
+ }
}
void
@@ -287,8 +325,8 @@ HnswIndex::add_document(uint32_t docid)
while (search_level >= 0) {
// TODO: Rename to search_level?
search_layer(input, _cfg.neighbors_to_explore_at_construction(), best_neighbors, search_level);
- auto neighbors = select_neighbors(best_neighbors.peek(), max_links_for_level(search_level));
- connect_new_node(docid, neighbors, search_level);
+ auto neighbors = select_neighbors(best_neighbors.peek(), _cfg.max_links_at_hierarchic_levels());
+ connect_new_node(docid, neighbors.used, search_level);
// TODO: Shrink neighbors if needed
--search_level;
}
@@ -436,4 +474,28 @@ HnswIndex::set_node(uint32_t docid, const HnswNode &node)
}
}
+bool
+HnswIndex::check_link_symmetry() const
+{
+ bool all_sym = true;
+ for (size_t docid = 0; docid < _node_refs.size(); ++docid) {
+ auto node_ref = _node_refs[docid].load_acquire();
+ if (node_ref.valid()) {
+ auto levels = _nodes.get(node_ref);
+ uint32_t level = 0;
+ for (const auto& links_ref : levels) {
+ auto links = _links.get(links_ref.load_acquire());
+ for (auto neighbor_docid : links) {
+ auto neighbor_links = get_link_array(neighbor_docid, level);
+ if (! has_link_to(neighbor_links, docid)) {
+ all_sym = false;
+ }
+ }
+ ++level;
+ }
+ }
+ }
+ return all_sym;
}
+
+} // namespace
diff --git a/searchlib/src/vespa/searchlib/tensor/hnsw_index.h b/searchlib/src/vespa/searchlib/tensor/hnsw_index.h
index 89c45d6b50c..b0c1cd1dcfd 100644
--- a/searchlib/src/vespa/searchlib/tensor/hnsw_index.h
+++ b/searchlib/src/vespa/searchlib/tensor/hnsw_index.h
@@ -108,9 +108,14 @@ protected:
* Used by select_neighbors_heuristic().
*/
bool have_closer_distance(HnswCandidate candidate, const LinkArray& curr_result) const;
- LinkArray select_neighbors_heuristic(const HnswCandidateVector& neighbors, uint32_t max_links) const;
- LinkArray select_neighbors_simple(const HnswCandidateVector& neighbors, uint32_t max_links) const;
- LinkArray select_neighbors(const HnswCandidateVector& neighbors, uint32_t max_links) const;
+ struct SelectResult {
+ LinkArray used;
+ LinkArray unused;
+ };
+ SelectResult select_neighbors_heuristic(const HnswCandidateVector& neighbors, uint32_t max_links) const;
+ SelectResult select_neighbors_simple(const HnswCandidateVector& neighbors, uint32_t max_links) const;
+ SelectResult select_neighbors(const HnswCandidateVector& neighbors, uint32_t max_links) const;
+ void shrink_if_needed(uint32_t docid, uint32_t level);
void connect_new_node(uint32_t docid, const LinkArrayRef &neighbors, uint32_t level);
void remove_link_to(uint32_t remove_from, uint32_t remove_id, uint32_t level);
@@ -150,6 +155,7 @@ public:
// Should only be used by unit tests.
HnswNode get_node(uint32_t docid) const;
void set_node(uint32_t docid, const HnswNode &node);
+ bool check_link_symmetry() const;
};
}