summary | refs | log | tree | commit | diff | stats
diff options
context:
space:
mode:
author: Geir Storli <geirst@verizonmedia.com> 2020-01-27 16:00:43 +0100
committer: GitHub <noreply@github.com> 2020-01-27 16:00:43 +0100
commit: b4793a390bddf73fda1e9b5668619bcb518bd2a2 (patch)
tree: 7611bd4105b8dcb730c46acd91e46c1917c5733c
parent: 82a01821e331be871e606cb590ce7bcc2c5b60e6 (diff)
parent: 8d8f4c01794fa60d3ce6a5f36f32e7cf93aa376f (diff)
Merge pull request #11960 from vespa-engine/arnej/refactor-hnsw-algo-only
refactor and unify
-rw-r--r-- eval/src/tests/ann/nns-l2.h 1
-rw-r--r-- eval/src/tests/ann/nns.h 4
-rw-r--r-- eval/src/tests/ann/sift_benchmark.cpp 9
-rw-r--r-- eval/src/tests/ann/xp-hnsw-wrap.cpp 55
-rw-r--r-- eval/src/tests/ann/xp-hnswlike-nns.cpp 472
5 files changed, 385 insertions, 156 deletions
diff --git a/eval/src/tests/ann/nns-l2.h b/eval/src/tests/ann/nns-l2.h
index dcad5f1bda6..857866ff73b 100644
--- a/eval/src/tests/ann/nns-l2.h
+++ b/eval/src/tests/ann/nns-l2.h
@@ -2,6 +2,7 @@
#pragma once
#include <string.h>
+#include <vespa/vespalib/util/arrayref.h>
#include <vespa/vespalib/hwaccelrated/iaccelrated.h>
template <typename T, size_t VLEN>
diff --git a/eval/src/tests/ann/nns.h b/eval/src/tests/ann/nns.h
index 79c1aac4379..ffe2882188e 100644
--- a/eval/src/tests/ann/nns.h
+++ b/eval/src/tests/ann/nns.h
@@ -67,3 +67,7 @@ make_rplsh_nns(uint32_t numDims, const DocVectorAccess<float> &dva);
extern
std::unique_ptr<NNS<float>>
make_hnsw_nns(uint32_t numDims, const DocVectorAccess<float> &dva);
+
+extern
+std::unique_ptr<NNS<float>>
+make_hnsw_wrap(uint32_t numDims, const DocVectorAccess<float> &dva);
diff --git a/eval/src/tests/ann/sift_benchmark.cpp b/eval/src/tests/ann/sift_benchmark.cpp
index 5ca505e5f1e..f3570eb9247 100644
--- a/eval/src/tests/ann/sift_benchmark.cpp
+++ b/eval/src/tests/ann/sift_benchmark.cpp
@@ -281,10 +281,17 @@ TEST("require that Annoy via NNS api mostly works") {
TEST("require that HNSW via NNS api mostly works") {
DocVectorAdapter adapter;
std::unique_ptr<NNS_API> nns = make_hnsw_nns(NUM_DIMS, adapter);
- benchmark_nns("HNSW", *nns, { 100, 200 });
+ benchmark_nns("HNSW-like", *nns, { 100, 150, 200 });
}
#endif
+#if 0
+TEST("require that HNSW wrapped api mostly works") { // benchmark case for the NNS returned by make_hnsw_wrap; enclosed in an #if 0 region
+ DocVectorAdapter adapter;
+ std::unique_ptr<NNS_API> nns = make_hnsw_wrap(NUM_DIMS, adapter);
+ benchmark_nns("HNSW-wrap", *nns, { 100, 150, 200 }); // same value list as the "HNSW-like" case
+}
+#endif
/**
* Before running the benchmark the ANN_SIFT1M data set must be downloaded and extracted:
diff --git a/eval/src/tests/ann/xp-hnsw-wrap.cpp b/eval/src/tests/ann/xp-hnsw-wrap.cpp
new file mode 100644
index 00000000000..33895b2bd7c
--- /dev/null
+++ b/eval/src/tests/ann/xp-hnsw-wrap.cpp
@@ -0,0 +1,55 @@
+// Copyright 2020 Oath Inc. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+
+#include "nns.h"
+#include <iostream>
+#include "/git/hnswlib/hnswlib/hnswlib.h"
+
+class HnswWrapNns : public NNS<float> // NNS<float> implementation that forwards every operation to hnswlib::HierarchicalNSW<float>
+{
+private:
+ using Implementation = hnswlib::HierarchicalNSW<float>;
+ hnswlib::L2Space _l2space; // declared before _hnsw, whose constructor is handed its address
+ Implementation _hnsw;
+
+public:
+ HnswWrapNns(uint32_t numDims, const DocVectorAccess<float> &dva)
+ : NNS(numDims, dva),
+ _l2space(numDims),
+ _hnsw(&_l2space, 1000000, 16, 200) // NOTE(review): presumably max elements, M and efConstruction - confirm against hnswlib
+ {
+ }
+
+ ~HnswWrapNns() {}
+
+ void addDoc(uint32_t docid) override { // adds the vector fetched from _dva to the index, labelled with docid
+ Vector vector = _dva.get(docid);
+ _hnsw.addPoint(vector.cbegin(), docid);
+ }
+
+ void removeDoc(uint32_t docid) override { // delegates to _hnsw.markDelete
+ _hnsw.markDelete(docid);
+ }
+
+ std::vector<NnsHit> topK(uint32_t k, Vector vector, uint32_t search_k) override { // returns at most k hits
+ std::vector<NnsHit> reversed;
+ auto priQ = _hnsw.searchKnn(vector.cbegin(), std::max(k, search_k)); // requests max(k, search_k) candidates
+ while (! priQ.empty()) { // drain the queue in pop order
+ auto pair = priQ.top();
+ reversed.emplace_back(pair.second, SqDist(pair.first)); // pair.second becomes the hit id, pair.first its SqDist
+ priQ.pop();
+ }
+ std::vector<NnsHit> result;
+ while (result.size() < k && !reversed.empty()) { // take from the back, so the entries popped last come first
+ result.push_back(reversed.back());
+ reversed.pop_back();
+ }
+ return result;
+ }
+};
+
+std::unique_ptr<NNS<float>>
+make_hnsw_wrap(uint32_t numDims, const DocVectorAccess<float> &dva) // factory for HnswWrapNns, returned through the NNS<float> base type
+{
+ NNS<float> *p = new HnswWrapNns(numDims, dva);
+ return std::unique_ptr<NNS<float>>(p); // the returned unique_ptr takes ownership of p
+}
diff --git a/eval/src/tests/ann/xp-hnswlike-nns.cpp b/eval/src/tests/ann/xp-hnswlike-nns.cpp
index ec831610a71..5ab6bd4bdaf 100644
--- a/eval/src/tests/ann/xp-hnswlike-nns.cpp
+++ b/eval/src/tests/ann/xp-hnswlike-nns.cpp
@@ -3,10 +3,35 @@
#include <algorithm>
#include <assert.h>
#include <queue>
-#include <random>
+#include "std-random.h"
#include "nns.h"
-using LinkList = std::vector<uint32_t>;
+static uint64_t distcalls_simple;
+static uint64_t distcalls_search_layer;
+static uint64_t distcalls_other;
+static uint64_t distcalls_heuristic;
+static uint64_t distcalls_shrink;
+static uint64_t distcalls_refill;
+static uint64_t refill_needed_calls;
+
+struct LinkList : std::vector<uint32_t> // vector of uint32_t ids with a membership test and an order-changing removal
+{
+ bool has_link_to(uint32_t id) const { // true when id occurs in the list (linear scan)
+ auto iter = std::find(begin(), end(), id);
+ return (iter != end());
+ }
+ void remove_link(uint32_t id) { // overwrites the first match with the last element, then drops the last slot
+ uint32_t last = back(); // NOTE(review): evaluated before any emptiness check - the list must not be empty here
+ for (iterator iter = begin(); iter != end(); ++iter) {
+ if (*iter == id) {
+ *iter = last;
+ pop_back();
+ return;
+ }
+ }
+ fprintf(stderr, "BAD missing link to remove: %u\n", id); // reached only when id was not found
+ }
+};
struct Node {
std::vector<LinkList> _links;
@@ -61,38 +86,36 @@ struct VisitedSetPool
};
struct HnswHit {
- float dist;
+ double dist;
uint32_t docid;
HnswHit(uint32_t di, SqDist sq) : dist(sq.distance), docid(di) {}
};
-
-using QueueEntry = HnswHit;
struct GreaterDist {
- bool operator() (const QueueEntry &lhs, const QueueEntry& rhs) const {
+ bool operator() (const HnswHit &lhs, const HnswHit& rhs) const {
return (rhs.dist < lhs.dist);
}
};
struct LesserDist {
- bool operator() (const QueueEntry &lhs, const QueueEntry& rhs) const {
+ bool operator() (const HnswHit &lhs, const HnswHit& rhs) const {
return (lhs.dist < rhs.dist);
}
};
-using NearestList = std::vector<QueueEntry>;
+using NearestList = std::vector<HnswHit>;
-struct NearestPriQ : std::priority_queue<QueueEntry, NearestList, GreaterDist>
+struct NearestPriQ : std::priority_queue<HnswHit, NearestList, GreaterDist>
{
};
-struct FurthestPriQ : std::priority_queue<QueueEntry, NearestList, LesserDist>
+struct FurthestPriQ : std::priority_queue<HnswHit, NearestList, LesserDist>
{
- NearestList steal() {
- NearestList result;
- c.swap(result);
- return result;
- }
- const NearestList& peek() const { return c; }
+ NearestList steal() {
+ NearestList result;
+ c.swap(result);
+ return result;
+ }
+ const NearestList& peek() const { return c; }
};
class HnswLikeNns : public NNS<float>
@@ -104,7 +127,7 @@ private:
uint32_t _M;
uint32_t _efConstruction;
double _levelMultiplier;
- std::default_random_engine _rndGen;
+ RndGen _rndGen;
VisitedSetPool _visitedSetPool;
double distance(Vector v, uint32_t id) const;
@@ -115,11 +138,13 @@ private:
}
int randomLevel() {
- std::uniform_real_distribution<double> distribution(0.0, 1.0);
- double r = -log(distribution(_rndGen)) * _levelMultiplier;
+ double unif = _rndGen.nextUniform();
+ double r = -log(1.0-unif) * _levelMultiplier;
return (int) r;
}
+ void dumpStats() const;
+
public:
HnswLikeNns(uint32_t numDims, const DocVectorAccess<float> &dva)
: NNS(numDims, dva),
@@ -127,13 +152,14 @@ public:
_entryId(0),
_entryLevel(-1),
_M(16),
- _efConstruction(150),
- _levelMultiplier(1.0 / log(1.0 * _M))
+ _efConstruction(200),
+ _levelMultiplier(1.0 / log(1.0 * _M)),
+ _rndGen()
{
_nodes.reserve(1234567);
}
- ~HnswLikeNns() {}
+ ~HnswLikeNns() { dumpStats(); }
LinkList& getLinkList(uint32_t docid, uint32_t level) {
// assert(docid < _nodes.size());
@@ -141,85 +167,43 @@ public:
return _nodes[docid]._links[level];
}
+ const LinkList& getLinkList(uint32_t docid, uint32_t level) const { // read-only overload usable from const members
+ return _nodes[docid]._links[level]; // plain indexing, no range checks
+ }
+
// simple greedy search
- QueueEntry search_layer_simple(Vector vector, QueueEntry curPoint, uint32_t searchLevel) {
+ HnswHit search_layer_simple(Vector vector, HnswHit curPoint, uint32_t searchLevel) {
bool keepGoing = true;
while (keepGoing) {
keepGoing = false;
const LinkList& neighbors = getLinkList(curPoint.docid, searchLevel);
for (uint32_t n_id : neighbors) {
double dist = distance(vector, n_id);
+ ++distcalls_simple;
if (dist < curPoint.dist) {
- curPoint = QueueEntry(n_id, SqDist(dist));
- keepGoing = true;
+ curPoint = HnswHit(n_id, SqDist(dist));
+ keepGoing = true;
}
}
}
return curPoint;
}
- void search_layer_foradd(Vector vector, FurthestPriQ &w,
- uint32_t ef, uint32_t searchLevel);
+ void search_layer(Vector vector, FurthestPriQ &w,
+ uint32_t ef, uint32_t searchLevel);
- FurthestPriQ search_layer(Vector vector, NearestList entryPoints,
- uint32_t ef, uint32_t searchLevel) {
- VisitedSet &visited = _visitedSetPool.get(_nodes.size());
- NearestPriQ candidates;
- FurthestPriQ w;
- for (auto point : entryPoints) {
- candidates.push(point);
- w.push(point);
- visited.mark(point.docid);
- }
- double limd = std::numeric_limits<double>::max();
- while (! candidates.empty()) {
- QueueEntry cand = candidates.top();
- candidates.pop();
- if (cand.dist > limd) {
- break;
- }
- for (uint32_t e_id : getLinkList(cand.docid, searchLevel)) {
- if (visited.isMarked(e_id)) continue;
- visited.mark(e_id);
- double e_dist = distance(vector, e_id);
- if (e_dist < limd) {
- candidates.emplace(e_id, SqDist(e_dist));
- w.emplace(e_id, SqDist(e_dist));
- if (w.size() > ef) {
- w.pop();
- limd = w.top().dist;
- }
- }
- }
- }
- return w;
- }
-
- bool haveCloserDistance(QueueEntry e, const LinkList &r) const {
+ bool haveCloserDistance(HnswHit e, const LinkList &r) const {
for (uint32_t prevId : r) {
double dist = distance(e.docid, prevId);
+ ++distcalls_heuristic;
if (dist < e.dist) return true;
}
return false;
}
- LinkList select_neighbors(NearestPriQ &&w, uint32_t curMax) const;
+ LinkList select_neighbors(const NearestList &neighbors, uint32_t curMax) const;
- LinkList select_neighbors(const NearestList &neighbors, uint32_t curMax) {
- if (neighbors.size() <= curMax) {
- LinkList result;
- result.reserve(curMax+1);
- for (const auto & entry : neighbors) {
- result.push_back(entry.docid);
- }
- return result;
- }
- NearestPriQ w;
- for (const QueueEntry & entry : neighbors) {
- w.push(entry);
- }
- return select_neighbors(std::move(w), curMax);
- }
+ LinkList remove_weakest(const NearestList &neighbors, uint32_t curMax, LinkList &removed) const;
void addDoc(uint32_t docid) override {
Vector vector = _dva.get(docid);
@@ -236,7 +220,8 @@ public:
}
int searchLevel = _entryLevel;
double entryDist = distance(vector, _entryId);
- QueueEntry entryPoint(_entryId, SqDist(entryDist));
+ ++distcalls_other;
+ HnswHit entryPoint(_entryId, SqDist(entryDist));
while (searchLevel > level) {
entryPoint = search_layer_simple(vector, entryPoint, searchLevel);
--searchLevel;
@@ -245,9 +230,8 @@ public:
FurthestPriQ w;
w.push(entryPoint);
while (searchLevel >= 0) {
- search_layer_foradd(vector, w, _efConstruction, searchLevel);
- uint32_t maxLinks = (searchLevel > 0) ? _M : (2 * _M);
- LinkList neighbors = select_neighbors(w.peek(), maxLinks);
+ search_layer(vector, w, _efConstruction, searchLevel);
+ LinkList neighbors = select_neighbors(w.peek(), _M);
connect_new_node(docid, neighbors, searchLevel);
each_shrink_ifneeded(neighbors, searchLevel);
--searchLevel;
@@ -256,46 +240,111 @@ public:
_entryLevel = level;
_entryId = docid;
}
+ if (_nodes.size() % 10000 == 0) {
+ double div = _nodes.size();
+ fprintf(stderr, "added docs: %d\n", (int)div);
+ fprintf(stderr, "distance calls for layer: %zu is %.3f per doc\n", distcalls_search_layer, distcalls_search_layer/ div);
+ fprintf(stderr, "distance calls for heuristic: %zu is %.3f per doc\n", distcalls_heuristic, distcalls_heuristic / div);
+ fprintf(stderr, "distance calls for simple: %zu is %.3f per doc\n", distcalls_simple, distcalls_simple / div);
+ fprintf(stderr, "distance calls for shrink: %zu is %.3f per doc\n", distcalls_shrink, distcalls_shrink / div);
+ fprintf(stderr, "distance calls for refill: %zu is %.3f per doc\n", distcalls_refill, distcalls_refill / div);
+ fprintf(stderr, "distance calls for other: %zu is %.3f per doc\n", distcalls_other, distcalls_other / div);
+ fprintf(stderr, "refill needed calls: %zu is %.3f per doc\n", refill_needed_calls, refill_needed_calls / div);
+ }
+ }
+
+ void remove_link_from(uint32_t from_id, uint32_t remove_id, uint32_t level) { // drops remove_id from from_id's link list on the given level
+ LinkList &links = getLinkList(from_id, level);
+ links.remove_link(remove_id); // remove_link prints a "BAD missing link" message if remove_id is absent
+ }
+
+ void refill_ifneeded(uint32_t my_id, const LinkList &replacements, uint32_t level) { // when my_id has fewer than 8 links, adds two-way links to ids from replacements
+ LinkList &my_links = getLinkList(my_id, level);
+ if (my_links.size() < 8) { // threshold is tested once, before the loop; the loop itself does not re-check it
+ ++refill_needed_calls;
+ for (uint32_t repl_id : replacements) {
+ if (repl_id == my_id) continue; // never link a node to itself
+ if (my_links.has_link_to(repl_id)) continue; // already linked
+ LinkList &other_links = getLinkList(repl_id, level);
+ if (other_links.size() >= _M) continue; // candidate already holds _M or more links (same cap on every level)
+ other_links.push_back(my_id);
+ my_links.push_back(repl_id);
+ }
+ }
+ }
void connect_new_node(uint32_t id, const LinkList &neighbors, uint32_t level);
+ void shrink_links(uint32_t shrink_id, uint32_t maxLinks, uint32_t level) { // re-selects shrink_id's links via remove_weakest and unlinks the dropped ones
+ LinkList &links = getLinkList(shrink_id, level);
+ NearestList distances;
+ for (uint32_t n_id : links) { // distance from shrink_id to each current neighbor
+ double n_dist = distance(shrink_id, n_id);
+ ++distcalls_shrink;
+ distances.emplace_back(n_id, SqDist(n_dist));
+ }
+ LinkList lostLinks;
+ LinkList oldLinks = links; // copy taken before links is overwritten on the next line
+ links = remove_weakest(distances, maxLinks, lostLinks);
+ for (uint32_t lost_id : lostLinks) {
+ remove_link_from(lost_id, shrink_id, level); // remove the reverse link as well
+ refill_ifneeded(lost_id, oldLinks, level); // shrink_id's former neighbors are the refill candidates
+ }
+ }
+
void each_shrink_ifneeded(const LinkList &neighbors, uint32_t level);
- void removeDoc(uint32_t ) override {
+ void removeDoc(uint32_t docid) override { // unlinks docid from its neighbors on every level, replaces its node, and re-picks the entry point if needed
+ Node &node = _nodes[docid];
+ bool need_new_entrypoint = (docid == _entryId);
+ for (int level = node._links.size(); level-- > 0; ) { // walks levels from the highest down to 0
+ const LinkList &my_links = node._links[level];
+ for (uint32_t n_id : my_links) {
+ if (need_new_entrypoint) { // the first neighbor met, on the highest level that has any, becomes the entry point
+ _entryId = n_id;
+ _entryLevel = level;
+ need_new_entrypoint = false;
+ }
+ remove_link_from(n_id, docid, level);
+ refill_ifneeded(n_id, my_links, level); // the removed node's neighbors are the refill candidates
+ }
+ }
+ node = Node(docid, 0, _M); // overwrite the slot with a freshly constructed Node
+ if (need_new_entrypoint) { // docid was the entry point and had no neighbor on any level
+ _entryLevel = -1;
+ _entryId = 0;
+ for (uint32_t i = 0; i < _nodes.size(); ++i) { // fall back to the lowest-index node whose _links is non-empty
+ if (_nodes[i]._links.size() > 0) {
+ _entryId = i;
+ _entryLevel = _nodes[i]._links.size() - 1;
+ break;
+ }
+ }
+ }
}
std::vector<NnsHit> topK(uint32_t k, Vector vector, uint32_t search_k) override {
std::vector<NnsHit> result;
if (_entryLevel < 0) return result;
double entryDist = distance(vector, _entryId);
- QueueEntry entryPoint(_entryId, SqDist(entryDist));
+ ++distcalls_other;
+ HnswHit entryPoint(_entryId, SqDist(entryDist));
int searchLevel = _entryLevel;
while (searchLevel > 0) {
entryPoint = search_layer_simple(vector, entryPoint, searchLevel);
--searchLevel;
}
- NearestList entryPoints;
- entryPoints.push_back(entryPoint);
- FurthestPriQ w = search_layer(vector, entryPoints, std::max(k, search_k), 0);
- if (w.size() < k) {
- fprintf(stderr, "fewer than expected hits: k=%u, ks=%u, got=%zu\n",
- k, search_k, w.size());
- }
+ FurthestPriQ w;
+ w.push(entryPoint);
+ search_layer(vector, w, std::max(k, search_k), 0);
while (w.size() > k) {
w.pop();
}
- std::vector<QueueEntry> reversed;
- reversed.reserve(w.size());
- while (! w.empty()) {
- reversed.push_back(w.top());
- w.pop();
- }
- result.reserve(reversed.size());
- while (! reversed.empty()) {
- const QueueEntry &hit = reversed.back();
+ NearestList tmp = w.steal();
+ std::sort(tmp.begin(), tmp.end(), LesserDist());
+ result.reserve(tmp.size());
+ for (const auto & hit : tmp) {
result.emplace_back(hit.docid, SqDist(hit.dist));
- reversed.pop_back();
}
return result;
}
@@ -310,78 +359,193 @@ HnswLikeNns::distance(Vector v, uint32_t b) const
void
HnswLikeNns::each_shrink_ifneeded(const LinkList &neighbors, uint32_t level) {
- uint32_t maxLinks = (level > 0) ? _M : (2 * _M);
- for (uint32_t old_id : neighbors) {
- LinkList &oldLinks = getLinkList(old_id, level);
- if (oldLinks.size() > maxLinks) {
- NearestPriQ w;
- for (uint32_t n_id : oldLinks) {
- double n_dist = distance(old_id, n_id);
- w.emplace(n_id, SqDist(n_dist));
- }
- oldLinks = select_neighbors(std::move(w), maxLinks);
- }
+ uint32_t maxLinks = (level > 0) ? _M : (2 * _M);
+ for (uint32_t old_id : neighbors) {
+ LinkList &oldLinks = getLinkList(old_id, level);
+ if (oldLinks.size() > maxLinks) {
+ shrink_links(old_id, maxLinks, level);
}
+ }
}
void
-HnswLikeNns::search_layer_foradd(Vector vector, FurthestPriQ &w,
- uint32_t ef, uint32_t searchLevel)
+HnswLikeNns::search_layer(Vector vector, FurthestPriQ &w,
+ uint32_t ef, uint32_t searchLevel)
{
- NearestPriQ candidates;
- VisitedSet &visited = _visitedSetPool.get(_nodes.size());
+ NearestPriQ candidates;
+ VisitedSet &visited = _visitedSetPool.get(_nodes.size());
- for (const QueueEntry& entry : w.peek()) {
- candidates.push(entry);
- visited.mark(entry.docid);
+ for (const HnswHit & entry : w.peek()) {
+ candidates.push(entry);
+ visited.mark(entry.docid);
+ }
+ double limd = std::numeric_limits<double>::max();
+ while (! candidates.empty()) {
+ HnswHit cand = candidates.top();
+ if (cand.dist > limd) {
+ break;
}
-
- double limd = std::numeric_limits<double>::max();
- while (! candidates.empty()) {
- QueueEntry cand = candidates.top();
- candidates.pop();
- if (cand.dist > limd) {
- break;
- }
- for (uint32_t e_id : getLinkList(cand.docid, searchLevel)) {
- if (visited.isMarked(e_id)) continue;
- visited.mark(e_id);
- double e_dist = distance(vector, e_id);
- if (e_dist < limd) {
- candidates.emplace(e_id, SqDist(e_dist));
- w.emplace(e_id, SqDist(e_dist));
- if (w.size() > ef) {
- w.pop();
- limd = w.top().dist;
- }
+ candidates.pop();
+ for (uint32_t e_id : getLinkList(cand.docid, searchLevel)) {
+ if (visited.isMarked(e_id)) continue;
+ visited.mark(e_id);
+ double e_dist = distance(vector, e_id);
+ ++distcalls_search_layer;
+ if (e_dist < limd) {
+ candidates.emplace(e_id, SqDist(e_dist));
+ w.emplace(e_id, SqDist(e_dist));
+ if (w.size() > ef) {
+ w.pop();
+ limd = w.top().dist;
}
}
}
- return;
+ }
+ return;
}
LinkList
-HnswLikeNns::select_neighbors(NearestPriQ &&w, uint32_t curMax) const {
- LinkList result;
- result.reserve(curMax+1);
- while (! w.empty()) {
- QueueEntry e = w.top();
- w.pop();
- if (haveCloserDistance(e, result)) continue;
+HnswLikeNns::remove_weakest(const NearestList &neighbors, uint32_t curMax, LinkList &lost) const
+{
+ LinkList result;
+ result.reserve(curMax+1);
+ NearestPriQ w;
+ for (const auto & entry : neighbors) {
+ w.push(entry);
+ }
+ while (! w.empty()) {
+ HnswHit e = w.top();
+ w.pop();
+ if (result.size() == curMax || haveCloserDistance(e, result)) {
+ lost.push_back(e.docid);
+ } else {
result.push_back(e.docid);
- if (result.size() >= curMax) break;
}
- return result;
+ }
+ return result;
}
+#ifdef NO_BACKFILL
+LinkList
+HnswLikeNns::select_neighbors(const NearestList &neighbors, uint32_t curMax) const // NO_BACKFILL variant: picks up to curMax ids, visiting candidates nearest first
+{
+ LinkList result;
+ result.reserve(curMax+1);
+ bool needFiltering = (neighbors.size() > curMax); // the heuristic filter runs only when there are more candidates than slots
+ NearestPriQ w;
+ for (const auto & entry : neighbors) {
+ w.push(entry);
+ }
+ while (! w.empty()) { // NearestPriQ yields the smallest dist at top()
+ HnswHit e = w.top();
+ w.pop();
+ if (needFiltering && haveCloserDistance(e, result)) { // skip e when some already selected id is closer to e than e.dist
+ continue;
+ }
+ result.push_back(e.docid);
+ if (result.size() == curMax) return result;
+ }
+ return result;
+}
+#else
+LinkList
+HnswLikeNns::select_neighbors(const NearestList &neighbors, uint32_t curMax) const // default variant: like the filter-only selection, plus a top-up from rejected ids
+{
+ LinkList result;
+ result.reserve(curMax+1);
+ bool needFiltering = (neighbors.size() > curMax); // the heuristic filter runs only when there are more candidates than slots
+ NearestPriQ w;
+ for (const auto & entry : neighbors) {
+ w.push(entry);
+ }
+ LinkList backfill; // ids rejected by the filter, kept in the order they were visited
+ while (! w.empty()) { // NearestPriQ yields the smallest dist at top()
+ HnswHit e = w.top();
+ w.pop();
+ if (needFiltering && haveCloserDistance(e, result)) { // e is closer to an already selected id than e.dist
+ backfill.push_back(e.docid);
+ continue;
+ }
+ result.push_back(e.docid);
+ if (result.size() == curMax) return result;
+ }
+ if (result.size() * 4 < curMax) { // fewer than curMax/4 ids selected: top up from the rejected ones
+ for (uint32_t fill_id : backfill) {
+ result.push_back(fill_id);
+ if (result.size() * 4 >= curMax) break; // stop once the curMax/4 mark is reached
+ }
+ }
+ return result;
+}
+#endif
+
void
HnswLikeNns::connect_new_node(uint32_t id, const LinkList &neighbors, uint32_t level) {
- LinkList &newLinks = getLinkList(id, level);
- for (uint32_t neigh_id : neighbors) {
- LinkList &oldLinks = getLinkList(neigh_id, level);
- newLinks.push_back(neigh_id);
- oldLinks.push_back(id);
+ LinkList &newLinks = getLinkList(id, level);
+ for (uint32_t neigh_id : neighbors) {
+ LinkList &oldLinks = getLinkList(neigh_id, level);
+ newLinks.push_back(neigh_id);
+ oldLinks.push_back(id);
+ }
+}
+
+void
+HnswLikeNns::dumpStats() const { // prints per-level node counts and level-0 link histograms to stderr; invoked from ~HnswLikeNns
+ std::vector<uint32_t> inLinkCounters;
+ inLinkCounters.resize(_nodes.size());
+ std::vector<uint32_t> outLinkCounters;
+ outLinkCounters.resize(_nodes.size());
+ std::vector<uint32_t> levelCounts;
+ levelCounts.resize(_entryLevel + 2); // NOTE(review): indexed by node._links.size() below - assumes no node has more than _entryLevel + 1 levels
+ std::vector<uint32_t> outLinkHist;
+ outLinkHist.resize(2 * _M + 2); // NOTE(review): indexed by level-0 link count - assumes no list holds more than 2 * _M + 1 ids
+ fprintf(stderr, "stats for HnswLikeNns with %zu nodes, entry level = %d, entry id = %u\n",
+ _nodes.size(), _entryLevel, _entryId);
+ for (uint32_t id = 0; id < _nodes.size(); ++id) {
+ const auto &node = _nodes[id];
+ uint32_t levels = node._links.size();
+ levelCounts[levels]++;
+ if (levels < 1) { // node without any link level counts as zero outward links
+ outLinkCounters[id] = 0;
+ outLinkHist[0]++;
+ continue;
}
+ const LinkList &link_list = getLinkList(id, 0);
+ uint32_t numlinks = link_list.size();
+ outLinkCounters[id] = numlinks;
+ outLinkHist[numlinks]++;
+ if (numlinks < 2) { // report nodes with 0 or 1 level-0 links, and verify each neighbor links back
+ fprintf(stderr, "node with %u links: id %u\n", numlinks, id);
+ for (uint32_t n_id : link_list) {
+ const LinkList &neigh_list = getLinkList(n_id, 0);
+ fprintf(stderr, "neighbor id %u has %zu links\n", n_id, neigh_list.size());
+ if (! neigh_list.has_link_to(id)) {
+ fprintf(stderr, "BAD neighbor %u is missing backlink\n", n_id);
+ }
+ }
+ }
+ for (uint32_t n_id : link_list) { // every outward link of id is an inward link of n_id
+ inLinkCounters[n_id]++;
+ }
+ }
+ for (uint32_t l = 0; l < levelCounts.size(); ++l) {
+ fprintf(stderr, "Nodes on %u levels: %u\n", l, levelCounts[l]);
+ }
+ for (uint32_t l = 0; l < outLinkHist.size(); ++l) {
+ fprintf(stderr, "Nodes with %u outward links on L0: %u\n", l, outLinkHist[l]);
+ }
+ uint32_t symmetrics = 0; // number of nodes whose inward count equals their outward count
+ std::vector<uint32_t> inLinkHist;
+ for (uint32_t id = 0; id < _nodes.size(); ++id) {
+ uint32_t cnt = inLinkCounters[id];
+ while (cnt >= inLinkHist.size()) inLinkHist.push_back(0); // grow the histogram on demand
+ inLinkHist[cnt]++;
+ if (cnt == outLinkCounters[id]) ++symmetrics;
+ }
+ for (uint32_t l = 0; l < inLinkHist.size(); ++l) {
+ fprintf(stderr, "Nodes with %u inward links on L0: %u\n", l, inLinkHist[l]);
+ }
+ fprintf(stderr, "Symmetric in-out nodes: %u\n", symmetrics);
}
@@ -390,5 +554,3 @@ make_hnsw_nns(uint32_t numDims, const DocVectorAccess<float> &dva)
{
return std::make_unique<HnswLikeNns>(numDims, dva);
}
-
-