summary refs log tree commit diff stats
path: root/eval/src/tests/ann/xp-hnswlike-nns.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'eval/src/tests/ann/xp-hnswlike-nns.cpp')
-rw-r--r-- eval/src/tests/ann/xp-hnswlike-nns.cpp 121
1 files changed, 111 insertions, 10 deletions
diff --git a/eval/src/tests/ann/xp-hnswlike-nns.cpp b/eval/src/tests/ann/xp-hnswlike-nns.cpp
index 5cdbdd8efa3..90fc0fe2e92 100644
--- a/eval/src/tests/ann/xp-hnswlike-nns.cpp
+++ b/eval/src/tests/ann/xp-hnswlike-nns.cpp
@@ -32,6 +32,11 @@ static size_t distcalls_heuristic;
static size_t distcalls_shrink;
static size_t distcalls_refill;
static size_t refill_needed_calls;
+static size_t shrink_needed_calls; // bumped in each_shrink_ifneeded for every neighbor whose link list exceeds maxLinks
+static size_t disconnected_weak_links; // reported in the stats output; NOTE(review): its increment site is not visible in this diff
+static size_t disconnected_for_symmetry; // bumped once per lost link whose backlink is removed under KEEP_SYM
+static size_t select_n_full; // select_neighbors returns that reached curMax links
+static size_t select_n_partial; // select_neighbors returns that ran out of candidates before reaching curMax
struct LinkList : std::vector<uint32_t>
{
@@ -76,6 +81,7 @@ struct VisitedSet
ptr = (Mark *)malloc(size * sizeof(Mark));
curval = -1;
sz = size;
+ clear();
}
void clear() {
++curval;
@@ -99,8 +105,9 @@ struct VisitedSetPool
VisitedSet &get(size_t size) { // hands out the pooled VisitedSet, replacing it first when `size` exceeds its sz
if (size > lastUsed->sz) {
lastUsed = std::make_unique<VisitedSet>(size*2); // new set is built with twice the requested size
+ } else {
+ lastUsed->clear(); // only a reused set is cleared here; NOTE(review): a new one is presumably cleared by its constructor - confirm
}
- lastUsed->clear();
return *lastUsed;
}
};
@@ -214,6 +221,10 @@ public:
void search_layer(Vector vector, FurthestPriQ &w,
uint32_t ef, uint32_t searchLevel);
+ void search_layer_with_filter(Vector vector, FurthestPriQ &w,
+ uint32_t ef, uint32_t searchLevel,
+ const BitVector &blacklist);
+
bool haveCloserDistance(HnswHit e, const LinkList &r) const {
for (uint32_t prevId : r) {
double dist = distance(e.docid, prevId);
@@ -278,6 +289,10 @@ public:
fprintf(stderr, "distance calls for refill: %zu is %.3f per op\n", distcalls_refill, distcalls_refill / div);
fprintf(stderr, "distance calls for other: %zu is %.3f per op\n", distcalls_other, distcalls_other / div);
fprintf(stderr, "refill needed calls: %zu is %.3f per op\n", refill_needed_calls, refill_needed_calls / div);
+ fprintf(stderr, "shrink needed calls: %zu is %.3f per op\n", shrink_needed_calls, shrink_needed_calls / div);
+ fprintf(stderr, "disconnected weak links: %zu is %.3f per op\n", disconnected_weak_links, disconnected_weak_links / div);
+ fprintf(stderr, "disconnected for symmetry: %zu is %.3f per op\n", disconnected_for_symmetry, disconnected_for_symmetry / div);
+ fprintf(stderr, "select neighbors: partial %zu vs full %zu\n", select_n_partial, select_n_full);
}
}
@@ -315,10 +330,19 @@ public:
LinkList lostLinks;
LinkList oldLinks = links;
links = remove_weakest(distances, maxLinks, lostLinks);
+#define KEEP_SYM
+#ifdef KEEP_SYM
for (uint32_t lost_id : lostLinks) {
+ ++disconnected_for_symmetry;
remove_link_from(lost_id, shrink_id, level);
+ }
+#define DO_REFILL_AFTER_KEEP_SYM
+#ifdef DO_REFILL_AFTER_KEEP_SYM
+ for (uint32_t lost_id : lostLinks) {
refill_ifneeded(lost_id, oldLinks, level);
}
+#endif
+#endif
}
void each_shrink_ifneeded(const LinkList &neighbors, uint32_t level);
@@ -337,7 +361,9 @@ public:
}
remove_link_from(n_id, docid, level);
}
- for (uint32_t n_id : my_links) {
+ while (! my_links.empty()) {
+ uint32_t n_id = my_links.back();
+ my_links.pop_back();
refill_ifneeded(n_id, my_links, level);
}
}
@@ -363,12 +389,12 @@ public:
++distcalls_other;
HnswHit entryPoint(_entryId, SqDist(entryDist));
int searchLevel = _entryLevel;
- FurthestPriQ w;
- w.push(entryPoint);
while (searchLevel > 0) {
- search_layer(vector, w, std::min(k, search_k), searchLevel);
+ entryPoint = search_layer_simple(vector, entryPoint, searchLevel);
--searchLevel;
}
+ FurthestPriQ w;
+ w.push(entryPoint);
search_layer(vector, w, std::max(k, search_k), 0);
while (w.size() > k) {
w.pop();
@@ -381,8 +407,11 @@ public:
}
return result;
}
+
+ std::vector<NnsHit> topKfilter(uint32_t k, Vector vector, uint32_t search_k, const BitVector &blacklist) override;
};
+
double
HnswLikeNns::distance(Vector v, uint32_t b) const
{
@@ -390,12 +419,40 @@ HnswLikeNns::distance(Vector v, uint32_t b) const
return l2distCalc.l2sq_dist(v, w);
}
+std::vector<NnsHit>
+HnswLikeNns::topKfilter(uint32_t k, Vector vector, uint32_t search_k, const BitVector &blacklist)
+{   // Returns at most k hits for `vector`, skipping every docid that is set in `blacklist`.
+    std::vector<NnsHit> result;
+    if (_entryLevel < 0) return result;  // negative entry level: return the empty result
+    double entryDist = distance(vector, _entryId);
+    ++distcalls_other;
+    HnswHit entryPoint(_entryId, SqDist(entryDist));
+    int searchLevel = _entryLevel;
+    while (searchLevel > 0) {  // levels above 0 only replace the entry point; the blacklist is not consulted here
+        entryPoint = search_layer_simple(vector, entryPoint, searchLevel);
+        --searchLevel;
+    }
+    FurthestPriQ w;
+    w.push(entryPoint);  // pushed without a blacklist check, so w may contain a blacklisted docid
+    search_layer_with_filter(vector, w, std::max(k, search_k), 0, blacklist);
+    NearestList tmp = w.steal();
+    std::sort(tmp.begin(), tmp.end(), LesserDist());
+    result.reserve(std::min((size_t)k, tmp.size()));
+    for (const auto & hit : tmp) {
+        if (result.size() >= k) break;  // tested before appending so that k == 0 yields an empty result
+        if (blacklist.isSet(hit.docid)) continue;  // drops a blacklisted entry point that was pushed above
+        result.emplace_back(hit.docid, SqDist(hit.dist));
+    }
+    return result;
+}
+
void
HnswLikeNns::each_shrink_ifneeded(const LinkList &neighbors, uint32_t level) {
uint32_t maxLinks = (level > 0) ? _M : (2 * _M);
for (uint32_t old_id : neighbors) {
LinkList &oldLinks = getLinkList(old_id, level);
if (oldLinks.size() > maxLinks) {
+ ++shrink_needed_calls;
shrink_links(old_id, maxLinks, level);
}
}
@@ -437,6 +494,44 @@ HnswLikeNns::search_layer(Vector vector, FurthestPriQ &w,
return;
}
+void
+HnswLikeNns::search_layer_with_filter(Vector vector, FurthestPriQ &w,
+ uint32_t ef, uint32_t searchLevel,
+ const BitVector &blacklist)
+{ // Searches one layer starting from the hits already in w; blacklisted docids are expanded but never added to w.
+ NearestPriQ candidates;
+ VisitedSet &visited = _visitedSetPool.get(_nodes.size());
+
+ for (const HnswHit & entry : w.peek()) {
+ candidates.push(entry);
+ visited.mark(entry.docid);
+ if (blacklist.isSet(entry.docid)) ++ef; // a blacklisted seed already sits in w, so allow one extra slot for it
+ }
+ double limd = std::numeric_limits<double>::max(); // distance bound; stays at max until w first grows past ef
+ while (! candidates.empty()) {
+ HnswHit cand = candidates.top();
+ if (cand.dist > limd) { // top candidate (presumably the nearest, given NearestPriQ) is beyond the bound
+ break;
+ }
+ candidates.pop();
+ for (uint32_t e_id : getLinkList(cand.docid, searchLevel)) {
+ if (visited.isMarked(e_id)) continue;
+ visited.mark(e_id);
+ double e_dist = distance(vector, e_id);
+ ++distcalls_search_layer;
+ if (e_dist < limd) {
+ candidates.emplace(e_id, SqDist(e_dist)); // queued for expansion even when blacklisted
+ if (blacklist.isSet(e_id)) continue; // blacklisted docids never enter w
+ w.emplace(e_id, SqDist(e_dist));
+ if (w.size() > ef) {
+ w.pop(); // drops the top of w (presumably the furthest hit, given FurthestPriQ)
+ limd = w.top().dist; // bound becomes the distance of the new top of w
+ }
+ }
+ }
+ }
+}
+
LinkList
HnswLikeNns::remove_weakest(const NearestList &neighbors, uint32_t curMax, LinkList &lost) const
{
@@ -458,13 +553,13 @@ HnswLikeNns::remove_weakest(const NearestList &neighbors, uint32_t curMax, LinkL
return result;
}
+#define NO_BACKFILL
#ifdef NO_BACKFILL
LinkList
HnswLikeNns::select_neighbors(const NearestList &neighbors, uint32_t curMax) const
{
LinkList result;
result.reserve(curMax+1);
- bool needFiltering = (neighbors.size() > curMax);
NearestPriQ w;
for (const auto & entry : neighbors) {
w.push(entry);
@@ -472,12 +567,16 @@ HnswLikeNns::select_neighbors(const NearestList &neighbors, uint32_t curMax) con
while (! w.empty()) {
HnswHit e = w.top();
w.pop();
- if (needFiltering && haveCloserDistance(e, result)) {
+ if (haveCloserDistance(e, result)) {
continue;
}
result.push_back(e.docid);
- if (result.size() == curMax) return result;
+ if (result.size() == curMax) {
+ ++select_n_full;
+ return result;
+ }
}
+ ++select_n_partial;
return result;
}
#else
@@ -502,10 +601,10 @@ HnswLikeNns::select_neighbors(const NearestList &neighbors, uint32_t curMax) con
result.push_back(e.docid);
if (result.size() == curMax) return result;
}
- if (result.size() * 4 < curMax) {
+ if (result.size() * 4 < _M) {
for (uint32_t fill_id : backfill) {
result.push_back(fill_id);
- if (result.size() * 4 >= curMax) break;
+ if (result.size() * 2 >= _M) break;
}
}
return result;
@@ -576,7 +675,9 @@ HnswLikeNns::dumpStats() const {
for (uint32_t n_id : link_list) {
const LinkList &neigh_list = getLinkList(n_id, 0);
if (! neigh_list.has_link_to(id)) {
+#ifdef KEEP_SYM
fprintf(stderr, "BAD: %u has link to neighbor %u, but backlink is missing\n", id, n_id);
+#endif
all_sym = false;
}
}