summaryrefslogtreecommitdiffstats
path: root/eval
diff options
context:
space:
mode:
authorArne Juul <arnej@verizonmedia.com>2020-02-25 14:40:02 +0000
committerArne Juul <arnej@verizonmedia.com>2020-02-25 14:47:36 +0000
commit16c477ad557d556fe4d63c871025f10b18aba84d (patch)
treedecad336db7e7513962978c9cc660ec407ded213 /eval
parent44fef3325d3e9bfa673d71b87721f0979f8404c8 (diff)
keep more code common
Diffstat (limited to 'eval')
-rw-r--r--eval/src/tests/ann/CMakeLists.txt2
-rw-r--r--eval/src/tests/ann/extended-hnsw.cpp686
-rw-r--r--eval/src/tests/ann/xp-hnswlike-nns.cpp527
3 files changed, 427 insertions, 788 deletions
diff --git a/eval/src/tests/ann/CMakeLists.txt b/eval/src/tests/ann/CMakeLists.txt
index 34babf1412f..0ba38994c01 100644
--- a/eval/src/tests/ann/CMakeLists.txt
+++ b/eval/src/tests/ann/CMakeLists.txt
@@ -14,7 +14,7 @@ vespa_add_executable(eval_gist_benchmark_app
SOURCES
gist_benchmark.cpp
xp-annoy-nns.cpp
- xp-hnswlike-nns.cpp
+ extended-hnsw.cpp
xp-lsh-nns.cpp
DEPENDS
vespaeval
diff --git a/eval/src/tests/ann/extended-hnsw.cpp b/eval/src/tests/ann/extended-hnsw.cpp
index 42f3a10b389..fbc4bedec05 100644
--- a/eval/src/tests/ann/extended-hnsw.cpp
+++ b/eval/src/tests/ann/extended-hnsw.cpp
@@ -1,29 +1,6 @@
// Copyright 2020 Oath Inc. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
-#include <algorithm>
-#include <assert.h>
-#include <queue>
-#include <cinttypes>
-#include "std-random.h"
-#include "nns.h"
-
-/*
- Todo:
-
- measure effect of:
- 1) removing leftover backlinks during "shrink" operation
- 2) refilling to low-watermark after 1) happens
- 3) refilling to mid-watermark after 1) happens
- 4) adding then removing 20% extra documents
- 5) removing 20% first-added documents
- 6) removing first-added documents while inserting new ones
-
- 7) auto-tune search_k to ensure >= 50% recall on 1000 Q with k=100
- 8) auto-tune search_k to ensure avg 90% recall on 1000 Q with k=100
- 9) auto-tune search_k to ensure >= 90% reachability of 10000 docids
-
- 10) timings for SIFT, GIST, and DEEP data (100k, 200k, 300k, 500k, 700k, 1000k)
- */
+#include "hnsw-like.h"
static size_t distcalls_simple;
static size_t distcalls_search_layer;
@@ -38,471 +15,300 @@ static size_t disconnected_for_symmetry;
static size_t select_n_full;
static size_t select_n_partial;
-struct LinkList : std::vector<uint32_t>
-{
- bool has_link_to(uint32_t id) const {
- auto iter = std::find(begin(), end(), id);
- return (iter != end());
- }
- void remove_link(uint32_t id) {
- uint32_t last = back();
- for (iterator iter = begin(); iter != end(); ++iter) {
- if (*iter == id) {
- *iter = last;
- pop_back();
- return;
- }
- }
- fprintf(stderr, "BAD missing link to remove: %u\n", id);
- abort();
- }
-};
-
-struct Node {
- std::vector<LinkList> _links;
- Node(uint32_t , uint32_t numLevels, uint32_t M)
- : _links(numLevels)
- {
- for (uint32_t i = 0; i < _links.size(); ++i) {
- _links[i].reserve((i == 0) ? (2 * M + 1) : (M+1));
- }
- }
-};
-
-struct VisitedSet
-{
- using Mark = unsigned short;
- Mark *ptr;
- Mark curval;
- size_t sz;
- VisitedSet(const VisitedSet &) = delete;
- VisitedSet& operator=(const VisitedSet &) = delete;
- explicit VisitedSet(size_t size) {
- ptr = (Mark *)malloc(size * sizeof(Mark));
- curval = -1;
- sz = size;
- clear();
- }
- void clear() {
- ++curval;
- if (curval == 0) {
- memset(ptr, 0, sz * sizeof(Mark));
- ++curval;
- }
- }
- ~VisitedSet() { free(ptr); }
- void mark(size_t id) { ptr[id] = curval; }
- bool isMarked(size_t id) const { return ptr[id] == curval; }
-};
-struct VisitedSetPool
+HnswLikeNns::HnswLikeNns(uint32_t numDims, const DocVectorAccess<float> &dva)
+ : NNS(numDims, dva),
+ _nodes(),
+ _entryId(0),
+ _entryLevel(-1),
+ _M(16),
+ _efConstruction(200),
+ _levelMultiplier(1.0 / log(1.0 * _M)),
+ _rndGen(),
+ _ops_counter(0)
{
- std::unique_ptr<VisitedSet> lastUsed;
- VisitedSetPool() {
- lastUsed = std::make_unique<VisitedSet>(250);
- }
- ~VisitedSetPool() {}
- VisitedSet &get(size_t size) {
- if (size > lastUsed->sz) {
- lastUsed = std::make_unique<VisitedSet>(size*2);
- } else {
- lastUsed->clear();
- }
- return *lastUsed;
- }
-};
-
-struct HnswHit {
- double dist;
- uint32_t docid;
- HnswHit(uint32_t di, SqDist sq) : dist(sq.distance), docid(di) {}
-};
-
-struct GreaterDist {
- bool operator() (const HnswHit &lhs, const HnswHit& rhs) const {
- return (rhs.dist < lhs.dist);
- }
-};
-struct LesserDist {
- bool operator() (const HnswHit &lhs, const HnswHit& rhs) const {
- return (lhs.dist < rhs.dist);
- }
-};
-
-using NearestList = std::vector<HnswHit>;
-
-struct NearestPriQ : std::priority_queue<HnswHit, NearestList, GreaterDist>
-{
-};
-
-struct FurthestPriQ : std::priority_queue<HnswHit, NearestList, LesserDist>
-{
- NearestList steal() {
- NearestList result;
- c.swap(result);
- return result;
- }
- const NearestList& peek() const { return c; }
-};
-
-class HnswLikeNns : public NNS<float>
-{
-private:
- std::vector<Node> _nodes;
- uint32_t _entryId;
- int _entryLevel;
- uint32_t _M;
- uint32_t _efConstruction;
- double _levelMultiplier;
- RndGen _rndGen;
- VisitedSetPool _visitedSetPool;
- size_t _ops_counter;
-
- double distance(Vector v, uint32_t id) const;
-
- double distance(uint32_t a, uint32_t b) const {
- Vector v = _dva.get(a);
- return distance(v, b);
- }
-
- int randomLevel() {
- double unif = _rndGen.nextUniform();
- double r = -log(1.0-unif) * _levelMultiplier;
- return (int) r;
- }
-
- uint32_t count_reachable() const;
- void dumpStats() const;
-
-public:
- HnswLikeNns(uint32_t numDims, const DocVectorAccess<float> &dva)
- : NNS(numDims, dva),
- _nodes(),
- _entryId(0),
- _entryLevel(-1),
- _M(16),
- _efConstruction(200),
- _levelMultiplier(1.0 / log(1.0 * _M)),
- _rndGen(),
- _ops_counter(0)
- {
- }
-
- ~HnswLikeNns() { dumpStats(); }
-
- LinkList& getLinkList(uint32_t docid, uint32_t level) {
- // assert(docid < _nodes.size());
- // assert(level < _nodes[docid]._links.size());
- return _nodes[docid]._links[level];
- }
-
- const LinkList& getLinkList(uint32_t docid, uint32_t level) const {
- return _nodes[docid]._links[level];
- }
+}
- // simple greedy search
- HnswHit search_layer_simple(Vector vector, HnswHit curPoint, uint32_t searchLevel) {
- bool keepGoing = true;
- while (keepGoing) {
- keepGoing = false;
- const LinkList& neighbors = getLinkList(curPoint.docid, searchLevel);
- for (uint32_t n_id : neighbors) {
- double dist = distance(vector, n_id);
- ++distcalls_simple;
- if (dist < curPoint.dist) {
- curPoint = HnswHit(n_id, SqDist(dist));
- keepGoing = true;
- }
+// simple greedy search
+HnswHit
+HnswLikeNns::search_layer_simple(Vector vector, HnswHit curPoint, uint32_t searchLevel) {
+ bool keepGoing = true;
+ while (keepGoing) {
+ keepGoing = false;
+ const LinkList& neighbors = getLinkList(curPoint.docid, searchLevel);
+ for (uint32_t n_id : neighbors) {
+ double dist = distance(vector, n_id);
+ ++distcalls_simple;
+ if (dist < curPoint.dist) {
+ curPoint = HnswHit(n_id, SqDist(dist));
+ keepGoing = true;
}
}
- return curPoint;
}
+ return curPoint;
+}
- void search_layer(Vector vector, FurthestPriQ &w,
- VisitedSet &visited,
- uint32_t ef, uint32_t searchLevel);
-
- void search_layer_with_filter(Vector vector, FurthestPriQ &w,
- VisitedSet &visited,
- uint32_t ef, uint32_t searchLevel,
- const BitVector &blacklist);
- bool haveCloserDistance(HnswHit e, const LinkList &r) const {
- for (uint32_t prevId : r) {
- double dist = distance(e.docid, prevId);
- ++distcalls_heuristic;
- if (dist < e.dist) return true;
- }
- return false;
+bool
+HnswLikeNns::haveCloserDistance(HnswHit e, const LinkList &r) const {
+ for (uint32_t prevId : r) {
+ double dist = distance(e.docid, prevId);
+ ++distcalls_heuristic;
+ if (dist < e.dist) return true;
}
+ return false;
+}
- LinkList select_neighbors(const NearestList &neighbors, uint32_t curMax) const;
-
- LinkList remove_weakest(const NearestList &neighbors, uint32_t curMax, LinkList &removed) const;
-
- void addDoc(uint32_t docid) override {
- Vector vector = _dva.get(docid);
- for (uint32_t id = _nodes.size(); id <= docid; ++id) {
- _nodes.emplace_back(id, 0, _M);
- }
- int level = randomLevel();
- assert(_nodes[docid]._links.size() == 0);
- _nodes[docid] = Node(docid, level+1, _M);
- if (_entryLevel < 0) {
- _entryId = docid;
- _entryLevel = level;
- track_ops();
- return;
- }
- int searchLevel = _entryLevel;
- VisitedSet &visited = _visitedSetPool.get(_nodes.size());
- double entryDist = distance(vector, _entryId);
- ++distcalls_other;
- HnswHit entryPoint(_entryId, SqDist(entryDist));
+void
+HnswLikeNns::addDoc(uint32_t docid) {
+ Vector vector = _dva.get(docid);
+ for (uint32_t id = _nodes.size(); id <= docid; ++id) {
+ _nodes.emplace_back(id, 0, _M);
+ }
+ int level = randomLevel();
+ assert(_nodes[docid]._links.size() == 0);
+ _nodes[docid] = Node(docid, level+1, _M);
+ if (_entryLevel < 0) {
+ _entryId = docid;
+ _entryLevel = level;
+ track_ops();
+ return;
+ }
+ int searchLevel = _entryLevel;
+ VisitedSet &visited = _visitedSetPool.get(_nodes.size());
+ double entryDist = distance(vector, _entryId);
+ ++distcalls_other;
+ HnswHit entryPoint(_entryId, SqDist(entryDist));
#undef MULTI_ENTRY_I
#ifdef MULTI_ENTRY_I
- FurthestPriQ w;
- w.push(entryPoint);
- while (searchLevel > level) {
- search_layer(vector, w, visited, 5 * _M, searchLevel);
- --searchLevel;
- }
+ FurthestPriQ w;
+ w.push(entryPoint);
+ while (searchLevel > level) {
+ search_layer(vector, w, visited, 5 * _M, searchLevel);
+ --searchLevel;
+ }
#else
- while (searchLevel > level) {
- entryPoint = search_layer_simple(vector, entryPoint, searchLevel);
- --searchLevel;
- }
- FurthestPriQ w;
- w.push(entryPoint);
+ while (searchLevel > level) {
+ entryPoint = search_layer_simple(vector, entryPoint, searchLevel);
+ --searchLevel;
+ }
+ FurthestPriQ w;
+ w.push(entryPoint);
#endif
- searchLevel = std::min(level, _entryLevel);
- while (searchLevel >= 0) {
- search_layer(vector, w, visited, _efConstruction, searchLevel);
- LinkList neighbors = select_neighbors(w.peek(), _M);
- connect_new_node(docid, neighbors, searchLevel);
- each_shrink_ifneeded(neighbors, searchLevel);
- --searchLevel;
- }
- if (level > _entryLevel) {
- _entryLevel = level;
- _entryId = docid;
- }
- track_ops();
+ searchLevel = std::min(level, _entryLevel);
+ while (searchLevel >= 0) {
+ search_layer(vector, w, visited, _efConstruction, searchLevel);
+ LinkList neighbors = select_neighbors(w.peek(), _M);
+ connect_new_node(docid, neighbors, searchLevel);
+ each_shrink_ifneeded(neighbors, searchLevel);
+ --searchLevel;
}
-
- void track_ops() {
- _ops_counter++;
- if ((_ops_counter % 10000) == 0) {
- double div = _ops_counter;
- fprintf(stderr, "add / remove ops: %zu\n", _ops_counter);
- fprintf(stderr, "distance calls for layer: %zu is %.3f per op\n", distcalls_search_layer, distcalls_search_layer/ div);
- fprintf(stderr, "distance calls for heuristic: %zu is %.3f per op\n", distcalls_heuristic, distcalls_heuristic / div);
- fprintf(stderr, "distance calls for simple: %zu is %.3f per op\n", distcalls_simple, distcalls_simple / div);
- fprintf(stderr, "distance calls for shrink: %zu is %.3f per op\n", distcalls_shrink, distcalls_shrink / div);
- fprintf(stderr, "distance calls for refill: %zu is %.3f per op\n", distcalls_refill, distcalls_refill / div);
- fprintf(stderr, "distance calls for other: %zu is %.3f per op\n", distcalls_other, distcalls_other / div);
- fprintf(stderr, "refill needed calls: %zu is %.3f per op\n", refill_needed_calls, refill_needed_calls / div);
- fprintf(stderr, "shrink needed calls: %zu is %.3f per op\n", shrink_needed_calls, shrink_needed_calls / div);
- fprintf(stderr, "disconnected weak links: %zu is %.3f per op\n", disconnected_weak_links, disconnected_weak_links / div);
- fprintf(stderr, "disconnected for symmetry: %zu is %.3f per op\n", disconnected_for_symmetry, disconnected_for_symmetry / div);
- fprintf(stderr, "select neighbors: partial %zu vs full %zu\n", select_n_partial, select_n_full);
- }
- }
-
- void remove_link_from(uint32_t from_id, uint32_t remove_id, uint32_t level) {
- LinkList &links = getLinkList(from_id, level);
- links.remove_link(remove_id);
+ if (level > _entryLevel) {
+ _entryLevel = level;
+ _entryId = docid;
}
+ track_ops();
+}
+
+void
+HnswLikeNns::track_ops() {
+ _ops_counter++;
+ if ((_ops_counter % 10000) == 0) {
+ double div = _ops_counter;
+ fprintf(stderr, "add / remove ops: %zu\n", _ops_counter);
+ fprintf(stderr, "distance calls for layer: %zu is %.3f per op\n", distcalls_search_layer, distcalls_search_layer/ div);
+ fprintf(stderr, "distance calls for heuristic: %zu is %.3f per op\n", distcalls_heuristic, distcalls_heuristic / div);
+ fprintf(stderr, "distance calls for simple: %zu is %.3f per op\n", distcalls_simple, distcalls_simple / div);
+ fprintf(stderr, "distance calls for shrink: %zu is %.3f per op\n", distcalls_shrink, distcalls_shrink / div);
+ fprintf(stderr, "distance calls for refill: %zu is %.3f per op\n", distcalls_refill, distcalls_refill / div);
+ fprintf(stderr, "distance calls for other: %zu is %.3f per op\n", distcalls_other, distcalls_other / div);
+ fprintf(stderr, "refill needed calls: %zu is %.3f per op\n", refill_needed_calls, refill_needed_calls / div);
+ fprintf(stderr, "shrink needed calls: %zu is %.3f per op\n", shrink_needed_calls, shrink_needed_calls / div);
+ fprintf(stderr, "disconnected weak links: %zu is %.3f per op\n", disconnected_weak_links, disconnected_weak_links / div);
+ fprintf(stderr, "disconnected for symmetry: %zu is %.3f per op\n", disconnected_for_symmetry, disconnected_for_symmetry / div);
+ fprintf(stderr, "select neighbors: partial %zu vs full %zu\n", select_n_partial, select_n_full);
+ }
+}
-#undef SIMPLE_REFILL
#ifdef SIMPLE_REFILL
- void refill_ifneeded(uint32_t my_id, const LinkList &replacements, uint32_t level) {
- LinkList &my_links = getLinkList(my_id, level);
- if (my_links.size() * 2 < _M) {
- const uint32_t maxLinks = (level > 0) ? _M : (2 * _M);
- ++refill_needed_calls;
- for (uint32_t repl_id : replacements) {
- if (repl_id == my_id) continue;
- if (my_links.has_link_to(repl_id)) continue;
- LinkList &other_links = getLinkList(repl_id, level);
- if (other_links.size() >= maxLinks) continue;
- other_links.push_back(my_id);
- my_links.push_back(repl_id);
- if (my_links.size() >= _M) return;
- }
- }
- }
-#else
- void refill_all(uint32_t my_id, const LinkList &replacements, uint32_t level) {
- LinkList &my_links = getLinkList(my_id, level);
+void
+HnswLikeNns::refill_ifneeded(uint32_t my_id, const LinkList &replacements, uint32_t level) {
+ LinkList &my_links = getLinkList(my_id, level);
+ if (my_links.size() * 2 < _M) {
const uint32_t maxLinks = (level > 0) ? _M : (2 * _M);
- NearestPriQ w;
+ ++refill_needed_calls;
for (uint32_t repl_id : replacements) {
if (repl_id == my_id) continue;
if (my_links.has_link_to(repl_id)) continue;
- const LinkList &other_links = getLinkList(repl_id, level);
+ LinkList &other_links = getLinkList(repl_id, level);
if (other_links.size() >= maxLinks) continue;
- double dist = distance(my_id, repl_id);
- ++distcalls_refill;
- w.emplace(repl_id, SqDist(dist));
- }
- while (! w.empty()) {
- HnswHit e = w.top();
- w.pop();
- if (haveCloserDistance(e, my_links)) continue;
- LinkList &other_links = getLinkList(e.docid, level);
- my_links.push_back(e.docid);
other_links.push_back(my_id);
- if (my_links.size() == _M) break;
+ my_links.push_back(repl_id);
+ if (my_links.size() >= _M) return;
}
}
- void refill_one(uint32_t my_id, const LinkList &replacements, uint32_t level) {
- LinkList &my_links = getLinkList(my_id, level);
- NearestPriQ w;
- for (uint32_t repl_id : replacements) {
- if (repl_id == my_id) continue;
- if (my_links.has_link_to(repl_id)) continue;
- LinkList &other_links = getLinkList(repl_id, level);
- if (other_links.size() >= _M) continue;
- double dist = distance(my_id, repl_id);
- ++distcalls_refill;
- w.emplace(repl_id, SqDist(dist));
- }
- while (! w.empty()) {
- HnswHit e = w.top();
- w.pop();
- if (haveCloserDistance(e, my_links)) continue;
- LinkList &other_links = getLinkList(e.docid, level);
- my_links.push_back(e.docid);
- other_links.push_back(my_id);
- return;
- }
+}
+#endif
+
+#define REFILL_ALL
+#ifdef REFILL_ALL
+void
+HnswLikeNns::refill_ifneeded(uint32_t my_id, const LinkList &replacements, uint32_t level) {
+ LinkList &my_links = getLinkList(my_id, level);
+ if (my_links.size() >= _M) return;
+ ++refill_needed_calls;
+ const uint32_t maxLinks = (level > 0) ? _M : (2 * _M);
+ NearestPriQ w;
+ for (uint32_t repl_id : replacements) {
+ if (repl_id == my_id) continue;
+ if (my_links.has_link_to(repl_id)) continue;
+ const LinkList &other_links = getLinkList(repl_id, level);
+ if (other_links.size() >= maxLinks) continue;
+ double dist = distance(my_id, repl_id);
+ ++distcalls_refill;
+ w.emplace(repl_id, SqDist(dist));
}
- void refill_ifneeded(uint32_t my_id, const LinkList &replacements, uint32_t level) {
- LinkList &my_links = getLinkList(my_id, level);
- if (my_links.size() < _M) {
- ++refill_needed_calls;
- refill_all(my_id, replacements, level);
- }
+ while (! w.empty()) {
+ HnswHit e = w.top();
+ w.pop();
+ if (haveCloserDistance(e, my_links)) continue;
+ LinkList &other_links = getLinkList(e.docid, level);
+ my_links.push_back(e.docid);
+ other_links.push_back(my_id);
+ if (my_links.size() == _M) break;
}
+}
#endif
- void connect_new_node(uint32_t id, const LinkList &neighbors, uint32_t level);
+#ifdef REFILL_ONE
+void
+HnswLikeNns::refill_ifneeded(uint32_t my_id, const LinkList &replacements, uint32_t level) {
+ LinkList &my_links = getLinkList(my_id, level);
+ if (my_links.size() >= _M) return;
+ ++refill_needed_calls;
+ NearestPriQ w;
+ for (uint32_t repl_id : replacements) {
+ if (repl_id == my_id) continue;
+ if (my_links.has_link_to(repl_id)) continue;
+ LinkList &other_links = getLinkList(repl_id, level);
+ if (other_links.size() >= _M) continue;
+ double dist = distance(my_id, repl_id);
+ ++distcalls_refill;
+ w.emplace(repl_id, SqDist(dist));
+ }
+ while (! w.empty()) {
+ HnswHit e = w.top();
+ w.pop();
+ if (haveCloserDistance(e, my_links)) continue;
+ LinkList &other_links = getLinkList(e.docid, level);
+ my_links.push_back(e.docid);
+ other_links.push_back(my_id);
+ return;
+ }
+}
+#endif
- void shrink_links(uint32_t shrink_id, uint32_t maxLinks, uint32_t level) {
- LinkList &links = getLinkList(shrink_id, level);
- NearestList distances;
- for (uint32_t n_id : links) {
- double n_dist = distance(shrink_id, n_id);
- ++distcalls_shrink;
- distances.emplace_back(n_id, SqDist(n_dist));
- }
- LinkList lostLinks;
- LinkList oldLinks = links;
- links = remove_weakest(distances, maxLinks, lostLinks);
+void
+HnswLikeNns::shrink_links(uint32_t shrink_id, uint32_t maxLinks, uint32_t level) {
+ LinkList &links = getLinkList(shrink_id, level);
+ NearestList distances;
+ for (uint32_t n_id : links) {
+ double n_dist = distance(shrink_id, n_id);
+ ++distcalls_shrink;
+ distances.emplace_back(n_id, SqDist(n_dist));
+ }
+ LinkList lostLinks;
+ LinkList oldLinks = links;
+ links = remove_weakest(distances, maxLinks, lostLinks);
#define KEEP_SYM
#ifdef KEEP_SYM
- for (uint32_t lost_id : lostLinks) {
- ++disconnected_for_symmetry;
- remove_link_from(lost_id, shrink_id, level);
- }
+ for (uint32_t lost_id : lostLinks) {
+ ++disconnected_for_symmetry;
+ remove_link_from(lost_id, shrink_id, level);
+ }
#define DO_REFILL_AFTER_KEEP_SYM
#ifdef DO_REFILL_AFTER_KEEP_SYM
- for (uint32_t lost_id : lostLinks) {
-#ifdef SIMPLE_REFILL
- refill_ifneeded(lost_id, oldLinks, level);
-#else
- refill_all(lost_id, oldLinks, level);
-#endif
- }
+ for (uint32_t lost_id : lostLinks) {
+ refill_ifneeded(lost_id, oldLinks, level);
+ }
#endif
#endif
- }
+}
- void each_shrink_ifneeded(const LinkList &neighbors, uint32_t level);
- void mutually_reconnect(LinkList cluster, int level) {
- while (! cluster.empty()) {
- uint32_t n_id = cluster.back();
- cluster.pop_back();
-#ifdef SIMPLE_REFILL
- refill_ifneeded(n_id, cluster, level);
-#else
- refill_all(n_id, cluster, level);
-#endif
- }
- }
-
- void removeDoc(uint32_t docid) override {
- Node &node = _nodes[docid];
- bool need_new_entrypoint = (docid == _entryId);
- for (int level = node._links.size(); level-- > 0; ) {
- LinkList my_links;
- my_links.swap(node._links[level]);
- for (uint32_t n_id : my_links) {
- if (need_new_entrypoint) {
- _entryId = n_id;
- _entryLevel = level;
- need_new_entrypoint = false;
- }
- remove_link_from(n_id, docid, level);
+void
+HnswLikeNns::removeDoc(uint32_t docid) {
+ Node &node = _nodes[docid];
+ bool need_new_entrypoint = (docid == _entryId);
+ for (int level = node._links.size(); level-- > 0; ) {
+ LinkList my_links;
+ my_links.swap(node._links[level]);
+ for (uint32_t n_id : my_links) {
+ if (need_new_entrypoint) {
+ _entryId = n_id;
+ _entryLevel = level;
+ need_new_entrypoint = false;
}
- mutually_reconnect(my_links, level);
- }
- node = Node(docid, 0, _M);
- if (need_new_entrypoint) {
- _entryLevel = -1;
- _entryId = 0;
- for (uint32_t i = 0; i < _nodes.size(); ++i) {
- if (_nodes[i]._links.size() > 0) {
- _entryId = i;
- _entryLevel = _nodes[i]._links.size() - 1;
- break;
- }
+ remove_link_from(n_id, docid, level);
+ }
+ while (! my_links.empty()) {
+ uint32_t n_id = my_links.back();
+ my_links.pop_back();
+ refill_ifneeded(n_id, my_links, level);
+ }
+ }
+ node = Node(docid, 0, _M);
+ if (need_new_entrypoint) {
+ _entryLevel = -1;
+ _entryId = 0;
+ for (uint32_t i = 0; i < _nodes.size(); ++i) {
+ if (_nodes[i]._links.size() > 0) {
+ _entryId = i;
+ _entryLevel = _nodes[i]._links.size() - 1;
+ break;
}
}
- track_ops();
}
+ track_ops();
+}
- std::vector<NnsHit> topK(uint32_t k, Vector vector, uint32_t search_k) override {
- std::vector<NnsHit> result;
- if (_entryLevel < 0) return result;
- double entryDist = distance(vector, _entryId);
- ++distcalls_other;
- HnswHit entryPoint(_entryId, SqDist(entryDist));
- int searchLevel = _entryLevel;
- VisitedSet &visited = _visitedSetPool.get(_nodes.size());
+std::vector<NnsHit>
+HnswLikeNns::topK(uint32_t k, Vector vector, uint32_t search_k) {
+ std::vector<NnsHit> result;
+ if (_entryLevel < 0) return result;
+ double entryDist = distance(vector, _entryId);
+ ++distcalls_other;
+ HnswHit entryPoint(_entryId, SqDist(entryDist));
+ int searchLevel = _entryLevel;
+ VisitedSet &visited = _visitedSetPool.get(_nodes.size());
#undef MULTI_ENTRY_S
#ifdef MULTI_ENTRY_S
- FurthestPriQ w;
- w.push(entryPoint);
- while (searchLevel > 0) {
- search_layer(vector, w, visited, std::min(k, search_k), searchLevel);
- --searchLevel;
- }
+ FurthestPriQ w;
+ w.push(entryPoint);
+ while (searchLevel > 0) {
+ search_layer(vector, w, visited, std::min(k, search_k), searchLevel);
+ --searchLevel;
+ }
#else
- while (searchLevel > 0) {
- entryPoint = search_layer_simple(vector, entryPoint, searchLevel);
- --searchLevel;
- }
- FurthestPriQ w;
- w.push(entryPoint);
+ while (searchLevel > 0) {
+ entryPoint = search_layer_simple(vector, entryPoint, searchLevel);
+ --searchLevel;
+ }
+ FurthestPriQ w;
+ w.push(entryPoint);
#endif
- search_layer(vector, w, visited, std::max(k, search_k), 0);
- while (w.size() > k) {
- w.pop();
- }
- NearestList tmp = w.steal();
- std::sort(tmp.begin(), tmp.end(), LesserDist());
- result.reserve(tmp.size());
- for (const auto & hit : tmp) {
- result.emplace_back(hit.docid, SqDist(hit.dist));
- }
- return result;
+ search_layer(vector, w, visited, std::max(k, search_k), 0);
+ while (w.size() > k) {
+ w.pop();
}
-
- std::vector<NnsHit> topKfilter(uint32_t k, Vector vector, uint32_t search_k, const BitVector &blacklist) override;
-};
+ NearestList tmp = w.steal();
+ std::sort(tmp.begin(), tmp.end(), LesserDist());
+ result.reserve(tmp.size());
+ for (const auto & hit : tmp) {
+ result.emplace_back(hit.docid, SqDist(hit.dist));
+ }
+ return result;
+}
double
diff --git a/eval/src/tests/ann/xp-hnswlike-nns.cpp b/eval/src/tests/ann/xp-hnswlike-nns.cpp
index 90fc0fe2e92..b7cae9f731c 100644
--- a/eval/src/tests/ann/xp-hnswlike-nns.cpp
+++ b/eval/src/tests/ann/xp-hnswlike-nns.cpp
@@ -1,11 +1,6 @@
// Copyright 2020 Oath Inc. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
-#include <algorithm>
-#include <assert.h>
-#include <queue>
-#include <cinttypes>
-#include "std-random.h"
-#include "nns.h"
+#include "hnsw-like.h"
/*
Todo:
@@ -38,378 +33,216 @@ static size_t disconnected_for_symmetry;
static size_t select_n_full;
static size_t select_n_partial;
-struct LinkList : std::vector<uint32_t>
-{
- bool has_link_to(uint32_t id) const {
- auto iter = std::find(begin(), end(), id);
- return (iter != end());
- }
- void remove_link(uint32_t id) {
- uint32_t last = back();
- for (iterator iter = begin(); iter != end(); ++iter) {
- if (*iter == id) {
- *iter = last;
- pop_back();
- return;
- }
- }
- fprintf(stderr, "BAD missing link to remove: %u\n", id);
- abort();
- }
-};
-struct Node {
- std::vector<LinkList> _links;
- Node(uint32_t , uint32_t numLevels, uint32_t M)
- : _links(numLevels)
- {
- for (uint32_t i = 0; i < _links.size(); ++i) {
- _links[i].reserve((i == 0) ? (2 * M + 1) : (M+1));
- }
- }
-};
-struct VisitedSet
+HnswLikeNns::HnswLikeNns(uint32_t numDims, const DocVectorAccess<float> &dva)
+ : NNS(numDims, dva),
+ _nodes(),
+ _entryId(0),
+ _entryLevel(-1),
+ _M(16),
+ _efConstruction(200),
+ _levelMultiplier(1.0 / log(1.0 * _M)),
+ _rndGen(),
+ _ops_counter(0)
{
- using Mark = unsigned short;
- Mark *ptr;
- Mark curval;
- size_t sz;
- VisitedSet(const VisitedSet &) = delete;
- VisitedSet& operator=(const VisitedSet &) = delete;
- explicit VisitedSet(size_t size) {
- ptr = (Mark *)malloc(size * sizeof(Mark));
- curval = -1;
- sz = size;
- clear();
- }
- void clear() {
- ++curval;
- if (curval == 0) {
- memset(ptr, 0, sz * sizeof(Mark));
- ++curval;
- }
- }
- ~VisitedSet() { free(ptr); }
- void mark(size_t id) { ptr[id] = curval; }
- bool isMarked(size_t id) const { return ptr[id] == curval; }
-};
+}
-struct VisitedSetPool
-{
- std::unique_ptr<VisitedSet> lastUsed;
- VisitedSetPool() {
- lastUsed = std::make_unique<VisitedSet>(250);
- }
- ~VisitedSetPool() {}
- VisitedSet &get(size_t size) {
- if (size > lastUsed->sz) {
- lastUsed = std::make_unique<VisitedSet>(size*2);
- } else {
- lastUsed->clear();
+// simple greedy search
+HnswHit
+HnswLikeNns::search_layer_simple(Vector vector, HnswHit curPoint, uint32_t searchLevel) {
+ bool keepGoing = true;
+ while (keepGoing) {
+ keepGoing = false;
+ const LinkList& neighbors = getLinkList(curPoint.docid, searchLevel);
+ for (uint32_t n_id : neighbors) {
+ double dist = distance(vector, n_id);
+ ++distcalls_simple;
+ if (dist < curPoint.dist) {
+ curPoint = HnswHit(n_id, SqDist(dist));
+ keepGoing = true;
+ }
}
- return *lastUsed;
- }
-};
-
-struct HnswHit {
- double dist;
- uint32_t docid;
- HnswHit(uint32_t di, SqDist sq) : dist(sq.distance), docid(di) {}
-};
-
-struct GreaterDist {
- bool operator() (const HnswHit &lhs, const HnswHit& rhs) const {
- return (rhs.dist < lhs.dist);
- }
-};
-struct LesserDist {
- bool operator() (const HnswHit &lhs, const HnswHit& rhs) const {
- return (lhs.dist < rhs.dist);
- }
-};
-
-using NearestList = std::vector<HnswHit>;
-
-struct NearestPriQ : std::priority_queue<HnswHit, NearestList, GreaterDist>
-{
-};
-
-struct FurthestPriQ : std::priority_queue<HnswHit, NearestList, LesserDist>
-{
- NearestList steal() {
- NearestList result;
- c.swap(result);
- return result;
- }
- const NearestList& peek() const { return c; }
-};
-
-class HnswLikeNns : public NNS<float>
-{
-private:
- std::vector<Node> _nodes;
- uint32_t _entryId;
- int _entryLevel;
- uint32_t _M;
- uint32_t _efConstruction;
- double _levelMultiplier;
- RndGen _rndGen;
- VisitedSetPool _visitedSetPool;
- size_t _ops_counter;
-
- double distance(Vector v, uint32_t id) const;
-
- double distance(uint32_t a, uint32_t b) const {
- Vector v = _dva.get(a);
- return distance(v, b);
- }
-
- int randomLevel() {
- double unif = _rndGen.nextUniform();
- double r = -log(1.0-unif) * _levelMultiplier;
- return (int) r;
- }
-
- uint32_t count_reachable() const;
- void dumpStats() const;
-
-public:
- HnswLikeNns(uint32_t numDims, const DocVectorAccess<float> &dva)
- : NNS(numDims, dva),
- _nodes(),
- _entryId(0),
- _entryLevel(-1),
- _M(16),
- _efConstruction(200),
- _levelMultiplier(1.0 / log(1.0 * _M)),
- _rndGen(),
- _ops_counter(0)
- {
}
+ return curPoint;
+}
- ~HnswLikeNns() { dumpStats(); }
-
- LinkList& getLinkList(uint32_t docid, uint32_t level) {
- // assert(docid < _nodes.size());
- // assert(level < _nodes[docid]._links.size());
- return _nodes[docid]._links[level];
+bool
+HnswLikeNns::haveCloserDistance(HnswHit e, const LinkList &r) const {
+ for (uint32_t prevId : r) {
+ double dist = distance(e.docid, prevId);
+ ++distcalls_heuristic;
+ if (dist < e.dist) return true;
}
+ return false;
+}
- const LinkList& getLinkList(uint32_t docid, uint32_t level) const {
- return _nodes[docid]._links[level];
+void
+HnswLikeNns::addDoc(uint32_t docid) {
+ Vector vector = _dva.get(docid);
+ for (uint32_t id = _nodes.size(); id <= docid; ++id) {
+ _nodes.emplace_back(id, 0, _M);
+ }
+ int level = randomLevel();
+ assert(_nodes[docid]._links.size() == 0);
+ _nodes[docid] = Node(docid, level+1, _M);
+ if (_entryLevel < 0) {
+ _entryId = docid;
+ _entryLevel = level;
+ track_ops();
+ return;
}
-
- // simple greedy search
- HnswHit search_layer_simple(Vector vector, HnswHit curPoint, uint32_t searchLevel) {
- bool keepGoing = true;
- while (keepGoing) {
- keepGoing = false;
- const LinkList& neighbors = getLinkList(curPoint.docid, searchLevel);
- for (uint32_t n_id : neighbors) {
- double dist = distance(vector, n_id);
- ++distcalls_simple;
- if (dist < curPoint.dist) {
- curPoint = HnswHit(n_id, SqDist(dist));
- keepGoing = true;
- }
- }
- }
- return curPoint;
+ int searchLevel = _entryLevel;
+ double entryDist = distance(vector, _entryId);
+ ++distcalls_other;
+ HnswHit entryPoint(_entryId, SqDist(entryDist));
+ while (searchLevel > level) {
+ entryPoint = search_layer_simple(vector, entryPoint, searchLevel);
+ --searchLevel;
}
-
- void search_layer(Vector vector, FurthestPriQ &w,
- uint32_t ef, uint32_t searchLevel);
-
- void search_layer_with_filter(Vector vector, FurthestPriQ &w,
- uint32_t ef, uint32_t searchLevel,
- const BitVector &blacklist);
-
- bool haveCloserDistance(HnswHit e, const LinkList &r) const {
- for (uint32_t prevId : r) {
- double dist = distance(e.docid, prevId);
- ++distcalls_heuristic;
- if (dist < e.dist) return true;
- }
- return false;
+ searchLevel = std::min(level, _entryLevel);
+ FurthestPriQ w;
+ w.push(entryPoint);
+ while (searchLevel >= 0) {
+ search_layer(vector, w, _efConstruction, searchLevel);
+ LinkList neighbors = select_neighbors(w.peek(), _M);
+ connect_new_node(docid, neighbors, searchLevel);
+ each_shrink_ifneeded(neighbors, searchLevel);
+ --searchLevel;
}
-
- LinkList select_neighbors(const NearestList &neighbors, uint32_t curMax) const;
-
- LinkList remove_weakest(const NearestList &neighbors, uint32_t curMax, LinkList &removed) const;
-
- void addDoc(uint32_t docid) override {
- Vector vector = _dva.get(docid);
- for (uint32_t id = _nodes.size(); id <= docid; ++id) {
- _nodes.emplace_back(id, 0, _M);
- }
- int level = randomLevel();
- assert(_nodes[docid]._links.size() == 0);
- _nodes[docid] = Node(docid, level+1, _M);
- if (_entryLevel < 0) {
- _entryId = docid;
- _entryLevel = level;
- track_ops();
- return;
- }
- int searchLevel = _entryLevel;
- double entryDist = distance(vector, _entryId);
- ++distcalls_other;
- HnswHit entryPoint(_entryId, SqDist(entryDist));
- while (searchLevel > level) {
- entryPoint = search_layer_simple(vector, entryPoint, searchLevel);
- --searchLevel;
- }
- searchLevel = std::min(level, _entryLevel);
- FurthestPriQ w;
- w.push(entryPoint);
- while (searchLevel >= 0) {
- search_layer(vector, w, _efConstruction, searchLevel);
- LinkList neighbors = select_neighbors(w.peek(), _M);
- connect_new_node(docid, neighbors, searchLevel);
- each_shrink_ifneeded(neighbors, searchLevel);
- --searchLevel;
- }
- if (level > _entryLevel) {
- _entryLevel = level;
- _entryId = docid;
- }
- track_ops();
+ if (level > _entryLevel) {
+ _entryLevel = level;
+ _entryId = docid;
}
+ track_ops();
+}
- void track_ops() {
- _ops_counter++;
- if ((_ops_counter % 10000) == 0) {
- double div = _ops_counter;
- fprintf(stderr, "add / remove ops: %zu\n", _ops_counter);
- fprintf(stderr, "distance calls for layer: %zu is %.3f per op\n", distcalls_search_layer, distcalls_search_layer/ div);
- fprintf(stderr, "distance calls for heuristic: %zu is %.3f per op\n", distcalls_heuristic, distcalls_heuristic / div);
- fprintf(stderr, "distance calls for simple: %zu is %.3f per op\n", distcalls_simple, distcalls_simple / div);
- fprintf(stderr, "distance calls for shrink: %zu is %.3f per op\n", distcalls_shrink, distcalls_shrink / div);
- fprintf(stderr, "distance calls for refill: %zu is %.3f per op\n", distcalls_refill, distcalls_refill / div);
- fprintf(stderr, "distance calls for other: %zu is %.3f per op\n", distcalls_other, distcalls_other / div);
- fprintf(stderr, "refill needed calls: %zu is %.3f per op\n", refill_needed_calls, refill_needed_calls / div);
- fprintf(stderr, "shrink needed calls: %zu is %.3f per op\n", shrink_needed_calls, shrink_needed_calls / div);
- fprintf(stderr, "disconnected weak links: %zu is %.3f per op\n", disconnected_weak_links, disconnected_weak_links / div);
- fprintf(stderr, "disconnected for symmetry: %zu is %.3f per op\n", disconnected_for_symmetry, disconnected_for_symmetry / div);
- fprintf(stderr, "select neighbors: partial %zu vs full %zu\n", select_n_partial, select_n_full);
- }
- }
-
- void remove_link_from(uint32_t from_id, uint32_t remove_id, uint32_t level) {
- LinkList &links = getLinkList(from_id, level);
- links.remove_link(remove_id);
- }
+void
+HnswLikeNns::track_ops() {
+ _ops_counter++;
+ if ((_ops_counter % 10000) == 0) {
+ double div = _ops_counter;
+ fprintf(stderr, "add / remove ops: %zu\n", _ops_counter);
+ fprintf(stderr, "distance calls for layer: %zu is %.3f per op\n", distcalls_search_layer, distcalls_search_layer/ div);
+ fprintf(stderr, "distance calls for heuristic: %zu is %.3f per op\n", distcalls_heuristic, distcalls_heuristic / div);
+ fprintf(stderr, "distance calls for simple: %zu is %.3f per op\n", distcalls_simple, distcalls_simple / div);
+ fprintf(stderr, "distance calls for shrink: %zu is %.3f per op\n", distcalls_shrink, distcalls_shrink / div);
+ fprintf(stderr, "distance calls for refill: %zu is %.3f per op\n", distcalls_refill, distcalls_refill / div);
+ fprintf(stderr, "distance calls for other: %zu is %.3f per op\n", distcalls_other, distcalls_other / div);
+ fprintf(stderr, "refill needed calls: %zu is %.3f per op\n", refill_needed_calls, refill_needed_calls / div);
+ fprintf(stderr, "shrink needed calls: %zu is %.3f per op\n", shrink_needed_calls, shrink_needed_calls / div);
+ fprintf(stderr, "disconnected weak links: %zu is %.3f per op\n", disconnected_weak_links, disconnected_weak_links / div);
+ fprintf(stderr, "disconnected for symmetry: %zu is %.3f per op\n", disconnected_for_symmetry, disconnected_for_symmetry / div);
+ fprintf(stderr, "select neighbors: partial %zu vs full %zu\n", select_n_partial, select_n_full);
+ }
+}
- void refill_ifneeded(uint32_t my_id, const LinkList &replacements, uint32_t level) {
- LinkList &my_links = getLinkList(my_id, level);
- if (my_links.size() < 8) {
- ++refill_needed_calls;
- for (uint32_t repl_id : replacements) {
- if (repl_id == my_id) continue;
- if (my_links.has_link_to(repl_id)) continue;
- LinkList &other_links = getLinkList(repl_id, level);
- if (other_links.size() + 1 >= _M) continue;
- other_links.push_back(my_id);
- my_links.push_back(repl_id);
- if (my_links.size() >= _M) return;
- }
+void
+HnswLikeNns::refill_ifneeded(uint32_t my_id, const LinkList &replacements, uint32_t level) {
+ LinkList &my_links = getLinkList(my_id, level);
+ if (my_links.size() < 8) {
+ ++refill_needed_calls;
+ for (uint32_t repl_id : replacements) {
+ if (repl_id == my_id) continue;
+ if (my_links.has_link_to(repl_id)) continue;
+ LinkList &other_links = getLinkList(repl_id, level);
+ if (other_links.size() + 1 >= _M) continue;
+ other_links.push_back(my_id);
+ my_links.push_back(repl_id);
+ if (my_links.size() >= _M) return;
}
}
+}
- void connect_new_node(uint32_t id, const LinkList &neighbors, uint32_t level);
-
- void shrink_links(uint32_t shrink_id, uint32_t maxLinks, uint32_t level) {
- LinkList &links = getLinkList(shrink_id, level);
- NearestList distances;
- for (uint32_t n_id : links) {
- double n_dist = distance(shrink_id, n_id);
- ++distcalls_shrink;
- distances.emplace_back(n_id, SqDist(n_dist));
- }
- LinkList lostLinks;
- LinkList oldLinks = links;
- links = remove_weakest(distances, maxLinks, lostLinks);
+void
+HnswLikeNns::shrink_links(uint32_t shrink_id, uint32_t maxLinks, uint32_t level) {
+ LinkList &links = getLinkList(shrink_id, level);
+ NearestList distances;
+ for (uint32_t n_id : links) {
+ double n_dist = distance(shrink_id, n_id);
+ ++distcalls_shrink;
+ distances.emplace_back(n_id, SqDist(n_dist));
+ }
+ LinkList lostLinks;
+ LinkList oldLinks = links;
+ links = remove_weakest(distances, maxLinks, lostLinks);
#define KEEP_SYM
#ifdef KEEP_SYM
- for (uint32_t lost_id : lostLinks) {
- ++disconnected_for_symmetry;
- remove_link_from(lost_id, shrink_id, level);
- }
+ for (uint32_t lost_id : lostLinks) {
+ ++disconnected_for_symmetry;
+ remove_link_from(lost_id, shrink_id, level);
+ }
#define DO_REFILL_AFTER_KEEP_SYM
#ifdef DO_REFILL_AFTER_KEEP_SYM
- for (uint32_t lost_id : lostLinks) {
- refill_ifneeded(lost_id, oldLinks, level);
- }
+ for (uint32_t lost_id : lostLinks) {
+ refill_ifneeded(lost_id, oldLinks, level);
+ }
#endif
#endif
- }
-
- void each_shrink_ifneeded(const LinkList &neighbors, uint32_t level);
+}
- void removeDoc(uint32_t docid) override {
- Node &node = _nodes[docid];
- bool need_new_entrypoint = (docid == _entryId);
- for (int level = node._links.size(); level-- > 0; ) {
- LinkList my_links;
- my_links.swap(node._links[level]);
- for (uint32_t n_id : my_links) {
- if (need_new_entrypoint) {
- _entryId = n_id;
- _entryLevel = level;
- need_new_entrypoint = false;
- }
- remove_link_from(n_id, docid, level);
- }
- while (! my_links.empty()) {
- uint32_t n_id = my_links.back();
- my_links.pop_back();
- refill_ifneeded(n_id, my_links, level);
+void
+HnswLikeNns::removeDoc(uint32_t docid) {
+ Node &node = _nodes[docid];
+ bool need_new_entrypoint = (docid == _entryId);
+ for (int level = node._links.size(); level-- > 0; ) {
+ LinkList my_links;
+ my_links.swap(node._links[level]);
+ for (uint32_t n_id : my_links) {
+ if (need_new_entrypoint) {
+ _entryId = n_id;
+ _entryLevel = level;
+ need_new_entrypoint = false;
}
- }
- node = Node(docid, 0, _M);
- if (need_new_entrypoint) {
- _entryLevel = -1;
- _entryId = 0;
- for (uint32_t i = 0; i < _nodes.size(); ++i) {
- if (_nodes[i]._links.size() > 0) {
- _entryId = i;
- _entryLevel = _nodes[i]._links.size() - 1;
- break;
- }
+ remove_link_from(n_id, docid, level);
+ }
+ while (! my_links.empty()) {
+ uint32_t n_id = my_links.back();
+ my_links.pop_back();
+ refill_ifneeded(n_id, my_links, level);
+ }
+ }
+ node = Node(docid, 0, _M);
+ if (need_new_entrypoint) {
+ _entryLevel = -1;
+ _entryId = 0;
+ for (uint32_t i = 0; i < _nodes.size(); ++i) {
+ if (_nodes[i]._links.size() > 0) {
+ _entryId = i;
+ _entryLevel = _nodes[i]._links.size() - 1;
+ break;
}
}
- track_ops();
}
+ track_ops();
+}
- std::vector<NnsHit> topK(uint32_t k, Vector vector, uint32_t search_k) override {
- std::vector<NnsHit> result;
- if (_entryLevel < 0) return result;
- double entryDist = distance(vector, _entryId);
- ++distcalls_other;
- HnswHit entryPoint(_entryId, SqDist(entryDist));
- int searchLevel = _entryLevel;
- while (searchLevel > 0) {
- entryPoint = search_layer_simple(vector, entryPoint, searchLevel);
- --searchLevel;
- }
- FurthestPriQ w;
- w.push(entryPoint);
- search_layer(vector, w, std::max(k, search_k), 0);
- while (w.size() > k) {
- w.pop();
- }
- NearestList tmp = w.steal();
- std::sort(tmp.begin(), tmp.end(), LesserDist());
- result.reserve(tmp.size());
- for (const auto & hit : tmp) {
- result.emplace_back(hit.docid, SqDist(hit.dist));
- }
- return result;
+std::vector<NnsHit>
+HnswLikeNns::topK(uint32_t k, Vector vector, uint32_t search_k) {
+ std::vector<NnsHit> result;
+ if (_entryLevel < 0) return result;
+ double entryDist = distance(vector, _entryId);
+ ++distcalls_other;
+ HnswHit entryPoint(_entryId, SqDist(entryDist));
+ int searchLevel = _entryLevel;
+ while (searchLevel > 0) {
+ entryPoint = search_layer_simple(vector, entryPoint, searchLevel);
+ --searchLevel;
}
-
- std::vector<NnsHit> topKfilter(uint32_t k, Vector vector, uint32_t search_k, const BitVector &blacklist) override;
-};
+ FurthestPriQ w;
+ w.push(entryPoint);
+ search_layer(vector, w, std::max(k, search_k), 0);
+ while (w.size() > k) {
+ w.pop();
+ }
+ NearestList tmp = w.steal();
+ std::sort(tmp.begin(), tmp.end(), LesserDist());
+ result.reserve(tmp.size());
+ for (const auto & hit : tmp) {
+ result.emplace_back(hit.docid, SqDist(hit.dist));
+ }
+ return result;
+}
double