From ffa2293de302d99051f7fc97d29c4dc606f045f1 Mon Sep 17 00:00:00 2001 From: Arne Juul Date: Mon, 24 Feb 2020 09:42:02 +0000 Subject: experimental HNSW with various extensions --- eval/src/tests/ann/extended-hnsw.cpp | 830 +++++++++++++++++++++++++++++++++++ 1 file changed, 830 insertions(+) create mode 100644 eval/src/tests/ann/extended-hnsw.cpp diff --git a/eval/src/tests/ann/extended-hnsw.cpp b/eval/src/tests/ann/extended-hnsw.cpp new file mode 100644 index 00000000000..42f3a10b389 --- /dev/null +++ b/eval/src/tests/ann/extended-hnsw.cpp @@ -0,0 +1,830 @@ +// Copyright 2020 Oath Inc. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. + +#include +#include +#include +#include +#include "std-random.h" +#include "nns.h" + +/* + Todo: + + measure effect of: + 1) removing leftover backlinks during "shrink" operation + 2) refilling to low-watermark after 1) happens + 3) refilling to mid-watermark after 1) happens + 4) adding then removing 20% extra documents + 5) removing 20% first-added documents + 6) removing first-added documents while inserting new ones + + 7) auto-tune search_k to ensure >= 50% recall on 1000 Q with k=100 + 8) auto-tune search_k to ensure avg 90% recall on 1000 Q with k=100 + 9) auto-tune search_k to ensure >= 90% reachability of 10000 docids + + 10) timings for SIFT, GIST, and DEEP data (100k, 200k, 300k, 500k, 700k, 1000k) + */ + +static size_t distcalls_simple; +static size_t distcalls_search_layer; +static size_t distcalls_other; +static size_t distcalls_heuristic; +static size_t distcalls_shrink; +static size_t distcalls_refill; +static size_t refill_needed_calls; +static size_t shrink_needed_calls; +static size_t disconnected_weak_links; +static size_t disconnected_for_symmetry; +static size_t select_n_full; +static size_t select_n_partial; + +struct LinkList : std::vector +{ + bool has_link_to(uint32_t id) const { + auto iter = std::find(begin(), end(), id); + return (iter != end()); + } + void remove_link(uint32_t id) { + uint32_t last = back(); + for (iterator iter = begin(); iter != end(); ++iter) { + if (*iter == id) { + *iter = last; + pop_back(); + return; + } + } + fprintf(stderr, "BAD missing link to remove: %u\n", id); + abort(); + } +}; + +struct Node { + std::vector _links; + Node(uint32_t , uint32_t numLevels, uint32_t M) + : _links(numLevels) + { + for (uint32_t i = 0; i < _links.size(); ++i) { + _links[i].reserve((i == 0) ? (2 * M + 1) : (M+1)); + } + } +}; + +struct VisitedSet +{ + using Mark = unsigned short; + Mark *ptr; + Mark curval; + size_t sz; + VisitedSet(const VisitedSet &) = delete; + VisitedSet& operator=(const VisitedSet &) = delete; + explicit VisitedSet(size_t size) { + ptr = (Mark *)malloc(size * sizeof(Mark)); + curval = -1; + sz = size; + clear(); + } + void clear() { + ++curval; + if (curval == 0) { + memset(ptr, 0, sz * sizeof(Mark)); + ++curval; + } + } + ~VisitedSet() { free(ptr); } + void mark(size_t id) { ptr[id] = curval; } + bool isMarked(size_t id) const { return ptr[id] == curval; } +}; + +struct VisitedSetPool +{ + std::unique_ptr lastUsed; + VisitedSetPool() { + lastUsed = std::make_unique(250); + } + ~VisitedSetPool() {} + VisitedSet &get(size_t size) { + if (size > lastUsed->sz) { + lastUsed = std::make_unique(size*2); + } else { + lastUsed->clear(); + } + return *lastUsed; + } +}; + +struct HnswHit { + double dist; + uint32_t docid; + HnswHit(uint32_t di, SqDist sq) : dist(sq.distance), docid(di) {} +}; + +struct GreaterDist { + bool operator() (const HnswHit &lhs, const HnswHit& rhs) const { + return (rhs.dist < lhs.dist); + } +}; +struct LesserDist { + bool operator() (const HnswHit &lhs, const HnswHit& rhs) const { + return (lhs.dist < rhs.dist); + } +}; + +using NearestList = std::vector; + +struct NearestPriQ : std::priority_queue +{ +}; + +struct FurthestPriQ : std::priority_queue +{ + NearestList steal() { + NearestList result; + c.swap(result); + return result; + } + const NearestList& peek() const { return c; } +}; + +class HnswLikeNns : public NNS +{ +private: + std::vector _nodes; + uint32_t _entryId; + int _entryLevel; + uint32_t _M; + uint32_t _efConstruction; + double _levelMultiplier; + RndGen _rndGen; + VisitedSetPool _visitedSetPool; + size_t _ops_counter; + + double distance(Vector v, uint32_t id) const; + + double distance(uint32_t a, uint32_t b) const { + Vector v = _dva.get(a); + return distance(v, b); + } + + int randomLevel() { + double unif = _rndGen.nextUniform(); + double r = -log(1.0-unif) * _levelMultiplier; + return (int) r; + } + + uint32_t count_reachable() const; + void dumpStats() const; + +public: + HnswLikeNns(uint32_t numDims, const DocVectorAccess &dva) + : NNS(numDims, dva), + _nodes(), + _entryId(0), + _entryLevel(-1), + _M(16), + _efConstruction(200), + _levelMultiplier(1.0 / log(1.0 * _M)), + _rndGen(), + _ops_counter(0) + { + } + + ~HnswLikeNns() { dumpStats(); } + + LinkList& getLinkList(uint32_t docid, uint32_t level) { + // assert(docid < _nodes.size()); + // assert(level < _nodes[docid]._links.size()); + return _nodes[docid]._links[level]; + } + + const LinkList& getLinkList(uint32_t docid, uint32_t level) const { + return _nodes[docid]._links[level]; + } + + // simple greedy search + HnswHit search_layer_simple(Vector vector, HnswHit curPoint, uint32_t searchLevel) { + bool keepGoing = true; + while (keepGoing) { + keepGoing = false; + const LinkList& neighbors = getLinkList(curPoint.docid, searchLevel); + for (uint32_t n_id : neighbors) { + double dist = distance(vector, n_id); + ++distcalls_simple; + if (dist < curPoint.dist) { + curPoint = HnswHit(n_id, SqDist(dist)); + keepGoing = true; + } + } + } + return curPoint; + } + + void search_layer(Vector vector, FurthestPriQ &w, + VisitedSet &visited, + uint32_t ef, uint32_t searchLevel); + + void search_layer_with_filter(Vector vector, FurthestPriQ &w, + VisitedSet &visited, + uint32_t ef, uint32_t searchLevel, + const BitVector &blacklist); + + bool haveCloserDistance(HnswHit e, const LinkList &r) const { + for (uint32_t prevId : r) { + double dist = distance(e.docid, prevId); + ++distcalls_heuristic; + if (dist < e.dist) return true; + } + return false; + } + + LinkList select_neighbors(const NearestList &neighbors, uint32_t curMax) const; + + LinkList remove_weakest(const NearestList &neighbors, uint32_t curMax, LinkList &removed) const; + + void addDoc(uint32_t docid) override { + Vector vector = _dva.get(docid); + for (uint32_t id = _nodes.size(); id <= docid; ++id) { + _nodes.emplace_back(id, 0, _M); + } + int level = randomLevel(); + assert(_nodes[docid]._links.size() == 0); + _nodes[docid] = Node(docid, level+1, _M); + if (_entryLevel < 0) { + _entryId = docid; + _entryLevel = level; + track_ops(); + return; + } + int searchLevel = _entryLevel; + VisitedSet &visited = _visitedSetPool.get(_nodes.size()); + double entryDist = distance(vector, _entryId); + ++distcalls_other; + HnswHit entryPoint(_entryId, SqDist(entryDist)); +#undef MULTI_ENTRY_I +#ifdef MULTI_ENTRY_I + FurthestPriQ w; + w.push(entryPoint); + while (searchLevel > level) { + search_layer(vector, w, visited, 5 * _M, searchLevel); + --searchLevel; + } +#else + while (searchLevel > level) { + entryPoint = search_layer_simple(vector, entryPoint, searchLevel); + --searchLevel; + } + FurthestPriQ w; + w.push(entryPoint); +#endif + searchLevel = std::min(level, _entryLevel); + while (searchLevel >= 0) { + search_layer(vector, w, visited, _efConstruction, searchLevel); + LinkList neighbors = select_neighbors(w.peek(), _M); + connect_new_node(docid, neighbors, searchLevel); + each_shrink_ifneeded(neighbors, searchLevel); + --searchLevel; + } + if (level > _entryLevel) { + _entryLevel = level; + _entryId = docid; + } + track_ops(); + } + + void track_ops() { + _ops_counter++; + if ((_ops_counter % 10000) == 0) { + double div = _ops_counter; + fprintf(stderr, "add / remove ops: %zu\n", _ops_counter); + fprintf(stderr, "distance calls for layer: %zu is %.3f per op\n", distcalls_search_layer, distcalls_search_layer/ div); + fprintf(stderr, "distance calls for heuristic: %zu is %.3f per op\n", distcalls_heuristic, distcalls_heuristic / div); + fprintf(stderr, "distance calls for simple: %zu is %.3f per op\n", distcalls_simple, distcalls_simple / div); + fprintf(stderr, "distance calls for shrink: %zu is %.3f per op\n", distcalls_shrink, distcalls_shrink / div); + fprintf(stderr, "distance calls for refill: %zu is %.3f per op\n", distcalls_refill, distcalls_refill / div); + fprintf(stderr, "distance calls for other: %zu is %.3f per op\n", distcalls_other, distcalls_other / div); + fprintf(stderr, "refill needed calls: %zu is %.3f per op\n", refill_needed_calls, refill_needed_calls / div); + fprintf(stderr, "shrink needed calls: %zu is %.3f per op\n", shrink_needed_calls, shrink_needed_calls / div); + fprintf(stderr, "disconnected weak links: %zu is %.3f per op\n", disconnected_weak_links, disconnected_weak_links / div); + fprintf(stderr, "disconnected for symmetry: %zu is %.3f per op\n", disconnected_for_symmetry, disconnected_for_symmetry / div); + fprintf(stderr, "select neighbors: partial %zu vs full %zu\n", select_n_partial, select_n_full); + } + } + + void remove_link_from(uint32_t from_id, uint32_t remove_id, uint32_t level) { + LinkList &links = getLinkList(from_id, level); + links.remove_link(remove_id); + } + +#undef SIMPLE_REFILL +#ifdef SIMPLE_REFILL + void refill_ifneeded(uint32_t my_id, const LinkList &replacements, uint32_t level) { + LinkList &my_links = getLinkList(my_id, level); + if (my_links.size() * 2 < _M) { + const uint32_t maxLinks = (level > 0) ? _M : (2 * _M); + ++refill_needed_calls; + for (uint32_t repl_id : replacements) { + if (repl_id == my_id) continue; + if (my_links.has_link_to(repl_id)) continue; + LinkList &other_links = getLinkList(repl_id, level); + if (other_links.size() >= maxLinks) continue; + other_links.push_back(my_id); + my_links.push_back(repl_id); + if (my_links.size() >= _M) return; + } + } + } +#else + void refill_all(uint32_t my_id, const LinkList &replacements, uint32_t level) { + LinkList &my_links = getLinkList(my_id, level); + const uint32_t maxLinks = (level > 0) ? _M : (2 * _M); + NearestPriQ w; + for (uint32_t repl_id : replacements) { + if (repl_id == my_id) continue; + if (my_links.has_link_to(repl_id)) continue; + const LinkList &other_links = getLinkList(repl_id, level); + if (other_links.size() >= maxLinks) continue; + double dist = distance(my_id, repl_id); + ++distcalls_refill; + w.emplace(repl_id, SqDist(dist)); + } + while (! w.empty()) { + HnswHit e = w.top(); + w.pop(); + if (haveCloserDistance(e, my_links)) continue; + LinkList &other_links = getLinkList(e.docid, level); + my_links.push_back(e.docid); + other_links.push_back(my_id); + if (my_links.size() == _M) break; + } + } + void refill_one(uint32_t my_id, const LinkList &replacements, uint32_t level) { + LinkList &my_links = getLinkList(my_id, level); + NearestPriQ w; + for (uint32_t repl_id : replacements) { + if (repl_id == my_id) continue; + if (my_links.has_link_to(repl_id)) continue; + LinkList &other_links = getLinkList(repl_id, level); + if (other_links.size() >= _M) continue; + double dist = distance(my_id, repl_id); + ++distcalls_refill; + w.emplace(repl_id, SqDist(dist)); + } + while (! w.empty()) { + HnswHit e = w.top(); + w.pop(); + if (haveCloserDistance(e, my_links)) continue; + LinkList &other_links = getLinkList(e.docid, level); + my_links.push_back(e.docid); + other_links.push_back(my_id); + return; + } + } + void refill_ifneeded(uint32_t my_id, const LinkList &replacements, uint32_t level) { + LinkList &my_links = getLinkList(my_id, level); + if (my_links.size() < _M) { + ++refill_needed_calls; + refill_all(my_id, replacements, level); + } + } +#endif + + void connect_new_node(uint32_t id, const LinkList &neighbors, uint32_t level); + + void shrink_links(uint32_t shrink_id, uint32_t maxLinks, uint32_t level) { + LinkList &links = getLinkList(shrink_id, level); + NearestList distances; + for (uint32_t n_id : links) { + double n_dist = distance(shrink_id, n_id); + ++distcalls_shrink; + distances.emplace_back(n_id, SqDist(n_dist)); + } + LinkList lostLinks; + LinkList oldLinks = links; + links = remove_weakest(distances, maxLinks, lostLinks); +#define KEEP_SYM +#ifdef KEEP_SYM + for (uint32_t lost_id : lostLinks) { + ++disconnected_for_symmetry; + remove_link_from(lost_id, shrink_id, level); + } +#define DO_REFILL_AFTER_KEEP_SYM +#ifdef DO_REFILL_AFTER_KEEP_SYM + for (uint32_t lost_id : lostLinks) { +#ifdef SIMPLE_REFILL + refill_ifneeded(lost_id, oldLinks, level); +#else + refill_all(lost_id, oldLinks, level); +#endif + } +#endif +#endif + } + + void each_shrink_ifneeded(const LinkList &neighbors, uint32_t level); + + void mutually_reconnect(LinkList cluster, int level) { + while (! cluster.empty()) { + uint32_t n_id = cluster.back(); + cluster.pop_back(); +#ifdef SIMPLE_REFILL + refill_ifneeded(n_id, cluster, level); +#else + refill_all(n_id, cluster, level); +#endif + } + } + + void removeDoc(uint32_t docid) override { + Node &node = _nodes[docid]; + bool need_new_entrypoint = (docid == _entryId); + for (int level = node._links.size(); level-- > 0; ) { + LinkList my_links; + my_links.swap(node._links[level]); + for (uint32_t n_id : my_links) { + if (need_new_entrypoint) { + _entryId = n_id; + _entryLevel = level; + need_new_entrypoint = false; + } + remove_link_from(n_id, docid, level); + } + mutually_reconnect(my_links, level); + } + node = Node(docid, 0, _M); + if (need_new_entrypoint) { + _entryLevel = -1; + _entryId = 0; + for (uint32_t i = 0; i < _nodes.size(); ++i) { + if (_nodes[i]._links.size() > 0) { + _entryId = i; + _entryLevel = _nodes[i]._links.size() - 1; + break; + } + } + } + track_ops(); + } + + std::vector topK(uint32_t k, Vector vector, uint32_t search_k) override { + std::vector result; + if (_entryLevel < 0) return result; + double entryDist = distance(vector, _entryId); + ++distcalls_other; + HnswHit entryPoint(_entryId, SqDist(entryDist)); + int searchLevel = _entryLevel; + VisitedSet &visited = _visitedSetPool.get(_nodes.size()); +#undef MULTI_ENTRY_S +#ifdef MULTI_ENTRY_S + FurthestPriQ w; + w.push(entryPoint); + while (searchLevel > 0) { + search_layer(vector, w, visited, std::min(k, search_k), searchLevel); + --searchLevel; + } +#else + while (searchLevel > 0) { + entryPoint = search_layer_simple(vector, entryPoint, searchLevel); + --searchLevel; + } + FurthestPriQ w; + w.push(entryPoint); +#endif + search_layer(vector, w, visited, std::max(k, search_k), 0); + while (w.size() > k) { + w.pop(); + } + NearestList tmp = w.steal(); + std::sort(tmp.begin(), tmp.end(), LesserDist()); + result.reserve(tmp.size()); + for (const auto & hit : tmp) { + result.emplace_back(hit.docid, SqDist(hit.dist)); + } + return result; + } + + std::vector topKfilter(uint32_t k, Vector vector, uint32_t search_k, const BitVector &blacklist) override; +}; + + +double +HnswLikeNns::distance(Vector v, uint32_t b) const +{ + Vector w = _dva.get(b); + return l2distCalc.l2sq_dist(v, w); +} + +std::vector +HnswLikeNns::topKfilter(uint32_t k, Vector vector, uint32_t search_k, const BitVector &blacklist) +{ + std::vector result; + if (_entryLevel < 0) return result; + double entryDist = distance(vector, _entryId); + ++distcalls_other; + HnswHit entryPoint(_entryId, SqDist(entryDist)); + int searchLevel = _entryLevel; + VisitedSet &visited = _visitedSetPool.get(_nodes.size()); +#ifdef MULTI_ENTRY_S + FurthestPriQ w; + w.push(entryPoint); + while (searchLevel > 0) { + search_layer(vector, w, visited, std::min(k, search_k), searchLevel); + --searchLevel; + } +#else + while (searchLevel > 0) { + entryPoint = search_layer_simple(vector, entryPoint, searchLevel); + --searchLevel; + } + FurthestPriQ w; + w.push(entryPoint); +#endif + search_layer_with_filter(vector, w, visited, std::max(k, search_k), 0, blacklist); + NearestList tmp = w.steal(); + std::sort(tmp.begin(), tmp.end(), LesserDist()); + result.reserve(std::min((size_t)k, tmp.size())); + for (const auto & hit : tmp) { + if (blacklist.isSet(hit.docid)) continue; + result.emplace_back(hit.docid, SqDist(hit.dist)); + if (result.size() == k) break; + } + return result; +} + +void +HnswLikeNns::each_shrink_ifneeded(const LinkList &neighbors, uint32_t level) { + uint32_t maxLinks = (level > 0) ? _M : (2 * _M); + for (uint32_t old_id : neighbors) { + LinkList &oldLinks = getLinkList(old_id, level); + if (oldLinks.size() > maxLinks) { + ++shrink_needed_calls; + shrink_links(old_id, maxLinks, level); + } + } +} + +void +HnswLikeNns::search_layer(Vector vector, FurthestPriQ &w, + VisitedSet &visited, + uint32_t ef, uint32_t searchLevel) +{ + NearestPriQ candidates; + + for (const HnswHit & entry : w.peek()) { + candidates.push(entry); + visited.mark(entry.docid); + } + double limd = std::numeric_limits::max(); + while (! candidates.empty()) { + HnswHit cand = candidates.top(); + if (cand.dist > limd) { + break; + } + candidates.pop(); + for (uint32_t e_id : getLinkList(cand.docid, searchLevel)) { + if (visited.isMarked(e_id)) continue; + visited.mark(e_id); + double e_dist = distance(vector, e_id); + ++distcalls_search_layer; + if (e_dist < limd) { + candidates.emplace(e_id, SqDist(e_dist)); + w.emplace(e_id, SqDist(e_dist)); + if (w.size() > ef) { + w.pop(); + limd = w.top().dist; + } + } + } + } + return; +} + +void +HnswLikeNns::search_layer_with_filter(Vector vector, FurthestPriQ &w, + VisitedSet &visited, + uint32_t ef, uint32_t searchLevel, + const BitVector &blacklist) +{ + NearestPriQ candidates; + + for (const HnswHit & entry : w.peek()) { + candidates.push(entry); + visited.mark(entry.docid); + if (blacklist.isSet(entry.docid)) ++ef; + } + double limd = std::numeric_limits::max(); + while (! candidates.empty()) { + HnswHit cand = candidates.top(); + if (cand.dist > limd) { + break; + } + candidates.pop(); + for (uint32_t e_id : getLinkList(cand.docid, searchLevel)) { + if (visited.isMarked(e_id)) continue; + visited.mark(e_id); + double e_dist = distance(vector, e_id); + ++distcalls_search_layer; + if (e_dist < limd) { + candidates.emplace(e_id, SqDist(e_dist)); + if (blacklist.isSet(e_id)) continue; + w.emplace(e_id, SqDist(e_dist)); + if (w.size() > ef) { + w.pop(); + limd = w.top().dist; + } + } + } + } +} + +LinkList +HnswLikeNns::remove_weakest(const NearestList &neighbors, uint32_t curMax, LinkList &lost) const +{ + LinkList result; + result.reserve(curMax+1); + NearestPriQ w; + for (const auto & entry : neighbors) { + w.push(entry); + } + while (! w.empty()) { + HnswHit e = w.top(); + w.pop(); + if (result.size() == curMax || haveCloserDistance(e, result)) { + lost.push_back(e.docid); + } else { + result.push_back(e.docid); + } + } + return result; +} + +#define NO_BACKFILL +#ifdef NO_BACKFILL +LinkList +HnswLikeNns::select_neighbors(const NearestList &neighbors, uint32_t curMax) const +{ + LinkList result; + result.reserve(curMax+1); + NearestPriQ w; + for (const auto & entry : neighbors) { + w.push(entry); + } + while (! w.empty()) { + HnswHit e = w.top(); + w.pop(); + if (haveCloserDistance(e, result)) { + continue; + } + result.push_back(e.docid); + if (result.size() == curMax) { + ++select_n_full; + return result; + } + } + ++select_n_partial; + return result; +} +#else +LinkList +HnswLikeNns::select_neighbors(const NearestList &neighbors, uint32_t curMax) const +{ + LinkList result; + result.reserve(curMax+1); + bool needFiltering = (neighbors.size() > curMax); + NearestPriQ w; + for (const auto & entry : neighbors) { + w.push(entry); + } + LinkList backfill; + while (! w.empty()) { + HnswHit e = w.top(); + w.pop(); + if (needFiltering && haveCloserDistance(e, result)) { + backfill.push_back(e.docid); + continue; + } + result.push_back(e.docid); + if (result.size() == curMax) return result; + } + if (result.size() * 4 < _M) { + for (uint32_t fill_id : backfill) { + result.push_back(fill_id); + if (result.size() * 2 >= _M) break; + } + } + return result; +} +#endif + +void +HnswLikeNns::connect_new_node(uint32_t id, const LinkList &neighbors, uint32_t level) { + LinkList &newLinks = getLinkList(id, level); + for (uint32_t neigh_id : neighbors) { + LinkList &oldLinks = getLinkList(neigh_id, level); + newLinks.push_back(neigh_id); + oldLinks.push_back(id); + } +#define DISCONNECT_OLD_WEAK_LINKS +#ifdef DISCONNECT_OLD_WEAK_LINKS + for (uint32_t i = 1; i < neighbors.size(); ++i) { + uint32_t n_1 = neighbors[i]; + LinkList &links_1 = getLinkList(n_1, level); + for (uint32_t j = 0; j < i; ++j) { + uint32_t n_2 = neighbors[j]; + if (links_1.has_link_to(n_2)) { + ++disconnected_weak_links; + LinkList &links_2 = getLinkList(n_2, level); + links_1.remove_link(n_2); + links_2.remove_link(n_1); + } + } + } +#endif +} + +uint32_t +HnswLikeNns::count_reachable() const { + VisitedSet visited(_nodes.size()); + int level = _entryLevel; + LinkList curList; + curList.push_back(_entryId); + visited.mark(_entryId); + uint32_t idx = 0; + while (level >= 0) { + while (idx < curList.size()) { + uint32_t id = curList[idx++]; + const LinkList &links = getLinkList(id, level); + for (uint32_t n_id : links) { + if (visited.isMarked(n_id)) continue; + visited.mark(n_id); + curList.push_back(n_id); + } + } + --level; + idx = 0; + } + return curList.size(); +} + +void +HnswLikeNns::dumpStats() const { + std::vector levelCounts; + levelCounts.resize(_entryLevel + 2); + std::vector outLinkHist; + outLinkHist.resize(2 * _M + 2); + uint32_t symmetrics = 0; + uint32_t level1links = 0; + uint32_t both_l_links = 0; + fprintf(stderr, "stats for HnswLikeNns with %zu nodes, entry level = %d, entry id = %u\n", + _nodes.size(), _entryLevel, _entryId); + + for (uint32_t id = 0; id < _nodes.size(); ++id) { + const auto &node = _nodes[id]; + uint32_t levels = node._links.size(); + levelCounts[levels]++; + if (levels < 1) { + outLinkHist[0]++; + continue; + } + const LinkList &link_list = getLinkList(id, 0); + uint32_t numlinks = link_list.size(); + outLinkHist[numlinks]++; + if (numlinks < 1) { + fprintf(stderr, "node with %u links: id %u\n", numlinks, id); + } + bool all_sym = true; + for (uint32_t n_id : link_list) { + const LinkList &neigh_list = getLinkList(n_id, 0); + if (! neigh_list.has_link_to(id)) { +#ifdef KEEP_SYM + fprintf(stderr, "BAD: %u has link to neighbor %u, but backlink is missing\n", id, n_id); +#endif + all_sym = false; + } + } + if (all_sym) ++symmetrics; + if (levels < 2) continue; + const LinkList &link_list_1 = getLinkList(id, 1); + for (uint32_t n_id : link_list_1) { + ++level1links; + if (link_list.has_link_to(n_id)) ++both_l_links; + } + } + for (uint32_t l = 0; l < levelCounts.size(); ++l) { + fprintf(stderr, "Nodes on %u levels: %u\n", l, levelCounts[l]); + } + fprintf(stderr, "reachable nodes %u / %zu\n", + count_reachable(), _nodes.size() - levelCounts[0]); + fprintf(stderr, "level 1 links overlapping on l0: %u / total: %u\n", + both_l_links, level1links); + for (uint32_t l = 0; l < outLinkHist.size(); ++l) { + if (outLinkHist[l] != 0) { + fprintf(stderr, "Nodes with %u outward links on L0: %u\n", l, outLinkHist[l]); + } + } + fprintf(stderr, "Symmetric in-out nodes: %u\n", symmetrics); +} + +std::unique_ptr> +make_hnsw_nns(uint32_t numDims, const DocVectorAccess &dva) +{ + return std::make_unique(numDims, dva); +} -- cgit v1.2.3 From cc3c709d6278ebd699d4f4c67f8f769c9b6fa177 Mon Sep 17 00:00:00 2001 From: Arne Juul Date: Wed, 12 Feb 2020 10:30:39 +0000 Subject: add and verify filter option split out common subroutines --- eval/src/tests/ann/CMakeLists.txt | 10 ++ eval/src/tests/ann/find-with-nns.h | 12 ++ eval/src/tests/ann/for-sift-top-k.h | 2 +- eval/src/tests/ann/gist_benchmark.cpp | 295 +++++++++++++++++++++++++++++++++ eval/src/tests/ann/nns.h | 26 +++ eval/src/tests/ann/quality-nns.h | 42 +++++ eval/src/tests/ann/remove-bm.cpp | 258 ++++++++-------------------- eval/src/tests/ann/sift_benchmark.cpp | 193 ++++++++++++++------- eval/src/tests/ann/verify-top-k.h | 27 +++ eval/src/tests/ann/xp-annoy-nns.cpp | 58 +++++++ eval/src/tests/ann/xp-hnsw-wrap.cpp | 28 ++++ eval/src/tests/ann/xp-hnswlike-nns.cpp | 121 ++++++++++++-- eval/src/tests/ann/xp-lsh-nns.cpp | 40 +++++ 13 files changed, 853 insertions(+), 259 deletions(-) create mode 100644 eval/src/tests/ann/find-with-nns.h create mode 100644 eval/src/tests/ann/gist_benchmark.cpp create mode 100644 eval/src/tests/ann/quality-nns.h create mode 100644 eval/src/tests/ann/verify-top-k.h diff --git a/eval/src/tests/ann/CMakeLists.txt b/eval/src/tests/ann/CMakeLists.txt index 52b4d675d9c..34babf1412f 100644 --- a/eval/src/tests/ann/CMakeLists.txt +++ b/eval/src/tests/ann/CMakeLists.txt @@ -10,6 +10,16 @@ vespa_add_executable(eval_sift_benchmark_app vespaeval ) +vespa_add_executable(eval_gist_benchmark_app + SOURCES + gist_benchmark.cpp + xp-annoy-nns.cpp + xp-hnswlike-nns.cpp + xp-lsh-nns.cpp + DEPENDS + vespaeval +) + vespa_add_executable(eval_remove_bm_app SOURCES remove-bm.cpp diff --git a/eval/src/tests/ann/find-with-nns.h b/eval/src/tests/ann/find-with-nns.h new file mode 100644 index 00000000000..3481b403f86 --- /dev/null +++ b/eval/src/tests/ann/find-with-nns.h @@ -0,0 +1,12 @@ +// Copyright 2020 Oath Inc. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. + +TopK find_with_nns(uint32_t sk, NNS_API &nns, uint32_t qid) { + TopK result; + const PointVector &qv = generatedQueries[qid]; + vespalib::ConstArrayRef query(qv.v, NUM_DIMS); + auto rv = nns.topK(result.K, query, sk); + for (size_t i = 0; i < result.K; ++i) { + result.hits[i] = Hit(rv[i].docid, rv[i].sq.distance); + } + return result; +} diff --git a/eval/src/tests/ann/for-sift-top-k.h b/eval/src/tests/ann/for-sift-top-k.h index ba91cb2aebc..8a659a507bc 100644 --- a/eval/src/tests/ann/for-sift-top-k.h +++ b/eval/src/tests/ann/for-sift-top-k.h @@ -6,7 +6,7 @@ struct TopK { static constexpr size_t K = 100; Hit hits[K]; - size_t recall(const TopK &other) { + size_t recall(const TopK &other) const { size_t overlap = 0; size_t i = 0; size_t j = 0; diff --git a/eval/src/tests/ann/gist_benchmark.cpp b/eval/src/tests/ann/gist_benchmark.cpp new file mode 100644 index 00000000000..45559fc2557 --- /dev/null +++ b/eval/src/tests/ann/gist_benchmark.cpp @@ -0,0 +1,295 @@ +// Copyright 2020 Oath Inc. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. + +#include +#include +#include +#include +#include +#include +#include +#include + +#define NUM_DIMS 960 +#define NUM_DOCS 200000 +#define NUM_REACH 10000 +#define NUM_Q 1000 + +#include "doc_vector_access.h" +#include "nns.h" +#include "for-sift-hit.h" +#include "for-sift-top-k.h" + +std::vector bruteforceResults; + +struct PointVector { + float v[NUM_DIMS]; + using ConstArr = vespalib::ConstArrayRef; + operator ConstArr() const { return ConstArr(v, NUM_DIMS); } +}; + +static PointVector *aligned_alloc(size_t num) { + size_t num_bytes = num * sizeof(PointVector); + double mega_bytes = num_bytes / (1024.0*1024.0); + fprintf(stderr, "allocate %.2f MB of vectors\n", mega_bytes); + char *mem = (char *)malloc(num_bytes + 512); + mem += 512; + size_t val = (size_t)mem; + size_t unalign = val % 512; + mem -= unalign; + return reinterpret_cast(mem); +} + +static PointVector *generatedQueries = aligned_alloc(NUM_Q); +static PointVector *generatedDocs = aligned_alloc(NUM_DOCS); + +struct DocVectorAdapter : public DocVectorAccess +{ + vespalib::ConstArrayRef get(uint32_t docid) const override { + ASSERT_TRUE(docid < NUM_DOCS); + return generatedDocs[docid]; + } +}; + +double computeDistance(const PointVector &query, uint32_t docid) { + const PointVector &docvector = generatedDocs[docid]; + return l2distCalc.l2sq_dist(query, docvector); +} + +void read_queries(std::string fn) { + int fd = open(fn.c_str(), O_RDONLY); + ASSERT_TRUE(fd > 0); + int d; + size_t rv; + fprintf(stderr, "reading %u queries from %s\n", NUM_Q, fn.c_str()); + for (uint32_t qid = 0; qid < NUM_Q; ++qid) { + rv = read(fd, &d, 4); + ASSERT_EQUAL(rv, 4u); + ASSERT_EQUAL(d, NUM_DIMS); + rv = read(fd, &generatedQueries[qid].v, NUM_DIMS*sizeof(float)); + ASSERT_EQUAL(rv, sizeof(PointVector)); + } + close(fd); +} + +void read_docs(std::string fn) { + int fd = open(fn.c_str(), O_RDONLY); + ASSERT_TRUE(fd > 0); + int d; + size_t rv; + fprintf(stderr, "reading %u doc vectors from %s\n", NUM_DOCS, fn.c_str()); + for (uint32_t docid = 0; docid < NUM_DOCS; ++docid) { + rv = read(fd, &d, 4); + ASSERT_EQUAL(rv, 4u); + ASSERT_EQUAL(d, NUM_DIMS); + rv = read(fd, &generatedDocs[docid].v, NUM_DIMS*sizeof(float)); + ASSERT_EQUAL(rv, sizeof(PointVector)); + } + close(fd); +} + +using TimePoint = std::chrono::steady_clock::time_point; +using Duration = std::chrono::steady_clock::duration; + +double to_ms(Duration elapsed) { + std::chrono::duration ms(elapsed); + return ms.count(); +} + +void read_data(std::string dir) { + TimePoint bef = std::chrono::steady_clock::now(); + read_queries(dir + "/gist_query.fvecs"); + TimePoint aft = std::chrono::steady_clock::now(); + fprintf(stderr, "read queries: %.3f ms\n", to_ms(aft - bef)); + bef = std::chrono::steady_clock::now(); + read_docs(dir + "/gist_base.fvecs"); + aft = std::chrono::steady_clock::now(); + fprintf(stderr, "read docs: %.3f ms\n", to_ms(aft - bef)); +} + + +struct BfHitComparator { + bool operator() (const Hit &lhs, const Hit& rhs) const { + if (lhs.distance < rhs.distance) return false; + if (lhs.distance > rhs.distance) return true; + return (lhs.docid > rhs.docid); + } +}; + +class BfHitHeap { +private: + size_t _size; + vespalib::PriorityQueue _priQ; +public: + explicit BfHitHeap(size_t maxSize) : _size(maxSize), _priQ() { + _priQ.reserve(maxSize); + } + ~BfHitHeap() {} + void maybe_use(const Hit &hit) { + if (_priQ.size() < _size) { + _priQ.push(hit); + } else if (hit.distance < _priQ.front().distance) { + _priQ.front() = hit; + _priQ.adjust(); + } + } + std::vector bestHits() { + std::vector result; + size_t i = _priQ.size(); + result.resize(i); + while (i-- > 0) { + result[i] = _priQ.front(); + _priQ.pop_front(); + } + return result; + } +}; + +TopK bruteforce_nns(const PointVector &query) { + TopK result; + BfHitHeap heap(result.K); + for (uint32_t docid = 0; docid < NUM_DOCS; ++docid) { + const PointVector &docvector = generatedDocs[docid]; + double d = l2distCalc.l2sq_dist(query, docvector); + Hit h(docid, d); + heap.maybe_use(h); + } + std::vector best = heap.bestHits(); + for (size_t i = 0; i < result.K; ++i) { + result.hits[i] = best[i]; + } + return result; +} + +void verifyBF(uint32_t qid) { + const PointVector &query = generatedQueries[qid]; + TopK &result = bruteforceResults[qid]; + double min_distance = result.hits[0].distance; + std::vector all_c2; + for (uint32_t i = 0; i < NUM_DOCS; ++i) { + double dist = computeDistance(query, i); + if (dist < min_distance) { + fprintf(stderr, "WARN dist %.9g < mindist %.9g\n", dist, min_distance); + } + EXPECT_FALSE(dist+0.000001 < min_distance); + if (min_distance > 0.0) all_c2.push_back(dist / min_distance); + } + if (all_c2.size() != NUM_DOCS) return; + std::sort(all_c2.begin(), all_c2.end()); + for (uint32_t idx : { 1, 3, 10, 30, 100, 300, 1000, 3000, NUM_DOCS/2, NUM_DOCS-1}) { + fprintf(stderr, "c2-factor[%u] = %.3f\n", idx, all_c2[idx]); + } +} + +using NNS_API = NNS; + +TEST("require that brute force works") { + TimePoint bef = std::chrono::steady_clock::now(); + fprintf(stderr, "generating %u brute force results\n", NUM_Q); + bruteforceResults.reserve(NUM_Q); + for (uint32_t cnt = 0; cnt < NUM_Q; ++cnt) { + const PointVector &query = generatedQueries[cnt]; + bruteforceResults.emplace_back(bruteforce_nns(query)); + } + TimePoint aft = std::chrono::steady_clock::now(); + fprintf(stderr, "timing for brute force: %.3f ms = %.3f ms per query\n", + to_ms(aft - bef), to_ms(aft - bef)/NUM_Q); + for (int cnt = 0; cnt < NUM_Q; cnt = (cnt+1)*2) { + verifyBF(cnt); + } +} + +#include "find-with-nns.h" +#include "verify-top-k.h" + +void timing_nns(const char *name, NNS_API &nns, std::vector sk_list) { + for (uint32_t search_k : sk_list) { + TimePoint bef = std::chrono::steady_clock::now(); + for (int cnt = 0; cnt < NUM_Q; ++cnt) { + find_with_nns(search_k, nns, cnt); + } + TimePoint aft = std::chrono::steady_clock::now(); + fprintf(stderr, "timing for %s search_k=%u: %.3f ms = %.3f ms/q\n", + name, search_k, to_ms(aft - bef), to_ms(aft - bef)/NUM_Q); + } +} + +#include "quality-nns.h" + +template +void bm_nns_simple(const char *name, FUNC creator, std::vector sk_list) { + std::unique_ptr nnsp = creator(); + NNS_API &nns = *nnsp; + fprintf(stderr, "trying %s indexing...\n", name); + TimePoint bef = std::chrono::steady_clock::now(); + for (uint32_t i = 0; i < NUM_DOCS; ++i) { + nns.addDoc(i); + } + TimePoint aft = std::chrono::steady_clock::now(); + fprintf(stderr, "build %s index with %u docs: %.3f ms\n", name, NUM_DOCS, to_ms(aft - bef)); + timing_nns(name, nns, sk_list); + fprintf(stderr, "Quality for %s [A] clean build with %u documents:\n", name, NUM_DOCS); + quality_nns(nns, sk_list); +} + +template +void benchmark_nns(const char *name, FUNC creator, std::vector sk_list) { + bm_nns_simple(name, creator, sk_list); +} + +#if 0 +TEST("require that Locality Sensitive Hashing mostly works") { + DocVectorAdapter adapter; + auto creator = [&adapter]() { return make_rplsh_nns(NUM_DIMS, adapter); }; + benchmark_nns("RPLSH", creator, { 200, 1000 }); +} +#endif + +#if 0 +TEST("require that Annoy via NNS api mostly works") { + DocVectorAdapter adapter; + auto creator = [&adapter]() { return make_annoy_nns(NUM_DIMS, adapter); }; + benchmark_nns("Annoy", creator, { 8000, 10000 }); +} +#endif + +#if 1 +TEST("require that HNSW via NNS api mostly works") { + DocVectorAdapter adapter; + auto creator = [&adapter]() { return make_hnsw_nns(NUM_DIMS, adapter); }; + benchmark_nns("HNSW-like", creator, { 100, 150, 200 }); +} +#endif + +#if 0 +TEST("require that HNSW wrapped api mostly works") { + DocVectorAdapter adapter; + auto creator = [&adapter]() { return make_hnsw_wrap(NUM_DIMS, adapter); }; + benchmark_nns("HNSW-wrap", creator, { 100, 150, 200 }); +} +#endif + +/** + * Before running the benchmark the ANN_GIST1M data set must be downloaded and extracted: + * wget ftp://ftp.irisa.fr/local/texmex/corpus/gist.tar.gz + * tar -xf gist.tar.gz + * + * The benchmark program will load the data set from $HOME/gist if no directory is specified. + * + * More information about the dataset is found here: http://corpus-texmex.irisa.fr/. + */ +int main(int argc, char **argv) { + TEST_MASTER.init(__FILE__); + std::string gist_dir = "."; + if (argc > 1) { + gist_dir = argv[1]; + } else { + char *home = getenv("HOME"); + if (home) { + gist_dir = home; + gist_dir += "/gist"; + } + } + read_data(gist_dir); + TEST_RUN_ALL(); + return (TEST_MASTER.fini() ? 0 : 1); +} diff --git a/eval/src/tests/ann/nns.h b/eval/src/tests/ann/nns.h index ffe2882188e..ef3e4b5d69c 100644 --- a/eval/src/tests/ann/nns.h +++ b/eval/src/tests/ann/nns.h @@ -37,6 +37,31 @@ struct NnsHitComparatorLessDocid { } }; +class BitVector { +private: + std::vector _bits; +public: + BitVector(size_t sz) : _bits((sz+63)/64) {} + BitVector& setBit(size_t idx) { + uint64_t mask = 1; + mask <<= (idx%64); + _bits[idx/64] |= mask; + return *this; + } + bool isSet(size_t idx) const { + uint64_t mask = 1; + mask <<= (idx%64); + uint64_t word = _bits[idx/64]; + return (word & mask) != 0; + } + BitVector& clearBit(size_t idx) { + uint64_t mask = 1; + mask <<= (idx%64); + _bits[idx/64] &= ~mask; + return *this; + } +}; + template class NNS { @@ -50,6 +75,7 @@ public: using Vector = vespalib::ConstArrayRef; virtual std::vector topK(uint32_t k, Vector vector, uint32_t search_k) = 0; + virtual std::vector topKfilter(uint32_t k, Vector vector, uint32_t search_k, const BitVector &blacklist) = 0; virtual ~NNS() {} protected: uint32_t _numDims; diff --git a/eval/src/tests/ann/quality-nns.h b/eval/src/tests/ann/quality-nns.h new file mode 100644 index 00000000000..9ac37f0ef04 --- /dev/null +++ b/eval/src/tests/ann/quality-nns.h @@ -0,0 +1,42 @@ +// Copyright 2020 Oath Inc. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. + +bool reach_with_nns_k(NNS_API &nns, uint32_t docid, uint32_t k) { + const PointVector &qv = generatedDocs[docid]; + vespalib::ConstArrayRef query(qv.v, NUM_DIMS); + auto rv = nns.topK(k, query, k); + if (rv.size() != k) { + fprintf(stderr, "Result/K=%u from query for %u is %zu hits\n", + k, docid, rv.size()); + return false; + } + if (rv[0].docid != docid) { + if (rv[0].sq.distance != 0.0) + fprintf(stderr, "Expected/K=%u to find %u but got %u with sq distance %.3f\n", + k, docid, rv[0].docid, rv[0].sq.distance); + } + return (rv[0].docid == docid || rv[0].sq.distance == 0.0); +} + +void quality_nns(NNS_API &nns, std::vector sk_list) { + for (uint32_t search_k : sk_list) { + double sum_recall = 0; + for (int cnt = 0; cnt < NUM_Q; ++cnt) { + sum_recall += verify_nns_quality(search_k, nns, cnt); + } + fprintf(stderr, "Overall average recall: %.2f\n", sum_recall / NUM_Q); + } + for (uint32_t search_k : { 1, 10, 100, 1000 }) { + TimePoint bef = std::chrono::steady_clock::now(); + uint32_t reached = 0; + for (uint32_t i = 0; i < NUM_REACH; ++i) { + uint32_t target = i * (NUM_DOCS / NUM_REACH); + if (reach_with_nns_k(nns, target, search_k)) ++reached; + } + fprintf(stderr, "Could reach %u of %u documents with k=%u\n", + reached, NUM_REACH, search_k); + TimePoint aft = std::chrono::steady_clock::now(); + fprintf(stderr, "reach time k=%u: %.3f ms = %.3f ms/q\n", + search_k, to_ms(aft - bef), to_ms(aft - bef)/NUM_REACH); + if (reached == NUM_REACH) break; + } +} diff --git a/eval/src/tests/ann/remove-bm.cpp b/eval/src/tests/ann/remove-bm.cpp index be010552ab8..005f3804af9 100644 --- a/eval/src/tests/ann/remove-bm.cpp +++ b/eval/src/tests/ann/remove-bm.cpp @@ -13,6 +13,7 @@ #define NUM_DOCS 250000 #define NUM_DOCS_REMOVE 50000 #define EFFECTIVE_DOCS (NUM_DOCS - NUM_DOCS_REMOVE) +#define NUM_REACH 10000 #define NUM_Q 1000 #include "doc_vector_access.h" @@ -30,10 +31,10 @@ struct PointVector { }; static PointVector *aligned_alloc(size_t num) { - size_t sz = num * sizeof(PointVector); - double mega_bytes = sz / (1024.0*1024.0); + size_t num_bytes = num * sizeof(PointVector); + double mega_bytes = num_bytes / (1024.0*1024.0); fprintf(stderr, "allocate %.2f MB of vectors\n", mega_bytes); - char *mem = (char *)malloc(sz + 512); + char *mem = (char *)malloc(num_bytes + 512); mem += 512; size_t val = (size_t)mem; size_t unalign = val % 512; @@ -221,83 +222,8 @@ TEST("require that brute force works") { } } -bool reach_with_nns_1(NNS_API &nns, uint32_t docid) { - const PointVector &qv = generatedDocs[docid]; - vespalib::ConstArrayRef query(qv.v, NUM_DIMS); - auto rv = nns.topK(1, query, 1); - if (rv.size() != 1) { - fprintf(stderr, "Result/A from query for %u is %zu hits\n", docid, rv.size()); - return false; - } - if (rv[0].docid != docid) { - if (rv[0].sq.distance != 0.0) - fprintf(stderr, "Expected/A to find %u but got %u with sq distance %.3f\n", - docid, rv[0].docid, rv[0].sq.distance); - } - return (rv[0].docid == docid || rv[0].sq.distance == 0.0); -} - -bool reach_with_nns_100(NNS_API &nns, uint32_t docid) { - const PointVector &qv = generatedDocs[docid]; - vespalib::ConstArrayRef query(qv.v, NUM_DIMS); - auto rv = nns.topK(10, query, 100); - if (rv.size() != 10) { - fprintf(stderr, "Result/B from query for %u is %zu hits\n", docid, rv.size()); - } - if (rv[0].docid != docid) { - if (rv[0].sq.distance != 0.0) - fprintf(stderr, "Expected/B to find %u but got %u with sq distance %.3f\n", - docid, rv[0].docid, rv[0].sq.distance); - } - return (rv[0].docid == docid || rv[0].sq.distance == 0.0); -} - -bool reach_with_nns_1k(NNS_API &nns, uint32_t docid) { - const PointVector &qv = generatedDocs[docid]; - vespalib::ConstArrayRef query(qv.v, NUM_DIMS); - auto rv = nns.topK(10, query, 1000); - if (rv.size() != 10) { - fprintf(stderr, "Result/C from query for %u is %zu hits\n", docid, rv.size()); - } - if (rv[0].docid != docid) { - if (rv[0].sq.distance != 0.0) - fprintf(stderr, "Expected/C to find %u but got %u with sq distance %.3f\n", - docid, rv[0].docid, rv[0].sq.distance); - } - return (rv[0].docid == docid || rv[0].sq.distance == 0.0); -} - -TopK find_with_nns(uint32_t sk, NNS_API &nns, uint32_t qid) { - TopK result; - const PointVector &qv = generatedQueries[qid]; - vespalib::ConstArrayRef query(qv.v, NUM_DIMS); - auto rv = nns.topK(result.K, query, sk); - for (size_t i = 0; i < result.K; ++i) { - result.hits[i] = Hit(rv[i].docid, rv[i].sq.distance); - } - return result; -} - -void verify_nns_quality(uint32_t sk, NNS_API &nns, uint32_t qid) { - TopK perfect = bruteforceResults[qid]; - TopK result = find_with_nns(sk, nns, qid); - int recall = perfect.recall(result); - EXPECT_TRUE(recall > 40); - double sum_error = 0.0; - double c_factor = 1.0; - for (size_t i = 0; i < result.K; ++i) { - double factor = (result.hits[i].distance / perfect.hits[i].distance); - if (factor < 0.99 || factor > 25) { - fprintf(stderr, "hit[%zu] got distance %.3f, expected %.3f\n", - i, result.hits[i].distance, perfect.hits[i].distance); - } - sum_error += factor; - c_factor = std::max(c_factor, factor); - } - EXPECT_TRUE(c_factor < 1.5); - fprintf(stderr, "quality sk=%u: query %u: recall %d c2-factor %.3f avg c2: %.3f\n", - sk, qid, recall, c_factor, sum_error / result.K); -} +#include "find-with-nns.h" +#include "verify-top-k.h" void timing_nns(const char *name, NNS_API &nns, std::vector sk_list) { for (uint32_t search_k : sk_list) { @@ -311,64 +237,22 @@ void timing_nns(const char *name, NNS_API &nns, std::vector sk_list) { } } -void quality_nns(NNS_API &nns, std::vector sk_list) { - for (uint32_t search_k : sk_list) { - for (int cnt = 0; cnt < NUM_Q; ++cnt) { - verify_nns_quality(search_k, nns, cnt); - } - } - uint32_t reached = 0; - for (uint32_t i = 0; i < 20000; ++i) { - if (reach_with_nns_1(nns, i)) ++reached; - } - fprintf(stderr, "Could reach %u of 20000 first documents with k=1\n", reached); - reached = 0; - for (uint32_t i = 0; i < 20000; ++i) { - if (reach_with_nns_100(nns, i)) ++reached; - } - fprintf(stderr, "Could reach %u of 20000 first documents with k=100\n", reached); - reached = 0; - for (uint32_t i = 0; i < 20000; ++i) { - if (reach_with_nns_1k(nns, i)) ++reached; - } - fprintf(stderr, "Could reach %u of 20000 first documents with k=1000\n", reached); -} +#include "quality-nns.h" -void benchmark_nns(const char *name, NNS_API &nns, std::vector sk_list) { +template +void bm_nns_simple(const char *name, FUNC creator, std::vector sk_list) { + std::unique_ptr nnsp = creator(); + NNS_API &nns = *nnsp; fprintf(stderr, "trying %s indexing...\n", name); - -#if 0 - TimePoint bef = std::chrono::steady_clock::now(); - for (uint32_t i = 0; i < NUM_DOCS_REMOVE; ++i) { - nns.addDoc(EFFECTIVE_DOCS + i); - } - for (uint32_t i = 0; i < EFFECTIVE_DOCS - NUM_DOCS_REMOVE; ++i) { - nns.addDoc(i); - } - for (uint32_t i = 0; i < NUM_DOCS_REMOVE; ++i) { - nns.removeDoc(EFFECTIVE_DOCS + i); - nns.addDoc(EFFECTIVE_DOCS - NUM_DOCS_REMOVE + i); - } - TimePoint aft = std::chrono::steady_clock::now(); - fprintf(stderr, "build %s index with %u docs: %.3f ms\n", name, EFFECTIVE_DOCS, to_ms(aft - bef)); - - timing_nns(name, nns, sk_list); - fprintf(stderr, "Quality for %s realistic build with %u documents:\n", name, EFFECTIVE_DOCS); - quality_nns(nns, sk_list); -#endif - -#if 1 TimePoint bef = std::chrono::steady_clock::now(); for (uint32_t i = 0; i < EFFECTIVE_DOCS; ++i) { nns.addDoc(i); } TimePoint aft = std::chrono::steady_clock::now(); fprintf(stderr, "build %s index with %u docs: %.3f ms\n", name, EFFECTIVE_DOCS, to_ms(aft - bef)); - timing_nns(name, nns, sk_list); - fprintf(stderr, "Quality for %s clean build with %u documents:\n", name, EFFECTIVE_DOCS); + fprintf(stderr, "Quality for %s [A] clean build with %u documents:\n", name, EFFECTIVE_DOCS); quality_nns(nns, sk_list); - bef = std::chrono::steady_clock::now(); for (uint32_t i = 0; i < NUM_DOCS_REMOVE; ++i) { nns.addDoc(EFFECTIVE_DOCS + i); @@ -379,111 +263,115 @@ void benchmark_nns(const char *name, NNS_API &nns, std::vector sk_list aft = std::chrono::steady_clock::now(); fprintf(stderr, "build %s index add then remove %u docs: %.3f ms\n", name, NUM_DOCS_REMOVE, to_ms(aft - bef)); - timing_nns(name, nns, sk_list); - fprintf(stderr, "Quality for %s remove-damaged build with %u documents:\n", name, EFFECTIVE_DOCS); + fprintf(stderr, "Quality for %s [B] remove-damaged build with %u documents:\n", name, EFFECTIVE_DOCS); quality_nns(nns, sk_list); -#endif +} -#if 0 +template +void bm_nns_remove_old(const char *name, FUNC creator, std::vector sk_list) { + std::unique_ptr nnsp = creator(); + NNS_API &nns = *nnsp; TimePoint bef = std::chrono::steady_clock::now(); + for (uint32_t i = 0; i < NUM_DOCS_REMOVE; ++i) { + nns.addDoc(EFFECTIVE_DOCS + i); + } for (uint32_t i = 0; i < EFFECTIVE_DOCS; ++i) { nns.addDoc(i); } + for (uint32_t i = 0; i < NUM_DOCS_REMOVE; ++i) { + nns.removeDoc(EFFECTIVE_DOCS + i); + } TimePoint aft = std::chrono::steady_clock::now(); fprintf(stderr, "build %s index with %u docs: %.3f ms\n", name, EFFECTIVE_DOCS, to_ms(aft - bef)); - timing_nns(name, nns, sk_list); - fprintf(stderr, "Quality for %s clean build with %u documents:\n", name, EFFECTIVE_DOCS); + fprintf(stderr, "Quality for %s [C] remove-oldest build with %u documents:\n", name, EFFECTIVE_DOCS); quality_nns(nns, sk_list); +} - bef = std::chrono::steady_clock::now(); - for (uint32_t i = 0; i < EFFECTIVE_DOCS; ++i) { - nns.removeDoc(i); - } - aft = std::chrono::steady_clock::now(); - fprintf(stderr, "build %s index removed %u docs: %.3f ms\n", name, EFFECTIVE_DOCS, to_ms(aft - bef)); - - const uint32_t addFirst = NUM_DOCS - (NUM_DOCS_REMOVE * 3); - const uint32_t addSecond = NUM_DOCS - (NUM_DOCS_REMOVE * 2); - - bef = std::chrono::steady_clock::now(); - for (uint32_t i = 0; i < addFirst; ++i) { - nns.addDoc(i); - } - aft = std::chrono::steady_clock::now(); - fprintf(stderr, "build %s index with %u docs: %.3f ms\n", name, addFirst, to_ms(aft - bef)); - - bef = std::chrono::steady_clock::now(); +template +void bm_nns_interleave(const char *name, FUNC creator, std::vector sk_list) { + std::unique_ptr nnsp = creator(); + NNS_API &nns = *nnsp; + TimePoint bef = std::chrono::steady_clock::now(); for (uint32_t i = 0; i < NUM_DOCS_REMOVE; ++i) { nns.addDoc(EFFECTIVE_DOCS + i); - nns.addDoc(addFirst + i); } - aft = std::chrono::steady_clock::now(); - fprintf(stderr, "build %s index added %u docs: %.3f ms\n", - name, 2 * NUM_DOCS_REMOVE, to_ms(aft - bef)); - - bef = std::chrono::steady_clock::now(); + for (uint32_t i = 0; i < EFFECTIVE_DOCS - NUM_DOCS_REMOVE; ++i) { + nns.addDoc(i); + } for (uint32_t i = 0; i < NUM_DOCS_REMOVE; ++i) { nns.removeDoc(EFFECTIVE_DOCS + i); - nns.addDoc(addSecond + i); + nns.addDoc(EFFECTIVE_DOCS - NUM_DOCS_REMOVE + i); } - aft = std::chrono::steady_clock::now(); - fprintf(stderr, "build %s index added %u and removed %u docs: %.3f ms\n", - name, NUM_DOCS_REMOVE, NUM_DOCS_REMOVE, to_ms(aft - bef)); - + TimePoint aft = std::chrono::steady_clock::now(); + fprintf(stderr, "build %s index with %u docs: %.3f ms\n", name, EFFECTIVE_DOCS, to_ms(aft - bef)); timing_nns(name, nns, sk_list); - fprintf(stderr, "Quality for %s with %u documents some churn:\n", name, EFFECTIVE_DOCS); + fprintf(stderr, "Quality for %s [D] realistic build with %u documents:\n", name, EFFECTIVE_DOCS); quality_nns(nns, sk_list); +} -#endif - -#if 0 - bef = std::chrono::steady_clock::now(); - fprintf(stderr, "removing and adding %u documents...\n", EFFECTIVE_DOCS); - for (uint32_t i = 0; i < EFFECTIVE_DOCS; ++i) { - nns.removeDoc(i); +template +void bm_nns_remove_old_add_new(const char *name, FUNC creator, std::vector sk_list) { + std::unique_ptr nnsp = creator(); + NNS_API &nns = *nnsp; + TimePoint bef = std::chrono::steady_clock::now(); + for (uint32_t i = 0; i < NUM_DOCS_REMOVE; ++i) { + nns.addDoc(EFFECTIVE_DOCS + i); + } + for (uint32_t i = 0; i < EFFECTIVE_DOCS - NUM_DOCS_REMOVE; ++i) { nns.addDoc(i); } - aft = std::chrono::steady_clock::now(); - fprintf(stderr, "build %s index rem/add %u docs: %.3f ms\n", - name, EFFECTIVE_DOCS, to_ms(aft - bef)); - + for (uint32_t i = 0; i < NUM_DOCS_REMOVE; ++i) { + nns.removeDoc(EFFECTIVE_DOCS + i); + } + for (uint32_t i = 0; i < NUM_DOCS_REMOVE; ++i) { + nns.addDoc(EFFECTIVE_DOCS - NUM_DOCS_REMOVE + i); + } + TimePoint aft = std::chrono::steady_clock::now(); + fprintf(stderr, "build %s index with %u docs: %.3f ms\n", name, EFFECTIVE_DOCS, to_ms(aft - bef)); timing_nns(name, nns, sk_list); - fprintf(stderr, "Quality for %s with %u documents full churn:\n", name, EFFECTIVE_DOCS); + fprintf(stderr, "Quality for %s [E] remove old, add new build with %u documents:\n", name, EFFECTIVE_DOCS); quality_nns(nns, sk_list); -#endif +} + +template +void benchmark_nns(const char *name, FUNC creator, std::vector sk_list) { + bm_nns_simple(name, creator, sk_list); + bm_nns_remove_old(name, creator, sk_list); + bm_nns_interleave(name, creator, sk_list); + bm_nns_remove_old_add_new(name, creator, sk_list); } #if 0 TEST("require that Locality Sensitive Hashing mostly works") { DocVectorAdapter adapter; - std::unique_ptr nns = make_rplsh_nns(NUM_DIMS, adapter); - benchmark_nns("RPLSH", *nns, { 200, 1000 }); + auto creator = [&adapter]() { return make_rplsh_nns(NUM_DIMS, adapter); }; + benchmark_nns("RPLSH", creator, { 200, 1000 }); } #endif #if 0 TEST("require that Annoy via NNS api mostly works") { DocVectorAdapter adapter; - std::unique_ptr nns = make_annoy_nns(NUM_DIMS, adapter); - benchmark_nns("Annoy", *nns, { 8000, 10000 }); + auto creator = [&adapter]() { return make_annoy_nns(NUM_DIMS, adapter); }; + benchmark_nns("Annoy", creator, { 8000, 10000 }); } #endif #if 1 TEST("require that HNSW via NNS api mostly works") { DocVectorAdapter adapter; - std::unique_ptr nns = make_hnsw_nns(NUM_DIMS, adapter); - benchmark_nns("HNSW-like", *nns, { 100, 150, 200 }); + auto creator = [&adapter]() { return make_hnsw_nns(NUM_DIMS, adapter); }; + benchmark_nns("HNSW-like", creator, { 100, 150, 200 }); } #endif #if 0 TEST("require that HNSW wrapped api mostly works") { DocVectorAdapter adapter; - std::unique_ptr nns = make_hnsw_wrap(NUM_DIMS, adapter); - benchmark_nns("HNSW-wrap", *nns, { 100, 150, 200 }); + auto creator = [&adapter]() { return make_hnsw_wrap(NUM_DIMS, adapter); }; + benchmark_nns("HNSW-wrap", creator, { 100, 150, 200 }); } #endif diff --git a/eval/src/tests/ann/sift_benchmark.cpp b/eval/src/tests/ann/sift_benchmark.cpp index 022c9404f5d..5f3c16e127d 100644 --- a/eval/src/tests/ann/sift_benchmark.cpp +++ b/eval/src/tests/ann/sift_benchmark.cpp @@ -13,14 +13,15 @@ #define NUM_DIMS 128 #define NUM_DOCS 1000000 #define NUM_Q 1000 +#define NUM_REACH 10000 #include "doc_vector_access.h" #include "nns.h" #include "for-sift-hit.h" #include "for-sift-top-k.h" +#include "std-random.h" std::vector bruteforceResults; -std::vector tmp_v(NUM_DIMS); struct PointVector { float v[NUM_DIMS]; @@ -29,10 +30,10 @@ struct PointVector { }; static PointVector *aligned_alloc(size_t num) { - size_t sz = num * sizeof(PointVector); - double mega_bytes = sz / (1024.0*1024.0); + size_t num_bytes = num * sizeof(PointVector); + double mega_bytes = num_bytes / (1024.0*1024.0); fprintf(stderr, "allocate %.2f MB of vectors\n", mega_bytes); - char *mem = (char *)malloc(sz + 512); + char *mem = (char *)malloc(num_bytes + 512); mem += 512; size_t val = (size_t)mem; size_t unalign = val % 512; @@ -53,7 +54,7 @@ struct DocVectorAdapter : public DocVectorAccess double computeDistance(const PointVector &query, uint32_t docid) { const PointVector &docvector = generatedDocs[docid]; - return l2distCalc.l2sq_dist(query, docvector, tmp_v); + return l2distCalc.l2sq_dist(query, docvector); } void read_queries(std::string fn) { @@ -151,7 +152,7 @@ TopK bruteforce_nns(const PointVector &query) { BfHitHeap heap(result.K); for (uint32_t docid = 0; docid < NUM_DOCS; ++docid) { const PointVector &docvector = generatedDocs[docid]; - double d = l2distCalc.l2sq_dist(query, docvector, tmp_v); + double d = l2distCalc.l2sq_dist(query, docvector); Hit h(docid, d); heap.maybe_use(h); } @@ -162,24 +163,58 @@ TopK bruteforce_nns(const PointVector &query) { return result; } +TopK bruteforce_nns_filter(const PointVector &query, const BitVector &blacklist) { + TopK result; + BfHitHeap heap(result.K); + for (uint32_t docid = 0; docid < NUM_DOCS; ++docid) { + if (blacklist.isSet(docid)) continue; + const PointVector &docvector = generatedDocs[docid]; + double d = l2distCalc.l2sq_dist(query, docvector); + Hit h(docid, d); + heap.maybe_use(h); + } + std::vector best = heap.bestHits(); + EXPECT_EQUAL(best.size(), result.K); + for (size_t i = 0; i < result.K; ++i) { + result.hits[i] = best[i]; + } + return result; +} + + void verifyBF(uint32_t qid) { const PointVector &query = generatedQueries[qid]; TopK &result = bruteforceResults[qid]; double min_distance = result.hits[0].distance; - std::vector all_c2; for (uint32_t i = 0; i < NUM_DOCS; ++i) { double dist = computeDistance(query, i); if (dist < min_distance) { fprintf(stderr, "WARN dist %.9g < mindist %.9g\n", dist, min_distance); } EXPECT_FALSE(dist+0.000001 < min_distance); - if (min_distance > 0) all_c2.push_back(dist / min_distance); } - if (all_c2.size() != NUM_DOCS) return; - std::sort(all_c2.begin(), all_c2.end()); - for (uint32_t idx : { 1, 3, 10, 30, 100, 300, 1000, 3000, NUM_DOCS/2, NUM_DOCS-1}) { - fprintf(stderr, "c2-factor[%u] = %.3f\n", idx, all_c2[idx]); +} + +void timing_bf_filter(int percent) +{ + BitVector blacklist(NUM_DOCS); + RndGen rnd; + for (uint32_t idx = 0; idx < NUM_DOCS; ++idx) { + if (rnd.nextUniform() < 0.01 * percent) { + blacklist.setBit(idx); + } else { + blacklist.clearBit(idx); + } + } + TimePoint bef = std::chrono::steady_clock::now(); + for (int cnt = 0; cnt < NUM_Q; ++cnt) { + const PointVector &qv = generatedQueries[cnt]; + auto res = bruteforce_nns_filter(qv, blacklist); + EXPECT_TRUE(res.hits[res.K - 1].distance > 0.0); } + TimePoint aft = std::chrono::steady_clock::now(); + fprintf(stderr, "timing for bruteforce filter %d %%: %.3f ms = %.3f ms/q\n", + percent, to_ms(aft - bef), to_ms(aft - bef)/NUM_Q); } TEST("require that brute force works") { @@ -195,52 +230,90 @@ TEST("require that brute force works") { for (int cnt = 0; cnt < NUM_Q; cnt = (cnt+1)*2) { verifyBF(cnt); } +#if 1 + for (uint32_t filter_percent : { 0, 1, 10, 50, 90, 95, 99 }) { + timing_bf_filter(filter_percent); + } +#endif } using NNS_API = NNS; -TopK find_with_nns(uint32_t sk, NNS_API &nns, uint32_t qid) { - TopK result; +size_t search_with_filter(uint32_t sk, NNS_API &nns, uint32_t qid, + const BitVector &blacklist) +{ const PointVector &qv = generatedQueries[qid]; vespalib::ConstArrayRef query(qv.v, NUM_DIMS); - auto rv = nns.topK(result.K, query, sk); - for (size_t i = 0; i < result.K; ++i) { - result.hits[i] = Hit(rv[i].docid, rv[i].sq.distance); + auto rv = nns.topKfilter(100, query, sk, blacklist); + return rv.size(); +} + +#include "find-with-nns.h" +#include "verify-top-k.h" + +void verify_with_filter(uint32_t sk, NNS_API &nns, uint32_t qid, + const BitVector &blacklist) +{ + const PointVector &qv = generatedQueries[qid]; + auto expected = bruteforce_nns_filter(qv, blacklist); + vespalib::ConstArrayRef query(qv.v, NUM_DIMS); + auto rv = nns.topKfilter(expected.K, query, sk, blacklist); + TopK actual; + for (size_t i = 0; i < actual.K; ++i) { + actual.hits[i] = Hit(rv[i].docid, rv[i].sq.distance); } - return result; + verify_top_k(expected, actual, sk, qid); } -void verify_nns_quality(uint32_t sk, NNS_API &nns, uint32_t qid) { - TopK perfect = bruteforceResults[qid]; - TopK result = find_with_nns(sk, nns, qid); - int recall = perfect.recall(result); - EXPECT_TRUE(recall > 40); - double sum_error = 0.0; - double c_factor = 1.0; - for (size_t i = 0; i < result.K; ++i) { - double factor = (result.hits[i].distance / perfect.hits[i].distance); - if (factor < 0.99 || factor > 25) { - fprintf(stderr, "hit[%zu] got distance %.3f, expected %.3f\n", - i, result.hits[i].distance, perfect.hits[i].distance); +void timing_nns_filter(const char *name, NNS_API &nns, + std::vector sk_list, int percent) +{ + BitVector blacklist(NUM_DOCS); + RndGen rnd; + for (uint32_t idx = 0; idx < NUM_DOCS; ++idx) { + if (rnd.nextUniform() < 0.01 * percent) { + blacklist.setBit(idx); + } else { + blacklist.clearBit(idx); + } + } + for (uint32_t search_k : sk_list) { + TimePoint bef = std::chrono::steady_clock::now(); + for (int cnt = 0; cnt < NUM_Q; ++cnt) { + uint32_t nh = search_with_filter(search_k, nns, cnt, blacklist); + EXPECT_EQUAL(nh, 100u); + } + TimePoint aft = std::chrono::steady_clock::now(); + fprintf(stderr, "timing for %s filter %d %% search_k=%u: %.3f ms = %.3f ms/q\n", + name, percent, search_k, to_ms(aft - bef), to_ms(aft - bef)/NUM_Q); +#if 0 + fprintf(stderr, "Quality check for %s filter %d %%:\n", name, percent); + for (int cnt = 0; cnt < NUM_Q; ++cnt) { + verify_with_filter(search_k, nns, cnt, blacklist); } - sum_error += factor; - c_factor = std::max(c_factor, factor); +#endif } - EXPECT_TRUE(c_factor < 1.5); - fprintf(stderr, "quality sk=%u: query %u: recall %d, c2-factor %.3f, avg c2: %.3f\n", - sk, qid, recall, c_factor, sum_error / result.K); - if (qid == 6) { - for (size_t i = 0; i < 10; ++i) { - fprintf(stderr, "topk[%zu] BF{%u %.3f} index{%u %.3f}\n", - i, - perfect.hits[i].docid, perfect.hits[i].distance, - result.hits[i].docid, result.hits[i].distance); +} + +void timing_nns(const char *name, NNS_API &nns, std::vector sk_list) { + for (uint32_t search_k : sk_list) { + TimePoint bef = std::chrono::steady_clock::now(); + for (int cnt = 0; cnt < NUM_Q; ++cnt) { + find_with_nns(search_k, nns, cnt); } + TimePoint aft = std::chrono::steady_clock::now(); + fprintf(stderr, "timing for %s search_k=%u: %.3f ms = %.3f ms/q\n", + name, search_k, to_ms(aft - bef), to_ms(aft - bef)/NUM_Q); } } -void benchmark_nns(const char *name, NNS_API &nns, std::vector sk_list) { +#include "quality-nns.h" + +template +void benchmark_nns(const char *name, FUNC creator, std::vector sk_list) { fprintf(stderr, "trying %s indexing...\n", name); + std::unique_ptr nnsp = creator(); + NNS_API &nns = *nnsp; TimePoint bef = std::chrono::steady_clock::now(); for (uint32_t i = 0; i < NUM_DOCS; ++i) { nns.addDoc(i); @@ -250,50 +323,44 @@ void benchmark_nns(const char *name, NNS_API &nns, std::vector sk_list TimePoint aft = std::chrono::steady_clock::now(); fprintf(stderr, "build %s index: %.3f ms\n", name, to_ms(aft - bef)); - for (uint32_t search_k : sk_list) { - bef = std::chrono::steady_clock::now(); - for (int cnt = 0; cnt < NUM_Q; ++cnt) { - find_with_nns(search_k, nns, cnt); - } - aft = std::chrono::steady_clock::now(); - fprintf(stderr, "timing for %s search_k=%u: %.3f ms = %.3f ms/q\n", - name, search_k, to_ms(aft - bef), to_ms(aft - bef)/NUM_Q); - for (int cnt = 0; cnt < NUM_Q; ++cnt) { - verify_nns_quality(search_k, nns, cnt); - } + fprintf(stderr, "Timings for %s :\n", name); + timing_nns(name, nns, sk_list); + for (uint32_t filter_percent : { 0, 1, 10, 50, 90, 95, 99 }) { + timing_nns_filter(name, nns, sk_list, filter_percent); } + fprintf(stderr, "Quality for %s :\n", name); + quality_nns(nns, sk_list); } - -#if 1 +#if 0 TEST("require that Locality Sensitive Hashing mostly works") { DocVectorAdapter adapter; - std::unique_ptr nns = make_rplsh_nns(NUM_DIMS, adapter); - benchmark_nns("RPLSH", *nns, { 200, 1000 }); + auto creator = [&adapter]() { return make_rplsh_nns(NUM_DIMS, adapter); }; + benchmark_nns("RPLSH", creator, { 200, 1000 }); } #endif #if 1 TEST("require that Annoy via NNS api mostly works") { DocVectorAdapter adapter; - std::unique_ptr nns = make_annoy_nns(NUM_DIMS, adapter); - benchmark_nns("Annoy", *nns, { 8000, 10000 }); + auto creator = [&adapter]() { return make_annoy_nns(NUM_DIMS, adapter); }; + benchmark_nns("Annoy", creator, { 8000, 10000 }); } #endif #if 1 TEST("require that HNSW via NNS api mostly works") { DocVectorAdapter adapter; - std::unique_ptr nns = make_hnsw_nns(NUM_DIMS, adapter); - benchmark_nns("HNSW-like", *nns, { 100, 150, 200 }); + auto creator = [&adapter]() { return make_hnsw_nns(NUM_DIMS, adapter); }; + benchmark_nns("HNSW-like", creator, { 100, 150, 200 }); } #endif #if 0 TEST("require that HNSW wrapped api mostly works") { DocVectorAdapter adapter; - std::unique_ptr nns = make_hnsw_wrap(NUM_DIMS, adapter); - benchmark_nns("HNSW-wrap", *nns, { 100, 150, 200 }); + auto creator = [&adapter]() { return make_hnsw_wrap(NUM_DIMS, adapter); }; + benchmark_nns("HNSW-wrap", creator, { 100, 150, 200 }); } #endif diff --git a/eval/src/tests/ann/verify-top-k.h b/eval/src/tests/ann/verify-top-k.h new file mode 100644 index 00000000000..220c273d017 --- /dev/null +++ b/eval/src/tests/ann/verify-top-k.h @@ -0,0 +1,27 @@ +// Copyright 2020 Oath Inc. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. + +int verify_top_k(const TopK &perfect, const TopK &result, uint32_t sk, uint32_t qid) { + int recall = perfect.recall(result); + EXPECT_TRUE(recall > 40); + double sum_error = 0.0; + double c_factor = 1.0; + for (size_t i = 0; i < result.K; ++i) { + double factor = (result.hits[i].distance / perfect.hits[i].distance); + if (factor < 0.99 || factor > 25) { + fprintf(stderr, "hit[%zu] got distance %.3f, expected %.3f\n", + i, result.hits[i].distance, perfect.hits[i].distance); + } + sum_error += factor; + c_factor = std::max(c_factor, factor); + } + EXPECT_TRUE(c_factor < 1.5); + fprintf(stderr, "quality sk=%u: query %u: recall %d c2-factor %.3f avg c2: %.3f\n", + sk, qid, recall, c_factor, sum_error / result.K); + return recall; +} + +int verify_nns_quality(uint32_t sk, NNS_API &nns, uint32_t qid) { + TopK perfect = bruteforceResults[qid]; + TopK result = find_with_nns(sk, nns, qid); + return verify_top_k(perfect, result, sk, qid); +} diff --git a/eval/src/tests/ann/xp-annoy-nns.cpp b/eval/src/tests/ann/xp-annoy-nns.cpp index f022aae5974..213e583d95a 100644 --- a/eval/src/tests/ann/xp-annoy-nns.cpp +++ b/eval/src/tests/ann/xp-annoy-nns.cpp @@ -27,6 +27,7 @@ struct Node { virtual Node *addDoc(uint32_t docid, V vector, AnnoyLikeNns &meta) = 0; virtual int remove(uint32_t docid, V vector) = 0; virtual void findCandidates(std::set &cands, V vector, NodeQueue &queue, double minDist) const = 0; + virtual void filterCandidates(std::set &cands, V vector, NodeQueue &queue, double minDist, const BitVector &blacklist) const = 0; virtual void stats(std::vector &depths) = 0; }; @@ -38,6 +39,7 @@ struct LeafNode : public Node { Node *addDoc(uint32_t docid, V vector, AnnoyLikeNns &meta) override; int remove(uint32_t docid, V vector) override; void findCandidates(std::set &cands, V vector, NodeQueue &queue, double minDist) const override; + void filterCandidates(std::set &cands, V vector, NodeQueue &queue, double minDist, const BitVector &blacklist) const override; Node *split(AnnoyLikeNns &meta); virtual void stats(std::vector &depths) override { depths.push_back(1); } @@ -55,6 +57,7 @@ struct SplitNode : public Node { Node *addDoc(uint32_t docid, V vector, AnnoyLikeNns &meta) override; int remove(uint32_t docid, V vector) override; void findCandidates(std::set &cands, V vector, NodeQueue &queue, double minDist) const override; + void filterCandidates(std::set &cands, V vector, NodeQueue &queue, double minDist, const BitVector &blacklist) const override; double planeDistance(V vector) const; virtual void stats(std::vector &depths) override { @@ -106,6 +109,8 @@ public: } std::vector topK(uint32_t k, Vector vector, uint32_t search_k) override; + std::vector topKfilter(uint32_t k, Vector vector, uint32_t search_k, const BitVector &bitvector) override; + V getVector(uint32_t docid) const { return _dva.get(docid); } double uniformRnd() { return _rndGen.nextUniform(); } uint32_t dims() const { return _numDims; } @@ -304,6 +309,16 @@ LeafNode::findCandidates(std::set &cands, V, NodeQueue &, double) cons } } +void +LeafNode::filterCandidates(std::set &cands, V, NodeQueue &, double, const BitVector &blacklist) const +{ + for (uint32_t d : docids) { + if (blacklist.isSet(d)) continue; + cands.insert(d); + } +} + + SplitNode::~SplitNode() { delete leftChildren; @@ -344,6 +359,15 @@ SplitNode::findCandidates(std::set &, V vector, NodeQueue &queue, doub queue.push(std::make_pair(std::min(d, minDist), rightChildren)); } +void +SplitNode::filterCandidates(std::set &, V vector, NodeQueue &queue, double minDist, const BitVector &) const +{ + double d = planeDistance(vector); + // fprintf(stderr, "push 2 nodes dist %g\n", d); + queue.push(std::make_pair(std::min(-d, minDist), leftChildren)); + queue.push(std::make_pair(std::min(d, minDist), rightChildren)); +} + std::vector AnnoyLikeNns::topK(uint32_t k, Vector vector, uint32_t search_k) { @@ -387,6 +411,40 @@ AnnoyLikeNns::topK(uint32_t k, Vector vector, uint32_t search_k) return r; } +std::vector +AnnoyLikeNns::topKfilter(uint32_t k, Vector vector, uint32_t search_k, const BitVector &blacklist) +{ + ++find_top_k_cnt; + std::vector r; + r.reserve(k); + std::set candidates; + NodeQueue queue; + for (Node *root : _roots) { + double dist = std::numeric_limits::max(); + queue.push(std::make_pair(dist, root)); + } + while ((candidates.size() < std::max(k, search_k)) && (queue.size() > 0)) { + const QueueNode& top = queue.top(); + double md = top.first; + // fprintf(stderr, "find candidates: node with min distance %g\n", md); + Node *n = top.second; + queue.pop(); + n->filterCandidates(candidates, vector, queue, md, blacklist); + ++find_cand_cnt; + } + for (uint32_t docid : candidates) { + if (blacklist.isSet(docid)) continue; + double dist = l2distCalc.l2sq_dist(vector, _dva.get(docid)); + NnsHit hit(docid, SqDist(dist)); + r.push_back(hit); + } + std::sort(r.begin(), r.end(), NnsHitComparatorLessDistance()); + while (r.size() > k) r.pop_back(); + return r; +} + + + void AnnoyLikeNns::dumpStats() { fprintf(stderr, "stats for AnnoyLikeNns:\n"); diff --git a/eval/src/tests/ann/xp-hnsw-wrap.cpp b/eval/src/tests/ann/xp-hnsw-wrap.cpp index 3eb01142dcd..45c7a974254 100644 --- a/eval/src/tests/ann/xp-hnsw-wrap.cpp +++ b/eval/src/tests/ann/xp-hnsw-wrap.cpp @@ -46,6 +46,34 @@ public: } return result; } + + std::vector topKfilter(uint32_t k, Vector vector, uint32_t search_k, const BitVector &blacklist) override { + std::vector reversed; + uint32_t adjusted_k = k+4; + uint32_t adjusted_sk = search_k+4; + for (int retry = 0; (retry < 5) && (reversed.size() < k); ++retry) { + reversed.clear(); + _hnsw.setEf(adjusted_sk); + auto priQ = _hnsw.searchKnn(vector.cbegin(), adjusted_k); + while (! priQ.empty()) { + auto pair = priQ.top(); + if (! blacklist.isSet(pair.second)) { + reversed.emplace_back(pair.second, SqDist(pair.first)); + } + priQ.pop(); + } + double got = 1 + reversed.size(); + double factor = 1.25 * k / got; + adjusted_k *= factor; + adjusted_sk *= factor; + } + std::vector result; + while (result.size() < k && !reversed.empty()) { + result.push_back(reversed.back()); + reversed.pop_back(); + } + return result; + } }; std::unique_ptr> diff --git a/eval/src/tests/ann/xp-hnswlike-nns.cpp b/eval/src/tests/ann/xp-hnswlike-nns.cpp index 5cdbdd8efa3..90fc0fe2e92 100644 --- a/eval/src/tests/ann/xp-hnswlike-nns.cpp +++ b/eval/src/tests/ann/xp-hnswlike-nns.cpp @@ -32,6 +32,11 @@ static size_t distcalls_heuristic; static size_t distcalls_shrink; static size_t distcalls_refill; static size_t refill_needed_calls; +static size_t shrink_needed_calls; +static size_t disconnected_weak_links; +static size_t disconnected_for_symmetry; +static size_t select_n_full; +static size_t select_n_partial; struct LinkList : std::vector { @@ -76,6 +81,7 @@ struct VisitedSet ptr = (Mark *)malloc(size * sizeof(Mark)); curval = -1; sz = size; + clear(); } void clear() { ++curval; @@ -99,8 +105,9 @@ struct VisitedSetPool VisitedSet &get(size_t size) { if (size > lastUsed->sz) { lastUsed = std::make_unique(size*2); + } else { + lastUsed->clear(); } - lastUsed->clear(); return *lastUsed; } }; @@ -214,6 +221,10 @@ public: void search_layer(Vector vector, FurthestPriQ &w, uint32_t ef, uint32_t searchLevel); + void search_layer_with_filter(Vector vector, FurthestPriQ &w, + uint32_t ef, uint32_t searchLevel, + const BitVector &blacklist); + bool haveCloserDistance(HnswHit e, const LinkList &r) const { for (uint32_t prevId : r) { double dist = distance(e.docid, prevId); @@ -278,6 +289,10 @@ public: fprintf(stderr, "distance calls for refill: %zu is %.3f per op\n", distcalls_refill, distcalls_refill / div); fprintf(stderr, "distance calls for other: %zu is %.3f per op\n", distcalls_other, distcalls_other / div); fprintf(stderr, "refill needed calls: %zu is %.3f per op\n", refill_needed_calls, refill_needed_calls / div); + fprintf(stderr, "shrink needed calls: %zu is %.3f per op\n", shrink_needed_calls, shrink_needed_calls / div); + fprintf(stderr, "disconnected weak links: %zu is %.3f per op\n", disconnected_weak_links, disconnected_weak_links / div); + fprintf(stderr, "disconnected for symmetry: %zu is %.3f per op\n", disconnected_for_symmetry, disconnected_for_symmetry / div); + fprintf(stderr, "select neighbors: partial %zu vs full %zu\n", select_n_partial, select_n_full); } } @@ -315,10 +330,19 @@ public: LinkList lostLinks; LinkList oldLinks = links; links = remove_weakest(distances, maxLinks, lostLinks); +#define KEEP_SYM +#ifdef KEEP_SYM for (uint32_t lost_id : lostLinks) { + ++disconnected_for_symmetry; remove_link_from(lost_id, shrink_id, level); + } +#define DO_REFILL_AFTER_KEEP_SYM +#ifdef DO_REFILL_AFTER_KEEP_SYM + for (uint32_t lost_id : lostLinks) { refill_ifneeded(lost_id, oldLinks, level); } +#endif +#endif } void each_shrink_ifneeded(const LinkList &neighbors, uint32_t level); @@ -337,7 +361,9 @@ public: } remove_link_from(n_id, docid, level); } - for (uint32_t n_id : my_links) { + while (! my_links.empty()) { + uint32_t n_id = my_links.back(); + my_links.pop_back(); refill_ifneeded(n_id, my_links, level); } } @@ -363,12 +389,12 @@ public: ++distcalls_other; HnswHit entryPoint(_entryId, SqDist(entryDist)); int searchLevel = _entryLevel; - FurthestPriQ w; - w.push(entryPoint); while (searchLevel > 0) { - search_layer(vector, w, std::min(k, search_k), searchLevel); + entryPoint = search_layer_simple(vector, entryPoint, searchLevel); --searchLevel; } + FurthestPriQ w; + w.push(entryPoint); search_layer(vector, w, std::max(k, search_k), 0); while (w.size() > k) { w.pop(); @@ -381,8 +407,11 @@ public: } return result; } + + std::vector topKfilter(uint32_t k, Vector vector, uint32_t search_k, const BitVector &blacklist) override; }; + double HnswLikeNns::distance(Vector v, uint32_t b) const { @@ -390,12 +419,40 @@ HnswLikeNns::distance(Vector v, uint32_t b) const return l2distCalc.l2sq_dist(v, w); } +std::vector +HnswLikeNns::topKfilter(uint32_t k, Vector vector, uint32_t search_k, const BitVector &blacklist) +{ + std::vector result; + if (_entryLevel < 0) return result; + double entryDist = distance(vector, _entryId); + ++distcalls_other; + HnswHit entryPoint(_entryId, SqDist(entryDist)); + int searchLevel = _entryLevel; + while (searchLevel > 0) { + entryPoint = search_layer_simple(vector, entryPoint, searchLevel); + --searchLevel; + } + FurthestPriQ w; + w.push(entryPoint); + search_layer_with_filter(vector, w, std::max(k, search_k), 0, blacklist); + NearestList tmp = w.steal(); + std::sort(tmp.begin(), tmp.end(), LesserDist()); + result.reserve(std::min((size_t)k, tmp.size())); + for (const auto & hit : tmp) { + if (blacklist.isSet(hit.docid)) continue; + result.emplace_back(hit.docid, SqDist(hit.dist)); + if (result.size() == k) break; + } + return result; +} + void HnswLikeNns::each_shrink_ifneeded(const LinkList &neighbors, uint32_t level) { uint32_t maxLinks = (level > 0) ? _M : (2 * _M); for (uint32_t old_id : neighbors) { LinkList &oldLinks = getLinkList(old_id, level); if (oldLinks.size() > maxLinks) { + ++shrink_needed_calls; shrink_links(old_id, maxLinks, level); } } @@ -437,6 +494,44 @@ HnswLikeNns::search_layer(Vector vector, FurthestPriQ &w, return; } +void +HnswLikeNns::search_layer_with_filter(Vector vector, FurthestPriQ &w, + uint32_t ef, uint32_t searchLevel, + const BitVector &blacklist) +{ + NearestPriQ candidates; + VisitedSet &visited = _visitedSetPool.get(_nodes.size()); + + for (const HnswHit & entry : w.peek()) { + candidates.push(entry); + visited.mark(entry.docid); + if (blacklist.isSet(entry.docid)) ++ef; + } + double limd = std::numeric_limits::max(); + while (! candidates.empty()) { + HnswHit cand = candidates.top(); + if (cand.dist > limd) { + break; + } + candidates.pop(); + for (uint32_t e_id : getLinkList(cand.docid, searchLevel)) { + if (visited.isMarked(e_id)) continue; + visited.mark(e_id); + double e_dist = distance(vector, e_id); + ++distcalls_search_layer; + if (e_dist < limd) { + candidates.emplace(e_id, SqDist(e_dist)); + if (blacklist.isSet(e_id)) continue; + w.emplace(e_id, SqDist(e_dist)); + if (w.size() > ef) { + w.pop(); + limd = w.top().dist; + } + } + } + } +} + LinkList HnswLikeNns::remove_weakest(const NearestList &neighbors, uint32_t curMax, LinkList &lost) const { @@ -458,13 +553,13 @@ HnswLikeNns::remove_weakest(const NearestList &neighbors, uint32_t curMax, LinkL return result; } +#define NO_BACKFILL #ifdef NO_BACKFILL LinkList HnswLikeNns::select_neighbors(const NearestList &neighbors, uint32_t curMax) const { LinkList result; result.reserve(curMax+1); - bool needFiltering = (neighbors.size() > curMax); NearestPriQ w; for (const auto & entry : neighbors) { w.push(entry); @@ -472,12 +567,16 @@ HnswLikeNns::select_neighbors(const NearestList &neighbors, uint32_t curMax) con while (! w.empty()) { HnswHit e = w.top(); w.pop(); - if (needFiltering && haveCloserDistance(e, result)) { + if (haveCloserDistance(e, result)) { continue; } result.push_back(e.docid); - if (result.size() == curMax) return result; + if (result.size() == curMax) { + ++select_n_full; + return result; + } } + ++select_n_partial; return result; } #else @@ -502,10 +601,10 @@ HnswLikeNns::select_neighbors(const NearestList &neighbors, uint32_t curMax) con result.push_back(e.docid); if (result.size() == curMax) return result; } - if (result.size() * 4 < curMax) { + if (result.size() * 4 < _M) { for (uint32_t fill_id : backfill) { result.push_back(fill_id); - if (result.size() * 4 >= curMax) break; + if (result.size() * 2 >= _M) break; } } return result; @@ -576,7 +675,9 @@ HnswLikeNns::dumpStats() const { for (uint32_t n_id : link_list) { const LinkList &neigh_list = getLinkList(n_id, 0); if (! neigh_list.has_link_to(id)) { +#ifdef KEEP_SYM fprintf(stderr, "BAD: %u has link to neighbor %u, but backlink is missing\n", id, n_id); +#endif all_sym = false; } } diff --git a/eval/src/tests/ann/xp-lsh-nns.cpp b/eval/src/tests/ann/xp-lsh-nns.cpp index 0ea119a9c70..c028a07a9d7 100644 --- a/eval/src/tests/ann/xp-lsh-nns.cpp +++ b/eval/src/tests/ann/xp-lsh-nns.cpp @@ -118,6 +118,7 @@ public: } } std::vector topK(uint32_t k, Vector vector, uint32_t search_k) override; + std::vector topKfilter(uint32_t k, Vector vector, uint32_t search_k, const BitVector &bitvector) override; V getVector(uint32_t docid) const { return _dva.get(docid); } double uniformRnd() { return _rndGen.nextUniform(); } @@ -195,6 +196,45 @@ public: } }; +std::vector +RpLshNns::topKfilter(uint32_t k, Vector vector, uint32_t search_k, const BitVector &blacklist) +{ + std::vector result; + result.reserve(k); + + std::vector tmp(_numDims); + vespalib::ArrayRef tmpArr(tmp); + + LsMaskHash query_hash = mask_hash_from_pv(vector, _transformationMatrix); + LshHitHeap heap(std::max(k, search_k)); + int limit_hash_dist = 99999; + int skipCnt = 0; + int fullCnt = 0; + int whdcCnt = 0; + size_t docidLimit = _generated_doc_hashes.size(); + for (uint32_t docid = 0; docid < docidLimit; ++docid) { + if (blacklist.isSet(docid)) continue; + int hd = hash_dist(query_hash, _generated_doc_hashes[docid]); + if (hd <= limit_hash_dist) { + ++fullCnt; + double dist = l2distCalc.l2sq_dist(vector, _dva.get(docid), tmpArr); + LshHit h(docid, dist, hd); + if (heap.maybe_use(h)) { + ++whdcCnt; + limit_hash_dist = heap.limitHashDistance(); + } + } else { + ++skipCnt; + } + } + std::vector best = heap.bestLshHits(); + size_t numHits = std::min((size_t)k, best.size()); + for (size_t i = 0; i < numHits; ++i) { + result.emplace_back(best[i].docid, SqDist(best[i].distance)); + } + return result; +} + std::vector RpLshNns::topK(uint32_t k, Vector vector, uint32_t search_k) { -- cgit v1.2.3 From 44fef3325d3e9bfa673d71b87721f0979f8404c8 Mon Sep 17 00:00:00 2001 From: Arne Juul Date: Tue, 25 Feb 2020 14:11:00 +0000 Subject: split out common subroutines --- eval/src/tests/ann/bruteforce-nns.h | 74 ++++++++++++++ eval/src/tests/ann/gist_benchmark.cpp | 181 +++------------------------------ eval/src/tests/ann/point-vector.h | 30 ++++++ eval/src/tests/ann/read-vecs.h | 45 +++++++++ eval/src/tests/ann/remove-bm.cpp | 182 +++------------------------------- eval/src/tests/ann/sift_benchmark.cpp | 160 +----------------------------- eval/src/tests/ann/time-util.h | 9 ++ 7 files changed, 190 insertions(+), 491 deletions(-) create mode 100644 eval/src/tests/ann/bruteforce-nns.h create mode 100644 eval/src/tests/ann/point-vector.h create mode 100644 eval/src/tests/ann/read-vecs.h create mode 100644 eval/src/tests/ann/time-util.h diff --git a/eval/src/tests/ann/bruteforce-nns.h b/eval/src/tests/ann/bruteforce-nns.h new file mode 100644 index 00000000000..0c7c48654f7 --- /dev/null +++ b/eval/src/tests/ann/bruteforce-nns.h @@ -0,0 +1,74 @@ +// Copyright Verizon Media. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. + +std::vector bruteforceResults; + +double computeDistance(const PointVector &query, uint32_t docid) { + const PointVector &docvector = generatedDocs[docid]; + return l2distCalc.l2sq_dist(query, docvector); +} + +struct BfHitComparator { + bool operator() (const Hit &lhs, const Hit& rhs) const { + if (lhs.distance < rhs.distance) return false; + if (lhs.distance > rhs.distance) return true; + return (lhs.docid > rhs.docid); + } +}; + +class BfHitHeap { +private: + size_t _size; + vespalib::PriorityQueue _priQ; +public: + explicit BfHitHeap(size_t maxSize) : _size(maxSize), _priQ() { + _priQ.reserve(maxSize); + } + ~BfHitHeap() {} + void maybe_use(const Hit &hit) { + if (_priQ.size() < _size) { + _priQ.push(hit); + } else if (hit.distance < _priQ.front().distance) { + _priQ.front() = hit; + _priQ.adjust(); + } + } + std::vector bestHits() { + std::vector result; + size_t i = _priQ.size(); + result.resize(i); + while (i-- > 0) { + result[i] = _priQ.front(); + _priQ.pop_front(); + } + return result; + } +}; + +TopK bruteforce_nns(const PointVector &query) { + TopK result; + BfHitHeap heap(result.K); + for (uint32_t docid = 0; docid < NUM_DOCS; ++docid) { + const PointVector &docvector = generatedDocs[docid]; + double d = l2distCalc.l2sq_dist(query, docvector); + Hit h(docid, d); + heap.maybe_use(h); + } + std::vector best = heap.bestHits(); + for (size_t i = 0; i < result.K; ++i) { + result.hits[i] = best[i]; + } + return result; +} + +void verifyBF(uint32_t qid) { + const PointVector &query = generatedQueries[qid]; + TopK &result = bruteforceResults[qid]; + double min_distance = result.hits[0].distance; + for (uint32_t i = 0; i < NUM_DOCS; ++i) { + double dist = computeDistance(query, i); + if (dist < min_distance) { + fprintf(stderr, "WARN dist %.9g < mindist %.9g\n", dist, min_distance); + } + EXPECT_FALSE(dist+0.000001 < min_distance); + } +} diff --git a/eval/src/tests/ann/gist_benchmark.cpp b/eval/src/tests/ann/gist_benchmark.cpp index 45559fc2557..de8bff877e6 100644 --- a/eval/src/tests/ann/gist_benchmark.cpp +++ b/eval/src/tests/ann/gist_benchmark.cpp @@ -18,167 +18,10 @@ #include "nns.h" #include "for-sift-hit.h" #include "for-sift-top-k.h" - -std::vector bruteforceResults; - -struct PointVector { - float v[NUM_DIMS]; - using ConstArr = vespalib::ConstArrayRef; - operator ConstArr() const { return ConstArr(v, NUM_DIMS); } -}; - -static PointVector *aligned_alloc(size_t num) { - size_t num_bytes = num * sizeof(PointVector); - double mega_bytes = num_bytes / (1024.0*1024.0); - fprintf(stderr, "allocate %.2f MB of vectors\n", mega_bytes); - char *mem = (char *)malloc(num_bytes + 512); - mem += 512; - size_t val = (size_t)mem; - size_t unalign = val % 512; - mem -= unalign; - return reinterpret_cast(mem); -} - -static PointVector *generatedQueries = aligned_alloc(NUM_Q); -static PointVector *generatedDocs = aligned_alloc(NUM_DOCS); - -struct DocVectorAdapter : public DocVectorAccess -{ - vespalib::ConstArrayRef get(uint32_t docid) const override { - ASSERT_TRUE(docid < NUM_DOCS); - return generatedDocs[docid]; - } -}; - -double computeDistance(const PointVector &query, uint32_t docid) { - const PointVector &docvector = generatedDocs[docid]; - return l2distCalc.l2sq_dist(query, docvector); -} - -void read_queries(std::string fn) { - int fd = open(fn.c_str(), O_RDONLY); - ASSERT_TRUE(fd > 0); - int d; - size_t rv; - fprintf(stderr, "reading %u queries from %s\n", NUM_Q, fn.c_str()); - for (uint32_t qid = 0; qid < NUM_Q; ++qid) { - rv = read(fd, &d, 4); - ASSERT_EQUAL(rv, 4u); - ASSERT_EQUAL(d, NUM_DIMS); - rv = read(fd, &generatedQueries[qid].v, NUM_DIMS*sizeof(float)); - ASSERT_EQUAL(rv, sizeof(PointVector)); - } - close(fd); -} - -void read_docs(std::string fn) { - int fd = open(fn.c_str(), O_RDONLY); - ASSERT_TRUE(fd > 0); - int d; - size_t rv; - fprintf(stderr, "reading %u doc vectors from %s\n", NUM_DOCS, fn.c_str()); - for (uint32_t docid = 0; docid < NUM_DOCS; ++docid) { - rv = read(fd, &d, 4); - ASSERT_EQUAL(rv, 4u); - ASSERT_EQUAL(d, NUM_DIMS); - rv = read(fd, &generatedDocs[docid].v, NUM_DIMS*sizeof(float)); - ASSERT_EQUAL(rv, sizeof(PointVector)); - } - close(fd); -} - -using TimePoint = std::chrono::steady_clock::time_point; -using Duration = std::chrono::steady_clock::duration; - -double to_ms(Duration elapsed) { - std::chrono::duration ms(elapsed); - return ms.count(); -} - -void read_data(std::string dir) { - TimePoint bef = std::chrono::steady_clock::now(); - read_queries(dir + "/gist_query.fvecs"); - TimePoint aft = std::chrono::steady_clock::now(); - fprintf(stderr, "read queries: %.3f ms\n", to_ms(aft - bef)); - bef = std::chrono::steady_clock::now(); - read_docs(dir + "/gist_base.fvecs"); - aft = std::chrono::steady_clock::now(); - fprintf(stderr, "read docs: %.3f ms\n", to_ms(aft - bef)); -} - - -struct BfHitComparator { - bool operator() (const Hit &lhs, const Hit& rhs) const { - if (lhs.distance < rhs.distance) return false; - if (lhs.distance > rhs.distance) return true; - return (lhs.docid > rhs.docid); - } -}; - -class BfHitHeap { -private: - size_t _size; - vespalib::PriorityQueue _priQ; -public: - explicit BfHitHeap(size_t maxSize) : _size(maxSize), _priQ() { - _priQ.reserve(maxSize); - } - ~BfHitHeap() {} - void maybe_use(const Hit &hit) { - if (_priQ.size() < _size) { - _priQ.push(hit); - } else if (hit.distance < _priQ.front().distance) { - _priQ.front() = hit; - _priQ.adjust(); - } - } - std::vector bestHits() { - std::vector result; - size_t i = _priQ.size(); - result.resize(i); - while (i-- > 0) { - result[i] = _priQ.front(); - _priQ.pop_front(); - } - return result; - } -}; - -TopK bruteforce_nns(const PointVector &query) { - TopK result; - BfHitHeap heap(result.K); - for (uint32_t docid = 0; docid < NUM_DOCS; ++docid) { - const PointVector &docvector = generatedDocs[docid]; - double d = l2distCalc.l2sq_dist(query, docvector); - Hit h(docid, d); - heap.maybe_use(h); - } - std::vector best = heap.bestHits(); - for (size_t i = 0; i < result.K; ++i) { - result.hits[i] = best[i]; - } - return result; -} - -void verifyBF(uint32_t qid) { - const PointVector &query = generatedQueries[qid]; - TopK &result = bruteforceResults[qid]; - double min_distance = result.hits[0].distance; - std::vector all_c2; - for (uint32_t i = 0; i < NUM_DOCS; ++i) { - double dist = computeDistance(query, i); - if (dist < min_distance) { - fprintf(stderr, "WARN dist %.9g < mindist %.9g\n", dist, min_distance); - } - EXPECT_FALSE(dist+0.000001 < min_distance); - if (min_distance > 0.0) all_c2.push_back(dist / min_distance); - } - if (all_c2.size() != NUM_DOCS) return; - std::sort(all_c2.begin(), all_c2.end()); - for (uint32_t idx : { 1, 3, 10, 30, 100, 300, 1000, 3000, NUM_DOCS/2, NUM_DOCS-1}) { - fprintf(stderr, "c2-factor[%u] = %.3f\n", idx, all_c2[idx]); - } -} +#include "time-util.h" +#include "point-vector.h" +#include "read-vecs.h" +#include "bruteforce-nns.h" using NNS_API = NNS; @@ -279,17 +122,21 @@ TEST("require that HNSW wrapped api mostly works") { */ int main(int argc, char **argv) { TEST_MASTER.init(__FILE__); - std::string gist_dir = "."; - if (argc > 1) { - gist_dir = argv[1]; + std::string data_set = "gist"; + std::string data_dir = "."; + if (argc > 2) { + data_set = argv[1]; + data_dir = argv[2]; + } else if (argc > 1) { + data_dir = argv[1]; } else { char *home = getenv("HOME"); if (home) { - gist_dir = home; - gist_dir += "/gist"; + data_dir = home; + data_dir += "/" + data_set; } } - read_data(gist_dir); + read_data(data_dir, data_set); TEST_RUN_ALL(); return (TEST_MASTER.fini() ? 0 : 1); } diff --git a/eval/src/tests/ann/point-vector.h b/eval/src/tests/ann/point-vector.h new file mode 100644 index 00000000000..eca60e11194 --- /dev/null +++ b/eval/src/tests/ann/point-vector.h @@ -0,0 +1,30 @@ +// Copyright Verizon Media. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. + +struct PointVector { + float v[NUM_DIMS]; + using ConstArr = vespalib::ConstArrayRef; + operator ConstArr() const { return ConstArr(v, NUM_DIMS); } +}; + +static PointVector *aligned_alloc(size_t num) { + size_t num_bytes = num * sizeof(PointVector); + double mega_bytes = num_bytes / (1024.0*1024.0); + fprintf(stderr, "allocate %.2f MB of vectors\n", mega_bytes); + char *mem = (char *)malloc(num_bytes + 512); + mem += 512; + size_t val = (size_t)mem; + size_t unalign = val % 512; + mem -= unalign; + return reinterpret_cast(mem); +} + +static PointVector *generatedQueries = aligned_alloc(NUM_Q); +static PointVector *generatedDocs = aligned_alloc(NUM_DOCS); + +struct DocVectorAdapter : public DocVectorAccess +{ + vespalib::ConstArrayRef get(uint32_t docid) const override { + ASSERT_TRUE(docid < NUM_DOCS); + return generatedDocs[docid]; + } +}; diff --git a/eval/src/tests/ann/read-vecs.h b/eval/src/tests/ann/read-vecs.h new file mode 100644 index 00000000000..39c2a332710 --- /dev/null +++ b/eval/src/tests/ann/read-vecs.h @@ -0,0 +1,45 @@ +// Copyright Verizon Media. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. + +void read_queries(std::string fn) { + int fd = open(fn.c_str(), O_RDONLY); + ASSERT_TRUE(fd > 0); + int d; + size_t rv; + fprintf(stderr, "reading %u queries from %s\n", NUM_Q, fn.c_str()); + for (uint32_t qid = 0; qid < NUM_Q; ++qid) { + rv = read(fd, &d, 4); + ASSERT_EQUAL(rv, 4u); + ASSERT_EQUAL(d, NUM_DIMS); + rv = read(fd, &generatedQueries[qid].v, NUM_DIMS*sizeof(float)); + ASSERT_EQUAL(rv, sizeof(PointVector)); + } + close(fd); +} + +void read_docs(std::string fn) { + int fd = open(fn.c_str(), O_RDONLY); + ASSERT_TRUE(fd > 0); + int d; + size_t rv; + fprintf(stderr, "reading %u doc vectors from %s\n", NUM_DOCS, fn.c_str()); + for (uint32_t docid = 0; docid < NUM_DOCS; ++docid) { + rv = read(fd, &d, 4); + ASSERT_EQUAL(rv, 4u); + ASSERT_EQUAL(d, NUM_DIMS); + rv = read(fd, &generatedDocs[docid].v, NUM_DIMS*sizeof(float)); + ASSERT_EQUAL(rv, sizeof(PointVector)); + } + close(fd); +} + +void read_data(const std::string& dir, const std::string& data_set) { + fprintf(stderr, "read data set '%s' from directory '%s'\n", data_set.c_str(), dir.c_str()); + TimePoint bef = std::chrono::steady_clock::now(); + read_queries(dir + "/" + data_set + "_query.fvecs"); + TimePoint aft = std::chrono::steady_clock::now(); + fprintf(stderr, "read queries: %.3f ms\n", to_ms(aft - bef)); + bef = std::chrono::steady_clock::now(); + read_docs(dir + "/" + data_set + "_base.fvecs"); + aft = std::chrono::steady_clock::now(); + fprintf(stderr, "read docs: %.3f ms\n", to_ms(aft - bef)); +} diff --git a/eval/src/tests/ann/remove-bm.cpp b/eval/src/tests/ann/remove-bm.cpp index 005f3804af9..546c2cfd75e 100644 --- a/eval/src/tests/ann/remove-bm.cpp +++ b/eval/src/tests/ann/remove-bm.cpp @@ -20,168 +20,10 @@ #include "nns.h" #include "for-sift-hit.h" #include "for-sift-top-k.h" - -std::vector bruteforceResults; -std::vector tmp_v(NUM_DIMS); - -struct PointVector { - float v[NUM_DIMS]; - using ConstArr = vespalib::ConstArrayRef; - operator ConstArr() const { return ConstArr(v, NUM_DIMS); } -}; - -static PointVector *aligned_alloc(size_t num) { - size_t num_bytes = num * sizeof(PointVector); - double mega_bytes = num_bytes / (1024.0*1024.0); - fprintf(stderr, "allocate %.2f MB of vectors\n", mega_bytes); - char *mem = (char *)malloc(num_bytes + 512); - mem += 512; - size_t val = (size_t)mem; - size_t unalign = val % 512; - mem -= unalign; - return reinterpret_cast(mem); -} - -static PointVector *generatedQueries = aligned_alloc(NUM_Q); -static PointVector *generatedDocs = aligned_alloc(NUM_DOCS); - -struct DocVectorAdapter : public DocVectorAccess -{ - vespalib::ConstArrayRef get(uint32_t docid) const override { - ASSERT_TRUE(docid < NUM_DOCS); - return generatedDocs[docid]; - } -}; - -double computeDistance(const PointVector &query, uint32_t docid) { - const PointVector &docvector = generatedDocs[docid]; - return l2distCalc.l2sq_dist(query, docvector, tmp_v); -} - -void read_queries(std::string fn) { - int fd = open(fn.c_str(), O_RDONLY); - ASSERT_TRUE(fd > 0); - int d; - size_t rv; - fprintf(stderr, "reading %u queries from %s\n", NUM_Q, fn.c_str()); - for (uint32_t qid = 0; qid < NUM_Q; ++qid) { - rv = read(fd, &d, 4); - ASSERT_EQUAL(rv, 4u); - ASSERT_EQUAL(d, NUM_DIMS); - rv = read(fd, &generatedQueries[qid].v, NUM_DIMS*sizeof(float)); - ASSERT_EQUAL(rv, sizeof(PointVector)); - } - close(fd); -} - -void read_docs(std::string fn) { - int fd = open(fn.c_str(), O_RDONLY); - ASSERT_TRUE(fd > 0); - int d; - size_t rv; - fprintf(stderr, "reading %u doc vectors from %s\n", NUM_DOCS, fn.c_str()); - for (uint32_t docid = 0; docid < NUM_DOCS; ++docid) { - rv = read(fd, &d, 4); - ASSERT_EQUAL(rv, 4u); - ASSERT_EQUAL(d, NUM_DIMS); - rv = read(fd, &generatedDocs[docid].v, NUM_DIMS*sizeof(float)); - ASSERT_EQUAL(rv, sizeof(PointVector)); - } - close(fd); -} - -using TimePoint = std::chrono::steady_clock::time_point; -using Duration = std::chrono::steady_clock::duration; - -double to_ms(Duration elapsed) { - std::chrono::duration ms(elapsed); - return ms.count(); -} - -void read_data(std::string dir) { - TimePoint bef = std::chrono::steady_clock::now(); - read_queries(dir + "/gist_query.fvecs"); - TimePoint aft = std::chrono::steady_clock::now(); - fprintf(stderr, "read queries: %.3f ms\n", to_ms(aft - bef)); - bef = std::chrono::steady_clock::now(); - read_docs(dir + "/gist_base.fvecs"); - aft = std::chrono::steady_clock::now(); - fprintf(stderr, "read docs: %.3f ms\n", to_ms(aft - bef)); -} - - -struct BfHitComparator { - bool operator() (const Hit &lhs, const Hit& rhs) const { - if (lhs.distance < rhs.distance) return false; - if (lhs.distance > rhs.distance) return true; - return (lhs.docid > rhs.docid); - } -}; - -class BfHitHeap { -private: - size_t _size; - vespalib::PriorityQueue _priQ; -public: - explicit BfHitHeap(size_t maxSize) : _size(maxSize), _priQ() { - _priQ.reserve(maxSize); - } - ~BfHitHeap() {} - void maybe_use(const Hit &hit) { - if (_priQ.size() < _size) { - _priQ.push(hit); - } else if (hit.distance < _priQ.front().distance) { - _priQ.front() = hit; - _priQ.adjust(); - } - } - std::vector bestHits() { - std::vector result; - size_t i = _priQ.size(); - result.resize(i); - while (i-- > 0) { - result[i] = _priQ.front(); - _priQ.pop_front(); - } - return result; - } -}; - -TopK bruteforce_nns(const PointVector &query) { - TopK result; - BfHitHeap heap(result.K); - for (uint32_t docid = 0; docid < EFFECTIVE_DOCS; ++docid) { - const PointVector &docvector = generatedDocs[docid]; - double d = l2distCalc.l2sq_dist(query, docvector, tmp_v); - Hit h(docid, d); - heap.maybe_use(h); - } - std::vector best = heap.bestHits(); - for (size_t i = 0; i < result.K; ++i) { - result.hits[i] = best[i]; - } - return result; -} - -void verifyBF(uint32_t qid) { - const PointVector &query = generatedQueries[qid]; - TopK &result = bruteforceResults[qid]; - double min_distance = result.hits[0].distance; - std::vector all_c2; - for (uint32_t i = 0; i < EFFECTIVE_DOCS; ++i) { - double dist = computeDistance(query, i); - if (dist < min_distance) { - fprintf(stderr, "WARN dist %.9g < mindist %.9g\n", dist, min_distance); - } - EXPECT_FALSE(dist+0.000001 < min_distance); - if (min_distance > 0.0) all_c2.push_back(dist / min_distance); - } - if (all_c2.size() != EFFECTIVE_DOCS) return; - std::sort(all_c2.begin(), all_c2.end()); - for (uint32_t idx : { 1, 3, 10, 30, 100, 300, 1000, 3000, EFFECTIVE_DOCS/2, EFFECTIVE_DOCS-1}) { - fprintf(stderr, "c2-factor[%u] = %.3f\n", idx, all_c2[idx]); - } -} +#include "time-util.h" +#include "point-vector.h" +#include "read-vecs.h" +#include "bruteforce-nns.h" using NNS_API = NNS; @@ -386,17 +228,21 @@ TEST("require that HNSW wrapped api mostly works") { */ int main(int argc, char **argv) { TEST_MASTER.init(__FILE__); - std::string gist_dir = "."; - if (argc > 1) { - gist_dir = argv[1]; + std::string data_set = "gist"; + std::string data_dir = "."; + if (argc > 2) { + data_set = argv[1]; + data_dir = argv[2]; + } else if (argc > 1) { + data_dir = argv[1]; } else { char *home = getenv("HOME"); if (home) { - gist_dir = home; - gist_dir += "/gist"; + data_dir = home; + data_dir += "/" + data_set; } } - read_data(gist_dir); + read_data(data_dir, data_set); TEST_RUN_ALL(); return (TEST_MASTER.fini() ? 0 : 1); } diff --git a/eval/src/tests/ann/sift_benchmark.cpp b/eval/src/tests/ann/sift_benchmark.cpp index 5f3c16e127d..b2fa66cd0f1 100644 --- a/eval/src/tests/ann/sift_benchmark.cpp +++ b/eval/src/tests/ann/sift_benchmark.cpp @@ -20,148 +20,10 @@ #include "for-sift-hit.h" #include "for-sift-top-k.h" #include "std-random.h" - -std::vector bruteforceResults; - -struct PointVector { - float v[NUM_DIMS]; - using ConstArr = vespalib::ConstArrayRef; - operator ConstArr() const { return ConstArr(v, NUM_DIMS); } -}; - -static PointVector *aligned_alloc(size_t num) { - size_t num_bytes = num * sizeof(PointVector); - double mega_bytes = num_bytes / (1024.0*1024.0); - fprintf(stderr, "allocate %.2f MB of vectors\n", mega_bytes); - char *mem = (char *)malloc(num_bytes + 512); - mem += 512; - size_t val = (size_t)mem; - size_t unalign = val % 512; - mem -= unalign; - return reinterpret_cast(mem); -} - -static PointVector *generatedQueries = aligned_alloc(NUM_Q); -static PointVector *generatedDocs = aligned_alloc(NUM_DOCS); - -struct DocVectorAdapter : public DocVectorAccess -{ - vespalib::ConstArrayRef get(uint32_t docid) const override { - ASSERT_TRUE(docid < NUM_DOCS); - return generatedDocs[docid]; - } -}; - -double computeDistance(const PointVector &query, uint32_t docid) { - const PointVector &docvector = generatedDocs[docid]; - return l2distCalc.l2sq_dist(query, docvector); -} - -void read_queries(std::string fn) { - int fd = open(fn.c_str(), O_RDONLY); - ASSERT_TRUE(fd > 0); - int d; - size_t rv; - fprintf(stderr, "reading %u queries from %s\n", NUM_Q, fn.c_str()); - for (uint32_t qid = 0; qid < NUM_Q; ++qid) { - rv = read(fd, &d, 4); - ASSERT_EQUAL(rv, 4u); - ASSERT_EQUAL(d, NUM_DIMS); - rv = read(fd, &generatedQueries[qid].v, NUM_DIMS*sizeof(float)); - ASSERT_EQUAL(rv, sizeof(PointVector)); - } - close(fd); -} - -void read_docs(std::string fn) { - int fd = open(fn.c_str(), O_RDONLY); - ASSERT_TRUE(fd > 0); - int d; - size_t rv; - fprintf(stderr, "reading %u doc vectors from %s\n", NUM_DOCS, fn.c_str()); - for (uint32_t docid = 0; docid < NUM_DOCS; ++docid) { - rv = read(fd, &d, 4); - ASSERT_EQUAL(rv, 4u); - ASSERT_EQUAL(d, NUM_DIMS); - rv = read(fd, &generatedDocs[docid].v, NUM_DIMS*sizeof(float)); - ASSERT_EQUAL(rv, sizeof(PointVector)); - } - close(fd); -} - -using TimePoint = std::chrono::steady_clock::time_point; -using Duration = std::chrono::steady_clock::duration; - -double to_ms(Duration elapsed) { - std::chrono::duration ms(elapsed); - return ms.count(); -} - -void read_data(const std::string& dir, const std::string& data_set) { - fprintf(stderr, "read data set '%s' from directory '%s'\n", data_set.c_str(), dir.c_str()); - TimePoint bef = std::chrono::steady_clock::now(); - read_queries(dir + "/" + data_set + "_query.fvecs"); - TimePoint aft = std::chrono::steady_clock::now(); - fprintf(stderr, "read queries: %.3f ms\n", to_ms(aft - bef)); - bef = std::chrono::steady_clock::now(); - read_docs(dir + "/" + data_set + "_base.fvecs"); - aft = std::chrono::steady_clock::now(); - fprintf(stderr, "read docs: %.3f ms\n", to_ms(aft - bef)); -} - - -struct BfHitComparator { - bool operator() (const Hit &lhs, const Hit& rhs) const { - if (lhs.distance < rhs.distance) return false; - if (lhs.distance > rhs.distance) return true; - return (lhs.docid > rhs.docid); - } -}; - -class BfHitHeap { -private: - size_t _size; - vespalib::PriorityQueue _priQ; -public: - explicit BfHitHeap(size_t maxSize) : _size(maxSize), _priQ() { - _priQ.reserve(maxSize); - } - ~BfHitHeap() {} - void maybe_use(const Hit &hit) { - if (_priQ.size() < _size) { - _priQ.push(hit); - } else if (hit.distance < _priQ.front().distance) { - _priQ.front() = hit; - _priQ.adjust(); - } - } - std::vector bestHits() { - std::vector result; - size_t i = _priQ.size(); - result.resize(i); - while (i-- > 0) { - result[i] = _priQ.front(); - _priQ.pop_front(); - } - return result; - } -}; - -TopK bruteforce_nns(const PointVector &query) { - TopK result; - BfHitHeap heap(result.K); - for (uint32_t docid = 0; docid < NUM_DOCS; ++docid) { - const PointVector &docvector = generatedDocs[docid]; - double d = l2distCalc.l2sq_dist(query, docvector); - Hit h(docid, d); - heap.maybe_use(h); - } - std::vector best = heap.bestHits(); - for (size_t i = 0; i < result.K; ++i) { - result.hits[i] = best[i]; - } - return result; -} +#include "time-util.h" +#include "point-vector.h" +#include "read-vecs.h" +#include "bruteforce-nns.h" TopK bruteforce_nns_filter(const PointVector &query, const BitVector &blacklist) { TopK result; @@ -181,20 +43,6 @@ TopK bruteforce_nns_filter(const PointVector &query, const BitVector &blacklist) return result; } - -void verifyBF(uint32_t qid) { - const PointVector &query = generatedQueries[qid]; - TopK &result = bruteforceResults[qid]; - double min_distance = result.hits[0].distance; - for (uint32_t i = 0; i < NUM_DOCS; ++i) { - double dist = computeDistance(query, i); - if (dist < min_distance) { - fprintf(stderr, "WARN dist %.9g < mindist %.9g\n", dist, min_distance); - } - EXPECT_FALSE(dist+0.000001 < min_distance); - } -} - void timing_bf_filter(int percent) { BitVector blacklist(NUM_DOCS); diff --git a/eval/src/tests/ann/time-util.h b/eval/src/tests/ann/time-util.h new file mode 100644 index 00000000000..2f5c2bdd583 --- /dev/null +++ b/eval/src/tests/ann/time-util.h @@ -0,0 +1,9 @@ +// Copyright Verizon Media. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. + +using TimePoint = std::chrono::steady_clock::time_point; +using Duration = std::chrono::steady_clock::duration; + +double to_ms(Duration elapsed) { + std::chrono::duration ms(elapsed); + return ms.count(); +} -- cgit v1.2.3 From 16c477ad557d556fe4d63c871025f10b18aba84d Mon Sep 17 00:00:00 2001 From: Arne Juul Date: Tue, 25 Feb 2020 14:40:02 +0000 Subject: keep more code common --- eval/src/tests/ann/CMakeLists.txt | 2 +- eval/src/tests/ann/extended-hnsw.cpp | 686 ++++++++++++--------------------- eval/src/tests/ann/xp-hnswlike-nns.cpp | 527 +++++++++---------------- 3 files changed, 427 insertions(+), 788 deletions(-) diff --git a/eval/src/tests/ann/CMakeLists.txt b/eval/src/tests/ann/CMakeLists.txt index 34babf1412f..0ba38994c01 100644 --- a/eval/src/tests/ann/CMakeLists.txt +++ b/eval/src/tests/ann/CMakeLists.txt @@ -14,7 +14,7 @@ vespa_add_executable(eval_gist_benchmark_app SOURCES gist_benchmark.cpp xp-annoy-nns.cpp - xp-hnswlike-nns.cpp + extended-hnsw.cpp xp-lsh-nns.cpp DEPENDS vespaeval diff --git a/eval/src/tests/ann/extended-hnsw.cpp b/eval/src/tests/ann/extended-hnsw.cpp index 42f3a10b389..fbc4bedec05 100644 --- a/eval/src/tests/ann/extended-hnsw.cpp +++ b/eval/src/tests/ann/extended-hnsw.cpp @@ -1,29 +1,6 @@ // Copyright 2020 Oath Inc. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. -#include -#include -#include -#include -#include "std-random.h" -#include "nns.h" - -/* - Todo: - - measure effect of: - 1) removing leftover backlinks during "shrink" operation - 2) refilling to low-watermark after 1) happens - 3) refilling to mid-watermark after 1) happens - 4) adding then removing 20% extra documents - 5) removing 20% first-added documents - 6) removing first-added documents while inserting new ones - - 7) auto-tune search_k to ensure >= 50% recall on 1000 Q with k=100 - 8) auto-tune search_k to ensure avg 90% recall on 1000 Q with k=100 - 9) auto-tune search_k to ensure >= 90% reachability of 10000 docids - - 10) timings for SIFT, GIST, and DEEP data (100k, 200k, 300k, 500k, 700k, 1000k) - */ +#include "hnsw-like.h" static size_t distcalls_simple; static size_t distcalls_search_layer; @@ -38,471 +15,300 @@ static size_t disconnected_for_symmetry; static size_t select_n_full; static size_t select_n_partial; -struct LinkList : std::vector -{ - bool has_link_to(uint32_t id) const { - auto iter = std::find(begin(), end(), id); - return (iter != end()); - } - void remove_link(uint32_t id) { - uint32_t last = back(); - for (iterator iter = begin(); iter != end(); ++iter) { - if (*iter == id) { - *iter = last; - pop_back(); - return; - } - } - fprintf(stderr, "BAD missing link to remove: %u\n", id); - abort(); - } -}; - -struct Node { - std::vector _links; - Node(uint32_t , uint32_t numLevels, uint32_t M) - : _links(numLevels) - { - for (uint32_t i = 0; i < _links.size(); ++i) { - _links[i].reserve((i == 0) ? (2 * M + 1) : (M+1)); - } - } -}; - -struct VisitedSet -{ - using Mark = unsigned short; - Mark *ptr; - Mark curval; - size_t sz; - VisitedSet(const VisitedSet &) = delete; - VisitedSet& operator=(const VisitedSet &) = delete; - explicit VisitedSet(size_t size) { - ptr = (Mark *)malloc(size * sizeof(Mark)); - curval = -1; - sz = size; - clear(); - } - void clear() { - ++curval; - if (curval == 0) { - memset(ptr, 0, sz * sizeof(Mark)); - ++curval; - } - } - ~VisitedSet() { free(ptr); } - void mark(size_t id) { ptr[id] = curval; } - bool isMarked(size_t id) const { return ptr[id] == curval; } -}; -struct VisitedSetPool +HnswLikeNns::HnswLikeNns(uint32_t numDims, const DocVectorAccess &dva) + : NNS(numDims, dva), + _nodes(), + _entryId(0), + _entryLevel(-1), + _M(16), + _efConstruction(200), + _levelMultiplier(1.0 / log(1.0 * _M)), + _rndGen(), + _ops_counter(0) { - std::unique_ptr lastUsed; - VisitedSetPool() { - lastUsed = std::make_unique(250); - } - ~VisitedSetPool() {} - VisitedSet &get(size_t size) { - if (size > lastUsed->sz) { - lastUsed = std::make_unique(size*2); - } else { - lastUsed->clear(); - } - return *lastUsed; - } -}; - -struct HnswHit { - double dist; - uint32_t docid; - HnswHit(uint32_t di, SqDist sq) : dist(sq.distance), docid(di) {} -}; - -struct GreaterDist { - bool operator() (const HnswHit &lhs, const HnswHit& rhs) const { - return (rhs.dist < lhs.dist); - } -}; -struct LesserDist { - bool operator() (const HnswHit &lhs, const HnswHit& rhs) const { - return (lhs.dist < rhs.dist); - } -}; - -using NearestList = std::vector; - -struct NearestPriQ : std::priority_queue -{ -}; - -struct FurthestPriQ : std::priority_queue -{ - NearestList steal() { - NearestList result; - c.swap(result); - return result; - } - const NearestList& peek() const { return c; } -}; - -class HnswLikeNns : public NNS -{ -private: - std::vector _nodes; - uint32_t _entryId; - int _entryLevel; - uint32_t _M; - uint32_t _efConstruction; - double _levelMultiplier; - RndGen _rndGen; - VisitedSetPool _visitedSetPool; - size_t _ops_counter; - - double distance(Vector v, uint32_t id) const; - - double distance(uint32_t a, uint32_t b) const { - Vector v = _dva.get(a); - return distance(v, b); - } - - int randomLevel() { - double unif = _rndGen.nextUniform(); - double r = -log(1.0-unif) * _levelMultiplier; - return (int) r; - } - - uint32_t count_reachable() const; - void dumpStats() const; - -public: - HnswLikeNns(uint32_t numDims, const DocVectorAccess &dva) - : NNS(numDims, dva), - _nodes(), - _entryId(0), - _entryLevel(-1), - _M(16), - _efConstruction(200), - _levelMultiplier(1.0 / log(1.0 * _M)), - _rndGen(), - _ops_counter(0) - { - } - - ~HnswLikeNns() { dumpStats(); } - - LinkList& getLinkList(uint32_t docid, uint32_t level) { - // assert(docid < _nodes.size()); - // assert(level < _nodes[docid]._links.size()); - return _nodes[docid]._links[level]; - } - - const LinkList& getLinkList(uint32_t docid, uint32_t level) const { - return _nodes[docid]._links[level]; - } +} - // simple greedy search - HnswHit search_layer_simple(Vector vector, HnswHit curPoint, uint32_t searchLevel) { - bool keepGoing = true; - while (keepGoing) { - keepGoing = false; - const LinkList& neighbors = getLinkList(curPoint.docid, searchLevel); - for (uint32_t n_id : neighbors) { - double dist = distance(vector, n_id); - ++distcalls_simple; - if (dist < curPoint.dist) { - curPoint = HnswHit(n_id, SqDist(dist)); - keepGoing = true; - } +// simple greedy search +HnswHit +HnswLikeNns::search_layer_simple(Vector vector, HnswHit curPoint, uint32_t searchLevel) { + bool keepGoing = true; + while (keepGoing) { + keepGoing = false; + const LinkList& neighbors = getLinkList(curPoint.docid, searchLevel); + for (uint32_t n_id : neighbors) { + double dist = distance(vector, n_id); + ++distcalls_simple; + if (dist < curPoint.dist) { + curPoint = HnswHit(n_id, SqDist(dist)); + keepGoing = true; } } - return curPoint; } + return curPoint; +} - void search_layer(Vector vector, FurthestPriQ &w, - VisitedSet &visited, - uint32_t ef, uint32_t searchLevel); - - void search_layer_with_filter(Vector vector, FurthestPriQ &w, - VisitedSet &visited, - uint32_t ef, uint32_t searchLevel, - const BitVector &blacklist); - bool haveCloserDistance(HnswHit e, const LinkList &r) const { - for (uint32_t prevId : r) { - double dist = distance(e.docid, prevId); - ++distcalls_heuristic; - if (dist < e.dist) return true; - } - return false; +bool +HnswLikeNns::haveCloserDistance(HnswHit e, const LinkList &r) const { + for (uint32_t prevId : r) { + double dist = distance(e.docid, prevId); + ++distcalls_heuristic; + if (dist < e.dist) return true; } + return false; +} - LinkList select_neighbors(const NearestList &neighbors, uint32_t curMax) const; - - LinkList remove_weakest(const NearestList &neighbors, uint32_t curMax, LinkList &removed) const; - - void addDoc(uint32_t docid) override { - Vector vector = _dva.get(docid); - for (uint32_t id = _nodes.size(); id <= docid; ++id) { - _nodes.emplace_back(id, 0, _M); - } - int level = randomLevel(); - assert(_nodes[docid]._links.size() == 0); - _nodes[docid] = Node(docid, level+1, _M); - if (_entryLevel < 0) { - _entryId = docid; - _entryLevel = level; - track_ops(); - return; - } - int searchLevel = _entryLevel; - VisitedSet &visited = _visitedSetPool.get(_nodes.size()); - double entryDist = distance(vector, _entryId); - ++distcalls_other; - HnswHit entryPoint(_entryId, SqDist(entryDist)); +void +HnswLikeNns::addDoc(uint32_t docid) { + Vector vector = _dva.get(docid); + for (uint32_t id = _nodes.size(); id <= docid; ++id) { + _nodes.emplace_back(id, 0, _M); + } + int level = randomLevel(); + assert(_nodes[docid]._links.size() == 0); + _nodes[docid] = Node(docid, level+1, _M); + if (_entryLevel < 0) { + _entryId = docid; + _entryLevel = level; + track_ops(); + return; + } + int searchLevel = _entryLevel; + VisitedSet &visited = _visitedSetPool.get(_nodes.size()); + double entryDist = distance(vector, _entryId); + ++distcalls_other; + HnswHit entryPoint(_entryId, SqDist(entryDist)); #undef MULTI_ENTRY_I #ifdef MULTI_ENTRY_I - FurthestPriQ w; - w.push(entryPoint); - while (searchLevel > level) { - search_layer(vector, w, visited, 5 * _M, searchLevel); - --searchLevel; - } + FurthestPriQ w; + w.push(entryPoint); + while (searchLevel > level) { + search_layer(vector, w, visited, 5 * _M, searchLevel); + --searchLevel; + } #else - while (searchLevel > level) { - entryPoint = search_layer_simple(vector, entryPoint, searchLevel); - --searchLevel; - } - FurthestPriQ w; - w.push(entryPoint); + while (searchLevel > level) { + entryPoint = search_layer_simple(vector, entryPoint, searchLevel); + --searchLevel; + } + FurthestPriQ w; + w.push(entryPoint); #endif - searchLevel = std::min(level, _entryLevel); - while (searchLevel >= 0) { - search_layer(vector, w, visited, _efConstruction, searchLevel); - LinkList neighbors = select_neighbors(w.peek(), _M); - connect_new_node(docid, neighbors, searchLevel); - each_shrink_ifneeded(neighbors, searchLevel); - --searchLevel; - } - if (level > _entryLevel) { - _entryLevel = level; - _entryId = docid; - } - track_ops(); + searchLevel = std::min(level, _entryLevel); + while (searchLevel >= 0) { + search_layer(vector, w, visited, _efConstruction, searchLevel); + LinkList neighbors = select_neighbors(w.peek(), _M); + connect_new_node(docid, neighbors, searchLevel); + each_shrink_ifneeded(neighbors, searchLevel); + --searchLevel; } - - void track_ops() { - _ops_counter++; - if ((_ops_counter % 10000) == 0) { - double div = _ops_counter; - fprintf(stderr, "add / remove ops: %zu\n", _ops_counter); - fprintf(stderr, "distance calls for layer: %zu is %.3f per op\n", distcalls_search_layer, distcalls_search_layer/ div); - fprintf(stderr, "distance calls for heuristic: %zu is %.3f per op\n", distcalls_heuristic, distcalls_heuristic / div); - fprintf(stderr, "distance calls for simple: %zu is %.3f per op\n", distcalls_simple, distcalls_simple / div); - fprintf(stderr, "distance calls for shrink: %zu is %.3f per op\n", distcalls_shrink, distcalls_shrink / div); - fprintf(stderr, "distance calls for refill: %zu is %.3f per op\n", distcalls_refill, distcalls_refill / div); - fprintf(stderr, "distance calls for other: %zu is %.3f per op\n", distcalls_other, distcalls_other / div); - fprintf(stderr, "refill needed calls: %zu is %.3f per op\n", refill_needed_calls, refill_needed_calls / div); - fprintf(stderr, "shrink needed calls: %zu is %.3f per op\n", shrink_needed_calls, shrink_needed_calls / div); - fprintf(stderr, "disconnected weak links: %zu is %.3f per op\n", disconnected_weak_links, disconnected_weak_links / div); - fprintf(stderr, "disconnected for symmetry: %zu is %.3f per op\n", disconnected_for_symmetry, disconnected_for_symmetry / div); - fprintf(stderr, "select neighbors: partial %zu vs full %zu\n", select_n_partial, select_n_full); - } - } - - void remove_link_from(uint32_t from_id, uint32_t remove_id, uint32_t level) { - LinkList &links = getLinkList(from_id, level); - links.remove_link(remove_id); + if (level > _entryLevel) { + _entryLevel = level; + _entryId = docid; } + track_ops(); +} + +void +HnswLikeNns::track_ops() { + _ops_counter++; + if ((_ops_counter % 10000) == 0) { + double div = _ops_counter; + fprintf(stderr, "add / remove ops: %zu\n", _ops_counter); + fprintf(stderr, "distance calls for layer: %zu is %.3f per op\n", distcalls_search_layer, distcalls_search_layer/ div); + fprintf(stderr, "distance calls for heuristic: %zu is %.3f per op\n", distcalls_heuristic, distcalls_heuristic / div); + fprintf(stderr, "distance calls for simple: %zu is %.3f per op\n", distcalls_simple, distcalls_simple / div); + fprintf(stderr, "distance calls for shrink: %zu is %.3f per op\n", distcalls_shrink, distcalls_shrink / div); + fprintf(stderr, "distance calls for refill: %zu is %.3f per op\n", distcalls_refill, distcalls_refill / div); + fprintf(stderr, "distance calls for other: %zu is %.3f per op\n", distcalls_other, distcalls_other / div); + fprintf(stderr, "refill needed calls: %zu is %.3f per op\n", refill_needed_calls, refill_needed_calls / div); + fprintf(stderr, "shrink needed calls: %zu is %.3f per op\n", shrink_needed_calls, shrink_needed_calls / div); + fprintf(stderr, "disconnected weak links: %zu is %.3f per op\n", disconnected_weak_links, disconnected_weak_links / div); + fprintf(stderr, "disconnected for symmetry: %zu is %.3f per op\n", disconnected_for_symmetry, disconnected_for_symmetry / div); + fprintf(stderr, "select neighbors: partial %zu vs full %zu\n", select_n_partial, select_n_full); + } +} -#undef SIMPLE_REFILL #ifdef SIMPLE_REFILL - void refill_ifneeded(uint32_t my_id, const LinkList &replacements, uint32_t level) { - LinkList &my_links = getLinkList(my_id, level); - if (my_links.size() * 2 < _M) { - const uint32_t maxLinks = (level > 0) ? _M : (2 * _M); - ++refill_needed_calls; - for (uint32_t repl_id : replacements) { - if (repl_id == my_id) continue; - if (my_links.has_link_to(repl_id)) continue; - LinkList &other_links = getLinkList(repl_id, level); - if (other_links.size() >= maxLinks) continue; - other_links.push_back(my_id); - my_links.push_back(repl_id); - if (my_links.size() >= _M) return; - } - } - } -#else - void refill_all(uint32_t my_id, const LinkList &replacements, uint32_t level) { - LinkList &my_links = getLinkList(my_id, level); +void +HnswLikeNns::refill_ifneeded(uint32_t my_id, const LinkList &replacements, uint32_t level) { + LinkList &my_links = getLinkList(my_id, level); + if (my_links.size() * 2 < _M) { const uint32_t maxLinks = (level > 0) ? _M : (2 * _M); - NearestPriQ w; + ++refill_needed_calls; for (uint32_t repl_id : replacements) { if (repl_id == my_id) continue; if (my_links.has_link_to(repl_id)) continue; - const LinkList &other_links = getLinkList(repl_id, level); + LinkList &other_links = getLinkList(repl_id, level); if (other_links.size() >= maxLinks) continue; - double dist = distance(my_id, repl_id); - ++distcalls_refill; - w.emplace(repl_id, SqDist(dist)); - } - while (! w.empty()) { - HnswHit e = w.top(); - w.pop(); - if (haveCloserDistance(e, my_links)) continue; - LinkList &other_links = getLinkList(e.docid, level); - my_links.push_back(e.docid); other_links.push_back(my_id); - if (my_links.size() == _M) break; + my_links.push_back(repl_id); + if (my_links.size() >= _M) return; } } - void refill_one(uint32_t my_id, const LinkList &replacements, uint32_t level) { - LinkList &my_links = getLinkList(my_id, level); - NearestPriQ w; - for (uint32_t repl_id : replacements) { - if (repl_id == my_id) continue; - if (my_links.has_link_to(repl_id)) continue; - LinkList &other_links = getLinkList(repl_id, level); - if (other_links.size() >= _M) continue; - double dist = distance(my_id, repl_id); - ++distcalls_refill; - w.emplace(repl_id, SqDist(dist)); - } - while (! w.empty()) { - HnswHit e = w.top(); - w.pop(); - if (haveCloserDistance(e, my_links)) continue; - LinkList &other_links = getLinkList(e.docid, level); - my_links.push_back(e.docid); - other_links.push_back(my_id); - return; - } +} +#endif + +#define REFILL_ALL +#ifdef REFILL_ALL +void +HnswLikeNns::refill_ifneeded(uint32_t my_id, const LinkList &replacements, uint32_t level) { + LinkList &my_links = getLinkList(my_id, level); + if (my_links.size() >= _M) return; + ++refill_needed_calls; + const uint32_t maxLinks = (level > 0) ? _M : (2 * _M); + NearestPriQ w; + for (uint32_t repl_id : replacements) { + if (repl_id == my_id) continue; + if (my_links.has_link_to(repl_id)) continue; + const LinkList &other_links = getLinkList(repl_id, level); + if (other_links.size() >= maxLinks) continue; + double dist = distance(my_id, repl_id); + ++distcalls_refill; + w.emplace(repl_id, SqDist(dist)); } - void refill_ifneeded(uint32_t my_id, const LinkList &replacements, uint32_t level) { - LinkList &my_links = getLinkList(my_id, level); - if (my_links.size() < _M) { - ++refill_needed_calls; - refill_all(my_id, replacements, level); - } + while (! w.empty()) { + HnswHit e = w.top(); + w.pop(); + if (haveCloserDistance(e, my_links)) continue; + LinkList &other_links = getLinkList(e.docid, level); + my_links.push_back(e.docid); + other_links.push_back(my_id); + if (my_links.size() == _M) break; } +} #endif - void connect_new_node(uint32_t id, const LinkList &neighbors, uint32_t level); +#ifdef REFILL_ONE +void +HnswLikeNns::refill_ifneeded(uint32_t my_id, const LinkList &replacements, uint32_t level) { + LinkList &my_links = getLinkList(my_id, level); + if (my_links.size() >= _M) return; + ++refill_needed_calls; + NearestPriQ w; + for (uint32_t repl_id : replacements) { + if (repl_id == my_id) continue; + if (my_links.has_link_to(repl_id)) continue; + LinkList &other_links = getLinkList(repl_id, level); + if (other_links.size() >= _M) continue; + double dist = distance(my_id, repl_id); + ++distcalls_refill; + w.emplace(repl_id, SqDist(dist)); + } + while (! w.empty()) { + HnswHit e = w.top(); + w.pop(); + if (haveCloserDistance(e, my_links)) continue; + LinkList &other_links = getLinkList(e.docid, level); + my_links.push_back(e.docid); + other_links.push_back(my_id); + return; + } +} +#endif - void shrink_links(uint32_t shrink_id, uint32_t maxLinks, uint32_t level) { - LinkList &links = getLinkList(shrink_id, level); - NearestList distances; - for (uint32_t n_id : links) { - double n_dist = distance(shrink_id, n_id); - ++distcalls_shrink; - distances.emplace_back(n_id, SqDist(n_dist)); - } - LinkList lostLinks; - LinkList oldLinks = links; - links = remove_weakest(distances, maxLinks, lostLinks); +void +HnswLikeNns::shrink_links(uint32_t shrink_id, uint32_t maxLinks, uint32_t level) { + LinkList &links = getLinkList(shrink_id, level); + NearestList distances; + for (uint32_t n_id : links) { + double n_dist = distance(shrink_id, n_id); + ++distcalls_shrink; + distances.emplace_back(n_id, SqDist(n_dist)); + } + LinkList lostLinks; + LinkList oldLinks = links; + links = remove_weakest(distances, maxLinks, lostLinks); #define KEEP_SYM #ifdef KEEP_SYM - for (uint32_t lost_id : lostLinks) { - ++disconnected_for_symmetry; - remove_link_from(lost_id, shrink_id, level); - } + for (uint32_t lost_id : lostLinks) { + ++disconnected_for_symmetry; + remove_link_from(lost_id, shrink_id, level); + } #define DO_REFILL_AFTER_KEEP_SYM #ifdef DO_REFILL_AFTER_KEEP_SYM - for (uint32_t lost_id : lostLinks) { -#ifdef SIMPLE_REFILL - refill_ifneeded(lost_id, oldLinks, level); -#else - refill_all(lost_id, oldLinks, level); -#endif - } + for (uint32_t lost_id : lostLinks) { + refill_ifneeded(lost_id, oldLinks, level); + } #endif #endif - } +} - void each_shrink_ifneeded(const LinkList &neighbors, uint32_t level); - void mutually_reconnect(LinkList cluster, int level) { - while (! cluster.empty()) { - uint32_t n_id = cluster.back(); - cluster.pop_back(); -#ifdef SIMPLE_REFILL - refill_ifneeded(n_id, cluster, level); -#else - refill_all(n_id, cluster, level); -#endif - } - } - - void removeDoc(uint32_t docid) override { - Node &node = _nodes[docid]; - bool need_new_entrypoint = (docid == _entryId); - for (int level = node._links.size(); level-- > 0; ) { - LinkList my_links; - my_links.swap(node._links[level]); - for (uint32_t n_id : my_links) { - if (need_new_entrypoint) { - _entryId = n_id; - _entryLevel = level; - need_new_entrypoint = false; - } - remove_link_from(n_id, docid, level); +void +HnswLikeNns::removeDoc(uint32_t docid) { + Node &node = _nodes[docid]; + bool need_new_entrypoint = (docid == _entryId); + for (int level = node._links.size(); level-- > 0; ) { + LinkList my_links; + my_links.swap(node._links[level]); + for (uint32_t n_id : my_links) { + if (need_new_entrypoint) { + _entryId = n_id; + _entryLevel = level; + need_new_entrypoint = false; } - mutually_reconnect(my_links, level); - } - node = Node(docid, 0, _M); - if (need_new_entrypoint) { - _entryLevel = -1; - _entryId = 0; - for (uint32_t i = 0; i < _nodes.size(); ++i) { - if (_nodes[i]._links.size() > 0) { - _entryId = i; - _entryLevel = _nodes[i]._links.size() - 1; - break; - } + remove_link_from(n_id, docid, level); + } + while (! my_links.empty()) { + uint32_t n_id = my_links.back(); + my_links.pop_back(); + refill_ifneeded(n_id, my_links, level); + } + } + node = Node(docid, 0, _M); + if (need_new_entrypoint) { + _entryLevel = -1; + _entryId = 0; + for (uint32_t i = 0; i < _nodes.size(); ++i) { + if (_nodes[i]._links.size() > 0) { + _entryId = i; + _entryLevel = _nodes[i]._links.size() - 1; + break; } } - track_ops(); } + track_ops(); +} - std::vector topK(uint32_t k, Vector vector, uint32_t search_k) override { - std::vector result; - if (_entryLevel < 0) return result; - double entryDist = distance(vector, _entryId); - ++distcalls_other; - HnswHit entryPoint(_entryId, SqDist(entryDist)); - int searchLevel = _entryLevel; - VisitedSet &visited = _visitedSetPool.get(_nodes.size()); +std::vector +HnswLikeNns::topK(uint32_t k, Vector vector, uint32_t search_k) { + std::vector result; + if (_entryLevel < 0) return result; + double entryDist = distance(vector, _entryId); + ++distcalls_other; + HnswHit entryPoint(_entryId, SqDist(entryDist)); + int searchLevel = _entryLevel; + VisitedSet &visited = _visitedSetPool.get(_nodes.size()); #undef MULTI_ENTRY_S #ifdef MULTI_ENTRY_S - FurthestPriQ w; - w.push(entryPoint); - while (searchLevel > 0) { - search_layer(vector, w, visited, std::min(k, search_k), searchLevel); - --searchLevel; - } + FurthestPriQ w; + w.push(entryPoint); + while (searchLevel > 0) { + search_layer(vector, w, visited, std::min(k, search_k), searchLevel); + --searchLevel; + } #else - while (searchLevel > 0) { - entryPoint = search_layer_simple(vector, entryPoint, searchLevel); - --searchLevel; - } - FurthestPriQ w; - w.push(entryPoint); + while (searchLevel > 0) { + entryPoint = search_layer_simple(vector, entryPoint, searchLevel); + --searchLevel; + } + FurthestPriQ w; + w.push(entryPoint); #endif - search_layer(vector, w, visited, std::max(k, search_k), 0); - while (w.size() > k) { - w.pop(); - } - NearestList tmp = w.steal(); - std::sort(tmp.begin(), tmp.end(), LesserDist()); - result.reserve(tmp.size()); - for (const auto & hit : tmp) { - result.emplace_back(hit.docid, SqDist(hit.dist)); - } - return result; + search_layer(vector, w, visited, std::max(k, search_k), 0); + while (w.size() > k) { + w.pop(); } - - std::vector topKfilter(uint32_t k, Vector vector, uint32_t search_k, const BitVector &blacklist) override; -}; + NearestList tmp = w.steal(); + std::sort(tmp.begin(), tmp.end(), LesserDist()); + result.reserve(tmp.size()); + for (const auto & hit : tmp) { + result.emplace_back(hit.docid, SqDist(hit.dist)); + } + return result; +} double diff --git a/eval/src/tests/ann/xp-hnswlike-nns.cpp b/eval/src/tests/ann/xp-hnswlike-nns.cpp index 90fc0fe2e92..b7cae9f731c 100644 --- a/eval/src/tests/ann/xp-hnswlike-nns.cpp +++ b/eval/src/tests/ann/xp-hnswlike-nns.cpp @@ -1,11 +1,6 @@ // Copyright 2020 Oath Inc. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. -#include -#include -#include -#include -#include "std-random.h" -#include "nns.h" +#include "hnsw-like.h" /* Todo: @@ -38,378 +33,216 @@ static size_t disconnected_for_symmetry; static size_t select_n_full; static size_t select_n_partial; -struct LinkList : std::vector -{ - bool has_link_to(uint32_t id) const { - auto iter = std::find(begin(), end(), id); - return (iter != end()); - } - void remove_link(uint32_t id) { - uint32_t last = back(); - for (iterator iter = begin(); iter != end(); ++iter) { - if (*iter == id) { - *iter = last; - pop_back(); - return; - } - } - fprintf(stderr, "BAD missing link to remove: %u\n", id); - abort(); - } -}; -struct Node { - std::vector _links; - Node(uint32_t , uint32_t numLevels, uint32_t M) - : _links(numLevels) - { - for (uint32_t i = 0; i < _links.size(); ++i) { - _links[i].reserve((i == 0) ? (2 * M + 1) : (M+1)); - } - } -}; -struct VisitedSet +HnswLikeNns::HnswLikeNns(uint32_t numDims, const DocVectorAccess &dva) + : NNS(numDims, dva), + _nodes(), + _entryId(0), + _entryLevel(-1), + _M(16), + _efConstruction(200), + _levelMultiplier(1.0 / log(1.0 * _M)), + _rndGen(), + _ops_counter(0) { - using Mark = unsigned short; - Mark *ptr; - Mark curval; - size_t sz; - VisitedSet(const VisitedSet &) = delete; - VisitedSet& operator=(const VisitedSet &) = delete; - explicit VisitedSet(size_t size) { - ptr = (Mark *)malloc(size * sizeof(Mark)); - curval = -1; - sz = size; - clear(); - } - void clear() { - ++curval; - if (curval == 0) { - memset(ptr, 0, sz * sizeof(Mark)); - ++curval; - } - } - ~VisitedSet() { free(ptr); } - void mark(size_t id) { ptr[id] = curval; } - bool isMarked(size_t id) const { return ptr[id] == curval; } -}; +} -struct VisitedSetPool -{ - std::unique_ptr lastUsed; - VisitedSetPool() { - lastUsed = std::make_unique(250); - } - ~VisitedSetPool() {} - VisitedSet &get(size_t size) { - if (size > lastUsed->sz) { - lastUsed = std::make_unique(size*2); - } else { - lastUsed->clear(); +// simple greedy search +HnswHit +HnswLikeNns::search_layer_simple(Vector vector, HnswHit curPoint, uint32_t searchLevel) { + bool keepGoing = true; + while (keepGoing) { + keepGoing = false; + const LinkList& neighbors = getLinkList(curPoint.docid, searchLevel); + for (uint32_t n_id : neighbors) { + double dist = distance(vector, n_id); + ++distcalls_simple; + if (dist < curPoint.dist) { + curPoint = HnswHit(n_id, SqDist(dist)); + keepGoing = true; + } } - return *lastUsed; - } -}; - -struct HnswHit { - double dist; - uint32_t docid; - HnswHit(uint32_t di, SqDist sq) : dist(sq.distance), docid(di) {} -}; - -struct GreaterDist { - bool operator() (const HnswHit &lhs, const HnswHit& rhs) const { - return (rhs.dist < lhs.dist); - } -}; -struct LesserDist { - bool operator() (const HnswHit &lhs, const HnswHit& rhs) const { - return (lhs.dist < rhs.dist); - } -}; - -using NearestList = std::vector; - -struct NearestPriQ : std::priority_queue -{ -}; - -struct FurthestPriQ : std::priority_queue -{ - NearestList steal() { - NearestList result; - c.swap(result); - return result; - } - const NearestList& peek() const { return c; } -}; - -class HnswLikeNns : public NNS -{ -private: - std::vector _nodes; - uint32_t _entryId; - int _entryLevel; - uint32_t _M; - uint32_t _efConstruction; - double _levelMultiplier; - RndGen _rndGen; - VisitedSetPool _visitedSetPool; - size_t _ops_counter; - - double distance(Vector v, uint32_t id) const; - - double distance(uint32_t a, uint32_t b) const { - Vector v = _dva.get(a); - return distance(v, b); - } - - int randomLevel() { - double unif = _rndGen.nextUniform(); - double r = -log(1.0-unif) * _levelMultiplier; - return (int) r; - } - - uint32_t count_reachable() const; - void dumpStats() const; - -public: - HnswLikeNns(uint32_t numDims, const DocVectorAccess &dva) - : NNS(numDims, dva), - _nodes(), - _entryId(0), - _entryLevel(-1), - _M(16), - _efConstruction(200), - _levelMultiplier(1.0 / log(1.0 * _M)), - _rndGen(), - _ops_counter(0) - { } + return curPoint; +} - ~HnswLikeNns() { dumpStats(); } - - LinkList& getLinkList(uint32_t docid, uint32_t level) { - // assert(docid < _nodes.size()); - // assert(level < _nodes[docid]._links.size()); - return _nodes[docid]._links[level]; +bool +HnswLikeNns::haveCloserDistance(HnswHit e, const LinkList &r) const { + for (uint32_t prevId : r) { + double dist = distance(e.docid, prevId); + ++distcalls_heuristic; + if (dist < e.dist) return true; } + return false; +} - const LinkList& getLinkList(uint32_t docid, uint32_t level) const { - return _nodes[docid]._links[level]; +void +HnswLikeNns::addDoc(uint32_t docid) { + Vector vector = _dva.get(docid); + for (uint32_t id = _nodes.size(); id <= docid; ++id) { + _nodes.emplace_back(id, 0, _M); + } + int level = randomLevel(); + assert(_nodes[docid]._links.size() == 0); + _nodes[docid] = Node(docid, level+1, _M); + if (_entryLevel < 0) { + _entryId = docid; + _entryLevel = level; + track_ops(); + return; } - - // simple greedy search - HnswHit search_layer_simple(Vector vector, HnswHit curPoint, uint32_t searchLevel) { - bool keepGoing = true; - while (keepGoing) { - keepGoing = false; - const LinkList& neighbors = getLinkList(curPoint.docid, searchLevel); - for (uint32_t n_id : neighbors) { - double dist = distance(vector, n_id); - ++distcalls_simple; - if (dist < curPoint.dist) { - curPoint = HnswHit(n_id, SqDist(dist)); - keepGoing = true; - } - } - } - return curPoint; + int searchLevel = _entryLevel; + double entryDist = distance(vector, _entryId); + ++distcalls_other; + HnswHit entryPoint(_entryId, SqDist(entryDist)); + while (searchLevel > level) { + entryPoint = search_layer_simple(vector, entryPoint, searchLevel); + --searchLevel; } - - void search_layer(Vector vector, FurthestPriQ &w, - uint32_t ef, uint32_t searchLevel); - - void search_layer_with_filter(Vector vector, FurthestPriQ &w, - uint32_t ef, uint32_t searchLevel, - const BitVector &blacklist); - - bool haveCloserDistance(HnswHit e, const LinkList &r) const { - for (uint32_t prevId : r) { - double dist = distance(e.docid, prevId); - ++distcalls_heuristic; - if (dist < e.dist) return true; - } - return false; + searchLevel = std::min(level, _entryLevel); + FurthestPriQ w; + w.push(entryPoint); + while (searchLevel >= 0) { + search_layer(vector, w, _efConstruction, searchLevel); + LinkList neighbors = select_neighbors(w.peek(), _M); + connect_new_node(docid, neighbors, searchLevel); + each_shrink_ifneeded(neighbors, searchLevel); + --searchLevel; } - - LinkList select_neighbors(const NearestList &neighbors, uint32_t curMax) const; - - LinkList remove_weakest(const NearestList &neighbors, uint32_t curMax, LinkList &removed) const; - - void addDoc(uint32_t docid) override { - Vector vector = _dva.get(docid); - for (uint32_t id = _nodes.size(); id <= docid; ++id) { - _nodes.emplace_back(id, 0, _M); - } - int level = randomLevel(); - assert(_nodes[docid]._links.size() == 0); - _nodes[docid] = Node(docid, level+1, _M); - if (_entryLevel < 0) { - _entryId = docid; - _entryLevel = level; - track_ops(); - return; - } - int searchLevel = _entryLevel; - double entryDist = distance(vector, _entryId); - ++distcalls_other; - HnswHit entryPoint(_entryId, SqDist(entryDist)); - while (searchLevel > level) { - entryPoint = search_layer_simple(vector, entryPoint, searchLevel); - --searchLevel; - } - searchLevel = std::min(level, _entryLevel); - FurthestPriQ w; - w.push(entryPoint); - while (searchLevel >= 0) { - search_layer(vector, w, _efConstruction, searchLevel); - LinkList neighbors = select_neighbors(w.peek(), _M); - connect_new_node(docid, neighbors, searchLevel); - each_shrink_ifneeded(neighbors, searchLevel); - --searchLevel; - } - if (level > _entryLevel) { - _entryLevel = level; - _entryId = docid; - } - track_ops(); + if (level > _entryLevel) { + _entryLevel = level; + _entryId = docid; } + track_ops(); +} - void track_ops() { - _ops_counter++; - if ((_ops_counter % 10000) == 0) { - double div = _ops_counter; - fprintf(stderr, "add / remove ops: %zu\n", _ops_counter); - fprintf(stderr, "distance calls for layer: %zu is %.3f per op\n", distcalls_search_layer, distcalls_search_layer/ div); - fprintf(stderr, "distance calls for heuristic: %zu is %.3f per op\n", distcalls_heuristic, distcalls_heuristic / div); - fprintf(stderr, "distance calls for simple: %zu is %.3f per op\n", distcalls_simple, distcalls_simple / div); - fprintf(stderr, "distance calls for shrink: %zu is %.3f per op\n", distcalls_shrink, distcalls_shrink / div); - fprintf(stderr, "distance calls for refill: %zu is %.3f per op\n", distcalls_refill, distcalls_refill / div); - fprintf(stderr, "distance calls for other: %zu is %.3f per op\n", distcalls_other, distcalls_other / div); - fprintf(stderr, "refill needed calls: %zu is %.3f per op\n", refill_needed_calls, refill_needed_calls / div); - fprintf(stderr, "shrink needed calls: %zu is %.3f per op\n", shrink_needed_calls, shrink_needed_calls / div); - fprintf(stderr, "disconnected weak links: %zu is %.3f per op\n", disconnected_weak_links, disconnected_weak_links / div); - fprintf(stderr, "disconnected for symmetry: %zu is %.3f per op\n", disconnected_for_symmetry, disconnected_for_symmetry / div); - fprintf(stderr, "select neighbors: partial %zu vs full %zu\n", select_n_partial, select_n_full); - } - } - - void remove_link_from(uint32_t from_id, uint32_t remove_id, uint32_t level) { - LinkList &links = getLinkList(from_id, level); - links.remove_link(remove_id); - } +void +HnswLikeNns::track_ops() { + _ops_counter++; + if ((_ops_counter % 10000) == 0) { + double div = _ops_counter; + fprintf(stderr, "add / remove ops: %zu\n", _ops_counter); + fprintf(stderr, "distance calls for layer: %zu is %.3f per op\n", distcalls_search_layer, distcalls_search_layer/ div); + fprintf(stderr, "distance calls for heuristic: %zu is %.3f per op\n", distcalls_heuristic, distcalls_heuristic / div); + fprintf(stderr, "distance calls for simple: %zu is %.3f per op\n", distcalls_simple, distcalls_simple / div); + fprintf(stderr, "distance calls for shrink: %zu is %.3f per op\n", distcalls_shrink, distcalls_shrink / div); + fprintf(stderr, "distance calls for refill: %zu is %.3f per op\n", distcalls_refill, distcalls_refill / div); + fprintf(stderr, "distance calls for other: %zu is %.3f per op\n", distcalls_other, distcalls_other / div); + fprintf(stderr, "refill needed calls: %zu is %.3f per op\n", refill_needed_calls, refill_needed_calls / div); + fprintf(stderr, "shrink needed calls: %zu is %.3f per op\n", shrink_needed_calls, shrink_needed_calls / div); + fprintf(stderr, "disconnected weak links: %zu is %.3f per op\n", disconnected_weak_links, disconnected_weak_links / div); + fprintf(stderr, "disconnected for symmetry: %zu is %.3f per op\n", disconnected_for_symmetry, disconnected_for_symmetry / div); + fprintf(stderr, "select neighbors: partial %zu vs full %zu\n", select_n_partial, select_n_full); + } +} - void refill_ifneeded(uint32_t my_id, const LinkList &replacements, uint32_t level) { - LinkList &my_links = getLinkList(my_id, level); - if (my_links.size() < 8) { - ++refill_needed_calls; - for (uint32_t repl_id : replacements) { - if (repl_id == my_id) continue; - if (my_links.has_link_to(repl_id)) continue; - LinkList &other_links = getLinkList(repl_id, level); - if (other_links.size() + 1 >= _M) continue; - other_links.push_back(my_id); - my_links.push_back(repl_id); - if (my_links.size() >= _M) return; - } +void +HnswLikeNns::refill_ifneeded(uint32_t my_id, const LinkList &replacements, uint32_t level) { + LinkList &my_links = getLinkList(my_id, level); + if (my_links.size() < 8) { + ++refill_needed_calls; + for (uint32_t repl_id : replacements) { + if (repl_id == my_id) continue; + if (my_links.has_link_to(repl_id)) continue; + LinkList &other_links = getLinkList(repl_id, level); + if (other_links.size() + 1 >= _M) continue; + other_links.push_back(my_id); + my_links.push_back(repl_id); + if (my_links.size() >= _M) return; } } +} - void connect_new_node(uint32_t id, const LinkList &neighbors, uint32_t level); - - void shrink_links(uint32_t shrink_id, uint32_t maxLinks, uint32_t level) { - LinkList &links = getLinkList(shrink_id, level); - NearestList distances; - for (uint32_t n_id : links) { - double n_dist = distance(shrink_id, n_id); - ++distcalls_shrink; - distances.emplace_back(n_id, SqDist(n_dist)); - } - LinkList lostLinks; - LinkList oldLinks = links; - links = remove_weakest(distances, maxLinks, lostLinks); +void +HnswLikeNns::shrink_links(uint32_t shrink_id, uint32_t maxLinks, uint32_t level) { + LinkList &links = getLinkList(shrink_id, level); + NearestList distances; + for (uint32_t n_id : links) { + double n_dist = distance(shrink_id, n_id); + ++distcalls_shrink; + distances.emplace_back(n_id, SqDist(n_dist)); + } + LinkList lostLinks; + LinkList oldLinks = links; + links = remove_weakest(distances, maxLinks, lostLinks); #define KEEP_SYM #ifdef KEEP_SYM - for (uint32_t lost_id : lostLinks) { - ++disconnected_for_symmetry; - remove_link_from(lost_id, shrink_id, level); - } + for (uint32_t lost_id : lostLinks) { + ++disconnected_for_symmetry; + remove_link_from(lost_id, shrink_id, level); + } #define DO_REFILL_AFTER_KEEP_SYM #ifdef DO_REFILL_AFTER_KEEP_SYM - for (uint32_t lost_id : lostLinks) { - refill_ifneeded(lost_id, oldLinks, level); - } + for (uint32_t lost_id : lostLinks) { + refill_ifneeded(lost_id, oldLinks, level); + } #endif #endif - } - - void each_shrink_ifneeded(const LinkList &neighbors, uint32_t level); +} - void removeDoc(uint32_t docid) override { - Node &node = _nodes[docid]; - bool need_new_entrypoint = (docid == _entryId); - for (int level = node._links.size(); level-- > 0; ) { - LinkList my_links; - my_links.swap(node._links[level]); - for (uint32_t n_id : my_links) { - if (need_new_entrypoint) { - _entryId = n_id; - _entryLevel = level; - need_new_entrypoint = false; - } - remove_link_from(n_id, docid, level); - } - while (! my_links.empty()) { - uint32_t n_id = my_links.back(); - my_links.pop_back(); - refill_ifneeded(n_id, my_links, level); +void +HnswLikeNns::removeDoc(uint32_t docid) { + Node &node = _nodes[docid]; + bool need_new_entrypoint = (docid == _entryId); + for (int level = node._links.size(); level-- > 0; ) { + LinkList my_links; + my_links.swap(node._links[level]); + for (uint32_t n_id : my_links) { + if (need_new_entrypoint) { + _entryId = n_id; + _entryLevel = level; + need_new_entrypoint = false; } - } - node = Node(docid, 0, _M); - if (need_new_entrypoint) { - _entryLevel = -1; - _entryId = 0; - for (uint32_t i = 0; i < _nodes.size(); ++i) { - if (_nodes[i]._links.size() > 0) { - _entryId = i; - _entryLevel = _nodes[i]._links.size() - 1; - break; - } + remove_link_from(n_id, docid, level); + } + while (! my_links.empty()) { + uint32_t n_id = my_links.back(); + my_links.pop_back(); + refill_ifneeded(n_id, my_links, level); + } + } + node = Node(docid, 0, _M); + if (need_new_entrypoint) { + _entryLevel = -1; + _entryId = 0; + for (uint32_t i = 0; i < _nodes.size(); ++i) { + if (_nodes[i]._links.size() > 0) { + _entryId = i; + _entryLevel = _nodes[i]._links.size() - 1; + break; } } - track_ops(); } + track_ops(); +} - std::vector topK(uint32_t k, Vector vector, uint32_t search_k) override { - std::vector result; - if (_entryLevel < 0) return result; - double entryDist = distance(vector, _entryId); - ++distcalls_other; - HnswHit entryPoint(_entryId, SqDist(entryDist)); - int searchLevel = _entryLevel; - while (searchLevel > 0) { - entryPoint = search_layer_simple(vector, entryPoint, searchLevel); - --searchLevel; - } - FurthestPriQ w; - w.push(entryPoint); - search_layer(vector, w, std::max(k, search_k), 0); - while (w.size() > k) { - w.pop(); - } - NearestList tmp = w.steal(); - std::sort(tmp.begin(), tmp.end(), LesserDist()); - result.reserve(tmp.size()); - for (const auto & hit : tmp) { - result.emplace_back(hit.docid, SqDist(hit.dist)); - } - return result; +std::vector +HnswLikeNns::topK(uint32_t k, Vector vector, uint32_t search_k) { + std::vector result; + if (_entryLevel < 0) return result; + double entryDist = distance(vector, _entryId); + ++distcalls_other; + HnswHit entryPoint(_entryId, SqDist(entryDist)); + int searchLevel = _entryLevel; + while (searchLevel > 0) { + entryPoint = search_layer_simple(vector, entryPoint, searchLevel); + --searchLevel; } - - std::vector topKfilter(uint32_t k, Vector vector, uint32_t search_k, const BitVector &blacklist) override; -}; + FurthestPriQ w; + w.push(entryPoint); + search_layer(vector, w, std::max(k, search_k), 0); + while (w.size() > k) { + w.pop(); + } + NearestList tmp = w.steal(); + std::sort(tmp.begin(), tmp.end(), LesserDist()); + result.reserve(tmp.size()); + for (const auto & hit : tmp) { + result.emplace_back(hit.docid, SqDist(hit.dist)); + } + return result; +} double -- cgit v1.2.3 From 6355c5dd0b2dc355812414e836b8ca6bbe2bc8ca Mon Sep 17 00:00:00 2001 From: Arne Juul Date: Wed, 26 Feb 2020 10:34:21 +0000 Subject: add common header file --- eval/src/tests/ann/hnsw-like.h | 203 +++++++++++++++++++++++++++++++++++++++++ 1 file changed, 203 insertions(+) create mode 100644 eval/src/tests/ann/hnsw-like.h diff --git a/eval/src/tests/ann/hnsw-like.h b/eval/src/tests/ann/hnsw-like.h new file mode 100644 index 00000000000..36064c69860 --- /dev/null +++ b/eval/src/tests/ann/hnsw-like.h @@ -0,0 +1,203 @@ +// Copyright Verizon Media. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. + +#pragma once + +#include +#include +#include +#include +#include "std-random.h" +#include "nns.h" + +struct LinkList : std::vector +{ + bool has_link_to(uint32_t id) const { + auto iter = std::find(begin(), end(), id); + return (iter != end()); + } + void remove_link(uint32_t id) { + uint32_t last = back(); + for (iterator iter = begin(); iter != end(); ++iter) { + if (*iter == id) { + *iter = last; + pop_back(); + return; + } + } + fprintf(stderr, "BAD missing link to remove: %u\n", id); + abort(); + } +}; + +struct Node { + std::vector _links; + Node(uint32_t , uint32_t numLevels, uint32_t M) + : _links(numLevels) + { + for (uint32_t i = 0; i < _links.size(); ++i) { + _links[i].reserve((i == 0) ? (2 * M + 1) : (M+1)); + } + } +}; + +struct VisitedSet +{ + using Mark = unsigned short; + Mark *ptr; + Mark curval; + size_t sz; + VisitedSet(const VisitedSet &) = delete; + VisitedSet& operator=(const VisitedSet &) = delete; + explicit VisitedSet(size_t size) { + ptr = (Mark *)malloc(size * sizeof(Mark)); + curval = -1; + sz = size; + clear(); + } + void clear() { + ++curval; + if (curval == 0) { + memset(ptr, 0, sz * sizeof(Mark)); + ++curval; + } + } + ~VisitedSet() { free(ptr); } + void mark(size_t id) { ptr[id] = curval; } + bool isMarked(size_t id) const { return ptr[id] == curval; } +}; + +struct VisitedSetPool +{ + std::unique_ptr lastUsed; + VisitedSetPool() { + lastUsed = std::make_unique(250); + } + ~VisitedSetPool() {} + VisitedSet &get(size_t size) { + if (size > lastUsed->sz) { + lastUsed = std::make_unique(size*2); + } else { + lastUsed->clear(); + } + return *lastUsed; + } +}; + +struct HnswHit { + double dist; + uint32_t docid; + HnswHit(uint32_t di, SqDist sq) : dist(sq.distance), docid(di) {} +}; + +struct GreaterDist { + bool operator() (const HnswHit &lhs, const HnswHit& rhs) const { + return (rhs.dist < lhs.dist); + } +}; +struct LesserDist { + bool operator() (const HnswHit &lhs, const HnswHit& rhs) const { + return (lhs.dist < rhs.dist); + } +}; + +using NearestList = std::vector; + +struct NearestPriQ : std::priority_queue +{ +}; + +struct FurthestPriQ : std::priority_queue +{ + NearestList steal() { + NearestList result; + c.swap(result); + return result; + } + const NearestList& peek() const { return c; } +}; + +class HnswLikeNns : public NNS +{ +private: + std::vector _nodes; + uint32_t _entryId; + int _entryLevel; + uint32_t _M; + uint32_t _efConstruction; + double _levelMultiplier; + RndGen _rndGen; + VisitedSetPool _visitedSetPool; + size_t _ops_counter; + + double distance(Vector v, uint32_t id) const; + + double distance(uint32_t a, uint32_t b) const { + Vector v = _dva.get(a); + return distance(v, b); + } + + int randomLevel() { + double unif = _rndGen.nextUniform(); + double r = -log(1.0-unif) * _levelMultiplier; + return (int) r; + } + + uint32_t count_reachable() const; + void dumpStats() const; + +public: + HnswLikeNns(uint32_t numDims, const DocVectorAccess &dva); + ~HnswLikeNns() { dumpStats(); } + + LinkList& getLinkList(uint32_t docid, uint32_t level) { + return _nodes[docid]._links[level]; + } + + const LinkList& getLinkList(uint32_t docid, uint32_t level) const { + return _nodes[docid]._links[level]; + } + + HnswHit search_layer_simple(Vector vector, HnswHit curPoint, uint32_t searchLevel); + + void search_layer(Vector vector, FurthestPriQ &w, + uint32_t ef, uint32_t searchLevel); + void search_layer(Vector vector, FurthestPriQ &w, + VisitedSet &visited, + uint32_t ef, uint32_t searchLevel); + void search_layer_with_filter(Vector vector, FurthestPriQ &w, + uint32_t ef, uint32_t searchLevel, + const BitVector &blacklist); + void search_layer_with_filter(Vector vector, FurthestPriQ &w, + VisitedSet &visited, + uint32_t ef, uint32_t searchLevel, + const BitVector &blacklist); + + bool haveCloserDistance(HnswHit e, const LinkList &r) const; + + LinkList select_neighbors(const NearestList &neighbors, uint32_t curMax) const; + + LinkList remove_weakest(const NearestList &neighbors, uint32_t curMax, LinkList &removed) const; + + void addDoc(uint32_t docid) override; + + void track_ops(); + + void remove_link_from(uint32_t from_id, uint32_t remove_id, uint32_t level) { + LinkList &links = getLinkList(from_id, level); + links.remove_link(remove_id); + } + + void refill_ifneeded(uint32_t my_id, const LinkList &replacements, uint32_t level); + + void connect_new_node(uint32_t id, const LinkList &neighbors, uint32_t level); + + void shrink_links(uint32_t shrink_id, uint32_t maxLinks, uint32_t level); + + void each_shrink_ifneeded(const LinkList &neighbors, uint32_t level); + + void removeDoc(uint32_t docid) override; + + std::vector topK(uint32_t k, Vector vector, uint32_t search_k) override; + + std::vector topKfilter(uint32_t k, Vector vector, uint32_t search_k, const BitVector &blacklist) override; +}; -- cgit v1.2.3