diff options
author | Arne Juul <arnej@verizonmedia.com> | 2020-02-25 14:40:02 +0000 |
---|---|---|
committer | Arne Juul <arnej@verizonmedia.com> | 2020-02-25 14:47:36 +0000 |
commit | 16c477ad557d556fe4d63c871025f10b18aba84d (patch) | |
tree | decad336db7e7513962978c9cc660ec407ded213 /eval | |
parent | 44fef3325d3e9bfa673d71b87721f0979f8404c8 (diff) |
keep more code common
Diffstat (limited to 'eval')
-rw-r--r-- | eval/src/tests/ann/CMakeLists.txt | 2 | ||||
-rw-r--r-- | eval/src/tests/ann/extended-hnsw.cpp | 686 | ||||
-rw-r--r-- | eval/src/tests/ann/xp-hnswlike-nns.cpp | 527 |
3 files changed, 427 insertions, 788 deletions
diff --git a/eval/src/tests/ann/CMakeLists.txt b/eval/src/tests/ann/CMakeLists.txt index 34babf1412f..0ba38994c01 100644 --- a/eval/src/tests/ann/CMakeLists.txt +++ b/eval/src/tests/ann/CMakeLists.txt @@ -14,7 +14,7 @@ vespa_add_executable(eval_gist_benchmark_app SOURCES gist_benchmark.cpp xp-annoy-nns.cpp - xp-hnswlike-nns.cpp + extended-hnsw.cpp xp-lsh-nns.cpp DEPENDS vespaeval diff --git a/eval/src/tests/ann/extended-hnsw.cpp b/eval/src/tests/ann/extended-hnsw.cpp index 42f3a10b389..fbc4bedec05 100644 --- a/eval/src/tests/ann/extended-hnsw.cpp +++ b/eval/src/tests/ann/extended-hnsw.cpp @@ -1,29 +1,6 @@ // Copyright 2020 Oath Inc. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. -#include <algorithm> -#include <assert.h> -#include <queue> -#include <cinttypes> -#include "std-random.h" -#include "nns.h" - -/* - Todo: - - measure effect of: - 1) removing leftover backlinks during "shrink" operation - 2) refilling to low-watermark after 1) happens - 3) refilling to mid-watermark after 1) happens - 4) adding then removing 20% extra documents - 5) removing 20% first-added documents - 6) removing first-added documents while inserting new ones - - 7) auto-tune search_k to ensure >= 50% recall on 1000 Q with k=100 - 8) auto-tune search_k to ensure avg 90% recall on 1000 Q with k=100 - 9) auto-tune search_k to ensure >= 90% reachability of 10000 docids - - 10) timings for SIFT, GIST, and DEEP data (100k, 200k, 300k, 500k, 700k, 1000k) - */ +#include "hnsw-like.h" static size_t distcalls_simple; static size_t distcalls_search_layer; @@ -38,471 +15,300 @@ static size_t disconnected_for_symmetry; static size_t select_n_full; static size_t select_n_partial; -struct LinkList : std::vector<uint32_t> -{ - bool has_link_to(uint32_t id) const { - auto iter = std::find(begin(), end(), id); - return (iter != end()); - } - void remove_link(uint32_t id) { - uint32_t last = back(); - for (iterator iter = begin(); iter != end(); ++iter) { - if (*iter == id) { - *iter = last; - pop_back(); - return; - } - } - fprintf(stderr, "BAD missing link to remove: %u\n", id); - abort(); - } -}; - -struct Node { - std::vector<LinkList> _links; - Node(uint32_t , uint32_t numLevels, uint32_t M) - : _links(numLevels) - { - for (uint32_t i = 0; i < _links.size(); ++i) { - _links[i].reserve((i == 0) ? (2 * M + 1) : (M+1)); - } - } -}; - -struct VisitedSet -{ - using Mark = unsigned short; - Mark *ptr; - Mark curval; - size_t sz; - VisitedSet(const VisitedSet &) = delete; - VisitedSet& operator=(const VisitedSet &) = delete; - explicit VisitedSet(size_t size) { - ptr = (Mark *)malloc(size * sizeof(Mark)); - curval = -1; - sz = size; - clear(); - } - void clear() { - ++curval; - if (curval == 0) { - memset(ptr, 0, sz * sizeof(Mark)); - ++curval; - } - } - ~VisitedSet() { free(ptr); } - void mark(size_t id) { ptr[id] = curval; } - bool isMarked(size_t id) const { return ptr[id] == curval; } -}; -struct VisitedSetPool +HnswLikeNns::HnswLikeNns(uint32_t numDims, const DocVectorAccess<float> &dva) + : NNS(numDims, dva), + _nodes(), + _entryId(0), + _entryLevel(-1), + _M(16), + _efConstruction(200), + _levelMultiplier(1.0 / log(1.0 * _M)), + _rndGen(), + _ops_counter(0) { - std::unique_ptr<VisitedSet> lastUsed; - VisitedSetPool() { - lastUsed = std::make_unique<VisitedSet>(250); - } - ~VisitedSetPool() {} - VisitedSet &get(size_t size) { - if (size > lastUsed->sz) { - lastUsed = std::make_unique<VisitedSet>(size*2); - } else { - lastUsed->clear(); - } - return *lastUsed; - } -}; - -struct HnswHit { - double dist; - uint32_t docid; - HnswHit(uint32_t di, SqDist sq) : dist(sq.distance), docid(di) {} -}; - -struct GreaterDist { - bool operator() (const HnswHit &lhs, const HnswHit& rhs) const { - return (rhs.dist < lhs.dist); - } -}; -struct LesserDist { - bool operator() (const HnswHit &lhs, const HnswHit& rhs) const { - return (lhs.dist < rhs.dist); - } -}; - -using NearestList = std::vector<HnswHit>; - -struct NearestPriQ : std::priority_queue<HnswHit, NearestList, GreaterDist> -{ -}; - -struct FurthestPriQ : std::priority_queue<HnswHit, NearestList, LesserDist> -{ - NearestList steal() { - NearestList result; - c.swap(result); - return result; - } - const NearestList& peek() const { return c; } -}; - -class HnswLikeNns : public NNS<float> -{ -private: - std::vector<Node> _nodes; - uint32_t _entryId; - int _entryLevel; - uint32_t _M; - uint32_t _efConstruction; - double _levelMultiplier; - RndGen _rndGen; - VisitedSetPool _visitedSetPool; - size_t _ops_counter; - - double distance(Vector v, uint32_t id) const; - - double distance(uint32_t a, uint32_t b) const { - Vector v = _dva.get(a); - return distance(v, b); - } - - int randomLevel() { - double unif = _rndGen.nextUniform(); - double r = -log(1.0-unif) * _levelMultiplier; - return (int) r; - } - - uint32_t count_reachable() const; - void dumpStats() const; - -public: - HnswLikeNns(uint32_t numDims, const DocVectorAccess<float> &dva) - : NNS(numDims, dva), - _nodes(), - _entryId(0), - _entryLevel(-1), - _M(16), - _efConstruction(200), - _levelMultiplier(1.0 / log(1.0 * _M)), - _rndGen(), - _ops_counter(0) - { - } - - ~HnswLikeNns() { dumpStats(); } - - LinkList& getLinkList(uint32_t docid, uint32_t level) { - // assert(docid < _nodes.size()); - // assert(level < _nodes[docid]._links.size()); - return _nodes[docid]._links[level]; - } - - const LinkList& getLinkList(uint32_t docid, uint32_t level) const { - return _nodes[docid]._links[level]; - } +} - // simple greedy search - HnswHit search_layer_simple(Vector vector, HnswHit curPoint, uint32_t searchLevel) { - bool keepGoing = true; - while (keepGoing) { - keepGoing = false; - const LinkList& neighbors = getLinkList(curPoint.docid, searchLevel); - for (uint32_t n_id : neighbors) { - double dist = distance(vector, n_id); - ++distcalls_simple; - if (dist < curPoint.dist) { - curPoint = HnswHit(n_id, SqDist(dist)); - keepGoing = true; - } +// simple greedy search +HnswHit +HnswLikeNns::search_layer_simple(Vector vector, HnswHit curPoint, uint32_t searchLevel) { + bool keepGoing = true; + while (keepGoing) { + keepGoing = false; + const LinkList& neighbors = getLinkList(curPoint.docid, searchLevel); + for (uint32_t n_id : neighbors) { + double dist = distance(vector, n_id); + ++distcalls_simple; + if (dist < curPoint.dist) { + curPoint = HnswHit(n_id, SqDist(dist)); + keepGoing = true; } } - return curPoint; } + return curPoint; +} - void search_layer(Vector vector, FurthestPriQ &w, - VisitedSet &visited, - uint32_t ef, uint32_t searchLevel); - - void search_layer_with_filter(Vector vector, FurthestPriQ &w, - VisitedSet &visited, - uint32_t ef, uint32_t searchLevel, - const BitVector &blacklist); - bool haveCloserDistance(HnswHit e, const LinkList &r) const { - for (uint32_t prevId : r) { - double dist = distance(e.docid, prevId); - ++distcalls_heuristic; - if (dist < e.dist) return true; - } - return false; +bool +HnswLikeNns::haveCloserDistance(HnswHit e, const LinkList &r) const { + for (uint32_t prevId : r) { + double dist = distance(e.docid, prevId); + ++distcalls_heuristic; + if (dist < e.dist) return true; } + return false; +} - LinkList select_neighbors(const NearestList &neighbors, uint32_t curMax) const; - - LinkList remove_weakest(const NearestList &neighbors, uint32_t curMax, LinkList &removed) const; - - void addDoc(uint32_t docid) override { - Vector vector = _dva.get(docid); - for (uint32_t id = _nodes.size(); id <= docid; ++id) { - _nodes.emplace_back(id, 0, _M); - } - int level = randomLevel(); - assert(_nodes[docid]._links.size() == 0); - _nodes[docid] = Node(docid, level+1, _M); - if (_entryLevel < 0) { - _entryId = docid; - _entryLevel = level; - track_ops(); - return; - } - int searchLevel = _entryLevel; - VisitedSet &visited = _visitedSetPool.get(_nodes.size()); - double entryDist = distance(vector, _entryId); - ++distcalls_other; - HnswHit entryPoint(_entryId, SqDist(entryDist)); +void +HnswLikeNns::addDoc(uint32_t docid) { + Vector vector = _dva.get(docid); + for (uint32_t id = _nodes.size(); id <= docid; ++id) { + _nodes.emplace_back(id, 0, _M); + } + int level = randomLevel(); + assert(_nodes[docid]._links.size() == 0); + _nodes[docid] = Node(docid, level+1, _M); + if (_entryLevel < 0) { + _entryId = docid; + _entryLevel = level; + track_ops(); + return; + } + int searchLevel = _entryLevel; + VisitedSet &visited = _visitedSetPool.get(_nodes.size()); + double entryDist = distance(vector, _entryId); + ++distcalls_other; + HnswHit entryPoint(_entryId, SqDist(entryDist)); #undef MULTI_ENTRY_I #ifdef MULTI_ENTRY_I - FurthestPriQ w; - w.push(entryPoint); - while (searchLevel > level) { - search_layer(vector, w, visited, 5 * _M, searchLevel); - --searchLevel; - } + FurthestPriQ w; + w.push(entryPoint); + while (searchLevel > level) { + search_layer(vector, w, visited, 5 * _M, searchLevel); + --searchLevel; + } #else - while (searchLevel > level) { - entryPoint = search_layer_simple(vector, entryPoint, searchLevel); - --searchLevel; - } - FurthestPriQ w; - w.push(entryPoint); + while (searchLevel > level) { + entryPoint = search_layer_simple(vector, entryPoint, searchLevel); + --searchLevel; + } + FurthestPriQ w; + w.push(entryPoint); #endif - searchLevel = std::min(level, _entryLevel); - while (searchLevel >= 0) { - search_layer(vector, w, visited, _efConstruction, searchLevel); - LinkList neighbors = select_neighbors(w.peek(), _M); - connect_new_node(docid, neighbors, searchLevel); - each_shrink_ifneeded(neighbors, searchLevel); - --searchLevel; - } - if (level > _entryLevel) { - _entryLevel = level; - _entryId = docid; - } - track_ops(); + searchLevel = std::min(level, _entryLevel); + while (searchLevel >= 0) { + search_layer(vector, w, visited, _efConstruction, searchLevel); + LinkList neighbors = select_neighbors(w.peek(), _M); + connect_new_node(docid, neighbors, searchLevel); + each_shrink_ifneeded(neighbors, searchLevel); + --searchLevel; } - - void track_ops() { - _ops_counter++; - if ((_ops_counter % 10000) == 0) { - double div = _ops_counter; - fprintf(stderr, "add / remove ops: %zu\n", _ops_counter); - fprintf(stderr, "distance calls for layer: %zu is %.3f per op\n", distcalls_search_layer, distcalls_search_layer/ div); - fprintf(stderr, "distance calls for heuristic: %zu is %.3f per op\n", distcalls_heuristic, distcalls_heuristic / div); - fprintf(stderr, "distance calls for simple: %zu is %.3f per op\n", distcalls_simple, distcalls_simple / div); - fprintf(stderr, "distance calls for shrink: %zu is %.3f per op\n", distcalls_shrink, distcalls_shrink / div); - fprintf(stderr, "distance calls for refill: %zu is %.3f per op\n", distcalls_refill, distcalls_refill / div); - fprintf(stderr, "distance calls for other: %zu is %.3f per op\n", distcalls_other, distcalls_other / div); - fprintf(stderr, "refill needed calls: %zu is %.3f per op\n", refill_needed_calls, refill_needed_calls / div); - fprintf(stderr, "shrink needed calls: %zu is %.3f per op\n", shrink_needed_calls, shrink_needed_calls / div); - fprintf(stderr, "disconnected weak links: %zu is %.3f per op\n", disconnected_weak_links, disconnected_weak_links / div); - fprintf(stderr, "disconnected for symmetry: %zu is %.3f per op\n", disconnected_for_symmetry, disconnected_for_symmetry / div); - fprintf(stderr, "select neighbors: partial %zu vs full %zu\n", select_n_partial, select_n_full); - } - } - - void remove_link_from(uint32_t from_id, uint32_t remove_id, uint32_t level) { - LinkList &links = getLinkList(from_id, level); - links.remove_link(remove_id); + if (level > _entryLevel) { + _entryLevel = level; + _entryId = docid; } + track_ops(); +} + +void +HnswLikeNns::track_ops() { + _ops_counter++; + if ((_ops_counter % 10000) == 0) { + double div = _ops_counter; + fprintf(stderr, "add / remove ops: %zu\n", _ops_counter); + fprintf(stderr, "distance calls for layer: %zu is %.3f per op\n", distcalls_search_layer, distcalls_search_layer/ div); + fprintf(stderr, "distance calls for heuristic: %zu is %.3f per op\n", distcalls_heuristic, distcalls_heuristic / div); + fprintf(stderr, "distance calls for simple: %zu is %.3f per op\n", distcalls_simple, distcalls_simple / div); + fprintf(stderr, "distance calls for shrink: %zu is %.3f per op\n", distcalls_shrink, distcalls_shrink / div); + fprintf(stderr, "distance calls for refill: %zu is %.3f per op\n", distcalls_refill, distcalls_refill / div); + fprintf(stderr, "distance calls for other: %zu is %.3f per op\n", distcalls_other, distcalls_other / div); + fprintf(stderr, "refill needed calls: %zu is %.3f per op\n", refill_needed_calls, refill_needed_calls / div); + fprintf(stderr, "shrink needed calls: %zu is %.3f per op\n", shrink_needed_calls, shrink_needed_calls / div); + fprintf(stderr, "disconnected weak links: %zu is %.3f per op\n", disconnected_weak_links, disconnected_weak_links / div); + fprintf(stderr, "disconnected for symmetry: %zu is %.3f per op\n", disconnected_for_symmetry, disconnected_for_symmetry / div); + fprintf(stderr, "select neighbors: partial %zu vs full %zu\n", select_n_partial, select_n_full); + } +} -#undef SIMPLE_REFILL #ifdef SIMPLE_REFILL - void refill_ifneeded(uint32_t my_id, const LinkList &replacements, uint32_t level) { - LinkList &my_links = getLinkList(my_id, level); - if (my_links.size() * 2 < _M) { - const uint32_t maxLinks = (level > 0) ? _M : (2 * _M); - ++refill_needed_calls; - for (uint32_t repl_id : replacements) { - if (repl_id == my_id) continue; - if (my_links.has_link_to(repl_id)) continue; - LinkList &other_links = getLinkList(repl_id, level); - if (other_links.size() >= maxLinks) continue; - other_links.push_back(my_id); - my_links.push_back(repl_id); - if (my_links.size() >= _M) return; - } - } - } -#else - void refill_all(uint32_t my_id, const LinkList &replacements, uint32_t level) { - LinkList &my_links = getLinkList(my_id, level); +void +HnswLikeNns::refill_ifneeded(uint32_t my_id, const LinkList &replacements, uint32_t level) { + LinkList &my_links = getLinkList(my_id, level); + if (my_links.size() * 2 < _M) { const uint32_t maxLinks = (level > 0) ? _M : (2 * _M); - NearestPriQ w; + ++refill_needed_calls; for (uint32_t repl_id : replacements) { if (repl_id == my_id) continue; if (my_links.has_link_to(repl_id)) continue; - const LinkList &other_links = getLinkList(repl_id, level); + LinkList &other_links = getLinkList(repl_id, level); if (other_links.size() >= maxLinks) continue; - double dist = distance(my_id, repl_id); - ++distcalls_refill; - w.emplace(repl_id, SqDist(dist)); - } - while (! w.empty()) { - HnswHit e = w.top(); - w.pop(); - if (haveCloserDistance(e, my_links)) continue; - LinkList &other_links = getLinkList(e.docid, level); - my_links.push_back(e.docid); other_links.push_back(my_id); - if (my_links.size() == _M) break; + my_links.push_back(repl_id); + if (my_links.size() >= _M) return; } } - void refill_one(uint32_t my_id, const LinkList &replacements, uint32_t level) { - LinkList &my_links = getLinkList(my_id, level); - NearestPriQ w; - for (uint32_t repl_id : replacements) { - if (repl_id == my_id) continue; - if (my_links.has_link_to(repl_id)) continue; - LinkList &other_links = getLinkList(repl_id, level); - if (other_links.size() >= _M) continue; - double dist = distance(my_id, repl_id); - ++distcalls_refill; - w.emplace(repl_id, SqDist(dist)); - } - while (! w.empty()) { - HnswHit e = w.top(); - w.pop(); - if (haveCloserDistance(e, my_links)) continue; - LinkList &other_links = getLinkList(e.docid, level); - my_links.push_back(e.docid); - other_links.push_back(my_id); - return; - } +} +#endif + +#define REFILL_ALL +#ifdef REFILL_ALL +void +HnswLikeNns::refill_ifneeded(uint32_t my_id, const LinkList &replacements, uint32_t level) { + LinkList &my_links = getLinkList(my_id, level); + if (my_links.size() >= _M) return; + ++refill_needed_calls; + const uint32_t maxLinks = (level > 0) ? _M : (2 * _M); + NearestPriQ w; + for (uint32_t repl_id : replacements) { + if (repl_id == my_id) continue; + if (my_links.has_link_to(repl_id)) continue; + const LinkList &other_links = getLinkList(repl_id, level); + if (other_links.size() >= maxLinks) continue; + double dist = distance(my_id, repl_id); + ++distcalls_refill; + w.emplace(repl_id, SqDist(dist)); } - void refill_ifneeded(uint32_t my_id, const LinkList &replacements, uint32_t level) { - LinkList &my_links = getLinkList(my_id, level); - if (my_links.size() < _M) { - ++refill_needed_calls; - refill_all(my_id, replacements, level); - } + while (! w.empty()) { + HnswHit e = w.top(); + w.pop(); + if (haveCloserDistance(e, my_links)) continue; + LinkList &other_links = getLinkList(e.docid, level); + my_links.push_back(e.docid); + other_links.push_back(my_id); + if (my_links.size() == _M) break; } +} #endif - void connect_new_node(uint32_t id, const LinkList &neighbors, uint32_t level); +#ifdef REFILL_ONE +void +HnswLikeNns::refill_ifneeded(uint32_t my_id, const LinkList &replacements, uint32_t level) { + LinkList &my_links = getLinkList(my_id, level); + if (my_links.size() >= _M) return; + ++refill_needed_calls; + NearestPriQ w; + for (uint32_t repl_id : replacements) { + if (repl_id == my_id) continue; + if (my_links.has_link_to(repl_id)) continue; + LinkList &other_links = getLinkList(repl_id, level); + if (other_links.size() >= _M) continue; + double dist = distance(my_id, repl_id); + ++distcalls_refill; + w.emplace(repl_id, SqDist(dist)); + } + while (! w.empty()) { + HnswHit e = w.top(); + w.pop(); + if (haveCloserDistance(e, my_links)) continue; + LinkList &other_links = getLinkList(e.docid, level); + my_links.push_back(e.docid); + other_links.push_back(my_id); + return; + } +} +#endif - void shrink_links(uint32_t shrink_id, uint32_t maxLinks, uint32_t level) { - LinkList &links = getLinkList(shrink_id, level); - NearestList distances; - for (uint32_t n_id : links) { - double n_dist = distance(shrink_id, n_id); - ++distcalls_shrink; - distances.emplace_back(n_id, SqDist(n_dist)); - } - LinkList lostLinks; - LinkList oldLinks = links; - links = remove_weakest(distances, maxLinks, lostLinks); +void +HnswLikeNns::shrink_links(uint32_t shrink_id, uint32_t maxLinks, uint32_t level) { + LinkList &links = getLinkList(shrink_id, level); + NearestList distances; + for (uint32_t n_id : links) { + double n_dist = distance(shrink_id, n_id); + ++distcalls_shrink; + distances.emplace_back(n_id, SqDist(n_dist)); + } + LinkList lostLinks; + LinkList oldLinks = links; + links = remove_weakest(distances, maxLinks, lostLinks); #define KEEP_SYM #ifdef KEEP_SYM - for (uint32_t lost_id : lostLinks) { - ++disconnected_for_symmetry; - remove_link_from(lost_id, shrink_id, level); - } + for (uint32_t lost_id : lostLinks) { + ++disconnected_for_symmetry; + remove_link_from(lost_id, shrink_id, level); + } #define DO_REFILL_AFTER_KEEP_SYM #ifdef DO_REFILL_AFTER_KEEP_SYM - for (uint32_t lost_id : lostLinks) { -#ifdef SIMPLE_REFILL - refill_ifneeded(lost_id, oldLinks, level); -#else - refill_all(lost_id, oldLinks, level); -#endif - } + for (uint32_t lost_id : lostLinks) { + refill_ifneeded(lost_id, oldLinks, level); + } #endif #endif - } +} - void each_shrink_ifneeded(const LinkList &neighbors, uint32_t level); - void mutually_reconnect(LinkList cluster, int level) { - while (! cluster.empty()) { - uint32_t n_id = cluster.back(); - cluster.pop_back(); -#ifdef SIMPLE_REFILL - refill_ifneeded(n_id, cluster, level); -#else - refill_all(n_id, cluster, level); -#endif - } - } - - void removeDoc(uint32_t docid) override { - Node &node = _nodes[docid]; - bool need_new_entrypoint = (docid == _entryId); - for (int level = node._links.size(); level-- > 0; ) { - LinkList my_links; - my_links.swap(node._links[level]); - for (uint32_t n_id : my_links) { - if (need_new_entrypoint) { - _entryId = n_id; - _entryLevel = level; - need_new_entrypoint = false; - } - remove_link_from(n_id, docid, level); +void +HnswLikeNns::removeDoc(uint32_t docid) { + Node &node = _nodes[docid]; + bool need_new_entrypoint = (docid == _entryId); + for (int level = node._links.size(); level-- > 0; ) { + LinkList my_links; + my_links.swap(node._links[level]); + for (uint32_t n_id : my_links) { + if (need_new_entrypoint) { + _entryId = n_id; + _entryLevel = level; + need_new_entrypoint = false; } - mutually_reconnect(my_links, level); - } - node = Node(docid, 0, _M); - if (need_new_entrypoint) { - _entryLevel = -1; - _entryId = 0; - for (uint32_t i = 0; i < _nodes.size(); ++i) { - if (_nodes[i]._links.size() > 0) { - _entryId = i; - _entryLevel = _nodes[i]._links.size() - 1; - break; - } + remove_link_from(n_id, docid, level); + } + while (! my_links.empty()) { + uint32_t n_id = my_links.back(); + my_links.pop_back(); + refill_ifneeded(n_id, my_links, level); + } + } + node = Node(docid, 0, _M); + if (need_new_entrypoint) { + _entryLevel = -1; + _entryId = 0; + for (uint32_t i = 0; i < _nodes.size(); ++i) { + if (_nodes[i]._links.size() > 0) { + _entryId = i; + _entryLevel = _nodes[i]._links.size() - 1; + break; } } - track_ops(); } + track_ops(); +} - std::vector<NnsHit> topK(uint32_t k, Vector vector, uint32_t search_k) override { - std::vector<NnsHit> result; - if (_entryLevel < 0) return result; - double entryDist = distance(vector, _entryId); - ++distcalls_other; - HnswHit entryPoint(_entryId, SqDist(entryDist)); - int searchLevel = _entryLevel; - VisitedSet &visited = _visitedSetPool.get(_nodes.size()); +std::vector<NnsHit> +HnswLikeNns::topK(uint32_t k, Vector vector, uint32_t search_k) { + std::vector<NnsHit> result; + if (_entryLevel < 0) return result; + double entryDist = distance(vector, _entryId); + ++distcalls_other; + HnswHit entryPoint(_entryId, SqDist(entryDist)); + int searchLevel = _entryLevel; + VisitedSet &visited = _visitedSetPool.get(_nodes.size()); #undef MULTI_ENTRY_S #ifdef MULTI_ENTRY_S - FurthestPriQ w; - w.push(entryPoint); - while (searchLevel > 0) { - search_layer(vector, w, visited, std::min(k, search_k), searchLevel); - --searchLevel; - } + FurthestPriQ w; + w.push(entryPoint); + while (searchLevel > 0) { + search_layer(vector, w, visited, std::min(k, search_k), searchLevel); + --searchLevel; + } #else - while (searchLevel > 0) { - entryPoint = search_layer_simple(vector, entryPoint, searchLevel); - --searchLevel; - } - FurthestPriQ w; - w.push(entryPoint); + while (searchLevel > 0) { + entryPoint = search_layer_simple(vector, entryPoint, searchLevel); + --searchLevel; + } + FurthestPriQ w; + w.push(entryPoint); #endif - search_layer(vector, w, visited, std::max(k, search_k), 0); - while (w.size() > k) { - w.pop(); - } - NearestList tmp = w.steal(); - std::sort(tmp.begin(), tmp.end(), LesserDist()); - result.reserve(tmp.size()); - for (const auto & hit : tmp) { - result.emplace_back(hit.docid, SqDist(hit.dist)); - } - return result; + search_layer(vector, w, visited, std::max(k, search_k), 0); + while (w.size() > k) { + w.pop(); } - - std::vector<NnsHit> topKfilter(uint32_t k, Vector vector, uint32_t search_k, const BitVector &blacklist) override; -}; + NearestList tmp = w.steal(); + std::sort(tmp.begin(), tmp.end(), LesserDist()); + result.reserve(tmp.size()); + for (const auto & hit : tmp) { + result.emplace_back(hit.docid, SqDist(hit.dist)); + } + return result; +} double diff --git a/eval/src/tests/ann/xp-hnswlike-nns.cpp b/eval/src/tests/ann/xp-hnswlike-nns.cpp index 90fc0fe2e92..b7cae9f731c 100644 --- a/eval/src/tests/ann/xp-hnswlike-nns.cpp +++ b/eval/src/tests/ann/xp-hnswlike-nns.cpp @@ -1,11 +1,6 @@ // Copyright 2020 Oath Inc. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. -#include <algorithm> -#include <assert.h> -#include <queue> -#include <cinttypes> -#include "std-random.h" -#include "nns.h" +#include "hnsw-like.h" /* Todo: @@ -38,378 +33,216 @@ static size_t disconnected_for_symmetry; static size_t select_n_full; static size_t select_n_partial; -struct LinkList : std::vector<uint32_t> -{ - bool has_link_to(uint32_t id) const { - auto iter = std::find(begin(), end(), id); - return (iter != end()); - } - void remove_link(uint32_t id) { - uint32_t last = back(); - for (iterator iter = begin(); iter != end(); ++iter) { - if (*iter == id) { - *iter = last; - pop_back(); - return; - } - } - fprintf(stderr, "BAD missing link to remove: %u\n", id); - abort(); - } -}; -struct Node { - std::vector<LinkList> _links; - Node(uint32_t , uint32_t numLevels, uint32_t M) - : _links(numLevels) - { - for (uint32_t i = 0; i < _links.size(); ++i) { - _links[i].reserve((i == 0) ? (2 * M + 1) : (M+1)); - } - } -}; -struct VisitedSet +HnswLikeNns::HnswLikeNns(uint32_t numDims, const DocVectorAccess<float> &dva) + : NNS(numDims, dva), + _nodes(), + _entryId(0), + _entryLevel(-1), + _M(16), + _efConstruction(200), + _levelMultiplier(1.0 / log(1.0 * _M)), + _rndGen(), + _ops_counter(0) { - using Mark = unsigned short; - Mark *ptr; - Mark curval; - size_t sz; - VisitedSet(const VisitedSet &) = delete; - VisitedSet& operator=(const VisitedSet &) = delete; - explicit VisitedSet(size_t size) { - ptr = (Mark *)malloc(size * sizeof(Mark)); - curval = -1; - sz = size; - clear(); - } - void clear() { - ++curval; - if (curval == 0) { - memset(ptr, 0, sz * sizeof(Mark)); - ++curval; - } - } - ~VisitedSet() { free(ptr); } - void mark(size_t id) { ptr[id] = curval; } - bool isMarked(size_t id) const { return ptr[id] == curval; } -}; +} -struct VisitedSetPool -{ - std::unique_ptr<VisitedSet> lastUsed; - VisitedSetPool() { - lastUsed = std::make_unique<VisitedSet>(250); - } - ~VisitedSetPool() {} - VisitedSet &get(size_t size) { - if (size > lastUsed->sz) { - lastUsed = std::make_unique<VisitedSet>(size*2); - } else { - lastUsed->clear(); +// simple greedy search +HnswHit +HnswLikeNns::search_layer_simple(Vector vector, HnswHit curPoint, uint32_t searchLevel) { + bool keepGoing = true; + while (keepGoing) { + keepGoing = false; + const LinkList& neighbors = getLinkList(curPoint.docid, searchLevel); + for (uint32_t n_id : neighbors) { + double dist = distance(vector, n_id); + ++distcalls_simple; + if (dist < curPoint.dist) { + curPoint = HnswHit(n_id, SqDist(dist)); + keepGoing = true; + } } - return *lastUsed; - } -}; - -struct HnswHit { - double dist; - uint32_t docid; - HnswHit(uint32_t di, SqDist sq) : dist(sq.distance), docid(di) {} -}; - -struct GreaterDist { - bool operator() (const HnswHit &lhs, const HnswHit& rhs) const { - return (rhs.dist < lhs.dist); - } -}; -struct LesserDist { - bool operator() (const HnswHit &lhs, const HnswHit& rhs) const { - return (lhs.dist < rhs.dist); - } -}; - -using NearestList = std::vector<HnswHit>; - -struct NearestPriQ : std::priority_queue<HnswHit, NearestList, GreaterDist> -{ -}; - -struct FurthestPriQ : std::priority_queue<HnswHit, NearestList, LesserDist> -{ - NearestList steal() { - NearestList result; - c.swap(result); - return result; - } - const NearestList& peek() const { return c; } -}; - -class HnswLikeNns : public NNS<float> -{ -private: - std::vector<Node> _nodes; - uint32_t _entryId; - int _entryLevel; - uint32_t _M; - uint32_t _efConstruction; - double _levelMultiplier; - RndGen _rndGen; - VisitedSetPool _visitedSetPool; - size_t _ops_counter; - - double distance(Vector v, uint32_t id) const; - - double distance(uint32_t a, uint32_t b) const { - Vector v = _dva.get(a); - return distance(v, b); - } - - int randomLevel() { - double unif = _rndGen.nextUniform(); - double r = -log(1.0-unif) * _levelMultiplier; - return (int) r; - } - - uint32_t count_reachable() const; - void dumpStats() const; - -public: - HnswLikeNns(uint32_t numDims, const DocVectorAccess<float> &dva) - : NNS(numDims, dva), - _nodes(), - _entryId(0), - _entryLevel(-1), - _M(16), - _efConstruction(200), - _levelMultiplier(1.0 / log(1.0 * _M)), - _rndGen(), - _ops_counter(0) - { } + return curPoint; +} - ~HnswLikeNns() { dumpStats(); } - - LinkList& getLinkList(uint32_t docid, uint32_t level) { - // assert(docid < _nodes.size()); - // assert(level < _nodes[docid]._links.size()); - return _nodes[docid]._links[level]; +bool +HnswLikeNns::haveCloserDistance(HnswHit e, const LinkList &r) const { + for (uint32_t prevId : r) { + double dist = distance(e.docid, prevId); + ++distcalls_heuristic; + if (dist < e.dist) return true; } + return false; +} - const LinkList& getLinkList(uint32_t docid, uint32_t level) const { - return _nodes[docid]._links[level]; +void +HnswLikeNns::addDoc(uint32_t docid) { + Vector vector = _dva.get(docid); + for (uint32_t id = _nodes.size(); id <= docid; ++id) { + _nodes.emplace_back(id, 0, _M); + } + int level = randomLevel(); + assert(_nodes[docid]._links.size() == 0); + _nodes[docid] = Node(docid, level+1, _M); + if (_entryLevel < 0) { + _entryId = docid; + _entryLevel = level; + track_ops(); + return; } - - // simple greedy search - HnswHit search_layer_simple(Vector vector, HnswHit curPoint, uint32_t searchLevel) { - bool keepGoing = true; - while (keepGoing) { - keepGoing = false; - const LinkList& neighbors = getLinkList(curPoint.docid, searchLevel); - for (uint32_t n_id : neighbors) { - double dist = distance(vector, n_id); - ++distcalls_simple; - if (dist < curPoint.dist) { - curPoint = HnswHit(n_id, SqDist(dist)); - keepGoing = true; - } - } - } - return curPoint; + int searchLevel = _entryLevel; + double entryDist = distance(vector, _entryId); + ++distcalls_other; + HnswHit entryPoint(_entryId, SqDist(entryDist)); + while (searchLevel > level) { + entryPoint = search_layer_simple(vector, entryPoint, searchLevel); + --searchLevel; } - - void search_layer(Vector vector, FurthestPriQ &w, - uint32_t ef, uint32_t searchLevel); - - void search_layer_with_filter(Vector vector, FurthestPriQ &w, - uint32_t ef, uint32_t searchLevel, - const BitVector &blacklist); - - bool haveCloserDistance(HnswHit e, const LinkList &r) const { - for (uint32_t prevId : r) { - double dist = distance(e.docid, prevId); - ++distcalls_heuristic; - if (dist < e.dist) return true; - } - return false; + searchLevel = std::min(level, _entryLevel); + FurthestPriQ w; + w.push(entryPoint); + while (searchLevel >= 0) { + search_layer(vector, w, _efConstruction, searchLevel); + LinkList neighbors = select_neighbors(w.peek(), _M); + connect_new_node(docid, neighbors, searchLevel); + each_shrink_ifneeded(neighbors, searchLevel); + --searchLevel; } - - LinkList select_neighbors(const NearestList &neighbors, uint32_t curMax) const; - - LinkList remove_weakest(const NearestList &neighbors, uint32_t curMax, LinkList &removed) const; - - void addDoc(uint32_t docid) override { - Vector vector = _dva.get(docid); - for (uint32_t id = _nodes.size(); id <= docid; ++id) { - _nodes.emplace_back(id, 0, _M); - } - int level = randomLevel(); - assert(_nodes[docid]._links.size() == 0); - _nodes[docid] = Node(docid, level+1, _M); - if (_entryLevel < 0) { - _entryId = docid; - _entryLevel = level; - track_ops(); - return; - } - int searchLevel = _entryLevel; - double entryDist = distance(vector, _entryId); - ++distcalls_other; - HnswHit entryPoint(_entryId, SqDist(entryDist)); - while (searchLevel > level) { - entryPoint = search_layer_simple(vector, entryPoint, searchLevel); - --searchLevel; - } - searchLevel = std::min(level, _entryLevel); - FurthestPriQ w; - w.push(entryPoint); - while (searchLevel >= 0) { - search_layer(vector, w, _efConstruction, searchLevel); - LinkList neighbors = select_neighbors(w.peek(), _M); - connect_new_node(docid, neighbors, searchLevel); - each_shrink_ifneeded(neighbors, searchLevel); - --searchLevel; - } - if (level > _entryLevel) { - _entryLevel = level; - _entryId = docid; - } - track_ops(); + if (level > _entryLevel) { + _entryLevel = level; + _entryId = docid; } + track_ops(); +} - void track_ops() { - _ops_counter++; - if ((_ops_counter % 10000) == 0) { - double div = _ops_counter; - fprintf(stderr, "add / remove ops: %zu\n", _ops_counter); - fprintf(stderr, "distance calls for layer: %zu is %.3f per op\n", distcalls_search_layer, distcalls_search_layer/ div); - fprintf(stderr, "distance calls for heuristic: %zu is %.3f per op\n", distcalls_heuristic, distcalls_heuristic / div); - fprintf(stderr, "distance calls for simple: %zu is %.3f per op\n", distcalls_simple, distcalls_simple / div); - fprintf(stderr, "distance calls for shrink: %zu is %.3f per op\n", distcalls_shrink, distcalls_shrink / div); - fprintf(stderr, "distance calls for refill: %zu is %.3f per op\n", distcalls_refill, distcalls_refill / div); - fprintf(stderr, "distance calls for other: %zu is %.3f per op\n", distcalls_other, distcalls_other / div); - fprintf(stderr, "refill needed calls: %zu is %.3f per op\n", refill_needed_calls, refill_needed_calls / div); - fprintf(stderr, "shrink needed calls: %zu is %.3f per op\n", shrink_needed_calls, shrink_needed_calls / div); - fprintf(stderr, "disconnected weak links: %zu is %.3f per op\n", disconnected_weak_links, disconnected_weak_links / div); - fprintf(stderr, "disconnected for symmetry: %zu is %.3f per op\n", disconnected_for_symmetry, disconnected_for_symmetry / div); - fprintf(stderr, "select neighbors: partial %zu vs full %zu\n", select_n_partial, select_n_full); - } - } - - void remove_link_from(uint32_t from_id, uint32_t remove_id, uint32_t level) { - LinkList &links = getLinkList(from_id, level); - links.remove_link(remove_id); - } +void +HnswLikeNns::track_ops() { + _ops_counter++; + if ((_ops_counter % 10000) == 0) { + double div = _ops_counter; + fprintf(stderr, "add / remove ops: %zu\n", _ops_counter); + fprintf(stderr, "distance calls for layer: %zu is %.3f per op\n", distcalls_search_layer, distcalls_search_layer/ div); + fprintf(stderr, "distance calls for heuristic: %zu is %.3f per op\n", distcalls_heuristic, distcalls_heuristic / div); + fprintf(stderr, "distance calls for simple: %zu is %.3f per op\n", distcalls_simple, distcalls_simple / div); + fprintf(stderr, "distance calls for shrink: %zu is %.3f per op\n", distcalls_shrink, distcalls_shrink / div); + fprintf(stderr, "distance calls for refill: %zu is %.3f per op\n", distcalls_refill, distcalls_refill / div); + fprintf(stderr, "distance calls for other: %zu is %.3f per op\n", distcalls_other, distcalls_other / div); + fprintf(stderr, "refill needed calls: %zu is %.3f per op\n", refill_needed_calls, refill_needed_calls / div); + fprintf(stderr, "shrink needed calls: %zu is %.3f per op\n", shrink_needed_calls, shrink_needed_calls / div); + fprintf(stderr, "disconnected weak links: %zu is %.3f per op\n", disconnected_weak_links, disconnected_weak_links / div); + fprintf(stderr, "disconnected for symmetry: %zu is %.3f per op\n", disconnected_for_symmetry, disconnected_for_symmetry / div); + fprintf(stderr, "select neighbors: partial %zu vs full %zu\n", select_n_partial, select_n_full); + } +} - void refill_ifneeded(uint32_t my_id, const LinkList &replacements, uint32_t level) { - LinkList &my_links = getLinkList(my_id, level); - if (my_links.size() < 8) { - ++refill_needed_calls; - for (uint32_t repl_id : replacements) { - if (repl_id == my_id) continue; - if (my_links.has_link_to(repl_id)) continue; - LinkList &other_links = getLinkList(repl_id, level); - if (other_links.size() + 1 >= _M) continue; - other_links.push_back(my_id); - my_links.push_back(repl_id); - if (my_links.size() >= _M) return; - } +void +HnswLikeNns::refill_ifneeded(uint32_t my_id, const LinkList &replacements, uint32_t level) { + LinkList &my_links = getLinkList(my_id, level); + if (my_links.size() < 8) { + ++refill_needed_calls; + for (uint32_t repl_id : replacements) { + if (repl_id == my_id) continue; + if (my_links.has_link_to(repl_id)) continue; + LinkList &other_links = getLinkList(repl_id, level); + if (other_links.size() + 1 >= _M) continue; + other_links.push_back(my_id); + my_links.push_back(repl_id); + if (my_links.size() >= _M) return; } } +} - void connect_new_node(uint32_t id, const LinkList &neighbors, uint32_t level); - - void shrink_links(uint32_t shrink_id, uint32_t maxLinks, uint32_t level) { - LinkList &links = getLinkList(shrink_id, level); - NearestList distances; - for (uint32_t n_id : links) { - double n_dist = distance(shrink_id, n_id); - ++distcalls_shrink; - distances.emplace_back(n_id, SqDist(n_dist)); - } - LinkList lostLinks; - LinkList oldLinks = links; - links = remove_weakest(distances, maxLinks, lostLinks); +void +HnswLikeNns::shrink_links(uint32_t shrink_id, uint32_t maxLinks, uint32_t level) { + LinkList &links = getLinkList(shrink_id, level); + NearestList distances; + for (uint32_t n_id : links) { + double n_dist = distance(shrink_id, n_id); + ++distcalls_shrink; + distances.emplace_back(n_id, SqDist(n_dist)); + } + LinkList lostLinks; + LinkList oldLinks = links; + links = remove_weakest(distances, maxLinks, lostLinks); #define KEEP_SYM #ifdef KEEP_SYM - for (uint32_t lost_id : lostLinks) { - ++disconnected_for_symmetry; - remove_link_from(lost_id, shrink_id, level); - } + for (uint32_t lost_id : lostLinks) { + ++disconnected_for_symmetry; + remove_link_from(lost_id, shrink_id, level); + } #define DO_REFILL_AFTER_KEEP_SYM #ifdef DO_REFILL_AFTER_KEEP_SYM - for (uint32_t lost_id : lostLinks) { - refill_ifneeded(lost_id, oldLinks, level); - } + for (uint32_t lost_id : lostLinks) { + refill_ifneeded(lost_id, oldLinks, level); + } #endif #endif - } - - void each_shrink_ifneeded(const LinkList &neighbors, uint32_t level); +} - void removeDoc(uint32_t docid) override { - Node &node = _nodes[docid]; - bool need_new_entrypoint = (docid == _entryId); - for (int level = node._links.size(); level-- > 0; ) { - LinkList my_links; - my_links.swap(node._links[level]); - for (uint32_t n_id : my_links) { - if (need_new_entrypoint) { - _entryId = n_id; - _entryLevel = level; - need_new_entrypoint = false; - } - remove_link_from(n_id, docid, level); - } - while (! my_links.empty()) { - uint32_t n_id = my_links.back(); - my_links.pop_back(); - refill_ifneeded(n_id, my_links, level); +void +HnswLikeNns::removeDoc(uint32_t docid) { + Node &node = _nodes[docid]; + bool need_new_entrypoint = (docid == _entryId); + for (int level = node._links.size(); level-- > 0; ) { + LinkList my_links; + my_links.swap(node._links[level]); + for (uint32_t n_id : my_links) { + if (need_new_entrypoint) { + _entryId = n_id; + _entryLevel = level; + need_new_entrypoint = false; } - } - node = Node(docid, 0, _M); - if (need_new_entrypoint) { - _entryLevel = -1; - _entryId = 0; - for (uint32_t i = 0; i < _nodes.size(); ++i) { - if (_nodes[i]._links.size() > 0) { - _entryId = i; - _entryLevel = _nodes[i]._links.size() - 1; - break; - } + remove_link_from(n_id, docid, level); + } + while (! my_links.empty()) { + uint32_t n_id = my_links.back(); + my_links.pop_back(); + refill_ifneeded(n_id, my_links, level); + } + } + node = Node(docid, 0, _M); + if (need_new_entrypoint) { + _entryLevel = -1; + _entryId = 0; + for (uint32_t i = 0; i < _nodes.size(); ++i) { + if (_nodes[i]._links.size() > 0) { + _entryId = i; + _entryLevel = _nodes[i]._links.size() - 1; + break; } } - track_ops(); } + track_ops(); +} - std::vector<NnsHit> topK(uint32_t k, Vector vector, uint32_t search_k) override { - std::vector<NnsHit> result; - if (_entryLevel < 0) return result; - double entryDist = distance(vector, _entryId); - ++distcalls_other; - HnswHit entryPoint(_entryId, SqDist(entryDist)); - int searchLevel = _entryLevel; - while (searchLevel > 0) { - entryPoint = search_layer_simple(vector, entryPoint, searchLevel); - --searchLevel; - } - FurthestPriQ w; - w.push(entryPoint); - search_layer(vector, w, std::max(k, search_k), 0); - while (w.size() > k) { - w.pop(); - } - NearestList tmp = w.steal(); - std::sort(tmp.begin(), tmp.end(), LesserDist()); - result.reserve(tmp.size()); - for (const auto & hit : tmp) { - result.emplace_back(hit.docid, SqDist(hit.dist)); - } - return result; +std::vector<NnsHit> +HnswLikeNns::topK(uint32_t k, Vector vector, uint32_t search_k) { + std::vector<NnsHit> result; + if (_entryLevel < 0) return result; + double entryDist = distance(vector, _entryId); + ++distcalls_other; + HnswHit entryPoint(_entryId, SqDist(entryDist)); + int searchLevel = _entryLevel; + while (searchLevel > 0) { + entryPoint = search_layer_simple(vector, entryPoint, searchLevel); + --searchLevel; } - - std::vector<NnsHit> topKfilter(uint32_t k, Vector vector, uint32_t search_k, const BitVector &blacklist) override; -}; + FurthestPriQ w; + w.push(entryPoint); + search_layer(vector, w, std::max(k, search_k), 0); + while (w.size() > k) { + w.pop(); + } + NearestList tmp = w.steal(); + std::sort(tmp.begin(), tmp.end(), LesserDist()); + result.reserve(tmp.size()); + for (const auto & hit : tmp) { + result.emplace_back(hit.docid, SqDist(hit.dist)); + } + return result; +} double |