diff options
Diffstat (limited to 'eval/src/tests/ann/xp-hnswlike-nns.cpp')
-rw-r--r-- | eval/src/tests/ann/xp-hnswlike-nns.cpp | 161 |
1 files changed, 54 insertions, 107 deletions
diff --git a/eval/src/tests/ann/xp-hnswlike-nns.cpp b/eval/src/tests/ann/xp-hnswlike-nns.cpp index 5cdbdd8efa3..72b3fdb21f9 100644 --- a/eval/src/tests/ann/xp-hnswlike-nns.cpp +++ b/eval/src/tests/ann/xp-hnswlike-nns.cpp @@ -7,31 +7,13 @@ #include "std-random.h" #include "nns.h" -/* - Todo: - - measure effect of: - 1) removing leftover backlinks during "shrink" operation - 2) refilling to low-watermark after 1) happens - 3) refilling to mid-watermark after 1) happens - 4) adding then removing 20% extra documents - 5) removing 20% first-added documents - 6) removing first-added documents while inserting new ones - - 7) auto-tune search_k to ensure >= 50% recall on 1000 Q with k=100 - 8) auto-tune search_k to ensure avg 90% recall on 1000 Q with k=100 - 9) auto-tune search_k to ensure >= 90% reachability of 10000 docids - - 10) timings for SIFT, GIST, and DEEP data (100k, 200k, 300k, 500k, 700k, 1000k) - */ - -static size_t distcalls_simple; -static size_t distcalls_search_layer; -static size_t distcalls_other; -static size_t distcalls_heuristic; -static size_t distcalls_shrink; -static size_t distcalls_refill; -static size_t refill_needed_calls; +static uint64_t distcalls_simple; +static uint64_t distcalls_search_layer; +static uint64_t distcalls_other; +static uint64_t distcalls_heuristic; +static uint64_t distcalls_shrink; +static uint64_t distcalls_refill; +static uint64_t refill_needed_calls; struct LinkList : std::vector<uint32_t> { @@ -49,7 +31,6 @@ struct LinkList : std::vector<uint32_t> } } fprintf(stderr, "BAD missing link to remove: %u\n", id); - abort(); } }; @@ -149,7 +130,6 @@ private: double _levelMultiplier; RndGen _rndGen; VisitedSetPool _visitedSetPool; - size_t _ops_counter; double distance(Vector v, uint32_t id) const; @@ -164,7 +144,6 @@ private: return (int) r; } - uint32_t count_reachable() const; void dumpStats() const; public: @@ -176,9 +155,9 @@ public: _M(16), _efConstruction(200), _levelMultiplier(1.0 / log(1.0 * _M)), - _rndGen(), - _ops_counter(0) + _rndGen() { + _nodes.reserve(1234567); } ~HnswLikeNns() { dumpStats(); } @@ -238,7 +217,6 @@ public: if (_entryLevel < 0) { _entryId = docid; _entryLevel = level; - track_ops(); return; } int searchLevel = _entryLevel; @@ -263,23 +241,18 @@ public: _entryLevel = level; _entryId = docid; } - track_ops(); - } - - void track_ops() { - _ops_counter++; - if ((_ops_counter % 10000) == 0) { - double div = _ops_counter; - fprintf(stderr, "add / remove ops: %zu\n", _ops_counter); - fprintf(stderr, "distance calls for layer: %zu is %.3f per op\n", distcalls_search_layer, distcalls_search_layer/ div); - fprintf(stderr, "distance calls for heuristic: %zu is %.3f per op\n", distcalls_heuristic, distcalls_heuristic / div); - fprintf(stderr, "distance calls for simple: %zu is %.3f per op\n", distcalls_simple, distcalls_simple / div); - fprintf(stderr, "distance calls for shrink: %zu is %.3f per op\n", distcalls_shrink, distcalls_shrink / div); - fprintf(stderr, "distance calls for refill: %zu is %.3f per op\n", distcalls_refill, distcalls_refill / div); - fprintf(stderr, "distance calls for other: %zu is %.3f per op\n", distcalls_other, distcalls_other / div); - fprintf(stderr, "refill needed calls: %zu is %.3f per op\n", refill_needed_calls, refill_needed_calls / div); + if (_nodes.size() % 10000 == 0) { + double div = _nodes.size(); + fprintf(stderr, "added docs: %d\n", (int)div); + fprintf(stderr, "distance calls for layer: %" PRIu64 " is %.3f per doc\n", distcalls_search_layer, distcalls_search_layer/ div); + fprintf(stderr, "distance calls for heuristic: %" PRIu64 " is %.3f per doc\n", distcalls_heuristic, distcalls_heuristic / div); + fprintf(stderr, "distance calls for simple: %" PRIu64 " is %.3f per doc\n", distcalls_simple, distcalls_simple / div); + fprintf(stderr, "distance calls for shrink: %" PRIu64 " is %.3f per doc\n", distcalls_shrink, distcalls_shrink / div); + fprintf(stderr, "distance calls for refill: %" PRIu64 " is %.3f per doc\n", distcalls_refill, distcalls_refill / div); + fprintf(stderr, "distance calls for other: %" PRIu64 " is %.3f per doc\n", distcalls_other, distcalls_other / div); + fprintf(stderr, "refill needed calls: %" PRIu64 " is %.3f per doc\n", refill_needed_calls, refill_needed_calls / div); } - } + } void remove_link_from(uint32_t from_id, uint32_t remove_id, uint32_t level) { LinkList &links = getLinkList(from_id, level); @@ -294,10 +267,9 @@ public: if (repl_id == my_id) continue; if (my_links.has_link_to(repl_id)) continue; LinkList &other_links = getLinkList(repl_id, level); - if (other_links.size() + 1 >= _M) continue; + if (other_links.size() >= _M) continue; other_links.push_back(my_id); my_links.push_back(repl_id); - if (my_links.size() >= _M) return; } } } @@ -327,17 +299,14 @@ public: Node &node = _nodes[docid]; bool need_new_entrypoint = (docid == _entryId); for (int level = node._links.size(); level-- > 0; ) { - LinkList my_links; - my_links.swap(node._links[level]); + const LinkList &my_links = node._links[level]; for (uint32_t n_id : my_links) { if (need_new_entrypoint) { _entryId = n_id; _entryLevel = level; - need_new_entrypoint = false; + need_new_entrypoint = false; } remove_link_from(n_id, docid, level); - } - for (uint32_t n_id : my_links) { refill_ifneeded(n_id, my_links, level); } } @@ -353,7 +322,6 @@ public: } } } - track_ops(); } std::vector<NnsHit> topK(uint32_t k, Vector vector, uint32_t search_k) override { @@ -363,12 +331,12 @@ public: ++distcalls_other; HnswHit entryPoint(_entryId, SqDist(entryDist)); int searchLevel = _entryLevel; - FurthestPriQ w; - w.push(entryPoint); while (searchLevel > 0) { - search_layer(vector, w, std::min(k, search_k), searchLevel); + entryPoint = search_layer_simple(vector, entryPoint, searchLevel); --searchLevel; } + FurthestPriQ w; + w.push(entryPoint); search_layer(vector, w, std::max(k, search_k), 0); while (w.size() > k) { w.pop(); @@ -522,87 +490,66 @@ HnswLikeNns::connect_new_node(uint32_t id, const LinkList &neighbors, uint32_t l } } -uint32_t -HnswLikeNns::count_reachable() const { - VisitedSet visited(_nodes.size()); - int level = _entryLevel; - LinkList curList; - curList.push_back(_entryId); - visited.mark(_entryId); - uint32_t idx = 0; - while (level >= 0) { - while (idx < curList.size()) { - uint32_t id = curList[idx++]; - const LinkList &links = getLinkList(id, level); - for (uint32_t n_id : links) { - if (visited.isMarked(n_id)) continue; - visited.mark(n_id); - curList.push_back(n_id); - } - } - --level; - idx = 0; - } - return curList.size(); -} - void HnswLikeNns::dumpStats() const { + std::vector<uint32_t> inLinkCounters; + inLinkCounters.resize(_nodes.size()); + std::vector<uint32_t> outLinkCounters; + outLinkCounters.resize(_nodes.size()); std::vector<uint32_t> levelCounts; levelCounts.resize(_entryLevel + 2); std::vector<uint32_t> outLinkHist; outLinkHist.resize(2 * _M + 2); - uint32_t symmetrics = 0; - uint32_t level1links = 0; - uint32_t both_l_links = 0; fprintf(stderr, "stats for HnswLikeNns with %zu nodes, entry level = %d, entry id = %u\n", _nodes.size(), _entryLevel, _entryId); - for (uint32_t id = 0; id < _nodes.size(); ++id) { const auto &node = _nodes[id]; uint32_t levels = node._links.size(); levelCounts[levels]++; if (levels < 1) { + outLinkCounters[id] = 0; outLinkHist[0]++; continue; } const LinkList &link_list = getLinkList(id, 0); uint32_t numlinks = link_list.size(); + outLinkCounters[id] = numlinks; outLinkHist[numlinks]++; - if (numlinks < 1) { + if (numlinks < 2) { fprintf(stderr, "node with %u links: id %u\n", numlinks, id); - } - bool all_sym = true; - for (uint32_t n_id : link_list) { - const LinkList &neigh_list = getLinkList(n_id, 0); - if (! neigh_list.has_link_to(id)) { - fprintf(stderr, "BAD: %u has link to neighbor %u, but backlink is missing\n", id, n_id); - all_sym = false; + for (uint32_t n_id : link_list) { + const LinkList &neigh_list = getLinkList(n_id, 0); + fprintf(stderr, "neighbor id %u has %zu links\n", n_id, neigh_list.size()); + if (! neigh_list.has_link_to(id)) { + fprintf(stderr, "BAD neighbor %u is missing backlink\n", n_id); + } } } - if (all_sym) ++symmetrics; - if (levels < 2) continue; - const LinkList &link_list_1 = getLinkList(id, 1); - for (uint32_t n_id : link_list_1) { - ++level1links; - if (link_list.has_link_to(n_id)) ++both_l_links; + for (uint32_t n_id : link_list) { + inLinkCounters[n_id]++; } } for (uint32_t l = 0; l < levelCounts.size(); ++l) { fprintf(stderr, "Nodes on %u levels: %u\n", l, levelCounts[l]); } - fprintf(stderr, "reachable nodes %u / %zu\n", - count_reachable(), _nodes.size() - levelCounts[0]); - fprintf(stderr, "level 1 links overlapping on l0: %u / total: %u\n", - both_l_links, level1links); for (uint32_t l = 0; l < outLinkHist.size(); ++l) { - if (outLinkHist[l] != 0) { - fprintf(stderr, "Nodes with %u outward links on L0: %u\n", l, outLinkHist[l]); - } + fprintf(stderr, "Nodes with %u outward links on L0: %u\n", l, outLinkHist[l]); + } + uint32_t symmetrics = 0; + std::vector<uint32_t> inLinkHist; + for (uint32_t id = 0; id < _nodes.size(); ++id) { + uint32_t cnt = inLinkCounters[id]; + while (cnt >= inLinkHist.size()) inLinkHist.push_back(0); + inLinkHist[cnt]++; + if (cnt == outLinkCounters[id]) ++symmetrics; + } + for (uint32_t l = 0; l < inLinkHist.size(); ++l) { + fprintf(stderr, "Nodes with %u inward links on L0: %u\n", l, inLinkHist[l]); } fprintf(stderr, "Symmetric in-out nodes: %u\n", symmetrics); } + std::unique_ptr<NNS<float>> make_hnsw_nns(uint32_t numDims, const DocVectorAccess<float> &dva) { |