diff options
Diffstat (limited to 'eval/src/tests/ann/xp-hnswlike-nns.cpp')
-rw-r--r-- | eval/src/tests/ann/xp-hnswlike-nns.cpp | 121 |
1 files changed, 111 insertions, 10 deletions
diff --git a/eval/src/tests/ann/xp-hnswlike-nns.cpp b/eval/src/tests/ann/xp-hnswlike-nns.cpp index 5cdbdd8efa3..90fc0fe2e92 100644 --- a/eval/src/tests/ann/xp-hnswlike-nns.cpp +++ b/eval/src/tests/ann/xp-hnswlike-nns.cpp @@ -32,6 +32,11 @@ static size_t distcalls_heuristic; static size_t distcalls_shrink; static size_t distcalls_refill; static size_t refill_needed_calls; +static size_t shrink_needed_calls; +static size_t disconnected_weak_links; +static size_t disconnected_for_symmetry; +static size_t select_n_full; +static size_t select_n_partial; struct LinkList : std::vector<uint32_t> { @@ -76,6 +81,7 @@ struct VisitedSet ptr = (Mark *)malloc(size * sizeof(Mark)); curval = -1; sz = size; + clear(); } void clear() { ++curval; @@ -99,8 +105,9 @@ struct VisitedSetPool VisitedSet &get(size_t size) { if (size > lastUsed->sz) { lastUsed = std::make_unique<VisitedSet>(size*2); + } else { + lastUsed->clear(); } - lastUsed->clear(); return *lastUsed; } }; @@ -214,6 +221,10 @@ public: void search_layer(Vector vector, FurthestPriQ &w, uint32_t ef, uint32_t searchLevel); + void search_layer_with_filter(Vector vector, FurthestPriQ &w, + uint32_t ef, uint32_t searchLevel, + const BitVector &blacklist); + bool haveCloserDistance(HnswHit e, const LinkList &r) const { for (uint32_t prevId : r) { double dist = distance(e.docid, prevId); @@ -278,6 +289,10 @@ public: fprintf(stderr, "distance calls for refill: %zu is %.3f per op\n", distcalls_refill, distcalls_refill / div); fprintf(stderr, "distance calls for other: %zu is %.3f per op\n", distcalls_other, distcalls_other / div); fprintf(stderr, "refill needed calls: %zu is %.3f per op\n", refill_needed_calls, refill_needed_calls / div); + fprintf(stderr, "shrink needed calls: %zu is %.3f per op\n", shrink_needed_calls, shrink_needed_calls / div); + fprintf(stderr, "disconnected weak links: %zu is %.3f per op\n", disconnected_weak_links, disconnected_weak_links / div); + fprintf(stderr, "disconnected for symmetry: %zu is %.3f per op\n", disconnected_for_symmetry, disconnected_for_symmetry / div); + fprintf(stderr, "select neighbors: partial %zu vs full %zu\n", select_n_partial, select_n_full); } } @@ -315,10 +330,19 @@ public: LinkList lostLinks; LinkList oldLinks = links; links = remove_weakest(distances, maxLinks, lostLinks); +#define KEEP_SYM +#ifdef KEEP_SYM for (uint32_t lost_id : lostLinks) { + ++disconnected_for_symmetry; remove_link_from(lost_id, shrink_id, level); + } +#define DO_REFILL_AFTER_KEEP_SYM +#ifdef DO_REFILL_AFTER_KEEP_SYM + for (uint32_t lost_id : lostLinks) { refill_ifneeded(lost_id, oldLinks, level); } +#endif +#endif } void each_shrink_ifneeded(const LinkList &neighbors, uint32_t level); @@ -337,7 +361,9 @@ public: } remove_link_from(n_id, docid, level); } - for (uint32_t n_id : my_links) { + while (! my_links.empty()) { + uint32_t n_id = my_links.back(); + my_links.pop_back(); refill_ifneeded(n_id, my_links, level); } } @@ -363,12 +389,12 @@ public: ++distcalls_other; HnswHit entryPoint(_entryId, SqDist(entryDist)); int searchLevel = _entryLevel; - FurthestPriQ w; - w.push(entryPoint); while (searchLevel > 0) { - search_layer(vector, w, std::min(k, search_k), searchLevel); + entryPoint = search_layer_simple(vector, entryPoint, searchLevel); --searchLevel; } + FurthestPriQ w; + w.push(entryPoint); search_layer(vector, w, std::max(k, search_k), 0); while (w.size() > k) { w.pop(); @@ -381,8 +407,11 @@ public: } return result; } + + std::vector<NnsHit> topKfilter(uint32_t k, Vector vector, uint32_t search_k, const BitVector &blacklist) override; }; + double HnswLikeNns::distance(Vector v, uint32_t b) const { @@ -390,12 +419,40 @@ HnswLikeNns::distance(Vector v, uint32_t b) const return l2distCalc.l2sq_dist(v, w); } +std::vector<NnsHit> +HnswLikeNns::topKfilter(uint32_t k, Vector vector, uint32_t search_k, const BitVector &blacklist) +{ + std::vector<NnsHit> result; + if (_entryLevel < 0) return result; + double entryDist = distance(vector, _entryId); + ++distcalls_other; + HnswHit entryPoint(_entryId, SqDist(entryDist)); + int searchLevel = _entryLevel; + while (searchLevel > 0) { + entryPoint = search_layer_simple(vector, entryPoint, searchLevel); + --searchLevel; + } + FurthestPriQ w; + w.push(entryPoint); + search_layer_with_filter(vector, w, std::max(k, search_k), 0, blacklist); + NearestList tmp = w.steal(); + std::sort(tmp.begin(), tmp.end(), LesserDist()); + result.reserve(std::min((size_t)k, tmp.size())); + for (const auto & hit : tmp) { + if (blacklist.isSet(hit.docid)) continue; + result.emplace_back(hit.docid, SqDist(hit.dist)); + if (result.size() == k) break; + } + return result; +} + void HnswLikeNns::each_shrink_ifneeded(const LinkList &neighbors, uint32_t level) { uint32_t maxLinks = (level > 0) ? _M : (2 * _M); for (uint32_t old_id : neighbors) { LinkList &oldLinks = getLinkList(old_id, level); if (oldLinks.size() > maxLinks) { + ++shrink_needed_calls; shrink_links(old_id, maxLinks, level); } } @@ -437,6 +494,44 @@ HnswLikeNns::search_layer(Vector vector, FurthestPriQ &w, return; } +void +HnswLikeNns::search_layer_with_filter(Vector vector, FurthestPriQ &w, + uint32_t ef, uint32_t searchLevel, + const BitVector &blacklist) +{ + NearestPriQ candidates; + VisitedSet &visited = _visitedSetPool.get(_nodes.size()); + + for (const HnswHit & entry : w.peek()) { + candidates.push(entry); + visited.mark(entry.docid); + if (blacklist.isSet(entry.docid)) ++ef; + } + double limd = std::numeric_limits<double>::max(); + while (! candidates.empty()) { + HnswHit cand = candidates.top(); + if (cand.dist > limd) { + break; + } + candidates.pop(); + for (uint32_t e_id : getLinkList(cand.docid, searchLevel)) { + if (visited.isMarked(e_id)) continue; + visited.mark(e_id); + double e_dist = distance(vector, e_id); + ++distcalls_search_layer; + if (e_dist < limd) { + candidates.emplace(e_id, SqDist(e_dist)); + if (blacklist.isSet(e_id)) continue; + w.emplace(e_id, SqDist(e_dist)); + if (w.size() > ef) { + w.pop(); + limd = w.top().dist; + } + } + } + } +} + LinkList HnswLikeNns::remove_weakest(const NearestList &neighbors, uint32_t curMax, LinkList &lost) const { @@ -458,13 +553,13 @@ HnswLikeNns::remove_weakest(const NearestList &neighbors, uint32_t curMax, LinkL return result; } +#define NO_BACKFILL #ifdef NO_BACKFILL LinkList HnswLikeNns::select_neighbors(const NearestList &neighbors, uint32_t curMax) const { LinkList result; result.reserve(curMax+1); - bool needFiltering = (neighbors.size() > curMax); NearestPriQ w; for (const auto & entry : neighbors) { w.push(entry); @@ -472,12 +567,16 @@ HnswLikeNns::select_neighbors(const NearestList &neighbors, uint32_t curMax) con while (! w.empty()) { HnswHit e = w.top(); w.pop(); - if (needFiltering && haveCloserDistance(e, result)) { + if (haveCloserDistance(e, result)) { continue; } result.push_back(e.docid); - if (result.size() == curMax) return result; + if (result.size() == curMax) { + ++select_n_full; + return result; + } } + ++select_n_partial; return result; } #else @@ -502,10 +601,10 @@ HnswLikeNns::select_neighbors(const NearestList &neighbors, uint32_t curMax) con result.push_back(e.docid); if (result.size() == curMax) return result; } - if (result.size() * 4 < curMax) { + if (result.size() * 4 < _M) { for (uint32_t fill_id : backfill) { result.push_back(fill_id); - if (result.size() * 4 >= curMax) break; + if (result.size() * 2 >= _M) break; } } return result; @@ -576,7 +675,9 @@ HnswLikeNns::dumpStats() const { for (uint32_t n_id : link_list) { const LinkList &neigh_list = getLinkList(n_id, 0); if (! neigh_list.has_link_to(id)) { +#ifdef KEEP_SYM fprintf(stderr, "BAD: %u has link to neighbor %u, but backlink is missing\n", id, n_id); +#endif all_sym = false; } } |