summary | refs | log | tree | commit | diff | stats
path: root/eval/src/tests/ann/xp-hnswlike-nns.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'eval/src/tests/ann/xp-hnswlike-nns.cpp')
-rw-r--r-- eval/src/tests/ann/xp-hnswlike-nns.cpp | 161
1 files changed, 54 insertions, 107 deletions
diff --git a/eval/src/tests/ann/xp-hnswlike-nns.cpp b/eval/src/tests/ann/xp-hnswlike-nns.cpp
index 5cdbdd8efa3..72b3fdb21f9 100644
--- a/eval/src/tests/ann/xp-hnswlike-nns.cpp
+++ b/eval/src/tests/ann/xp-hnswlike-nns.cpp
@@ -7,31 +7,13 @@
#include "std-random.h"
#include "nns.h"
-/*
- Todo:
-
- measure effect of:
- 1) removing leftover backlinks during "shrink" operation
- 2) refilling to low-watermark after 1) happens
- 3) refilling to mid-watermark after 1) happens
- 4) adding then removing 20% extra documents
- 5) removing 20% first-added documents
- 6) removing first-added documents while inserting new ones
-
- 7) auto-tune search_k to ensure >= 50% recall on 1000 Q with k=100
- 8) auto-tune search_k to ensure avg 90% recall on 1000 Q with k=100
- 9) auto-tune search_k to ensure >= 90% reachability of 10000 docids
-
- 10) timings for SIFT, GIST, and DEEP data (100k, 200k, 300k, 500k, 700k, 1000k)
- */
-
-static size_t distcalls_simple;
-static size_t distcalls_search_layer;
-static size_t distcalls_other;
-static size_t distcalls_heuristic;
-static size_t distcalls_shrink;
-static size_t distcalls_refill;
-static size_t refill_needed_calls;
+static uint64_t distcalls_simple;
+static uint64_t distcalls_search_layer;
+static uint64_t distcalls_other;
+static uint64_t distcalls_heuristic;
+static uint64_t distcalls_shrink;
+static uint64_t distcalls_refill;
+static uint64_t refill_needed_calls;
struct LinkList : std::vector<uint32_t>
{
@@ -49,7 +31,6 @@ struct LinkList : std::vector<uint32_t>
}
}
fprintf(stderr, "BAD missing link to remove: %u\n", id);
- abort();
}
};
@@ -149,7 +130,6 @@ private:
double _levelMultiplier;
RndGen _rndGen;
VisitedSetPool _visitedSetPool;
- size_t _ops_counter;
double distance(Vector v, uint32_t id) const;
@@ -164,7 +144,6 @@ private:
return (int) r;
}
- uint32_t count_reachable() const;
void dumpStats() const;
public:
@@ -176,9 +155,9 @@ public:
_M(16),
_efConstruction(200),
_levelMultiplier(1.0 / log(1.0 * _M)),
- _rndGen(),
- _ops_counter(0)
+ _rndGen()
{
+ _nodes.reserve(1234567);
}
~HnswLikeNns() { dumpStats(); }
@@ -238,7 +217,6 @@ public:
if (_entryLevel < 0) {
_entryId = docid;
_entryLevel = level;
- track_ops();
return;
}
int searchLevel = _entryLevel;
@@ -263,23 +241,18 @@ public:
_entryLevel = level;
_entryId = docid;
}
- track_ops();
- }
-
- void track_ops() {
- _ops_counter++;
- if ((_ops_counter % 10000) == 0) {
- double div = _ops_counter;
- fprintf(stderr, "add / remove ops: %zu\n", _ops_counter);
- fprintf(stderr, "distance calls for layer: %zu is %.3f per op\n", distcalls_search_layer, distcalls_search_layer/ div);
- fprintf(stderr, "distance calls for heuristic: %zu is %.3f per op\n", distcalls_heuristic, distcalls_heuristic / div);
- fprintf(stderr, "distance calls for simple: %zu is %.3f per op\n", distcalls_simple, distcalls_simple / div);
- fprintf(stderr, "distance calls for shrink: %zu is %.3f per op\n", distcalls_shrink, distcalls_shrink / div);
- fprintf(stderr, "distance calls for refill: %zu is %.3f per op\n", distcalls_refill, distcalls_refill / div);
- fprintf(stderr, "distance calls for other: %zu is %.3f per op\n", distcalls_other, distcalls_other / div);
- fprintf(stderr, "refill needed calls: %zu is %.3f per op\n", refill_needed_calls, refill_needed_calls / div);
+ if (_nodes.size() % 10000 == 0) {
+ double div = _nodes.size();
+ fprintf(stderr, "added docs: %d\n", (int)div);
+ fprintf(stderr, "distance calls for layer: %" PRIu64 " is %.3f per doc\n", distcalls_search_layer, distcalls_search_layer/ div);
+ fprintf(stderr, "distance calls for heuristic: %" PRIu64 " is %.3f per doc\n", distcalls_heuristic, distcalls_heuristic / div);
+ fprintf(stderr, "distance calls for simple: %" PRIu64 " is %.3f per doc\n", distcalls_simple, distcalls_simple / div);
+ fprintf(stderr, "distance calls for shrink: %" PRIu64 " is %.3f per doc\n", distcalls_shrink, distcalls_shrink / div);
+ fprintf(stderr, "distance calls for refill: %" PRIu64 " is %.3f per doc\n", distcalls_refill, distcalls_refill / div);
+ fprintf(stderr, "distance calls for other: %" PRIu64 " is %.3f per doc\n", distcalls_other, distcalls_other / div);
+ fprintf(stderr, "refill needed calls: %" PRIu64 " is %.3f per doc\n", refill_needed_calls, refill_needed_calls / div);
}
- }
+ }
void remove_link_from(uint32_t from_id, uint32_t remove_id, uint32_t level) {
LinkList &links = getLinkList(from_id, level);
@@ -294,10 +267,9 @@ public:
if (repl_id == my_id) continue;
if (my_links.has_link_to(repl_id)) continue;
LinkList &other_links = getLinkList(repl_id, level);
- if (other_links.size() + 1 >= _M) continue;
+ if (other_links.size() >= _M) continue;
other_links.push_back(my_id);
my_links.push_back(repl_id);
- if (my_links.size() >= _M) return;
}
}
}
@@ -327,17 +299,14 @@ public:
Node &node = _nodes[docid];
bool need_new_entrypoint = (docid == _entryId);
for (int level = node._links.size(); level-- > 0; ) {
- LinkList my_links;
- my_links.swap(node._links[level]);
+ const LinkList &my_links = node._links[level];
for (uint32_t n_id : my_links) {
if (need_new_entrypoint) {
_entryId = n_id;
_entryLevel = level;
- need_new_entrypoint = false;
+ need_new_entrypoint = false;
}
remove_link_from(n_id, docid, level);
- }
- for (uint32_t n_id : my_links) {
refill_ifneeded(n_id, my_links, level);
}
}
@@ -353,7 +322,6 @@ public:
}
}
}
- track_ops();
}
std::vector<NnsHit> topK(uint32_t k, Vector vector, uint32_t search_k) override {
@@ -363,12 +331,12 @@ public:
++distcalls_other;
HnswHit entryPoint(_entryId, SqDist(entryDist));
int searchLevel = _entryLevel;
- FurthestPriQ w;
- w.push(entryPoint);
while (searchLevel > 0) {
- search_layer(vector, w, std::min(k, search_k), searchLevel);
+ entryPoint = search_layer_simple(vector, entryPoint, searchLevel);
--searchLevel;
}
+ FurthestPriQ w;
+ w.push(entryPoint);
search_layer(vector, w, std::max(k, search_k), 0);
while (w.size() > k) {
w.pop();
@@ -522,87 +490,66 @@ HnswLikeNns::connect_new_node(uint32_t id, const LinkList &neighbors, uint32_t l
}
}
-uint32_t
-HnswLikeNns::count_reachable() const {
- VisitedSet visited(_nodes.size());
- int level = _entryLevel;
- LinkList curList;
- curList.push_back(_entryId);
- visited.mark(_entryId);
- uint32_t idx = 0;
- while (level >= 0) {
- while (idx < curList.size()) {
- uint32_t id = curList[idx++];
- const LinkList &links = getLinkList(id, level);
- for (uint32_t n_id : links) {
- if (visited.isMarked(n_id)) continue;
- visited.mark(n_id);
- curList.push_back(n_id);
- }
- }
- --level;
- idx = 0;
- }
- return curList.size();
-}
-
void
HnswLikeNns::dumpStats() const {
+ std::vector<uint32_t> inLinkCounters;
+ inLinkCounters.resize(_nodes.size());
+ std::vector<uint32_t> outLinkCounters;
+ outLinkCounters.resize(_nodes.size());
std::vector<uint32_t> levelCounts;
levelCounts.resize(_entryLevel + 2);
std::vector<uint32_t> outLinkHist;
outLinkHist.resize(2 * _M + 2);
- uint32_t symmetrics = 0;
- uint32_t level1links = 0;
- uint32_t both_l_links = 0;
fprintf(stderr, "stats for HnswLikeNns with %zu nodes, entry level = %d, entry id = %u\n",
_nodes.size(), _entryLevel, _entryId);
-
for (uint32_t id = 0; id < _nodes.size(); ++id) {
const auto &node = _nodes[id];
uint32_t levels = node._links.size();
levelCounts[levels]++;
if (levels < 1) {
+ outLinkCounters[id] = 0;
outLinkHist[0]++;
continue;
}
const LinkList &link_list = getLinkList(id, 0);
uint32_t numlinks = link_list.size();
+ outLinkCounters[id] = numlinks;
outLinkHist[numlinks]++;
- if (numlinks < 1) {
+ if (numlinks < 2) {
fprintf(stderr, "node with %u links: id %u\n", numlinks, id);
- }
- bool all_sym = true;
- for (uint32_t n_id : link_list) {
- const LinkList &neigh_list = getLinkList(n_id, 0);
- if (! neigh_list.has_link_to(id)) {
- fprintf(stderr, "BAD: %u has link to neighbor %u, but backlink is missing\n", id, n_id);
- all_sym = false;
+ for (uint32_t n_id : link_list) {
+ const LinkList &neigh_list = getLinkList(n_id, 0);
+ fprintf(stderr, "neighbor id %u has %zu links\n", n_id, neigh_list.size());
+ if (! neigh_list.has_link_to(id)) {
+ fprintf(stderr, "BAD neighbor %u is missing backlink\n", n_id);
+ }
}
}
- if (all_sym) ++symmetrics;
- if (levels < 2) continue;
- const LinkList &link_list_1 = getLinkList(id, 1);
- for (uint32_t n_id : link_list_1) {
- ++level1links;
- if (link_list.has_link_to(n_id)) ++both_l_links;
+ for (uint32_t n_id : link_list) {
+ inLinkCounters[n_id]++;
}
}
for (uint32_t l = 0; l < levelCounts.size(); ++l) {
fprintf(stderr, "Nodes on %u levels: %u\n", l, levelCounts[l]);
}
- fprintf(stderr, "reachable nodes %u / %zu\n",
- count_reachable(), _nodes.size() - levelCounts[0]);
- fprintf(stderr, "level 1 links overlapping on l0: %u / total: %u\n",
- both_l_links, level1links);
for (uint32_t l = 0; l < outLinkHist.size(); ++l) {
- if (outLinkHist[l] != 0) {
- fprintf(stderr, "Nodes with %u outward links on L0: %u\n", l, outLinkHist[l]);
- }
+ fprintf(stderr, "Nodes with %u outward links on L0: %u\n", l, outLinkHist[l]);
+ }
+ uint32_t symmetrics = 0;
+ std::vector<uint32_t> inLinkHist;
+ for (uint32_t id = 0; id < _nodes.size(); ++id) {
+ uint32_t cnt = inLinkCounters[id];
+ while (cnt >= inLinkHist.size()) inLinkHist.push_back(0);
+ inLinkHist[cnt]++;
+ if (cnt == outLinkCounters[id]) ++symmetrics;
+ }
+ for (uint32_t l = 0; l < inLinkHist.size(); ++l) {
+ fprintf(stderr, "Nodes with %u inward links on L0: %u\n", l, inLinkHist[l]);
}
fprintf(stderr, "Symmetric in-out nodes: %u\n", symmetrics);
}
+
std::unique_ptr<NNS<float>>
make_hnsw_nns(uint32_t numDims, const DocVectorAccess<float> &dva)
{