diff options
author | Arne Juul <arnej@verizonmedia.com> | 2020-02-24 09:42:02 +0000 |
---|---|---|
committer | Arne Juul <arnej@verizonmedia.com> | 2020-02-24 09:57:59 +0000 |
commit | ffa2293de302d99051f7fc97d29c4dc606f045f1 (patch) | |
tree | 95228696bf26ebf152c487c5d2fe189cd8dae078 /eval/src | |
parent | 00813e6561cae0365aad710d30a9bc0647e6a01f (diff) |
experimental HNSW with various extensions
Diffstat (limited to 'eval/src')
-rw-r--r-- | eval/src/tests/ann/extended-hnsw.cpp | 830 |
1 files changed, 830 insertions, 0 deletions
diff --git a/eval/src/tests/ann/extended-hnsw.cpp b/eval/src/tests/ann/extended-hnsw.cpp new file mode 100644 index 00000000000..42f3a10b389 --- /dev/null +++ b/eval/src/tests/ann/extended-hnsw.cpp @@ -0,0 +1,830 @@ +// Copyright 2020 Oath Inc. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. + +#include <algorithm> +#include <assert.h> +#include <queue> +#include <cinttypes> +#include "std-random.h" +#include "nns.h" + +/* + Todo: + + measure effect of: + 1) removing leftover backlinks during "shrink" operation + 2) refilling to low-watermark after 1) happens + 3) refilling to mid-watermark after 1) happens + 4) adding then removing 20% extra documents + 5) removing 20% first-added documents + 6) removing first-added documents while inserting new ones + + 7) auto-tune search_k to ensure >= 50% recall on 1000 Q with k=100 + 8) auto-tune search_k to ensure avg 90% recall on 1000 Q with k=100 + 9) auto-tune search_k to ensure >= 90% reachability of 10000 docids + + 10) timings for SIFT, GIST, and DEEP data (100k, 200k, 300k, 500k, 700k, 1000k) + */ + +static size_t distcalls_simple; +static size_t distcalls_search_layer; +static size_t distcalls_other; +static size_t distcalls_heuristic; +static size_t distcalls_shrink; +static size_t distcalls_refill; +static size_t refill_needed_calls; +static size_t shrink_needed_calls; +static size_t disconnected_weak_links; +static size_t disconnected_for_symmetry; +static size_t select_n_full; +static size_t select_n_partial; + +struct LinkList : std::vector<uint32_t> +{ + bool has_link_to(uint32_t id) const { + auto iter = std::find(begin(), end(), id); + return (iter != end()); + } + void remove_link(uint32_t id) { + uint32_t last = back(); + for (iterator iter = begin(); iter != end(); ++iter) { + if (*iter == id) { + *iter = last; + pop_back(); + return; + } + } + fprintf(stderr, "BAD missing link to remove: %u\n", id); + abort(); + } +}; + +struct Node { + std::vector<LinkList> _links; + Node(uint32_t , uint32_t numLevels, uint32_t M) + : _links(numLevels) + { + for (uint32_t i = 0; i < _links.size(); ++i) { + _links[i].reserve((i == 0) ? (2 * M + 1) : (M+1)); + } + } +}; + +struct VisitedSet +{ + using Mark = unsigned short; + Mark *ptr; + Mark curval; + size_t sz; + VisitedSet(const VisitedSet &) = delete; + VisitedSet& operator=(const VisitedSet &) = delete; + explicit VisitedSet(size_t size) { + ptr = (Mark *)malloc(size * sizeof(Mark)); + curval = -1; + sz = size; + clear(); + } + void clear() { + ++curval; + if (curval == 0) { + memset(ptr, 0, sz * sizeof(Mark)); + ++curval; + } + } + ~VisitedSet() { free(ptr); } + void mark(size_t id) { ptr[id] = curval; } + bool isMarked(size_t id) const { return ptr[id] == curval; } +}; + +struct VisitedSetPool +{ + std::unique_ptr<VisitedSet> lastUsed; + VisitedSetPool() { + lastUsed = std::make_unique<VisitedSet>(250); + } + ~VisitedSetPool() {} + VisitedSet &get(size_t size) { + if (size > lastUsed->sz) { + lastUsed = std::make_unique<VisitedSet>(size*2); + } else { + lastUsed->clear(); + } + return *lastUsed; + } +}; + +struct HnswHit { + double dist; + uint32_t docid; + HnswHit(uint32_t di, SqDist sq) : dist(sq.distance), docid(di) {} +}; + +struct GreaterDist { + bool operator() (const HnswHit &lhs, const HnswHit& rhs) const { + return (rhs.dist < lhs.dist); + } +}; +struct LesserDist { + bool operator() (const HnswHit &lhs, const HnswHit& rhs) const { + return (lhs.dist < rhs.dist); + } +}; + +using NearestList = std::vector<HnswHit>; + +struct NearestPriQ : std::priority_queue<HnswHit, NearestList, GreaterDist> +{ +}; + +struct FurthestPriQ : std::priority_queue<HnswHit, NearestList, LesserDist> +{ + NearestList steal() { + NearestList result; + c.swap(result); + return result; + } + const NearestList& peek() const { return c; } +}; + +class HnswLikeNns : public NNS<float> +{ +private: + std::vector<Node> _nodes; + uint32_t _entryId; + int _entryLevel; + uint32_t _M; + uint32_t _efConstruction; + double _levelMultiplier; + RndGen _rndGen; + VisitedSetPool _visitedSetPool; + size_t _ops_counter; + + double distance(Vector v, uint32_t id) const; + + double distance(uint32_t a, uint32_t b) const { + Vector v = _dva.get(a); + return distance(v, b); + } + + int randomLevel() { + double unif = _rndGen.nextUniform(); + double r = -log(1.0-unif) * _levelMultiplier; + return (int) r; + } + + uint32_t count_reachable() const; + void dumpStats() const; + +public: + HnswLikeNns(uint32_t numDims, const DocVectorAccess<float> &dva) + : NNS(numDims, dva), + _nodes(), + _entryId(0), + _entryLevel(-1), + _M(16), + _efConstruction(200), + _levelMultiplier(1.0 / log(1.0 * _M)), + _rndGen(), + _ops_counter(0) + { + } + + ~HnswLikeNns() { dumpStats(); } + + LinkList& getLinkList(uint32_t docid, uint32_t level) { + // assert(docid < _nodes.size()); + // assert(level < _nodes[docid]._links.size()); + return _nodes[docid]._links[level]; + } + + const LinkList& getLinkList(uint32_t docid, uint32_t level) const { + return _nodes[docid]._links[level]; + } + + // simple greedy search + HnswHit search_layer_simple(Vector vector, HnswHit curPoint, uint32_t searchLevel) { + bool keepGoing = true; + while (keepGoing) { + keepGoing = false; + const LinkList& neighbors = getLinkList(curPoint.docid, searchLevel); + for (uint32_t n_id : neighbors) { + double dist = distance(vector, n_id); + ++distcalls_simple; + if (dist < curPoint.dist) { + curPoint = HnswHit(n_id, SqDist(dist)); + keepGoing = true; + } + } + } + return curPoint; + } + + void search_layer(Vector vector, FurthestPriQ &w, + VisitedSet &visited, + uint32_t ef, uint32_t searchLevel); + + void search_layer_with_filter(Vector vector, FurthestPriQ &w, + VisitedSet &visited, + uint32_t ef, uint32_t searchLevel, + const BitVector &blacklist); + + bool haveCloserDistance(HnswHit e, const LinkList &r) const { + for (uint32_t prevId : r) { + double dist = distance(e.docid, prevId); + ++distcalls_heuristic; + if (dist < e.dist) return true; + } + return false; + } + + LinkList select_neighbors(const NearestList &neighbors, uint32_t curMax) const; + + LinkList remove_weakest(const NearestList &neighbors, uint32_t curMax, LinkList &removed) const; + + void addDoc(uint32_t docid) override { + Vector vector = _dva.get(docid); + for (uint32_t id = _nodes.size(); id <= docid; ++id) { + _nodes.emplace_back(id, 0, _M); + } + int level = randomLevel(); + assert(_nodes[docid]._links.size() == 0); + _nodes[docid] = Node(docid, level+1, _M); + if (_entryLevel < 0) { + _entryId = docid; + _entryLevel = level; + track_ops(); + return; + } + int searchLevel = _entryLevel; + VisitedSet &visited = _visitedSetPool.get(_nodes.size()); + double entryDist = distance(vector, _entryId); + ++distcalls_other; + HnswHit entryPoint(_entryId, SqDist(entryDist)); +#undef MULTI_ENTRY_I +#ifdef MULTI_ENTRY_I + FurthestPriQ w; + w.push(entryPoint); + while (searchLevel > level) { + search_layer(vector, w, visited, 5 * _M, searchLevel); + --searchLevel; + } +#else + while (searchLevel > level) { + entryPoint = search_layer_simple(vector, entryPoint, searchLevel); + --searchLevel; + } + FurthestPriQ w; + w.push(entryPoint); +#endif + searchLevel = std::min(level, _entryLevel); + while (searchLevel >= 0) { + search_layer(vector, w, visited, _efConstruction, searchLevel); + LinkList neighbors = select_neighbors(w.peek(), _M); + connect_new_node(docid, neighbors, searchLevel); + each_shrink_ifneeded(neighbors, searchLevel); + --searchLevel; + } + if (level > _entryLevel) { + _entryLevel = level; + _entryId = docid; + } + track_ops(); + } + + void track_ops() { + _ops_counter++; + if ((_ops_counter % 10000) == 0) { + double div = _ops_counter; + fprintf(stderr, "add / remove ops: %zu\n", _ops_counter); + fprintf(stderr, "distance calls for layer: %zu is %.3f per op\n", distcalls_search_layer, distcalls_search_layer/ div); + fprintf(stderr, "distance calls for heuristic: %zu is %.3f per op\n", distcalls_heuristic, distcalls_heuristic / div); + fprintf(stderr, "distance calls for simple: %zu is %.3f per op\n", distcalls_simple, distcalls_simple / div); + fprintf(stderr, "distance calls for shrink: %zu is %.3f per op\n", distcalls_shrink, distcalls_shrink / div); + fprintf(stderr, "distance calls for refill: %zu is %.3f per op\n", distcalls_refill, distcalls_refill / div); + fprintf(stderr, "distance calls for other: %zu is %.3f per op\n", distcalls_other, distcalls_other / div); + fprintf(stderr, "refill needed calls: %zu is %.3f per op\n", refill_needed_calls, refill_needed_calls / div); + fprintf(stderr, "shrink needed calls: %zu is %.3f per op\n", shrink_needed_calls, shrink_needed_calls / div); + fprintf(stderr, "disconnected weak links: %zu is %.3f per op\n", disconnected_weak_links, disconnected_weak_links / div); + fprintf(stderr, "disconnected for symmetry: %zu is %.3f per op\n", disconnected_for_symmetry, disconnected_for_symmetry / div); + fprintf(stderr, "select neighbors: partial %zu vs full %zu\n", select_n_partial, select_n_full); + } + } + + void remove_link_from(uint32_t from_id, uint32_t remove_id, uint32_t level) { + LinkList &links = getLinkList(from_id, level); + links.remove_link(remove_id); + } + +#undef SIMPLE_REFILL +#ifdef SIMPLE_REFILL + void refill_ifneeded(uint32_t my_id, const LinkList &replacements, uint32_t level) { + LinkList &my_links = getLinkList(my_id, level); + if (my_links.size() * 2 < _M) { + const uint32_t maxLinks = (level > 0) ? _M : (2 * _M); + ++refill_needed_calls; + for (uint32_t repl_id : replacements) { + if (repl_id == my_id) continue; + if (my_links.has_link_to(repl_id)) continue; + LinkList &other_links = getLinkList(repl_id, level); + if (other_links.size() >= maxLinks) continue; + other_links.push_back(my_id); + my_links.push_back(repl_id); + if (my_links.size() >= _M) return; + } + } + } +#else + void refill_all(uint32_t my_id, const LinkList &replacements, uint32_t level) { + LinkList &my_links = getLinkList(my_id, level); + const uint32_t maxLinks = (level > 0) ? _M : (2 * _M); + NearestPriQ w; + for (uint32_t repl_id : replacements) { + if (repl_id == my_id) continue; + if (my_links.has_link_to(repl_id)) continue; + const LinkList &other_links = getLinkList(repl_id, level); + if (other_links.size() >= maxLinks) continue; + double dist = distance(my_id, repl_id); + ++distcalls_refill; + w.emplace(repl_id, SqDist(dist)); + } + while (! w.empty()) { + HnswHit e = w.top(); + w.pop(); + if (haveCloserDistance(e, my_links)) continue; + LinkList &other_links = getLinkList(e.docid, level); + my_links.push_back(e.docid); + other_links.push_back(my_id); + if (my_links.size() == _M) break; + } + } + void refill_one(uint32_t my_id, const LinkList &replacements, uint32_t level) { + LinkList &my_links = getLinkList(my_id, level); + NearestPriQ w; + for (uint32_t repl_id : replacements) { + if (repl_id == my_id) continue; + if (my_links.has_link_to(repl_id)) continue; + LinkList &other_links = getLinkList(repl_id, level); + if (other_links.size() >= _M) continue; + double dist = distance(my_id, repl_id); + ++distcalls_refill; + w.emplace(repl_id, SqDist(dist)); + } + while (! w.empty()) { + HnswHit e = w.top(); + w.pop(); + if (haveCloserDistance(e, my_links)) continue; + LinkList &other_links = getLinkList(e.docid, level); + my_links.push_back(e.docid); + other_links.push_back(my_id); + return; + } + } + void refill_ifneeded(uint32_t my_id, const LinkList &replacements, uint32_t level) { + LinkList &my_links = getLinkList(my_id, level); + if (my_links.size() < _M) { + ++refill_needed_calls; + refill_all(my_id, replacements, level); + } + } +#endif + + void connect_new_node(uint32_t id, const LinkList &neighbors, uint32_t level); + + void shrink_links(uint32_t shrink_id, uint32_t maxLinks, uint32_t level) { + LinkList &links = getLinkList(shrink_id, level); + NearestList distances; + for (uint32_t n_id : links) { + double n_dist = distance(shrink_id, n_id); + ++distcalls_shrink; + distances.emplace_back(n_id, SqDist(n_dist)); + } + LinkList lostLinks; + LinkList oldLinks = links; + links = remove_weakest(distances, maxLinks, lostLinks); +#define KEEP_SYM +#ifdef KEEP_SYM + for (uint32_t lost_id : lostLinks) { + ++disconnected_for_symmetry; + remove_link_from(lost_id, shrink_id, level); + } +#define DO_REFILL_AFTER_KEEP_SYM +#ifdef DO_REFILL_AFTER_KEEP_SYM + for (uint32_t lost_id : lostLinks) { +#ifdef SIMPLE_REFILL + refill_ifneeded(lost_id, oldLinks, level); +#else + refill_all(lost_id, oldLinks, level); +#endif + } +#endif +#endif + } + + void each_shrink_ifneeded(const LinkList &neighbors, uint32_t level); + + void mutually_reconnect(LinkList cluster, int level) { + while (! cluster.empty()) { + uint32_t n_id = cluster.back(); + cluster.pop_back(); +#ifdef SIMPLE_REFILL + refill_ifneeded(n_id, cluster, level); +#else + refill_all(n_id, cluster, level); +#endif + } + } + + void removeDoc(uint32_t docid) override { + Node &node = _nodes[docid]; + bool need_new_entrypoint = (docid == _entryId); + for (int level = node._links.size(); level-- > 0; ) { + LinkList my_links; + my_links.swap(node._links[level]); + for (uint32_t n_id : my_links) { + if (need_new_entrypoint) { + _entryId = n_id; + _entryLevel = level; + need_new_entrypoint = false; + } + remove_link_from(n_id, docid, level); + } + mutually_reconnect(my_links, level); + } + node = Node(docid, 0, _M); + if (need_new_entrypoint) { + _entryLevel = -1; + _entryId = 0; + for (uint32_t i = 0; i < _nodes.size(); ++i) { + if (_nodes[i]._links.size() > 0) { + _entryId = i; + _entryLevel = _nodes[i]._links.size() - 1; + break; + } + } + } + track_ops(); + } + + std::vector<NnsHit> topK(uint32_t k, Vector vector, uint32_t search_k) override { + std::vector<NnsHit> result; + if (_entryLevel < 0) return result; + double entryDist = distance(vector, _entryId); + ++distcalls_other; + HnswHit entryPoint(_entryId, SqDist(entryDist)); + int searchLevel = _entryLevel; + VisitedSet &visited = _visitedSetPool.get(_nodes.size()); +#undef MULTI_ENTRY_S +#ifdef MULTI_ENTRY_S + FurthestPriQ w; + w.push(entryPoint); + while (searchLevel > 0) { + search_layer(vector, w, visited, std::min(k, search_k), searchLevel); + --searchLevel; + } +#else + while (searchLevel > 0) { + entryPoint = search_layer_simple(vector, entryPoint, searchLevel); + --searchLevel; + } + FurthestPriQ w; + w.push(entryPoint); +#endif + search_layer(vector, w, visited, std::max(k, search_k), 0); + while (w.size() > k) { + w.pop(); + } + NearestList tmp = w.steal(); + std::sort(tmp.begin(), tmp.end(), LesserDist()); + result.reserve(tmp.size()); + for (const auto & hit : tmp) { + result.emplace_back(hit.docid, SqDist(hit.dist)); + } + return result; + } + + std::vector<NnsHit> topKfilter(uint32_t k, Vector vector, uint32_t search_k, const BitVector &blacklist) override; +}; + + +double +HnswLikeNns::distance(Vector v, uint32_t b) const +{ + Vector w = _dva.get(b); + return l2distCalc.l2sq_dist(v, w); +} + +std::vector<NnsHit> +HnswLikeNns::topKfilter(uint32_t k, Vector vector, uint32_t search_k, const BitVector &blacklist) +{ + std::vector<NnsHit> result; + if (_entryLevel < 0) return result; + double entryDist = distance(vector, _entryId); + ++distcalls_other; + HnswHit entryPoint(_entryId, SqDist(entryDist)); + int searchLevel = _entryLevel; + VisitedSet &visited = _visitedSetPool.get(_nodes.size()); +#ifdef MULTI_ENTRY_S + FurthestPriQ w; + w.push(entryPoint); + while (searchLevel > 0) { + search_layer(vector, w, visited, std::min(k, search_k), searchLevel); + --searchLevel; + } +#else + while (searchLevel > 0) { + entryPoint = search_layer_simple(vector, entryPoint, searchLevel); + --searchLevel; + } + FurthestPriQ w; + w.push(entryPoint); +#endif + search_layer_with_filter(vector, w, visited, std::max(k, search_k), 0, blacklist); + NearestList tmp = w.steal(); + std::sort(tmp.begin(), tmp.end(), LesserDist()); + result.reserve(std::min((size_t)k, tmp.size())); + for (const auto & hit : tmp) { + if (blacklist.isSet(hit.docid)) continue; + result.emplace_back(hit.docid, SqDist(hit.dist)); + if (result.size() == k) break; + } + return result; +} + +void +HnswLikeNns::each_shrink_ifneeded(const LinkList &neighbors, uint32_t level) { + uint32_t maxLinks = (level > 0) ? _M : (2 * _M); + for (uint32_t old_id : neighbors) { + LinkList &oldLinks = getLinkList(old_id, level); + if (oldLinks.size() > maxLinks) { + ++shrink_needed_calls; + shrink_links(old_id, maxLinks, level); + } + } +} + +void +HnswLikeNns::search_layer(Vector vector, FurthestPriQ &w, + VisitedSet &visited, + uint32_t ef, uint32_t searchLevel) +{ + NearestPriQ candidates; + + for (const HnswHit & entry : w.peek()) { + candidates.push(entry); + visited.mark(entry.docid); + } + double limd = std::numeric_limits<double>::max(); + while (! candidates.empty()) { + HnswHit cand = candidates.top(); + if (cand.dist > limd) { + break; + } + candidates.pop(); + for (uint32_t e_id : getLinkList(cand.docid, searchLevel)) { + if (visited.isMarked(e_id)) continue; + visited.mark(e_id); + double e_dist = distance(vector, e_id); + ++distcalls_search_layer; + if (e_dist < limd) { + candidates.emplace(e_id, SqDist(e_dist)); + w.emplace(e_id, SqDist(e_dist)); + if (w.size() > ef) { + w.pop(); + limd = w.top().dist; + } + } + } + } + return; +} + +void +HnswLikeNns::search_layer_with_filter(Vector vector, FurthestPriQ &w, + VisitedSet &visited, + uint32_t ef, uint32_t searchLevel, + const BitVector &blacklist) +{ + NearestPriQ candidates; + + for (const HnswHit & entry : w.peek()) { + candidates.push(entry); + visited.mark(entry.docid); + if (blacklist.isSet(entry.docid)) ++ef; + } + double limd = std::numeric_limits<double>::max(); + while (! candidates.empty()) { + HnswHit cand = candidates.top(); + if (cand.dist > limd) { + break; + } + candidates.pop(); + for (uint32_t e_id : getLinkList(cand.docid, searchLevel)) { + if (visited.isMarked(e_id)) continue; + visited.mark(e_id); + double e_dist = distance(vector, e_id); + ++distcalls_search_layer; + if (e_dist < limd) { + candidates.emplace(e_id, SqDist(e_dist)); + if (blacklist.isSet(e_id)) continue; + w.emplace(e_id, SqDist(e_dist)); + if (w.size() > ef) { + w.pop(); + limd = w.top().dist; + } + } + } + } +} + +LinkList +HnswLikeNns::remove_weakest(const NearestList &neighbors, uint32_t curMax, LinkList &lost) const +{ + LinkList result; + result.reserve(curMax+1); + NearestPriQ w; + for (const auto & entry : neighbors) { + w.push(entry); + } + while (! w.empty()) { + HnswHit e = w.top(); + w.pop(); + if (result.size() == curMax || haveCloserDistance(e, result)) { + lost.push_back(e.docid); + } else { + result.push_back(e.docid); + } + } + return result; +} + +#define NO_BACKFILL +#ifdef NO_BACKFILL +LinkList +HnswLikeNns::select_neighbors(const NearestList &neighbors, uint32_t curMax) const +{ + LinkList result; + result.reserve(curMax+1); + NearestPriQ w; + for (const auto & entry : neighbors) { + w.push(entry); + } + while (! w.empty()) { + HnswHit e = w.top(); + w.pop(); + if (haveCloserDistance(e, result)) { + continue; + } + result.push_back(e.docid); + if (result.size() == curMax) { + ++select_n_full; + return result; + } + } + ++select_n_partial; + return result; +} +#else +LinkList +HnswLikeNns::select_neighbors(const NearestList &neighbors, uint32_t curMax) const +{ + LinkList result; + result.reserve(curMax+1); + bool needFiltering = (neighbors.size() > curMax); + NearestPriQ w; + for (const auto & entry : neighbors) { + w.push(entry); + } + LinkList backfill; + while (! w.empty()) { + HnswHit e = w.top(); + w.pop(); + if (needFiltering && haveCloserDistance(e, result)) { + backfill.push_back(e.docid); + continue; + } + result.push_back(e.docid); + if (result.size() == curMax) return result; + } + if (result.size() * 4 < _M) { + for (uint32_t fill_id : backfill) { + result.push_back(fill_id); + if (result.size() * 2 >= _M) break; + } + } + return result; +} +#endif + +void +HnswLikeNns::connect_new_node(uint32_t id, const LinkList &neighbors, uint32_t level) { + LinkList &newLinks = getLinkList(id, level); + for (uint32_t neigh_id : neighbors) { + LinkList &oldLinks = getLinkList(neigh_id, level); + newLinks.push_back(neigh_id); + oldLinks.push_back(id); + } +#define DISCONNECT_OLD_WEAK_LINKS +#ifdef DISCONNECT_OLD_WEAK_LINKS + for (uint32_t i = 1; i < neighbors.size(); ++i) { + uint32_t n_1 = neighbors[i]; + LinkList &links_1 = getLinkList(n_1, level); + for (uint32_t j = 0; j < i; ++j) { + uint32_t n_2 = neighbors[j]; + if (links_1.has_link_to(n_2)) { + ++disconnected_weak_links; + LinkList &links_2 = getLinkList(n_2, level); + links_1.remove_link(n_2); + links_2.remove_link(n_1); + } + } + } +#endif +} + +uint32_t +HnswLikeNns::count_reachable() const { + VisitedSet visited(_nodes.size()); + int level = _entryLevel; + LinkList curList; + curList.push_back(_entryId); + visited.mark(_entryId); + uint32_t idx = 0; + while (level >= 0) { + while (idx < curList.size()) { + uint32_t id = curList[idx++]; + const LinkList &links = getLinkList(id, level); + for (uint32_t n_id : links) { + if (visited.isMarked(n_id)) continue; + visited.mark(n_id); + curList.push_back(n_id); + } + } + --level; + idx = 0; + } + return curList.size(); +} + +void +HnswLikeNns::dumpStats() const { + std::vector<uint32_t> levelCounts; + levelCounts.resize(_entryLevel + 2); + std::vector<uint32_t> outLinkHist; + outLinkHist.resize(2 * _M + 2); + uint32_t symmetrics = 0; + uint32_t level1links = 0; + uint32_t both_l_links = 0; + fprintf(stderr, "stats for HnswLikeNns with %zu nodes, entry level = %d, entry id = %u\n", + _nodes.size(), _entryLevel, _entryId); + + for (uint32_t id = 0; id < _nodes.size(); ++id) { + const auto &node = _nodes[id]; + uint32_t levels = node._links.size(); + levelCounts[levels]++; + if (levels < 1) { + outLinkHist[0]++; + continue; + } + const LinkList &link_list = getLinkList(id, 0); + uint32_t numlinks = link_list.size(); + outLinkHist[numlinks]++; + if (numlinks < 1) { + fprintf(stderr, "node with %u links: id %u\n", numlinks, id); + } + bool all_sym = true; + for (uint32_t n_id : link_list) { + const LinkList &neigh_list = getLinkList(n_id, 0); + if (! neigh_list.has_link_to(id)) { +#ifdef KEEP_SYM + fprintf(stderr, "BAD: %u has link to neighbor %u, but backlink is missing\n", id, n_id); +#endif + all_sym = false; + } + } + if (all_sym) ++symmetrics; + if (levels < 2) continue; + const LinkList &link_list_1 = getLinkList(id, 1); + for (uint32_t n_id : link_list_1) { + ++level1links; + if (link_list.has_link_to(n_id)) ++both_l_links; + } + } + for (uint32_t l = 0; l < levelCounts.size(); ++l) { + fprintf(stderr, "Nodes on %u levels: %u\n", l, levelCounts[l]); + } + fprintf(stderr, "reachable nodes %u / %zu\n", + count_reachable(), _nodes.size() - levelCounts[0]); + fprintf(stderr, "level 1 links overlapping on l0: %u / total: %u\n", + both_l_links, level1links); + for (uint32_t l = 0; l < outLinkHist.size(); ++l) { + if (outLinkHist[l] != 0) { + fprintf(stderr, "Nodes with %u outward links on L0: %u\n", l, outLinkHist[l]); + } + } + fprintf(stderr, "Symmetric in-out nodes: %u\n", symmetrics); +} + +std::unique_ptr<NNS<float>> +make_hnsw_nns(uint32_t numDims, const DocVectorAccess<float> &dva) +{ + return std::make_unique<HnswLikeNns>(numDims, dva); +} |