aboutsummaryrefslogtreecommitdiffstats
path: root/eval/src/tests/ann/extended-hnsw.cpp
diff options
context:
space:
mode:
author Arne Juul <arnej@verizonmedia.com> 2020-02-24 09:42:02 +0000
committer Arne Juul <arnej@verizonmedia.com> 2020-02-24 09:57:59 +0000
commit ffa2293de302d99051f7fc97d29c4dc606f045f1 (patch)
tree 95228696bf26ebf152c487c5d2fe189cd8dae078 /eval/src/tests/ann/extended-hnsw.cpp
parent 00813e6561cae0365aad710d30a9bc0647e6a01f (diff)
experimental HNSW with various extensions
Diffstat (limited to 'eval/src/tests/ann/extended-hnsw.cpp')
-rw-r--r-- eval/src/tests/ann/extended-hnsw.cpp 830
1 files changed, 830 insertions, 0 deletions
diff --git a/eval/src/tests/ann/extended-hnsw.cpp b/eval/src/tests/ann/extended-hnsw.cpp
new file mode 100644
index 00000000000..42f3a10b389
--- /dev/null
+++ b/eval/src/tests/ann/extended-hnsw.cpp
@@ -0,0 +1,830 @@
+// Copyright 2020 Oath Inc. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+
+#include <algorithm>
+#include <assert.h>
+#include <queue>
+#include <cinttypes>
+#include "std-random.h"
+#include "nns.h"
+
+/*
+ Todo:
+
+ measure effect of:
+ 1) removing leftover backlinks during "shrink" operation
+ 2) refilling to low-watermark after 1) happens
+ 3) refilling to mid-watermark after 1) happens
+ 4) adding then removing 20% extra documents
+ 5) removing 20% first-added documents
+ 6) removing first-added documents while inserting new ones
+
+ 7) auto-tune search_k to ensure >= 50% recall on 1000 Q with k=100
+ 8) auto-tune search_k to ensure avg 90% recall on 1000 Q with k=100
+ 9) auto-tune search_k to ensure >= 90% reachability of 10000 docids
+
+ 10) timings for SIFT, GIST, and DEEP data (100k, 200k, 300k, 500k, 700k, 1000k)
+ */
+
+static size_t distcalls_simple; // distance computations made by search_layer_simple()
+static size_t distcalls_search_layer; // distance computations made by search_layer() and search_layer_with_filter()
+static size_t distcalls_other; // entry-point distance computations in addDoc(), topK() and topKfilter()
+static size_t distcalls_heuristic; // distance computations made by haveCloserDistance()
+static size_t distcalls_shrink; // distance computations made by shrink_links()
+static size_t distcalls_refill; // distance computations made by refill_all() and refill_one()
+static size_t refill_needed_calls; // times refill_ifneeded() found a link list that needed refilling
+static size_t shrink_needed_calls; // times each_shrink_ifneeded() found a link list above its limit
+static size_t disconnected_weak_links; // neighbor-to-neighbor links removed by connect_new_node()
+static size_t disconnected_for_symmetry; // back-links removed by shrink_links() for links that were dropped
+static size_t select_n_full; // select_neighbors() calls that returned the full requested number of links
+static size_t select_n_partial; // select_neighbors() calls that ran out of candidates before that
+
+// The outgoing links (neighbor docids) of one node on one level.
+// Element order carries no meaning: remove_link() fills the hole with the last element.
+struct LinkList : std::vector<uint32_t>
+{
+    // Returns true when 'id' is present in the list (linear scan).
+    bool has_link_to(uint32_t id) const {
+        return std::find(begin(), end(), id) != end();
+    }
+    // Removes the first occurrence of 'id' by overwriting it with the last
+    // element and shrinking the list by one.
+    // Aborts with a diagnostic when 'id' is not present; this includes the
+    // empty list, which must not be touched with back().
+    void remove_link(uint32_t id) {
+        auto iter = std::find(begin(), end(), id);
+        if (iter == end()) {
+            fprintf(stderr, "BAD missing link to remove: %u\n", id);
+            abort();
+        }
+        *iter = back();
+        pop_back();
+    }
+};
+
+// Graph node for one document: holds one link list per level the document is present on.
+struct Node {
+    std::vector<LinkList> _links;
+    // The first argument is accepted but not used.
+    // Reserves room for 2*M+1 links on level 0 and M+1 links on every other level.
+    Node(uint32_t /* unused */, uint32_t numLevels, uint32_t M)
+        : _links(numLevels)
+    {
+        uint32_t level = 0;
+        for (LinkList &links : _links) {
+            const uint32_t capacity = (level == 0) ? (2 * M + 1) : (M+1);
+            links.reserve(capacity);
+            ++level;
+        }
+    }
+};
+
+// Set of visited ids backed by an array of generation marks.
+// An id counts as marked when its slot equals the current generation
+// 'curval', so clear() normally only has to bump the generation.
+struct VisitedSet
+{
+    using Mark = unsigned short;
+    Mark *ptr;
+    Mark curval;
+    size_t sz;
+    VisitedSet(const VisitedSet &) = delete;
+    VisitedSet& operator=(const VisitedSet &) = delete;
+    // Creates a set for ids in [0, size) with nothing marked.
+    // calloc() hands back zeroed slots and rejects a size that would
+    // overflow; starting at generation 1 means no zeroed slot is marked.
+    explicit VisitedSet(size_t size)
+        : ptr(static_cast<Mark *>(calloc(size, sizeof(Mark)))),
+          curval(1),
+          sz(size)
+    {
+        // calloc(0, ...) may legitimately return a null pointer
+        if (ptr == nullptr && size > 0) {
+            fprintf(stderr, "BAD could not allocate visited set for %zu ids\n", size);
+            abort();
+        }
+    }
+    // Unmarks every id.
+    void clear() {
+        ++curval;
+        if (curval == 0) {
+            // generation counter wrapped: stale marks could match again,
+            // so wipe all slots and skip generation 0
+            if (sz > 0) {
+                memset(ptr, 0, sz * sizeof(Mark));
+            }
+            ++curval;
+        }
+    }
+    ~VisitedSet() { free(ptr); }
+    // No bounds checking: 'id' must be below the size given at construction.
+    void mark(size_t id) { ptr[id] = curval; }
+    bool isMarked(size_t id) const { return ptr[id] == curval; }
+};
+
+// Owns a single VisitedSet that is reused between calls and replaced by a
+// bigger one when the requested size does not fit.
+struct VisitedSetPool
+{
+    std::unique_ptr<VisitedSet> lastUsed;
+    VisitedSetPool()
+        : lastUsed(std::make_unique<VisitedSet>(250))
+    {
+    }
+    ~VisitedSetPool() {}
+    // Returns a set with no ids marked that can hold at least 'size' ids.
+    // The reference is only good until a later call replaces the set.
+    VisitedSet &get(size_t size) {
+        if (size <= lastUsed->sz) {
+            // big enough already: just forget the old marks
+            lastUsed->clear();
+            return *lastUsed;
+        }
+        // allocate twice the requested size to leave headroom
+        lastUsed = std::make_unique<VisitedSet>(size*2);
+        return *lastUsed;
+    }
+};
+
+// A document id paired with a distance value taken from a SqDist.
+struct HnswHit {
+    double dist;
+    uint32_t docid;
+    HnswHit(uint32_t di, SqDist sq)
+        : dist(sq.distance),
+          docid(di)
+    {
+    }
+};
+
+// Ordering predicate: true when the left hit is further away than the right hit.
+struct GreaterDist {
+    bool operator() (const HnswHit &lhs, const HnswHit& rhs) const {
+        return (lhs.dist > rhs.dist);
+    }
+};
+// Ordering predicate: true when the left hit is closer than the right hit.
+struct LesserDist {
+    bool operator() (const HnswHit &lhs, const HnswHit& rhs) const {
+        return (rhs.dist > lhs.dist);
+    }
+};
+
+// Plain list of hits; also the storage used by both priority queues below.
+using NearestList = std::vector<HnswHit>;
+
+// Priority queue whose top() is the hit with the smallest distance.
+struct NearestPriQ : std::priority_queue<HnswHit, NearestList, GreaterDist> {};
+
+// Priority queue whose top() is the hit with the largest distance,
+// with direct access to the underlying container.
+struct FurthestPriQ : std::priority_queue<HnswHit, NearestList, LesserDist>
+{
+    // Hands over all queued hits (in heap order) and leaves the queue empty.
+    NearestList steal() {
+        NearestList taken;
+        taken.swap(c);
+        return taken;
+    }
+    // Read-only view of the queued hits, in heap order.
+    const NearestList& peek() const {
+        return c;
+    }
+};
+
+/**
+ * Experimental HNSW-style nearest neighbor index.
+ *
+ * Every document gets a Node with one link list per level; searches start
+ * at the entry node on the top level, walk greedily down to the wanted
+ * level and then run a wider search (search_layer) on the lower levels.
+ * Optional behavior is selected with the preprocessor switches defined
+ * inside this class (MULTI_ENTRY_I, MULTI_ENTRY_S, SIMPLE_REFILL, KEEP_SYM,
+ * DO_REFILL_AFTER_KEEP_SYM).
+ * Statistics are printed to stderr every 10000 add/remove operations and
+ * when the index is destroyed.
+ */
+class HnswLikeNns : public NNS<float>
+{
+private:
+    std::vector<Node> _nodes;        // indexed by docid; a node with zero levels is not part of the graph
+    uint32_t _entryId;               // docid where searches start
+    int _entryLevel;                 // level searches start on; -1 while there is no entry point
+    uint32_t _M;                     // link target per level (level 0 allows 2 * _M)
+    uint32_t _efConstruction;        // candidate list size used by addDoc()
+    double _levelMultiplier;         // scale factor for randomLevel()
+    RndGen _rndGen;
+    VisitedSetPool _visitedSetPool;
+    size_t _ops_counter;             // number of addDoc() / removeDoc() operations tracked
+
+    // Distance between vector 'v' and the stored vector of document 'id'.
+    double distance(Vector v, uint32_t id) const;
+
+    // Distance between the stored vectors of documents 'a' and 'b'.
+    double distance(uint32_t a, uint32_t b) const {
+        Vector v = _dva.get(a);
+        return distance(v, b);
+    }
+
+    // Draws the top level for a new node: -log(1 - u) * _levelMultiplier, truncated.
+    int randomLevel() {
+        double unif = _rndGen.nextUniform();
+        double r = -log(1.0-unif) * _levelMultiplier;
+        return (int) r;
+    }
+
+    uint32_t count_reachable() const;
+    void dumpStats() const;
+
+public:
+    HnswLikeNns(uint32_t numDims, const DocVectorAccess<float> &dva)
+        : NNS(numDims, dva),
+          _nodes(),
+          _entryId(0),
+          _entryLevel(-1),
+          _M(16),
+          _efConstruction(200),
+          _levelMultiplier(1.0 / log(1.0 * _M)),
+          _rndGen(),
+          _ops_counter(0)
+    {
+    }
+
+    ~HnswLikeNns() { dumpStats(); }
+
+    // No range checking: 'docid' and 'level' must refer to an existing link list.
+    LinkList& getLinkList(uint32_t docid, uint32_t level) {
+        // assert(docid < _nodes.size());
+        // assert(level < _nodes[docid]._links.size());
+        return _nodes[docid]._links[level];
+    }
+
+    const LinkList& getLinkList(uint32_t docid, uint32_t level) const {
+        return _nodes[docid]._links[level];
+    }
+
+    // simple greedy search: keep moving to any neighbor that is closer
+    // until no neighbor of the current point improves the distance
+    HnswHit search_layer_simple(Vector vector, HnswHit curPoint, uint32_t searchLevel) {
+        bool keepGoing = true;
+        while (keepGoing) {
+            keepGoing = false;
+            const LinkList& neighbors = getLinkList(curPoint.docid, searchLevel);
+            for (uint32_t n_id : neighbors) {
+                double dist = distance(vector, n_id);
+                ++distcalls_simple;
+                if (dist < curPoint.dist) {
+                    curPoint = HnswHit(n_id, SqDist(dist));
+                    keepGoing = true;
+                }
+            }
+        }
+        return curPoint;
+    }
+
+    void search_layer(Vector vector, FurthestPriQ &w,
+                      VisitedSet &visited,
+                      uint32_t ef, uint32_t searchLevel);
+
+    void search_layer_with_filter(Vector vector, FurthestPriQ &w,
+                                  VisitedSet &visited,
+                                  uint32_t ef, uint32_t searchLevel,
+                                  const BitVector &blacklist);
+
+    // True when some document in 'r' is closer to e.docid than e.dist.
+    bool haveCloserDistance(HnswHit e, const LinkList &r) const {
+        for (uint32_t prevId : r) {
+            double dist = distance(e.docid, prevId);
+            ++distcalls_heuristic;
+            if (dist < e.dist) return true;
+        }
+        return false;
+    }
+
+    LinkList select_neighbors(const NearestList &neighbors, uint32_t curMax) const;
+
+    LinkList remove_weakest(const NearestList &neighbors, uint32_t curMax, LinkList &removed) const;
+
+    // Inserts 'docid' into the graph. Slots for all docids up to 'docid'
+    // are created as empty nodes; the slot for 'docid' itself must be empty.
+    void addDoc(uint32_t docid) override {
+        Vector vector = _dva.get(docid);
+        for (uint32_t id = _nodes.size(); id <= docid; ++id) {
+            _nodes.emplace_back(id, 0, _M);
+        }
+        int level = randomLevel();
+        assert(_nodes[docid]._links.size() == 0);
+        _nodes[docid] = Node(docid, level+1, _M);
+        if (_entryLevel < 0) {
+            // first node in the graph: becomes the entry point, nothing to link
+            _entryId = docid;
+            _entryLevel = level;
+            track_ops();
+            return;
+        }
+        int searchLevel = _entryLevel;
+        VisitedSet &visited = _visitedSetPool.get(_nodes.size());
+        double entryDist = distance(vector, _entryId);
+        ++distcalls_other;
+        HnswHit entryPoint(_entryId, SqDist(entryDist));
+#undef MULTI_ENTRY_I
+#ifdef MULTI_ENTRY_I
+        FurthestPriQ w;
+        w.push(entryPoint);
+        while (searchLevel > level) {
+            search_layer(vector, w, visited, 5 * _M, searchLevel);
+            --searchLevel;
+        }
+#else
+        // descend greedily through the levels above the new node's top level
+        while (searchLevel > level) {
+            entryPoint = search_layer_simple(vector, entryPoint, searchLevel);
+            --searchLevel;
+        }
+        FurthestPriQ w;
+        w.push(entryPoint);
+#endif
+        searchLevel = std::min(level, _entryLevel);
+        // link the new node on each level it shares with the existing graph;
+        // 'w' carries the candidates found so far down to the next level
+        while (searchLevel >= 0) {
+            search_layer(vector, w, visited, _efConstruction, searchLevel);
+            LinkList neighbors = select_neighbors(w.peek(), _M);
+            connect_new_node(docid, neighbors, searchLevel);
+            each_shrink_ifneeded(neighbors, searchLevel);
+            --searchLevel;
+        }
+        if (level > _entryLevel) {
+            _entryLevel = level;
+            _entryId = docid;
+        }
+        track_ops();
+    }
+
+    // Counts one add/remove operation and prints the global counters to
+    // stderr on every 10000th operation.
+    void track_ops() {
+        _ops_counter++;
+        if ((_ops_counter % 10000) == 0) {
+            double div = _ops_counter;
+            fprintf(stderr, "add / remove ops: %zu\n", _ops_counter);
+            fprintf(stderr, "distance calls for layer: %zu is %.3f per op\n", distcalls_search_layer, distcalls_search_layer/ div);
+            fprintf(stderr, "distance calls for heuristic: %zu is %.3f per op\n", distcalls_heuristic, distcalls_heuristic / div);
+            fprintf(stderr, "distance calls for simple: %zu is %.3f per op\n", distcalls_simple, distcalls_simple / div);
+            fprintf(stderr, "distance calls for shrink: %zu is %.3f per op\n", distcalls_shrink, distcalls_shrink / div);
+            fprintf(stderr, "distance calls for refill: %zu is %.3f per op\n", distcalls_refill, distcalls_refill / div);
+            fprintf(stderr, "distance calls for other: %zu is %.3f per op\n", distcalls_other, distcalls_other / div);
+            fprintf(stderr, "refill needed calls: %zu is %.3f per op\n", refill_needed_calls, refill_needed_calls / div);
+            fprintf(stderr, "shrink needed calls: %zu is %.3f per op\n", shrink_needed_calls, shrink_needed_calls / div);
+            fprintf(stderr, "disconnected weak links: %zu is %.3f per op\n", disconnected_weak_links, disconnected_weak_links / div);
+            fprintf(stderr, "disconnected for symmetry: %zu is %.3f per op\n", disconnected_for_symmetry, disconnected_for_symmetry / div);
+            fprintf(stderr, "select neighbors: partial %zu vs full %zu\n", select_n_partial, select_n_full);
+        }
+    }
+
+    // Removes the link from 'from_id' to 'remove_id' on 'level' (aborts if it is missing).
+    void remove_link_from(uint32_t from_id, uint32_t remove_id, uint32_t level) {
+        LinkList &links = getLinkList(from_id, level);
+        links.remove_link(remove_id);
+    }
+
+#undef SIMPLE_REFILL
+#ifdef SIMPLE_REFILL
+    void refill_ifneeded(uint32_t my_id, const LinkList &replacements, uint32_t level) {
+        LinkList &my_links = getLinkList(my_id, level);
+        if (my_links.size() * 2 < _M) {
+            const uint32_t maxLinks = (level > 0) ? _M : (2 * _M);
+            ++refill_needed_calls;
+            for (uint32_t repl_id : replacements) {
+                if (repl_id == my_id) continue;
+                if (my_links.has_link_to(repl_id)) continue;
+                LinkList &other_links = getLinkList(repl_id, level);
+                if (other_links.size() >= maxLinks) continue;
+                other_links.push_back(my_id);
+                my_links.push_back(repl_id);
+                if (my_links.size() >= _M) return;
+            }
+        }
+    }
+#else
+    // Adds symmetric links between 'my_id' and suitable documents from
+    // 'replacements', nearest first. A candidate is skipped when it is
+    // already linked, when its own list is full, or when one of my current
+    // links is closer to it than I am.
+    // Stops when my list reaches exactly _M links, and never grows my list
+    // beyond the per-level limit (a list that starts above _M would
+    // otherwise never hit the "== _M" stop condition).
+    void refill_all(uint32_t my_id, const LinkList &replacements, uint32_t level) {
+        LinkList &my_links = getLinkList(my_id, level);
+        const uint32_t maxLinks = (level > 0) ? _M : (2 * _M);
+        NearestPriQ w;
+        for (uint32_t repl_id : replacements) {
+            if (repl_id == my_id) continue;
+            if (my_links.has_link_to(repl_id)) continue;
+            const LinkList &other_links = getLinkList(repl_id, level);
+            if (other_links.size() >= maxLinks) continue;
+            double dist = distance(my_id, repl_id);
+            ++distcalls_refill;
+            w.emplace(repl_id, SqDist(dist));
+        }
+        while (! w.empty() && my_links.size() < maxLinks) {
+            HnswHit e = w.top();
+            w.pop();
+            if (haveCloserDistance(e, my_links)) continue;
+            LinkList &other_links = getLinkList(e.docid, level);
+            my_links.push_back(e.docid);
+            other_links.push_back(my_id);
+            if (my_links.size() == _M) break;
+        }
+    }
+    // Like refill_all(), but adds at most one link and only considers
+    // candidates with fewer than _M links.
+    void refill_one(uint32_t my_id, const LinkList &replacements, uint32_t level) {
+        LinkList &my_links = getLinkList(my_id, level);
+        NearestPriQ w;
+        for (uint32_t repl_id : replacements) {
+            if (repl_id == my_id) continue;
+            if (my_links.has_link_to(repl_id)) continue;
+            LinkList &other_links = getLinkList(repl_id, level);
+            if (other_links.size() >= _M) continue;
+            double dist = distance(my_id, repl_id);
+            ++distcalls_refill;
+            w.emplace(repl_id, SqDist(dist));
+        }
+        while (! w.empty()) {
+            HnswHit e = w.top();
+            w.pop();
+            if (haveCloserDistance(e, my_links)) continue;
+            LinkList &other_links = getLinkList(e.docid, level);
+            my_links.push_back(e.docid);
+            other_links.push_back(my_id);
+            return;
+        }
+    }
+    // Calls refill_all() when 'my_id' has fewer than _M links on 'level'.
+    void refill_ifneeded(uint32_t my_id, const LinkList &replacements, uint32_t level) {
+        LinkList &my_links = getLinkList(my_id, level);
+        if (my_links.size() < _M) {
+            ++refill_needed_calls;
+            refill_all(my_id, replacements, level);
+        }
+    }
+#endif
+
+    void connect_new_node(uint32_t id, const LinkList &neighbors, uint32_t level);
+
+    // Cuts the link list of 'shrink_id' down to at most 'maxLinks' entries
+    // using remove_weakest(). With KEEP_SYM the back-links from the dropped
+    // neighbors are removed too, and with DO_REFILL_AFTER_KEEP_SYM each
+    // dropped neighbor is then offered the old neighborhood as replacements.
+    void shrink_links(uint32_t shrink_id, uint32_t maxLinks, uint32_t level) {
+        LinkList &links = getLinkList(shrink_id, level);
+        NearestList distances;
+        for (uint32_t n_id : links) {
+            double n_dist = distance(shrink_id, n_id);
+            ++distcalls_shrink;
+            distances.emplace_back(n_id, SqDist(n_dist));
+        }
+        LinkList lostLinks;
+        LinkList oldLinks = links;
+        links = remove_weakest(distances, maxLinks, lostLinks);
+#define KEEP_SYM
+#ifdef KEEP_SYM
+        for (uint32_t lost_id : lostLinks) {
+            ++disconnected_for_symmetry;
+            remove_link_from(lost_id, shrink_id, level);
+        }
+#define DO_REFILL_AFTER_KEEP_SYM
+#ifdef DO_REFILL_AFTER_KEEP_SYM
+        for (uint32_t lost_id : lostLinks) {
+#ifdef SIMPLE_REFILL
+            refill_ifneeded(lost_id, oldLinks, level);
+#else
+            refill_all(lost_id, oldLinks, level);
+#endif
+        }
+#endif
+#endif
+    }
+
+    void each_shrink_ifneeded(const LinkList &neighbors, uint32_t level);
+
+    // Offers each member of 'cluster' the members that come before it in
+    // the list as replacement links. 'cluster' is taken by value because
+    // it is consumed from the back.
+    void mutually_reconnect(LinkList cluster, int level) {
+        while (! cluster.empty()) {
+            uint32_t n_id = cluster.back();
+            cluster.pop_back();
+#ifdef SIMPLE_REFILL
+            refill_ifneeded(n_id, cluster, level);
+#else
+            refill_all(n_id, cluster, level);
+#endif
+        }
+    }
+
+    // Takes 'docid' out of the graph: removes all links to and from it,
+    // lets its former neighbors reconnect among themselves, and picks a
+    // new entry point when the removed document was the entry point.
+    // A docid that was never added (outside the node table) is ignored.
+    void removeDoc(uint32_t docid) override {
+        if (docid >= _nodes.size()) {
+            return;
+        }
+        Node &node = _nodes[docid];
+        bool need_new_entrypoint = (docid == _entryId);
+        for (int level = node._links.size(); level-- > 0; ) {
+            LinkList my_links;
+            my_links.swap(node._links[level]);
+            for (uint32_t n_id : my_links) {
+                if (need_new_entrypoint) {
+                    // first neighbor found on the highest linked level takes over
+                    _entryId = n_id;
+                    _entryLevel = level;
+                    need_new_entrypoint = false;
+                }
+                remove_link_from(n_id, docid, level);
+            }
+            mutually_reconnect(my_links, level);
+        }
+        node = Node(docid, 0, _M);
+        if (need_new_entrypoint) {
+            // the old entry point had no links at all: fall back to the
+            // remaining node with the most levels, so _entryLevel is not
+            // lower than the height of a node still in the graph
+            _entryLevel = -1;
+            _entryId = 0;
+            for (uint32_t i = 0; i < _nodes.size(); ++i) {
+                int topLevel = static_cast<int>(_nodes[i]._links.size()) - 1;
+                if (topLevel > _entryLevel) {
+                    _entryId = i;
+                    _entryLevel = topLevel;
+                }
+            }
+        }
+        track_ops();
+    }
+
+    // Returns up to 'k' hits sorted by increasing distance. The level 0
+    // search uses a candidate list of max(k, search_k) entries.
+    std::vector<NnsHit> topK(uint32_t k, Vector vector, uint32_t search_k) override {
+        std::vector<NnsHit> result;
+        if (_entryLevel < 0) return result;
+        double entryDist = distance(vector, _entryId);
+        ++distcalls_other;
+        HnswHit entryPoint(_entryId, SqDist(entryDist));
+        int searchLevel = _entryLevel;
+        VisitedSet &visited = _visitedSetPool.get(_nodes.size());
+#undef MULTI_ENTRY_S
+#ifdef MULTI_ENTRY_S
+        FurthestPriQ w;
+        w.push(entryPoint);
+        while (searchLevel > 0) {
+            search_layer(vector, w, visited, std::min(k, search_k), searchLevel);
+            --searchLevel;
+        }
+#else
+        while (searchLevel > 0) {
+            entryPoint = search_layer_simple(vector, entryPoint, searchLevel);
+            --searchLevel;
+        }
+        FurthestPriQ w;
+        w.push(entryPoint);
+#endif
+        search_layer(vector, w, visited, std::max(k, search_k), 0);
+        // drop the furthest hits until only k remain
+        while (w.size() > k) {
+            w.pop();
+        }
+        NearestList tmp = w.steal();
+        std::sort(tmp.begin(), tmp.end(), LesserDist());
+        result.reserve(tmp.size());
+        for (const auto & hit : tmp) {
+            result.emplace_back(hit.docid, SqDist(hit.dist));
+        }
+        return result;
+    }
+
+    std::vector<NnsHit> topKfilter(uint32_t k, Vector vector, uint32_t search_k, const BitVector &blacklist) override;
+};
+
+
+// Distance between vector 'v' and the stored vector of document 'b',
+// as computed by l2distCalc.l2sq_dist().
+double
+HnswLikeNns::distance(Vector v, uint32_t b) const
+{
+    const Vector other = _dva.get(b);
+    const double result = l2distCalc.l2sq_dist(v, other);
+    return result;
+}
+
+// Like topK(), but documents whose bit is set in 'blacklist' are never
+// returned. Returns at most 'k' hits sorted by increasing distance; k == 0
+// gives an empty result, as it does for topK().
+std::vector<NnsHit>
+HnswLikeNns::topKfilter(uint32_t k, Vector vector, uint32_t search_k, const BitVector &blacklist)
+{
+    std::vector<NnsHit> result;
+    if (_entryLevel < 0) return result;
+    double entryDist = distance(vector, _entryId);
+    ++distcalls_other;
+    HnswHit entryPoint(_entryId, SqDist(entryDist));
+    int searchLevel = _entryLevel;
+    VisitedSet &visited = _visitedSetPool.get(_nodes.size());
+#ifdef MULTI_ENTRY_S
+    FurthestPriQ w;
+    w.push(entryPoint);
+    while (searchLevel > 0) {
+        search_layer(vector, w, visited, std::min(k, search_k), searchLevel);
+        --searchLevel;
+    }
+#else
+    while (searchLevel > 0) {
+        entryPoint = search_layer_simple(vector, entryPoint, searchLevel);
+        --searchLevel;
+    }
+    FurthestPriQ w;
+    w.push(entryPoint);
+#endif
+    search_layer_with_filter(vector, w, visited, std::max(k, search_k), 0, blacklist);
+    NearestList tmp = w.steal();
+    std::sort(tmp.begin(), tmp.end(), LesserDist());
+    result.reserve(std::min((size_t)k, tmp.size()));
+    for (const auto & hit : tmp) {
+        // test the limit before adding, so that k == 0 adds nothing
+        if (result.size() >= k) break;
+        // the entry point may be blacklisted and still sit in the list
+        if (blacklist.isSet(hit.docid)) continue;
+        result.emplace_back(hit.docid, SqDist(hit.dist));
+    }
+    return result;
+}
+
+// Shrinks the link list of every document in 'neighbors' that holds more
+// than the limit for 'level' (2 * _M on level 0, _M on other levels).
+void
+HnswLikeNns::each_shrink_ifneeded(const LinkList &neighbors, uint32_t level) {
+    const uint32_t limit = (level == 0) ? (2 * _M) : _M;
+    for (uint32_t neighborId : neighbors) {
+        if (getLinkList(neighborId, level).size() <= limit) {
+            continue;
+        }
+        ++shrink_needed_calls;
+        shrink_links(neighborId, limit, level);
+    }
+}
+
+// Best-first search on one level.
+// 'w' holds the starting hits on entry and the best hits found on return;
+// once it grows beyond 'ef' entries the furthest one is dropped each time.
+// Documents already marked in 'visited' are not looked at again.
+void
+HnswLikeNns::search_layer(Vector vector, FurthestPriQ &w,
+                          VisitedSet &visited,
+                          uint32_t ef, uint32_t searchLevel)
+{
+    NearestPriQ frontier;
+
+    for (const HnswHit & seed : w.peek()) {
+        visited.mark(seed.docid);
+        frontier.push(seed);
+    }
+    // distance of the furthest hit kept in 'w'; unbounded until 'w' has overflowed once
+    double worstKept = std::numeric_limits<double>::max();
+    while (! frontier.empty()) {
+        const HnswHit closest = frontier.top();
+        if (closest.dist > worstKept) {
+            // the nearest unexpanded candidate is further away than everything kept
+            break;
+        }
+        frontier.pop();
+        for (uint32_t neighborId : getLinkList(closest.docid, searchLevel)) {
+            if (visited.isMarked(neighborId)) {
+                continue;
+            }
+            visited.mark(neighborId);
+            const double neighborDist = distance(vector, neighborId);
+            ++distcalls_search_layer;
+            if (neighborDist < worstKept) {
+                frontier.emplace(neighborId, SqDist(neighborDist));
+                w.emplace(neighborId, SqDist(neighborDist));
+                if (w.size() > ef) {
+                    w.pop();
+                    worstKept = w.top().dist;
+                }
+            }
+        }
+    }
+}
+
+// Same search as search_layer(), except that blacklisted documents are
+// still expanded as candidates but never added to 'w'.
+// 'ef' is raised by one for every blacklisted hit already present in 'w'
+// on entry, so those do not take up room meant for real results.
+void
+HnswLikeNns::search_layer_with_filter(Vector vector, FurthestPriQ &w,
+                                      VisitedSet &visited,
+                                      uint32_t ef, uint32_t searchLevel,
+                                      const BitVector &blacklist)
+{
+    NearestPriQ frontier;
+
+    for (const HnswHit & seed : w.peek()) {
+        visited.mark(seed.docid);
+        frontier.push(seed);
+        if (blacklist.isSet(seed.docid)) {
+            ++ef;
+        }
+    }
+    // distance of the furthest hit kept in 'w'; unbounded until 'w' has overflowed once
+    double worstKept = std::numeric_limits<double>::max();
+    while (! frontier.empty()) {
+        const HnswHit closest = frontier.top();
+        if (closest.dist > worstKept) {
+            break;
+        }
+        frontier.pop();
+        for (uint32_t neighborId : getLinkList(closest.docid, searchLevel)) {
+            if (visited.isMarked(neighborId)) {
+                continue;
+            }
+            visited.mark(neighborId);
+            const double neighborDist = distance(vector, neighborId);
+            ++distcalls_search_layer;
+            if (neighborDist < worstKept) {
+                frontier.emplace(neighborId, SqDist(neighborDist));
+                if (blacklist.isSet(neighborId)) {
+                    // usable as a stepping stone, but not as a result
+                    continue;
+                }
+                w.emplace(neighborId, SqDist(neighborDist));
+                if (w.size() > ef) {
+                    w.pop();
+                    worstKept = w.top().dist;
+                }
+            }
+        }
+    }
+}
+
+// Splits 'neighbors' into links to keep (returned) and links to drop
+// (appended to 'lost'), going through them nearest first.
+// A neighbor is dropped when 'curMax' links are already kept, or when a
+// kept link is closer to it than the owner of the list is.
+LinkList
+HnswLikeNns::remove_weakest(const NearestList &neighbors, uint32_t curMax, LinkList &lost) const
+{
+    LinkList kept;
+    kept.reserve(curMax+1);
+    NearestPriQ byDistance;
+    for (const HnswHit & hit : neighbors) {
+        byDistance.push(hit);
+    }
+    for (; ! byDistance.empty(); byDistance.pop()) {
+        const HnswHit nearest = byDistance.top();
+        // the heuristic is only evaluated while there is still room
+        const bool keep = (kept.size() != curMax) && ! haveCloserDistance(nearest, kept);
+        if (keep) {
+            kept.push_back(nearest.docid);
+        } else {
+            lost.push_back(nearest.docid);
+        }
+    }
+    return kept;
+}
+
+#define NO_BACKFILL
+#ifdef NO_BACKFILL
+// Picks up to 'curMax' link targets from 'neighbors', nearest first,
+// skipping every candidate that has an already picked target closer to it
+// than its own distance. Counts whether the full number was reached.
+LinkList
+HnswLikeNns::select_neighbors(const NearestList &neighbors, uint32_t curMax) const
+{
+    LinkList picked;
+    picked.reserve(curMax+1);
+    NearestPriQ byDistance;
+    for (const HnswHit & hit : neighbors) {
+        byDistance.push(hit);
+    }
+    for (; ! byDistance.empty(); byDistance.pop()) {
+        const HnswHit nearest = byDistance.top();
+        if (haveCloserDistance(nearest, picked)) {
+            continue;
+        }
+        picked.push_back(nearest.docid);
+        if (picked.size() == curMax) {
+            ++select_n_full;
+            return picked;
+        }
+    }
+    ++select_n_partial;
+    return picked;
+}
+#else
+// Variant with backfill: the closeness heuristic is only applied when
+// there are more candidates than 'curMax'. If the result ends up with
+// fewer than _M/4 links, rejected candidates are added back until it
+// holds at least _M/2.
+LinkList
+HnswLikeNns::select_neighbors(const NearestList &neighbors, uint32_t curMax) const
+{
+    LinkList picked;
+    picked.reserve(curMax+1);
+    const bool applyHeuristic = (neighbors.size() > curMax);
+    NearestPriQ byDistance;
+    for (const HnswHit & hit : neighbors) {
+        byDistance.push(hit);
+    }
+    LinkList rejected;
+    for (; ! byDistance.empty(); byDistance.pop()) {
+        const HnswHit nearest = byDistance.top();
+        if (applyHeuristic && haveCloserDistance(nearest, picked)) {
+            rejected.push_back(nearest.docid);
+            continue;
+        }
+        picked.push_back(nearest.docid);
+        if (picked.size() == curMax) {
+            return picked;
+        }
+    }
+    if (picked.size() * 4 < _M) {
+        for (uint32_t fillId : rejected) {
+            picked.push_back(fillId);
+            if (picked.size() * 2 >= _M) {
+                break;
+            }
+        }
+    }
+    return picked;
+}
+#endif
+
+// Links node 'id' with every document in 'neighbors' in both directions.
+// With DISCONNECT_OLD_WEAK_LINKS, any direct link between two of those
+// neighbors is then removed (on both sides).
+void
+HnswLikeNns::connect_new_node(uint32_t id, const LinkList &neighbors, uint32_t level) {
+    LinkList &ownLinks = getLinkList(id, level);
+    for (uint32_t neighborId : neighbors) {
+        LinkList &theirLinks = getLinkList(neighborId, level);
+        ownLinks.push_back(neighborId);
+        theirLinks.push_back(id);
+    }
+#define DISCONNECT_OLD_WEAK_LINKS
+#ifdef DISCONNECT_OLD_WEAK_LINKS
+    // look at every pair of neighbors (later, earlier) exactly once
+    for (uint32_t later = 1; later < neighbors.size(); ++later) {
+        const uint32_t laterId = neighbors[later];
+        LinkList &laterLinks = getLinkList(laterId, level);
+        for (uint32_t earlier = 0; earlier < later; ++earlier) {
+            const uint32_t earlierId = neighbors[earlier];
+            if (! laterLinks.has_link_to(earlierId)) {
+                continue;
+            }
+            ++disconnected_weak_links;
+            LinkList &earlierLinks = getLinkList(earlierId, level);
+            laterLinks.remove_link(earlierId);
+            earlierLinks.remove_link(laterId);
+        }
+    }
+#endif
+}
+
+// Counts the documents that can be reached from the entry point by
+// following links, starting on the entry level and continuing on each
+// lower level from everything found so far. The entry point is included.
+// Returns 0 when there is no entry point (_entryLevel < 0).
+uint32_t
+HnswLikeNns::count_reachable() const {
+    if (_entryLevel < 0) {
+        // no entry point: _entryId does not name a node, and the visited
+        // set may have no slot for it, so it must not be marked
+        return 0;
+    }
+    VisitedSet visited(_nodes.size());
+    int level = _entryLevel;
+    LinkList curList;
+    curList.push_back(_entryId);
+    visited.mark(_entryId);
+    uint32_t idx = 0;
+    while (level >= 0) {
+        // curList doubles as the work queue: it grows while being scanned
+        while (idx < curList.size()) {
+            uint32_t id = curList[idx++];
+            const LinkList &links = getLinkList(id, level);
+            for (uint32_t n_id : links) {
+                if (visited.isMarked(n_id)) continue;
+                visited.mark(n_id);
+                curList.push_back(n_id);
+            }
+        }
+        --level;
+        // rescan everything found so far on the next lower level
+        idx = 0;
+    }
+    return curList.size();
+}
+
+// Prints statistics about the graph to stderr: number of nodes per level
+// count, reachability from the entry point, histogram of level 0 link
+// counts, overlap between level 1 and level 0 links, and the number of
+// nodes whose level 0 links all have a back-link.
+void
+HnswLikeNns::dumpStats() const {
+    // Histograms start with the expected number of buckets but grow on
+    // demand: a node can have more levels than _entryLevel + 1, and a link
+    // list is not guaranteed to stay within 2 * _M + 1 entries.
+    auto bump = [](std::vector<uint32_t> &hist, size_t bucket) {
+        if (bucket >= hist.size()) {
+            hist.resize(bucket + 1);
+        }
+        ++hist[bucket];
+    };
+    std::vector<uint32_t> levelCounts;
+    levelCounts.resize(_entryLevel + 2);
+    std::vector<uint32_t> outLinkHist;
+    outLinkHist.resize(2 * _M + 2);
+    uint32_t symmetrics = 0;
+    uint32_t level1links = 0;
+    uint32_t both_l_links = 0;
+    fprintf(stderr, "stats for HnswLikeNns with %zu nodes, entry level = %d, entry id = %u\n",
+            _nodes.size(), _entryLevel, _entryId);
+
+    for (uint32_t id = 0; id < _nodes.size(); ++id) {
+        const auto &node = _nodes[id];
+        uint32_t levels = node._links.size();
+        bump(levelCounts, levels);
+        if (levels < 1) {
+            // empty slot: not part of the graph
+            outLinkHist[0]++;
+            continue;
+        }
+        const LinkList &link_list = getLinkList(id, 0);
+        uint32_t numlinks = link_list.size();
+        bump(outLinkHist, numlinks);
+        if (numlinks < 1) {
+            fprintf(stderr, "node with %u links: id %u\n", numlinks, id);
+        }
+        bool all_sym = true;
+        for (uint32_t n_id : link_list) {
+            const LinkList &neigh_list = getLinkList(n_id, 0);
+            if (! neigh_list.has_link_to(id)) {
+#ifdef KEEP_SYM
+                fprintf(stderr, "BAD: %u has link to neighbor %u, but backlink is missing\n", id, n_id);
+#endif
+                all_sym = false;
+            }
+        }
+        if (all_sym) ++symmetrics;
+        if (levels < 2) continue;
+        const LinkList &link_list_1 = getLinkList(id, 1);
+        for (uint32_t n_id : link_list_1) {
+            ++level1links;
+            if (link_list.has_link_to(n_id)) ++both_l_links;
+        }
+    }
+    for (uint32_t l = 0; l < levelCounts.size(); ++l) {
+        fprintf(stderr, "Nodes on %u levels: %u\n", l, levelCounts[l]);
+    }
+    fprintf(stderr, "reachable nodes %u / %zu\n",
+            count_reachable(), _nodes.size() - levelCounts[0]);
+    fprintf(stderr, "level 1 links overlapping on l0: %u / total: %u\n",
+            both_l_links, level1links);
+    for (uint32_t l = 0; l < outLinkHist.size(); ++l) {
+        if (outLinkHist[l] != 0) {
+            fprintf(stderr, "Nodes with %u outward links on L0: %u\n", l, outLinkHist[l]);
+        }
+    }
+    fprintf(stderr, "Symmetric in-out nodes: %u\n", symmetrics);
+}
+
+// Factory: creates an HnswLikeNns and returns it through the generic NNS interface.
+std::unique_ptr<NNS<float>>
+make_hnsw_nns(uint32_t numDims, const DocVectorAccess<float> &dva)
+{
+    std::unique_ptr<NNS<float>> index = std::make_unique<HnswLikeNns>(numDims, dva);
+    return index;
+}