aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorArne H Juul <arnej27959@users.noreply.github.com>2020-02-26 15:57:08 +0100
committerGitHub <noreply@github.com>2020-02-26 15:57:08 +0100
commit2aa1ee8401290bfce3b97409e8fc634b8386f247 (patch)
tree0e4a926007fc45647886d36cad630d043a44c010
parent37f6c5b31cb2809a54c34dc3a4e51307f3320fbd (diff)
parent6355c5dd0b2dc355812414e836b8ca6bbe2bc8ca (diff)
Merge pull request #12321 from vespa-engine/arnej/rework-ann-filter-bm
Arnej/rework ann filter bm
-rw-r--r--eval/src/tests/ann/CMakeLists.txt10
-rw-r--r--eval/src/tests/ann/bruteforce-nns.h74
-rw-r--r--eval/src/tests/ann/extended-hnsw.cpp636
-rw-r--r--eval/src/tests/ann/find-with-nns.h12
-rw-r--r--eval/src/tests/ann/for-sift-top-k.h2
-rw-r--r--eval/src/tests/ann/gist_benchmark.cpp142
-rw-r--r--eval/src/tests/ann/hnsw-like.h203
-rw-r--r--eval/src/tests/ann/nns.h26
-rw-r--r--eval/src/tests/ann/point-vector.h30
-rw-r--r--eval/src/tests/ann/quality-nns.h42
-rw-r--r--eval/src/tests/ann/read-vecs.h45
-rw-r--r--eval/src/tests/ann/remove-bm.cpp434
-rw-r--r--eval/src/tests/ann/sift_benchmark.cpp305
-rw-r--r--eval/src/tests/ann/time-util.h9
-rw-r--r--eval/src/tests/ann/verify-top-k.h27
-rw-r--r--eval/src/tests/ann/xp-annoy-nns.cpp58
-rw-r--r--eval/src/tests/ann/xp-hnsw-wrap.cpp28
-rw-r--r--eval/src/tests/ann/xp-hnswlike-nns.cpp612
-rw-r--r--eval/src/tests/ann/xp-lsh-nns.cpp40
19 files changed, 1850 insertions, 885 deletions
diff --git a/eval/src/tests/ann/CMakeLists.txt b/eval/src/tests/ann/CMakeLists.txt
index 52b4d675d9c..0ba38994c01 100644
--- a/eval/src/tests/ann/CMakeLists.txt
+++ b/eval/src/tests/ann/CMakeLists.txt
@@ -10,6 +10,16 @@ vespa_add_executable(eval_sift_benchmark_app
vespaeval
)
+vespa_add_executable(eval_gist_benchmark_app
+ SOURCES
+ gist_benchmark.cpp
+ xp-annoy-nns.cpp
+ extended-hnsw.cpp
+ xp-lsh-nns.cpp
+ DEPENDS
+ vespaeval
+)
+
vespa_add_executable(eval_remove_bm_app
SOURCES
remove-bm.cpp
diff --git a/eval/src/tests/ann/bruteforce-nns.h b/eval/src/tests/ann/bruteforce-nns.h
new file mode 100644
index 00000000000..0c7c48654f7
--- /dev/null
+++ b/eval/src/tests/ann/bruteforce-nns.h
@@ -0,0 +1,74 @@
+// Copyright Verizon Media. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+
+std::vector<TopK> bruteforceResults;
+
+double computeDistance(const PointVector &query, uint32_t docid) {
+ const PointVector &docvector = generatedDocs[docid];
+ return l2distCalc.l2sq_dist(query, docvector);
+}
+
+struct BfHitComparator {
+ bool operator() (const Hit &lhs, const Hit& rhs) const {
+ if (lhs.distance < rhs.distance) return false;
+ if (lhs.distance > rhs.distance) return true;
+ return (lhs.docid > rhs.docid);
+ }
+};
+
+class BfHitHeap {
+private:
+ size_t _size;
+ vespalib::PriorityQueue<Hit, BfHitComparator> _priQ;
+public:
+ explicit BfHitHeap(size_t maxSize) : _size(maxSize), _priQ() {
+ _priQ.reserve(maxSize);
+ }
+ ~BfHitHeap() {}
+ void maybe_use(const Hit &hit) {
+ if (_priQ.size() < _size) {
+ _priQ.push(hit);
+ } else if (hit.distance < _priQ.front().distance) {
+ _priQ.front() = hit;
+ _priQ.adjust();
+ }
+ }
+ std::vector<Hit> bestHits() {
+ std::vector<Hit> result;
+ size_t i = _priQ.size();
+ result.resize(i);
+ while (i-- > 0) {
+ result[i] = _priQ.front();
+ _priQ.pop_front();
+ }
+ return result;
+ }
+};
+
+TopK bruteforce_nns(const PointVector &query) {
+ TopK result;
+ BfHitHeap heap(result.K);
+ for (uint32_t docid = 0; docid < NUM_DOCS; ++docid) {
+ const PointVector &docvector = generatedDocs[docid];
+ double d = l2distCalc.l2sq_dist(query, docvector);
+ Hit h(docid, d);
+ heap.maybe_use(h);
+ }
+ std::vector<Hit> best = heap.bestHits();
+ for (size_t i = 0; i < result.K; ++i) {
+ result.hits[i] = best[i];
+ }
+ return result;
+}
+
+void verifyBF(uint32_t qid) {
+ const PointVector &query = generatedQueries[qid];
+ TopK &result = bruteforceResults[qid];
+ double min_distance = result.hits[0].distance;
+ for (uint32_t i = 0; i < NUM_DOCS; ++i) {
+ double dist = computeDistance(query, i);
+ if (dist < min_distance) {
+ fprintf(stderr, "WARN dist %.9g < mindist %.9g\n", dist, min_distance);
+ }
+ EXPECT_FALSE(dist+0.000001 < min_distance);
+ }
+}
diff --git a/eval/src/tests/ann/extended-hnsw.cpp b/eval/src/tests/ann/extended-hnsw.cpp
new file mode 100644
index 00000000000..fbc4bedec05
--- /dev/null
+++ b/eval/src/tests/ann/extended-hnsw.cpp
@@ -0,0 +1,636 @@
+// Copyright 2020 Oath Inc. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+
+#include "hnsw-like.h"
+
+static size_t distcalls_simple;
+static size_t distcalls_search_layer;
+static size_t distcalls_other;
+static size_t distcalls_heuristic;
+static size_t distcalls_shrink;
+static size_t distcalls_refill;
+static size_t refill_needed_calls;
+static size_t shrink_needed_calls;
+static size_t disconnected_weak_links;
+static size_t disconnected_for_symmetry;
+static size_t select_n_full;
+static size_t select_n_partial;
+
+
+HnswLikeNns::HnswLikeNns(uint32_t numDims, const DocVectorAccess<float> &dva)
+ : NNS(numDims, dva),
+ _nodes(),
+ _entryId(0),
+ _entryLevel(-1),
+ _M(16),
+ _efConstruction(200),
+ _levelMultiplier(1.0 / log(1.0 * _M)),
+ _rndGen(),
+ _ops_counter(0)
+{
+}
+
+// simple greedy search
+HnswHit
+HnswLikeNns::search_layer_simple(Vector vector, HnswHit curPoint, uint32_t searchLevel) {
+ bool keepGoing = true;
+ while (keepGoing) {
+ keepGoing = false;
+ const LinkList& neighbors = getLinkList(curPoint.docid, searchLevel);
+ for (uint32_t n_id : neighbors) {
+ double dist = distance(vector, n_id);
+ ++distcalls_simple;
+ if (dist < curPoint.dist) {
+ curPoint = HnswHit(n_id, SqDist(dist));
+ keepGoing = true;
+ }
+ }
+ }
+ return curPoint;
+}
+
+
+bool
+HnswLikeNns::haveCloserDistance(HnswHit e, const LinkList &r) const {
+ for (uint32_t prevId : r) {
+ double dist = distance(e.docid, prevId);
+ ++distcalls_heuristic;
+ if (dist < e.dist) return true;
+ }
+ return false;
+}
+
+void
+HnswLikeNns::addDoc(uint32_t docid) {
+ Vector vector = _dva.get(docid);
+ for (uint32_t id = _nodes.size(); id <= docid; ++id) {
+ _nodes.emplace_back(id, 0, _M);
+ }
+ int level = randomLevel();
+ assert(_nodes[docid]._links.size() == 0);
+ _nodes[docid] = Node(docid, level+1, _M);
+ if (_entryLevel < 0) {
+ _entryId = docid;
+ _entryLevel = level;
+ track_ops();
+ return;
+ }
+ int searchLevel = _entryLevel;
+ VisitedSet &visited = _visitedSetPool.get(_nodes.size());
+ double entryDist = distance(vector, _entryId);
+ ++distcalls_other;
+ HnswHit entryPoint(_entryId, SqDist(entryDist));
+#undef MULTI_ENTRY_I
+#ifdef MULTI_ENTRY_I
+ FurthestPriQ w;
+ w.push(entryPoint);
+ while (searchLevel > level) {
+ search_layer(vector, w, visited, 5 * _M, searchLevel);
+ --searchLevel;
+ }
+#else
+ while (searchLevel > level) {
+ entryPoint = search_layer_simple(vector, entryPoint, searchLevel);
+ --searchLevel;
+ }
+ FurthestPriQ w;
+ w.push(entryPoint);
+#endif
+ searchLevel = std::min(level, _entryLevel);
+ while (searchLevel >= 0) {
+ search_layer(vector, w, visited, _efConstruction, searchLevel);
+ LinkList neighbors = select_neighbors(w.peek(), _M);
+ connect_new_node(docid, neighbors, searchLevel);
+ each_shrink_ifneeded(neighbors, searchLevel);
+ --searchLevel;
+ }
+ if (level > _entryLevel) {
+ _entryLevel = level;
+ _entryId = docid;
+ }
+ track_ops();
+}
+
+void
+HnswLikeNns::track_ops() {
+ _ops_counter++;
+ if ((_ops_counter % 10000) == 0) {
+ double div = _ops_counter;
+ fprintf(stderr, "add / remove ops: %zu\n", _ops_counter);
+ fprintf(stderr, "distance calls for layer: %zu is %.3f per op\n", distcalls_search_layer, distcalls_search_layer/ div);
+ fprintf(stderr, "distance calls for heuristic: %zu is %.3f per op\n", distcalls_heuristic, distcalls_heuristic / div);
+ fprintf(stderr, "distance calls for simple: %zu is %.3f per op\n", distcalls_simple, distcalls_simple / div);
+ fprintf(stderr, "distance calls for shrink: %zu is %.3f per op\n", distcalls_shrink, distcalls_shrink / div);
+ fprintf(stderr, "distance calls for refill: %zu is %.3f per op\n", distcalls_refill, distcalls_refill / div);
+ fprintf(stderr, "distance calls for other: %zu is %.3f per op\n", distcalls_other, distcalls_other / div);
+ fprintf(stderr, "refill needed calls: %zu is %.3f per op\n", refill_needed_calls, refill_needed_calls / div);
+ fprintf(stderr, "shrink needed calls: %zu is %.3f per op\n", shrink_needed_calls, shrink_needed_calls / div);
+ fprintf(stderr, "disconnected weak links: %zu is %.3f per op\n", disconnected_weak_links, disconnected_weak_links / div);
+ fprintf(stderr, "disconnected for symmetry: %zu is %.3f per op\n", disconnected_for_symmetry, disconnected_for_symmetry / div);
+ fprintf(stderr, "select neighbors: partial %zu vs full %zu\n", select_n_partial, select_n_full);
+ }
+}
+
+#ifdef SIMPLE_REFILL
+void
+HnswLikeNns::refill_ifneeded(uint32_t my_id, const LinkList &replacements, uint32_t level) {
+ LinkList &my_links = getLinkList(my_id, level);
+ if (my_links.size() * 2 < _M) {
+ const uint32_t maxLinks = (level > 0) ? _M : (2 * _M);
+ ++refill_needed_calls;
+ for (uint32_t repl_id : replacements) {
+ if (repl_id == my_id) continue;
+ if (my_links.has_link_to(repl_id)) continue;
+ LinkList &other_links = getLinkList(repl_id, level);
+ if (other_links.size() >= maxLinks) continue;
+ other_links.push_back(my_id);
+ my_links.push_back(repl_id);
+ if (my_links.size() >= _M) return;
+ }
+ }
+}
+#endif
+
+#define REFILL_ALL
+#ifdef REFILL_ALL
+void
+HnswLikeNns::refill_ifneeded(uint32_t my_id, const LinkList &replacements, uint32_t level) {
+ LinkList &my_links = getLinkList(my_id, level);
+ if (my_links.size() >= _M) return;
+ ++refill_needed_calls;
+ const uint32_t maxLinks = (level > 0) ? _M : (2 * _M);
+ NearestPriQ w;
+ for (uint32_t repl_id : replacements) {
+ if (repl_id == my_id) continue;
+ if (my_links.has_link_to(repl_id)) continue;
+ const LinkList &other_links = getLinkList(repl_id, level);
+ if (other_links.size() >= maxLinks) continue;
+ double dist = distance(my_id, repl_id);
+ ++distcalls_refill;
+ w.emplace(repl_id, SqDist(dist));
+ }
+ while (! w.empty()) {
+ HnswHit e = w.top();
+ w.pop();
+ if (haveCloserDistance(e, my_links)) continue;
+ LinkList &other_links = getLinkList(e.docid, level);
+ my_links.push_back(e.docid);
+ other_links.push_back(my_id);
+ if (my_links.size() == _M) break;
+ }
+}
+#endif
+
+#ifdef REFILL_ONE
+void
+HnswLikeNns::refill_ifneeded(uint32_t my_id, const LinkList &replacements, uint32_t level) {
+ LinkList &my_links = getLinkList(my_id, level);
+ if (my_links.size() >= _M) return;
+ ++refill_needed_calls;
+ NearestPriQ w;
+ for (uint32_t repl_id : replacements) {
+ if (repl_id == my_id) continue;
+ if (my_links.has_link_to(repl_id)) continue;
+ LinkList &other_links = getLinkList(repl_id, level);
+ if (other_links.size() >= _M) continue;
+ double dist = distance(my_id, repl_id);
+ ++distcalls_refill;
+ w.emplace(repl_id, SqDist(dist));
+ }
+ while (! w.empty()) {
+ HnswHit e = w.top();
+ w.pop();
+ if (haveCloserDistance(e, my_links)) continue;
+ LinkList &other_links = getLinkList(e.docid, level);
+ my_links.push_back(e.docid);
+ other_links.push_back(my_id);
+ return;
+ }
+}
+#endif
+
+void
+HnswLikeNns::shrink_links(uint32_t shrink_id, uint32_t maxLinks, uint32_t level) {
+ LinkList &links = getLinkList(shrink_id, level);
+ NearestList distances;
+ for (uint32_t n_id : links) {
+ double n_dist = distance(shrink_id, n_id);
+ ++distcalls_shrink;
+ distances.emplace_back(n_id, SqDist(n_dist));
+ }
+ LinkList lostLinks;
+ LinkList oldLinks = links;
+ links = remove_weakest(distances, maxLinks, lostLinks);
+#define KEEP_SYM
+#ifdef KEEP_SYM
+ for (uint32_t lost_id : lostLinks) {
+ ++disconnected_for_symmetry;
+ remove_link_from(lost_id, shrink_id, level);
+ }
+#define DO_REFILL_AFTER_KEEP_SYM
+#ifdef DO_REFILL_AFTER_KEEP_SYM
+ for (uint32_t lost_id : lostLinks) {
+ refill_ifneeded(lost_id, oldLinks, level);
+ }
+#endif
+#endif
+}
+
+
+void
+HnswLikeNns::removeDoc(uint32_t docid) {
+ Node &node = _nodes[docid];
+ bool need_new_entrypoint = (docid == _entryId);
+ for (int level = node._links.size(); level-- > 0; ) {
+ LinkList my_links;
+ my_links.swap(node._links[level]);
+ for (uint32_t n_id : my_links) {
+ if (need_new_entrypoint) {
+ _entryId = n_id;
+ _entryLevel = level;
+ need_new_entrypoint = false;
+ }
+ remove_link_from(n_id, docid, level);
+ }
+ while (! my_links.empty()) {
+ uint32_t n_id = my_links.back();
+ my_links.pop_back();
+ refill_ifneeded(n_id, my_links, level);
+ }
+ }
+ node = Node(docid, 0, _M);
+ if (need_new_entrypoint) {
+ _entryLevel = -1;
+ _entryId = 0;
+ for (uint32_t i = 0; i < _nodes.size(); ++i) {
+ if (_nodes[i]._links.size() > 0) {
+ _entryId = i;
+ _entryLevel = _nodes[i]._links.size() - 1;
+ break;
+ }
+ }
+ }
+ track_ops();
+}
+
+std::vector<NnsHit>
+HnswLikeNns::topK(uint32_t k, Vector vector, uint32_t search_k) {
+ std::vector<NnsHit> result;
+ if (_entryLevel < 0) return result;
+ double entryDist = distance(vector, _entryId);
+ ++distcalls_other;
+ HnswHit entryPoint(_entryId, SqDist(entryDist));
+ int searchLevel = _entryLevel;
+ VisitedSet &visited = _visitedSetPool.get(_nodes.size());
+#undef MULTI_ENTRY_S
+#ifdef MULTI_ENTRY_S
+ FurthestPriQ w;
+ w.push(entryPoint);
+ while (searchLevel > 0) {
+ search_layer(vector, w, visited, std::min(k, search_k), searchLevel);
+ --searchLevel;
+ }
+#else
+ while (searchLevel > 0) {
+ entryPoint = search_layer_simple(vector, entryPoint, searchLevel);
+ --searchLevel;
+ }
+ FurthestPriQ w;
+ w.push(entryPoint);
+#endif
+ search_layer(vector, w, visited, std::max(k, search_k), 0);
+ while (w.size() > k) {
+ w.pop();
+ }
+ NearestList tmp = w.steal();
+ std::sort(tmp.begin(), tmp.end(), LesserDist());
+ result.reserve(tmp.size());
+ for (const auto & hit : tmp) {
+ result.emplace_back(hit.docid, SqDist(hit.dist));
+ }
+ return result;
+}
+
+
+double
+HnswLikeNns::distance(Vector v, uint32_t b) const
+{
+ Vector w = _dva.get(b);
+ return l2distCalc.l2sq_dist(v, w);
+}
+
+std::vector<NnsHit>
+HnswLikeNns::topKfilter(uint32_t k, Vector vector, uint32_t search_k, const BitVector &blacklist)
+{
+ std::vector<NnsHit> result;
+ if (_entryLevel < 0) return result;
+ double entryDist = distance(vector, _entryId);
+ ++distcalls_other;
+ HnswHit entryPoint(_entryId, SqDist(entryDist));
+ int searchLevel = _entryLevel;
+ VisitedSet &visited = _visitedSetPool.get(_nodes.size());
+#ifdef MULTI_ENTRY_S
+ FurthestPriQ w;
+ w.push(entryPoint);
+ while (searchLevel > 0) {
+ search_layer(vector, w, visited, std::min(k, search_k), searchLevel);
+ --searchLevel;
+ }
+#else
+ while (searchLevel > 0) {
+ entryPoint = search_layer_simple(vector, entryPoint, searchLevel);
+ --searchLevel;
+ }
+ FurthestPriQ w;
+ w.push(entryPoint);
+#endif
+ search_layer_with_filter(vector, w, visited, std::max(k, search_k), 0, blacklist);
+ NearestList tmp = w.steal();
+ std::sort(tmp.begin(), tmp.end(), LesserDist());
+ result.reserve(std::min((size_t)k, tmp.size()));
+ for (const auto & hit : tmp) {
+ if (blacklist.isSet(hit.docid)) continue;
+ result.emplace_back(hit.docid, SqDist(hit.dist));
+ if (result.size() == k) break;
+ }
+ return result;
+}
+
+void
+HnswLikeNns::each_shrink_ifneeded(const LinkList &neighbors, uint32_t level) {
+ uint32_t maxLinks = (level > 0) ? _M : (2 * _M);
+ for (uint32_t old_id : neighbors) {
+ LinkList &oldLinks = getLinkList(old_id, level);
+ if (oldLinks.size() > maxLinks) {
+ ++shrink_needed_calls;
+ shrink_links(old_id, maxLinks, level);
+ }
+ }
+}
+
+void
+HnswLikeNns::search_layer(Vector vector, FurthestPriQ &w,
+ VisitedSet &visited,
+ uint32_t ef, uint32_t searchLevel)
+{
+ NearestPriQ candidates;
+
+ for (const HnswHit & entry : w.peek()) {
+ candidates.push(entry);
+ visited.mark(entry.docid);
+ }
+ double limd = std::numeric_limits<double>::max();
+ while (! candidates.empty()) {
+ HnswHit cand = candidates.top();
+ if (cand.dist > limd) {
+ break;
+ }
+ candidates.pop();
+ for (uint32_t e_id : getLinkList(cand.docid, searchLevel)) {
+ if (visited.isMarked(e_id)) continue;
+ visited.mark(e_id);
+ double e_dist = distance(vector, e_id);
+ ++distcalls_search_layer;
+ if (e_dist < limd) {
+ candidates.emplace(e_id, SqDist(e_dist));
+ w.emplace(e_id, SqDist(e_dist));
+ if (w.size() > ef) {
+ w.pop();
+ limd = w.top().dist;
+ }
+ }
+ }
+ }
+ return;
+}
+
+void
+HnswLikeNns::search_layer_with_filter(Vector vector, FurthestPriQ &w,
+ VisitedSet &visited,
+ uint32_t ef, uint32_t searchLevel,
+ const BitVector &blacklist)
+{
+ NearestPriQ candidates;
+
+ for (const HnswHit & entry : w.peek()) {
+ candidates.push(entry);
+ visited.mark(entry.docid);
+ if (blacklist.isSet(entry.docid)) ++ef;
+ }
+ double limd = std::numeric_limits<double>::max();
+ while (! candidates.empty()) {
+ HnswHit cand = candidates.top();
+ if (cand.dist > limd) {
+ break;
+ }
+ candidates.pop();
+ for (uint32_t e_id : getLinkList(cand.docid, searchLevel)) {
+ if (visited.isMarked(e_id)) continue;
+ visited.mark(e_id);
+ double e_dist = distance(vector, e_id);
+ ++distcalls_search_layer;
+ if (e_dist < limd) {
+ candidates.emplace(e_id, SqDist(e_dist));
+ if (blacklist.isSet(e_id)) continue;
+ w.emplace(e_id, SqDist(e_dist));
+ if (w.size() > ef) {
+ w.pop();
+ limd = w.top().dist;
+ }
+ }
+ }
+ }
+}
+
+LinkList
+HnswLikeNns::remove_weakest(const NearestList &neighbors, uint32_t curMax, LinkList &lost) const
+{
+ LinkList result;
+ result.reserve(curMax+1);
+ NearestPriQ w;
+ for (const auto & entry : neighbors) {
+ w.push(entry);
+ }
+ while (! w.empty()) {
+ HnswHit e = w.top();
+ w.pop();
+ if (result.size() == curMax || haveCloserDistance(e, result)) {
+ lost.push_back(e.docid);
+ } else {
+ result.push_back(e.docid);
+ }
+ }
+ return result;
+}
+
+#define NO_BACKFILL
+#ifdef NO_BACKFILL
+LinkList
+HnswLikeNns::select_neighbors(const NearestList &neighbors, uint32_t curMax) const
+{
+ LinkList result;
+ result.reserve(curMax+1);
+ NearestPriQ w;
+ for (const auto & entry : neighbors) {
+ w.push(entry);
+ }
+ while (! w.empty()) {
+ HnswHit e = w.top();
+ w.pop();
+ if (haveCloserDistance(e, result)) {
+ continue;
+ }
+ result.push_back(e.docid);
+ if (result.size() == curMax) {
+ ++select_n_full;
+ return result;
+ }
+ }
+ ++select_n_partial;
+ return result;
+}
+#else
+LinkList
+HnswLikeNns::select_neighbors(const NearestList &neighbors, uint32_t curMax) const
+{
+ LinkList result;
+ result.reserve(curMax+1);
+ bool needFiltering = (neighbors.size() > curMax);
+ NearestPriQ w;
+ for (const auto & entry : neighbors) {
+ w.push(entry);
+ }
+ LinkList backfill;
+ while (! w.empty()) {
+ HnswHit e = w.top();
+ w.pop();
+ if (needFiltering && haveCloserDistance(e, result)) {
+ backfill.push_back(e.docid);
+ continue;
+ }
+ result.push_back(e.docid);
+ if (result.size() == curMax) return result;
+ }
+ if (result.size() * 4 < _M) {
+ for (uint32_t fill_id : backfill) {
+ result.push_back(fill_id);
+ if (result.size() * 2 >= _M) break;
+ }
+ }
+ return result;
+}
+#endif
+
+void
+HnswLikeNns::connect_new_node(uint32_t id, const LinkList &neighbors, uint32_t level) {
+ LinkList &newLinks = getLinkList(id, level);
+ for (uint32_t neigh_id : neighbors) {
+ LinkList &oldLinks = getLinkList(neigh_id, level);
+ newLinks.push_back(neigh_id);
+ oldLinks.push_back(id);
+ }
+#define DISCONNECT_OLD_WEAK_LINKS
+#ifdef DISCONNECT_OLD_WEAK_LINKS
+ for (uint32_t i = 1; i < neighbors.size(); ++i) {
+ uint32_t n_1 = neighbors[i];
+ LinkList &links_1 = getLinkList(n_1, level);
+ for (uint32_t j = 0; j < i; ++j) {
+ uint32_t n_2 = neighbors[j];
+ if (links_1.has_link_to(n_2)) {
+ ++disconnected_weak_links;
+ LinkList &links_2 = getLinkList(n_2, level);
+ links_1.remove_link(n_2);
+ links_2.remove_link(n_1);
+ }
+ }
+ }
+#endif
+}
+
+uint32_t
+HnswLikeNns::count_reachable() const {
+ VisitedSet visited(_nodes.size());
+ int level = _entryLevel;
+ LinkList curList;
+ curList.push_back(_entryId);
+ visited.mark(_entryId);
+ uint32_t idx = 0;
+ while (level >= 0) {
+ while (idx < curList.size()) {
+ uint32_t id = curList[idx++];
+ const LinkList &links = getLinkList(id, level);
+ for (uint32_t n_id : links) {
+ if (visited.isMarked(n_id)) continue;
+ visited.mark(n_id);
+ curList.push_back(n_id);
+ }
+ }
+ --level;
+ idx = 0;
+ }
+ return curList.size();
+}
+
+void
+HnswLikeNns::dumpStats() const {
+ std::vector<uint32_t> levelCounts;
+ levelCounts.resize(_entryLevel + 2);
+ std::vector<uint32_t> outLinkHist;
+ outLinkHist.resize(2 * _M + 2);
+ uint32_t symmetrics = 0;
+ uint32_t level1links = 0;
+ uint32_t both_l_links = 0;
+ fprintf(stderr, "stats for HnswLikeNns with %zu nodes, entry level = %d, entry id = %u\n",
+ _nodes.size(), _entryLevel, _entryId);
+
+ for (uint32_t id = 0; id < _nodes.size(); ++id) {
+ const auto &node = _nodes[id];
+ uint32_t levels = node._links.size();
+ levelCounts[levels]++;
+ if (levels < 1) {
+ outLinkHist[0]++;
+ continue;
+ }
+ const LinkList &link_list = getLinkList(id, 0);
+ uint32_t numlinks = link_list.size();
+ outLinkHist[numlinks]++;
+ if (numlinks < 1) {
+ fprintf(stderr, "node with %u links: id %u\n", numlinks, id);
+ }
+ bool all_sym = true;
+ for (uint32_t n_id : link_list) {
+ const LinkList &neigh_list = getLinkList(n_id, 0);
+ if (! neigh_list.has_link_to(id)) {
+#ifdef KEEP_SYM
+ fprintf(stderr, "BAD: %u has link to neighbor %u, but backlink is missing\n", id, n_id);
+#endif
+ all_sym = false;
+ }
+ }
+ if (all_sym) ++symmetrics;
+ if (levels < 2) continue;
+ const LinkList &link_list_1 = getLinkList(id, 1);
+ for (uint32_t n_id : link_list_1) {
+ ++level1links;
+ if (link_list.has_link_to(n_id)) ++both_l_links;
+ }
+ }
+ for (uint32_t l = 0; l < levelCounts.size(); ++l) {
+ fprintf(stderr, "Nodes on %u levels: %u\n", l, levelCounts[l]);
+ }
+ fprintf(stderr, "reachable nodes %u / %zu\n",
+ count_reachable(), _nodes.size() - levelCounts[0]);
+ fprintf(stderr, "level 1 links overlapping on l0: %u / total: %u\n",
+ both_l_links, level1links);
+ for (uint32_t l = 0; l < outLinkHist.size(); ++l) {
+ if (outLinkHist[l] != 0) {
+ fprintf(stderr, "Nodes with %u outward links on L0: %u\n", l, outLinkHist[l]);
+ }
+ }
+ fprintf(stderr, "Symmetric in-out nodes: %u\n", symmetrics);
+}
+
+std::unique_ptr<NNS<float>>
+make_hnsw_nns(uint32_t numDims, const DocVectorAccess<float> &dva)
+{
+ return std::make_unique<HnswLikeNns>(numDims, dva);
+}
diff --git a/eval/src/tests/ann/find-with-nns.h b/eval/src/tests/ann/find-with-nns.h
new file mode 100644
index 00000000000..3481b403f86
--- /dev/null
+++ b/eval/src/tests/ann/find-with-nns.h
@@ -0,0 +1,12 @@
+// Copyright 2020 Oath Inc. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+
+TopK find_with_nns(uint32_t sk, NNS_API &nns, uint32_t qid) {
+ TopK result;
+ const PointVector &qv = generatedQueries[qid];
+ vespalib::ConstArrayRef<float> query(qv.v, NUM_DIMS);
+ auto rv = nns.topK(result.K, query, sk);
+ for (size_t i = 0; i < result.K; ++i) {
+ result.hits[i] = Hit(rv[i].docid, rv[i].sq.distance);
+ }
+ return result;
+}
diff --git a/eval/src/tests/ann/for-sift-top-k.h b/eval/src/tests/ann/for-sift-top-k.h
index ba91cb2aebc..8a659a507bc 100644
--- a/eval/src/tests/ann/for-sift-top-k.h
+++ b/eval/src/tests/ann/for-sift-top-k.h
@@ -6,7 +6,7 @@ struct TopK {
static constexpr size_t K = 100;
Hit hits[K];
- size_t recall(const TopK &other) {
+ size_t recall(const TopK &other) const {
size_t overlap = 0;
size_t i = 0;
size_t j = 0;
diff --git a/eval/src/tests/ann/gist_benchmark.cpp b/eval/src/tests/ann/gist_benchmark.cpp
new file mode 100644
index 00000000000..de8bff877e6
--- /dev/null
+++ b/eval/src/tests/ann/gist_benchmark.cpp
@@ -0,0 +1,142 @@
+// Copyright 2020 Oath Inc. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+
+#include <vespa/vespalib/testkit/test_kit.h>
+#include <vespa/vespalib/util/priority_queue.h>
+#include <sys/types.h>
+#include <sys/stat.h>
+#include <fcntl.h>
+#include <unistd.h>
+#include <stdio.h>
+#include <chrono>
+
+#define NUM_DIMS 960
+#define NUM_DOCS 200000
+#define NUM_REACH 10000
+#define NUM_Q 1000
+
+#include "doc_vector_access.h"
+#include "nns.h"
+#include "for-sift-hit.h"
+#include "for-sift-top-k.h"
+#include "time-util.h"
+#include "point-vector.h"
+#include "read-vecs.h"
+#include "bruteforce-nns.h"
+
+using NNS_API = NNS<float>;
+
+TEST("require that brute force works") {
+ TimePoint bef = std::chrono::steady_clock::now();
+ fprintf(stderr, "generating %u brute force results\n", NUM_Q);
+ bruteforceResults.reserve(NUM_Q);
+ for (uint32_t cnt = 0; cnt < NUM_Q; ++cnt) {
+ const PointVector &query = generatedQueries[cnt];
+ bruteforceResults.emplace_back(bruteforce_nns(query));
+ }
+ TimePoint aft = std::chrono::steady_clock::now();
+ fprintf(stderr, "timing for brute force: %.3f ms = %.3f ms per query\n",
+ to_ms(aft - bef), to_ms(aft - bef)/NUM_Q);
+ for (int cnt = 0; cnt < NUM_Q; cnt = (cnt+1)*2) {
+ verifyBF(cnt);
+ }
+}
+
+#include "find-with-nns.h"
+#include "verify-top-k.h"
+
+void timing_nns(const char *name, NNS_API &nns, std::vector<uint32_t> sk_list) {
+ for (uint32_t search_k : sk_list) {
+ TimePoint bef = std::chrono::steady_clock::now();
+ for (int cnt = 0; cnt < NUM_Q; ++cnt) {
+ find_with_nns(search_k, nns, cnt);
+ }
+ TimePoint aft = std::chrono::steady_clock::now();
+ fprintf(stderr, "timing for %s search_k=%u: %.3f ms = %.3f ms/q\n",
+ name, search_k, to_ms(aft - bef), to_ms(aft - bef)/NUM_Q);
+ }
+}
+
+#include "quality-nns.h"
+
+template <typename FUNC>
+void bm_nns_simple(const char *name, FUNC creator, std::vector<uint32_t> sk_list) {
+ std::unique_ptr<NNS_API> nnsp = creator();
+ NNS_API &nns = *nnsp;
+ fprintf(stderr, "trying %s indexing...\n", name);
+ TimePoint bef = std::chrono::steady_clock::now();
+ for (uint32_t i = 0; i < NUM_DOCS; ++i) {
+ nns.addDoc(i);
+ }
+ TimePoint aft = std::chrono::steady_clock::now();
+ fprintf(stderr, "build %s index with %u docs: %.3f ms\n", name, NUM_DOCS, to_ms(aft - bef));
+ timing_nns(name, nns, sk_list);
+ fprintf(stderr, "Quality for %s [A] clean build with %u documents:\n", name, NUM_DOCS);
+ quality_nns(nns, sk_list);
+}
+
+template <typename FUNC>
+void benchmark_nns(const char *name, FUNC creator, std::vector<uint32_t> sk_list) {
+ bm_nns_simple(name, creator, sk_list);
+}
+
+#if 0
+TEST("require that Locality Sensitive Hashing mostly works") {
+ DocVectorAdapter adapter;
+ auto creator = [&adapter]() { return make_rplsh_nns(NUM_DIMS, adapter); };
+ benchmark_nns("RPLSH", creator, { 200, 1000 });
+}
+#endif
+
+#if 0
+TEST("require that Annoy via NNS api mostly works") {
+ DocVectorAdapter adapter;
+ auto creator = [&adapter]() { return make_annoy_nns(NUM_DIMS, adapter); };
+ benchmark_nns("Annoy", creator, { 8000, 10000 });
+}
+#endif
+
+#if 1
+TEST("require that HNSW via NNS api mostly works") {
+ DocVectorAdapter adapter;
+ auto creator = [&adapter]() { return make_hnsw_nns(NUM_DIMS, adapter); };
+ benchmark_nns("HNSW-like", creator, { 100, 150, 200 });
+}
+#endif
+
+#if 0
+TEST("require that HNSW wrapped api mostly works") {
+ DocVectorAdapter adapter;
+ auto creator = [&adapter]() { return make_hnsw_wrap(NUM_DIMS, adapter); };
+ benchmark_nns("HNSW-wrap", creator, { 100, 150, 200 });
+}
+#endif
+
+/**
+ * Before running the benchmark the ANN_GIST1M data set must be downloaded and extracted:
+ * wget ftp://ftp.irisa.fr/local/texmex/corpus/gist.tar.gz
+ * tar -xf gist.tar.gz
+ *
+ * The benchmark program will load the data set from $HOME/gist if no directory is specified.
+ *
+ * More information about the dataset is found here: http://corpus-texmex.irisa.fr/.
+ */
+int main(int argc, char **argv) {
+ TEST_MASTER.init(__FILE__);
+ std::string data_set = "gist";
+ std::string data_dir = ".";
+ if (argc > 2) {
+ data_set = argv[1];
+ data_dir = argv[2];
+ } else if (argc > 1) {
+ data_dir = argv[1];
+ } else {
+ char *home = getenv("HOME");
+ if (home) {
+ data_dir = home;
+ data_dir += "/" + data_set;
+ }
+ }
+ read_data(data_dir, data_set);
+ TEST_RUN_ALL();
+ return (TEST_MASTER.fini() ? 0 : 1);
+}
diff --git a/eval/src/tests/ann/hnsw-like.h b/eval/src/tests/ann/hnsw-like.h
new file mode 100644
index 00000000000..36064c69860
--- /dev/null
+++ b/eval/src/tests/ann/hnsw-like.h
@@ -0,0 +1,203 @@
+// Copyright Verizon Media. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+
+#pragma once
+
+#include <algorithm>
+#include <assert.h>
+#include <queue>
+#include <cinttypes>
+#include "std-random.h"
+#include "nns.h"
+
+struct LinkList : std::vector<uint32_t>
+{
+ bool has_link_to(uint32_t id) const {
+ auto iter = std::find(begin(), end(), id);
+ return (iter != end());
+ }
+ void remove_link(uint32_t id) {
+ uint32_t last = back();
+ for (iterator iter = begin(); iter != end(); ++iter) {
+ if (*iter == id) {
+ *iter = last;
+ pop_back();
+ return;
+ }
+ }
+ fprintf(stderr, "BAD missing link to remove: %u\n", id);
+ abort();
+ }
+};
+
+struct Node {
+ std::vector<LinkList> _links;
+ Node(uint32_t , uint32_t numLevels, uint32_t M)
+ : _links(numLevels)
+ {
+ for (uint32_t i = 0; i < _links.size(); ++i) {
+ _links[i].reserve((i == 0) ? (2 * M + 1) : (M+1));
+ }
+ }
+};
+
+struct VisitedSet
+{
+ using Mark = unsigned short;
+ Mark *ptr;
+ Mark curval;
+ size_t sz;
+ VisitedSet(const VisitedSet &) = delete;
+ VisitedSet& operator=(const VisitedSet &) = delete;
+ explicit VisitedSet(size_t size) {
+ ptr = (Mark *)malloc(size * sizeof(Mark));
+ curval = -1;
+ sz = size;
+ clear();
+ }
+ void clear() {
+ ++curval;
+ if (curval == 0) {
+ memset(ptr, 0, sz * sizeof(Mark));
+ ++curval;
+ }
+ }
+ ~VisitedSet() { free(ptr); }
+ void mark(size_t id) { ptr[id] = curval; }
+ bool isMarked(size_t id) const { return ptr[id] == curval; }
+};
+
+struct VisitedSetPool
+{
+ std::unique_ptr<VisitedSet> lastUsed;
+ VisitedSetPool() {
+ lastUsed = std::make_unique<VisitedSet>(250);
+ }
+ ~VisitedSetPool() {}
+ VisitedSet &get(size_t size) {
+ if (size > lastUsed->sz) {
+ lastUsed = std::make_unique<VisitedSet>(size*2);
+ } else {
+ lastUsed->clear();
+ }
+ return *lastUsed;
+ }
+};
+
+struct HnswHit {
+ double dist;
+ uint32_t docid;
+ HnswHit(uint32_t di, SqDist sq) : dist(sq.distance), docid(di) {}
+};
+
+struct GreaterDist {
+ bool operator() (const HnswHit &lhs, const HnswHit& rhs) const {
+ return (rhs.dist < lhs.dist);
+ }
+};
+struct LesserDist {
+ bool operator() (const HnswHit &lhs, const HnswHit& rhs) const {
+ return (lhs.dist < rhs.dist);
+ }
+};
+
+using NearestList = std::vector<HnswHit>;
+
+struct NearestPriQ : std::priority_queue<HnswHit, NearestList, GreaterDist>
+{
+};
+
+struct FurthestPriQ : std::priority_queue<HnswHit, NearestList, LesserDist>
+{
+ NearestList steal() {
+ NearestList result;
+ c.swap(result);
+ return result;
+ }
+ const NearestList& peek() const { return c; }
+};
+
+class HnswLikeNns : public NNS<float>
+{
+private:
+ std::vector<Node> _nodes;
+ uint32_t _entryId;
+ int _entryLevel;
+ uint32_t _M;
+ uint32_t _efConstruction;
+ double _levelMultiplier;
+ RndGen _rndGen;
+ VisitedSetPool _visitedSetPool;
+ size_t _ops_counter;
+
+ double distance(Vector v, uint32_t id) const;
+
+ double distance(uint32_t a, uint32_t b) const {
+ Vector v = _dva.get(a);
+ return distance(v, b);
+ }
+
+ int randomLevel() {
+ double unif = _rndGen.nextUniform();
+ double r = -log(1.0-unif) * _levelMultiplier;
+ return (int) r;
+ }
+
+ uint32_t count_reachable() const;
+ void dumpStats() const;
+
+public:
+ HnswLikeNns(uint32_t numDims, const DocVectorAccess<float> &dva);
+ ~HnswLikeNns() { dumpStats(); }
+
+ LinkList& getLinkList(uint32_t docid, uint32_t level) {
+ return _nodes[docid]._links[level];
+ }
+
+ const LinkList& getLinkList(uint32_t docid, uint32_t level) const {
+ return _nodes[docid]._links[level];
+ }
+
+ HnswHit search_layer_simple(Vector vector, HnswHit curPoint, uint32_t searchLevel);
+
+ void search_layer(Vector vector, FurthestPriQ &w,
+ uint32_t ef, uint32_t searchLevel);
+ void search_layer(Vector vector, FurthestPriQ &w,
+ VisitedSet &visited,
+ uint32_t ef, uint32_t searchLevel);
+ void search_layer_with_filter(Vector vector, FurthestPriQ &w,
+ uint32_t ef, uint32_t searchLevel,
+ const BitVector &blacklist);
+ void search_layer_with_filter(Vector vector, FurthestPriQ &w,
+ VisitedSet &visited,
+ uint32_t ef, uint32_t searchLevel,
+ const BitVector &blacklist);
+
+ bool haveCloserDistance(HnswHit e, const LinkList &r) const;
+
+ LinkList select_neighbors(const NearestList &neighbors, uint32_t curMax) const;
+
+ LinkList remove_weakest(const NearestList &neighbors, uint32_t curMax, LinkList &removed) const;
+
+ void addDoc(uint32_t docid) override;
+
+ void track_ops();
+
+ void remove_link_from(uint32_t from_id, uint32_t remove_id, uint32_t level) {
+ LinkList &links = getLinkList(from_id, level);
+ links.remove_link(remove_id);
+ }
+
+ void refill_ifneeded(uint32_t my_id, const LinkList &replacements, uint32_t level);
+
+ void connect_new_node(uint32_t id, const LinkList &neighbors, uint32_t level);
+
+ void shrink_links(uint32_t shrink_id, uint32_t maxLinks, uint32_t level);
+
+ void each_shrink_ifneeded(const LinkList &neighbors, uint32_t level);
+
+ void removeDoc(uint32_t docid) override;
+
+ std::vector<NnsHit> topK(uint32_t k, Vector vector, uint32_t search_k) override;
+
+ std::vector<NnsHit> topKfilter(uint32_t k, Vector vector, uint32_t search_k, const BitVector &blacklist) override;
+};
diff --git a/eval/src/tests/ann/nns.h b/eval/src/tests/ann/nns.h
index ffe2882188e..ef3e4b5d69c 100644
--- a/eval/src/tests/ann/nns.h
+++ b/eval/src/tests/ann/nns.h
@@ -37,6 +37,31 @@ struct NnsHitComparatorLessDocid {
}
};
+class BitVector {
+private:
+ std::vector<uint64_t> _bits;
+public:
+ BitVector(size_t sz) : _bits((sz+63)/64) {}
+ BitVector& setBit(size_t idx) {
+ uint64_t mask = 1;
+ mask <<= (idx%64);
+ _bits[idx/64] |= mask;
+ return *this;
+ }
+ bool isSet(size_t idx) const {
+ uint64_t mask = 1;
+ mask <<= (idx%64);
+ uint64_t word = _bits[idx/64];
+ return (word & mask) != 0;
+ }
+ BitVector& clearBit(size_t idx) {
+ uint64_t mask = 1;
+ mask <<= (idx%64);
+ _bits[idx/64] &= ~mask;
+ return *this;
+ }
+};
+
template <typename FltType = float>
class NNS
{
@@ -50,6 +75,7 @@ public:
using Vector = vespalib::ConstArrayRef<FltType>;
virtual std::vector<NnsHit> topK(uint32_t k, Vector vector, uint32_t search_k) = 0;
+ virtual std::vector<NnsHit> topKfilter(uint32_t k, Vector vector, uint32_t search_k, const BitVector &blacklist) = 0;
virtual ~NNS() {}
protected:
uint32_t _numDims;
diff --git a/eval/src/tests/ann/point-vector.h b/eval/src/tests/ann/point-vector.h
new file mode 100644
index 00000000000..eca60e11194
--- /dev/null
+++ b/eval/src/tests/ann/point-vector.h
@@ -0,0 +1,30 @@
+// Copyright Verizon Media. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+
+struct PointVector {
+ float v[NUM_DIMS];
+ using ConstArr = vespalib::ConstArrayRef<float>;
+ operator ConstArr() const { return ConstArr(v, NUM_DIMS); }
+};
+
+static PointVector *aligned_alloc(size_t num) {
+ size_t num_bytes = num * sizeof(PointVector);
+ double mega_bytes = num_bytes / (1024.0*1024.0);
+ fprintf(stderr, "allocate %.2f MB of vectors\n", mega_bytes);
+ char *mem = (char *)malloc(num_bytes + 512);
+ mem += 512;
+ size_t val = (size_t)mem;
+ size_t unalign = val % 512;
+ mem -= unalign;
+ return reinterpret_cast<PointVector *>(mem);
+}
+
+static PointVector *generatedQueries = aligned_alloc(NUM_Q);
+static PointVector *generatedDocs = aligned_alloc(NUM_DOCS);
+
+struct DocVectorAdapter : public DocVectorAccess<float>
+{
+ vespalib::ConstArrayRef<float> get(uint32_t docid) const override {
+ ASSERT_TRUE(docid < NUM_DOCS);
+ return generatedDocs[docid];
+ }
+};
diff --git a/eval/src/tests/ann/quality-nns.h b/eval/src/tests/ann/quality-nns.h
new file mode 100644
index 00000000000..9ac37f0ef04
--- /dev/null
+++ b/eval/src/tests/ann/quality-nns.h
@@ -0,0 +1,42 @@
+// Copyright 2020 Oath Inc. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+
+bool reach_with_nns_k(NNS_API &nns, uint32_t docid, uint32_t k) {
+ const PointVector &qv = generatedDocs[docid];
+ vespalib::ConstArrayRef<float> query(qv.v, NUM_DIMS);
+ auto rv = nns.topK(k, query, k);
+ if (rv.size() != k) {
+ fprintf(stderr, "Result/K=%u from query for %u is %zu hits\n",
+ k, docid, rv.size());
+ return false;
+ }
+ if (rv[0].docid != docid) {
+ if (rv[0].sq.distance != 0.0)
+ fprintf(stderr, "Expected/K=%u to find %u but got %u with sq distance %.3f\n",
+ k, docid, rv[0].docid, rv[0].sq.distance);
+ }
+ return (rv[0].docid == docid || rv[0].sq.distance == 0.0);
+}
+
+void quality_nns(NNS_API &nns, std::vector<uint32_t> sk_list) {
+ for (uint32_t search_k : sk_list) {
+ double sum_recall = 0;
+ for (int cnt = 0; cnt < NUM_Q; ++cnt) {
+ sum_recall += verify_nns_quality(search_k, nns, cnt);
+ }
+ fprintf(stderr, "Overall average recall: %.2f\n", sum_recall / NUM_Q);
+ }
+ for (uint32_t search_k : { 1, 10, 100, 1000 }) {
+ TimePoint bef = std::chrono::steady_clock::now();
+ uint32_t reached = 0;
+ for (uint32_t i = 0; i < NUM_REACH; ++i) {
+ uint32_t target = i * (NUM_DOCS / NUM_REACH);
+ if (reach_with_nns_k(nns, target, search_k)) ++reached;
+ }
+ fprintf(stderr, "Could reach %u of %u documents with k=%u\n",
+ reached, NUM_REACH, search_k);
+ TimePoint aft = std::chrono::steady_clock::now();
+ fprintf(stderr, "reach time k=%u: %.3f ms = %.3f ms/q\n",
+ search_k, to_ms(aft - bef), to_ms(aft - bef)/NUM_REACH);
+ if (reached == NUM_REACH) break;
+ }
+}
diff --git a/eval/src/tests/ann/read-vecs.h b/eval/src/tests/ann/read-vecs.h
new file mode 100644
index 00000000000..39c2a332710
--- /dev/null
+++ b/eval/src/tests/ann/read-vecs.h
@@ -0,0 +1,45 @@
+// Copyright Verizon Media. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+
+void read_queries(std::string fn) {
+ int fd = open(fn.c_str(), O_RDONLY);
+ ASSERT_TRUE(fd > 0);
+ int d;
+ size_t rv;
+ fprintf(stderr, "reading %u queries from %s\n", NUM_Q, fn.c_str());
+ for (uint32_t qid = 0; qid < NUM_Q; ++qid) {
+ rv = read(fd, &d, 4);
+ ASSERT_EQUAL(rv, 4u);
+ ASSERT_EQUAL(d, NUM_DIMS);
+ rv = read(fd, &generatedQueries[qid].v, NUM_DIMS*sizeof(float));
+ ASSERT_EQUAL(rv, sizeof(PointVector));
+ }
+ close(fd);
+}
+
+void read_docs(std::string fn) {
+ int fd = open(fn.c_str(), O_RDONLY);
+ ASSERT_TRUE(fd > 0);
+ int d;
+ size_t rv;
+ fprintf(stderr, "reading %u doc vectors from %s\n", NUM_DOCS, fn.c_str());
+ for (uint32_t docid = 0; docid < NUM_DOCS; ++docid) {
+ rv = read(fd, &d, 4);
+ ASSERT_EQUAL(rv, 4u);
+ ASSERT_EQUAL(d, NUM_DIMS);
+ rv = read(fd, &generatedDocs[docid].v, NUM_DIMS*sizeof(float));
+ ASSERT_EQUAL(rv, sizeof(PointVector));
+ }
+ close(fd);
+}
+
+void read_data(const std::string& dir, const std::string& data_set) {
+ fprintf(stderr, "read data set '%s' from directory '%s'\n", data_set.c_str(), dir.c_str());
+ TimePoint bef = std::chrono::steady_clock::now();
+ read_queries(dir + "/" + data_set + "_query.fvecs");
+ TimePoint aft = std::chrono::steady_clock::now();
+ fprintf(stderr, "read queries: %.3f ms\n", to_ms(aft - bef));
+ bef = std::chrono::steady_clock::now();
+ read_docs(dir + "/" + data_set + "_base.fvecs");
+ aft = std::chrono::steady_clock::now();
+ fprintf(stderr, "read docs: %.3f ms\n", to_ms(aft - bef));
+}
diff --git a/eval/src/tests/ann/remove-bm.cpp b/eval/src/tests/ann/remove-bm.cpp
index be010552ab8..546c2cfd75e 100644
--- a/eval/src/tests/ann/remove-bm.cpp
+++ b/eval/src/tests/ann/remove-bm.cpp
@@ -13,174 +13,17 @@
#define NUM_DOCS 250000
#define NUM_DOCS_REMOVE 50000
#define EFFECTIVE_DOCS (NUM_DOCS - NUM_DOCS_REMOVE)
+#define NUM_REACH 10000
#define NUM_Q 1000
#include "doc_vector_access.h"
#include "nns.h"
#include "for-sift-hit.h"
#include "for-sift-top-k.h"
-
-std::vector<TopK> bruteforceResults;
-std::vector<float> tmp_v(NUM_DIMS);
-
-struct PointVector {
- float v[NUM_DIMS];
- using ConstArr = vespalib::ConstArrayRef<float>;
- operator ConstArr() const { return ConstArr(v, NUM_DIMS); }
-};
-
-static PointVector *aligned_alloc(size_t num) {
- size_t sz = num * sizeof(PointVector);
- double mega_bytes = sz / (1024.0*1024.0);
- fprintf(stderr, "allocate %.2f MB of vectors\n", mega_bytes);
- char *mem = (char *)malloc(sz + 512);
- mem += 512;
- size_t val = (size_t)mem;
- size_t unalign = val % 512;
- mem -= unalign;
- return reinterpret_cast<PointVector *>(mem);
-}
-
-static PointVector *generatedQueries = aligned_alloc(NUM_Q);
-static PointVector *generatedDocs = aligned_alloc(NUM_DOCS);
-
-struct DocVectorAdapter : public DocVectorAccess<float>
-{
- vespalib::ConstArrayRef<float> get(uint32_t docid) const override {
- ASSERT_TRUE(docid < NUM_DOCS);
- return generatedDocs[docid];
- }
-};
-
-double computeDistance(const PointVector &query, uint32_t docid) {
- const PointVector &docvector = generatedDocs[docid];
- return l2distCalc.l2sq_dist(query, docvector, tmp_v);
-}
-
-void read_queries(std::string fn) {
- int fd = open(fn.c_str(), O_RDONLY);
- ASSERT_TRUE(fd > 0);
- int d;
- size_t rv;
- fprintf(stderr, "reading %u queries from %s\n", NUM_Q, fn.c_str());
- for (uint32_t qid = 0; qid < NUM_Q; ++qid) {
- rv = read(fd, &d, 4);
- ASSERT_EQUAL(rv, 4u);
- ASSERT_EQUAL(d, NUM_DIMS);
- rv = read(fd, &generatedQueries[qid].v, NUM_DIMS*sizeof(float));
- ASSERT_EQUAL(rv, sizeof(PointVector));
- }
- close(fd);
-}
-
-void read_docs(std::string fn) {
- int fd = open(fn.c_str(), O_RDONLY);
- ASSERT_TRUE(fd > 0);
- int d;
- size_t rv;
- fprintf(stderr, "reading %u doc vectors from %s\n", NUM_DOCS, fn.c_str());
- for (uint32_t docid = 0; docid < NUM_DOCS; ++docid) {
- rv = read(fd, &d, 4);
- ASSERT_EQUAL(rv, 4u);
- ASSERT_EQUAL(d, NUM_DIMS);
- rv = read(fd, &generatedDocs[docid].v, NUM_DIMS*sizeof(float));
- ASSERT_EQUAL(rv, sizeof(PointVector));
- }
- close(fd);
-}
-
-using TimePoint = std::chrono::steady_clock::time_point;
-using Duration = std::chrono::steady_clock::duration;
-
-double to_ms(Duration elapsed) {
- std::chrono::duration<double, std::milli> ms(elapsed);
- return ms.count();
-}
-
-void read_data(std::string dir) {
- TimePoint bef = std::chrono::steady_clock::now();
- read_queries(dir + "/gist_query.fvecs");
- TimePoint aft = std::chrono::steady_clock::now();
- fprintf(stderr, "read queries: %.3f ms\n", to_ms(aft - bef));
- bef = std::chrono::steady_clock::now();
- read_docs(dir + "/gist_base.fvecs");
- aft = std::chrono::steady_clock::now();
- fprintf(stderr, "read docs: %.3f ms\n", to_ms(aft - bef));
-}
-
-
-struct BfHitComparator {
- bool operator() (const Hit &lhs, const Hit& rhs) const {
- if (lhs.distance < rhs.distance) return false;
- if (lhs.distance > rhs.distance) return true;
- return (lhs.docid > rhs.docid);
- }
-};
-
-class BfHitHeap {
-private:
- size_t _size;
- vespalib::PriorityQueue<Hit, BfHitComparator> _priQ;
-public:
- explicit BfHitHeap(size_t maxSize) : _size(maxSize), _priQ() {
- _priQ.reserve(maxSize);
- }
- ~BfHitHeap() {}
- void maybe_use(const Hit &hit) {
- if (_priQ.size() < _size) {
- _priQ.push(hit);
- } else if (hit.distance < _priQ.front().distance) {
- _priQ.front() = hit;
- _priQ.adjust();
- }
- }
- std::vector<Hit> bestHits() {
- std::vector<Hit> result;
- size_t i = _priQ.size();
- result.resize(i);
- while (i-- > 0) {
- result[i] = _priQ.front();
- _priQ.pop_front();
- }
- return result;
- }
-};
-
-TopK bruteforce_nns(const PointVector &query) {
- TopK result;
- BfHitHeap heap(result.K);
- for (uint32_t docid = 0; docid < EFFECTIVE_DOCS; ++docid) {
- const PointVector &docvector = generatedDocs[docid];
- double d = l2distCalc.l2sq_dist(query, docvector, tmp_v);
- Hit h(docid, d);
- heap.maybe_use(h);
- }
- std::vector<Hit> best = heap.bestHits();
- for (size_t i = 0; i < result.K; ++i) {
- result.hits[i] = best[i];
- }
- return result;
-}
-
-void verifyBF(uint32_t qid) {
- const PointVector &query = generatedQueries[qid];
- TopK &result = bruteforceResults[qid];
- double min_distance = result.hits[0].distance;
- std::vector<double> all_c2;
- for (uint32_t i = 0; i < EFFECTIVE_DOCS; ++i) {
- double dist = computeDistance(query, i);
- if (dist < min_distance) {
- fprintf(stderr, "WARN dist %.9g < mindist %.9g\n", dist, min_distance);
- }
- EXPECT_FALSE(dist+0.000001 < min_distance);
- if (min_distance > 0.0) all_c2.push_back(dist / min_distance);
- }
- if (all_c2.size() != EFFECTIVE_DOCS) return;
- std::sort(all_c2.begin(), all_c2.end());
- for (uint32_t idx : { 1, 3, 10, 30, 100, 300, 1000, 3000, EFFECTIVE_DOCS/2, EFFECTIVE_DOCS-1}) {
- fprintf(stderr, "c2-factor[%u] = %.3f\n", idx, all_c2[idx]);
- }
-}
+#include "time-util.h"
+#include "point-vector.h"
+#include "read-vecs.h"
+#include "bruteforce-nns.h"
using NNS_API = NNS<float>;
@@ -221,83 +64,8 @@ TEST("require that brute force works") {
}
}
-bool reach_with_nns_1(NNS_API &nns, uint32_t docid) {
- const PointVector &qv = generatedDocs[docid];
- vespalib::ConstArrayRef<float> query(qv.v, NUM_DIMS);
- auto rv = nns.topK(1, query, 1);
- if (rv.size() != 1) {
- fprintf(stderr, "Result/A from query for %u is %zu hits\n", docid, rv.size());
- return false;
- }
- if (rv[0].docid != docid) {
- if (rv[0].sq.distance != 0.0)
- fprintf(stderr, "Expected/A to find %u but got %u with sq distance %.3f\n",
- docid, rv[0].docid, rv[0].sq.distance);
- }
- return (rv[0].docid == docid || rv[0].sq.distance == 0.0);
-}
-
-bool reach_with_nns_100(NNS_API &nns, uint32_t docid) {
- const PointVector &qv = generatedDocs[docid];
- vespalib::ConstArrayRef<float> query(qv.v, NUM_DIMS);
- auto rv = nns.topK(10, query, 100);
- if (rv.size() != 10) {
- fprintf(stderr, "Result/B from query for %u is %zu hits\n", docid, rv.size());
- }
- if (rv[0].docid != docid) {
- if (rv[0].sq.distance != 0.0)
- fprintf(stderr, "Expected/B to find %u but got %u with sq distance %.3f\n",
- docid, rv[0].docid, rv[0].sq.distance);
- }
- return (rv[0].docid == docid || rv[0].sq.distance == 0.0);
-}
-
-bool reach_with_nns_1k(NNS_API &nns, uint32_t docid) {
- const PointVector &qv = generatedDocs[docid];
- vespalib::ConstArrayRef<float> query(qv.v, NUM_DIMS);
- auto rv = nns.topK(10, query, 1000);
- if (rv.size() != 10) {
- fprintf(stderr, "Result/C from query for %u is %zu hits\n", docid, rv.size());
- }
- if (rv[0].docid != docid) {
- if (rv[0].sq.distance != 0.0)
- fprintf(stderr, "Expected/C to find %u but got %u with sq distance %.3f\n",
- docid, rv[0].docid, rv[0].sq.distance);
- }
- return (rv[0].docid == docid || rv[0].sq.distance == 0.0);
-}
-
-TopK find_with_nns(uint32_t sk, NNS_API &nns, uint32_t qid) {
- TopK result;
- const PointVector &qv = generatedQueries[qid];
- vespalib::ConstArrayRef<float> query(qv.v, NUM_DIMS);
- auto rv = nns.topK(result.K, query, sk);
- for (size_t i = 0; i < result.K; ++i) {
- result.hits[i] = Hit(rv[i].docid, rv[i].sq.distance);
- }
- return result;
-}
-
-void verify_nns_quality(uint32_t sk, NNS_API &nns, uint32_t qid) {
- TopK perfect = bruteforceResults[qid];
- TopK result = find_with_nns(sk, nns, qid);
- int recall = perfect.recall(result);
- EXPECT_TRUE(recall > 40);
- double sum_error = 0.0;
- double c_factor = 1.0;
- for (size_t i = 0; i < result.K; ++i) {
- double factor = (result.hits[i].distance / perfect.hits[i].distance);
- if (factor < 0.99 || factor > 25) {
- fprintf(stderr, "hit[%zu] got distance %.3f, expected %.3f\n",
- i, result.hits[i].distance, perfect.hits[i].distance);
- }
- sum_error += factor;
- c_factor = std::max(c_factor, factor);
- }
- EXPECT_TRUE(c_factor < 1.5);
- fprintf(stderr, "quality sk=%u: query %u: recall %d c2-factor %.3f avg c2: %.3f\n",
- sk, qid, recall, c_factor, sum_error / result.K);
-}
+#include "find-with-nns.h"
+#include "verify-top-k.h"
void timing_nns(const char *name, NNS_API &nns, std::vector<uint32_t> sk_list) {
for (uint32_t search_k : sk_list) {
@@ -311,64 +79,22 @@ void timing_nns(const char *name, NNS_API &nns, std::vector<uint32_t> sk_list) {
}
}
-void quality_nns(NNS_API &nns, std::vector<uint32_t> sk_list) {
- for (uint32_t search_k : sk_list) {
- for (int cnt = 0; cnt < NUM_Q; ++cnt) {
- verify_nns_quality(search_k, nns, cnt);
- }
- }
- uint32_t reached = 0;
- for (uint32_t i = 0; i < 20000; ++i) {
- if (reach_with_nns_1(nns, i)) ++reached;
- }
- fprintf(stderr, "Could reach %u of 20000 first documents with k=1\n", reached);
- reached = 0;
- for (uint32_t i = 0; i < 20000; ++i) {
- if (reach_with_nns_100(nns, i)) ++reached;
- }
- fprintf(stderr, "Could reach %u of 20000 first documents with k=100\n", reached);
- reached = 0;
- for (uint32_t i = 0; i < 20000; ++i) {
- if (reach_with_nns_1k(nns, i)) ++reached;
- }
- fprintf(stderr, "Could reach %u of 20000 first documents with k=1000\n", reached);
-}
+#include "quality-nns.h"
-void benchmark_nns(const char *name, NNS_API &nns, std::vector<uint32_t> sk_list) {
+template <typename FUNC>
+void bm_nns_simple(const char *name, FUNC creator, std::vector<uint32_t> sk_list) {
+ std::unique_ptr<NNS_API> nnsp = creator();
+ NNS_API &nns = *nnsp;
fprintf(stderr, "trying %s indexing...\n", name);
-
-#if 0
- TimePoint bef = std::chrono::steady_clock::now();
- for (uint32_t i = 0; i < NUM_DOCS_REMOVE; ++i) {
- nns.addDoc(EFFECTIVE_DOCS + i);
- }
- for (uint32_t i = 0; i < EFFECTIVE_DOCS - NUM_DOCS_REMOVE; ++i) {
- nns.addDoc(i);
- }
- for (uint32_t i = 0; i < NUM_DOCS_REMOVE; ++i) {
- nns.removeDoc(EFFECTIVE_DOCS + i);
- nns.addDoc(EFFECTIVE_DOCS - NUM_DOCS_REMOVE + i);
- }
- TimePoint aft = std::chrono::steady_clock::now();
- fprintf(stderr, "build %s index with %u docs: %.3f ms\n", name, EFFECTIVE_DOCS, to_ms(aft - bef));
-
- timing_nns(name, nns, sk_list);
- fprintf(stderr, "Quality for %s realistic build with %u documents:\n", name, EFFECTIVE_DOCS);
- quality_nns(nns, sk_list);
-#endif
-
-#if 1
TimePoint bef = std::chrono::steady_clock::now();
for (uint32_t i = 0; i < EFFECTIVE_DOCS; ++i) {
nns.addDoc(i);
}
TimePoint aft = std::chrono::steady_clock::now();
fprintf(stderr, "build %s index with %u docs: %.3f ms\n", name, EFFECTIVE_DOCS, to_ms(aft - bef));
-
timing_nns(name, nns, sk_list);
- fprintf(stderr, "Quality for %s clean build with %u documents:\n", name, EFFECTIVE_DOCS);
+ fprintf(stderr, "Quality for %s [A] clean build with %u documents:\n", name, EFFECTIVE_DOCS);
quality_nns(nns, sk_list);
-
bef = std::chrono::steady_clock::now();
for (uint32_t i = 0; i < NUM_DOCS_REMOVE; ++i) {
nns.addDoc(EFFECTIVE_DOCS + i);
@@ -379,111 +105,115 @@ void benchmark_nns(const char *name, NNS_API &nns, std::vector<uint32_t> sk_list
aft = std::chrono::steady_clock::now();
fprintf(stderr, "build %s index add then remove %u docs: %.3f ms\n",
name, NUM_DOCS_REMOVE, to_ms(aft - bef));
-
timing_nns(name, nns, sk_list);
- fprintf(stderr, "Quality for %s remove-damaged build with %u documents:\n", name, EFFECTIVE_DOCS);
+ fprintf(stderr, "Quality for %s [B] remove-damaged build with %u documents:\n", name, EFFECTIVE_DOCS);
quality_nns(nns, sk_list);
-#endif
+}
-#if 0
+template <typename FUNC>
+void bm_nns_remove_old(const char *name, FUNC creator, std::vector<uint32_t> sk_list) {
+ std::unique_ptr<NNS_API> nnsp = creator();
+ NNS_API &nns = *nnsp;
TimePoint bef = std::chrono::steady_clock::now();
+ for (uint32_t i = 0; i < NUM_DOCS_REMOVE; ++i) {
+ nns.addDoc(EFFECTIVE_DOCS + i);
+ }
for (uint32_t i = 0; i < EFFECTIVE_DOCS; ++i) {
nns.addDoc(i);
}
+ for (uint32_t i = 0; i < NUM_DOCS_REMOVE; ++i) {
+ nns.removeDoc(EFFECTIVE_DOCS + i);
+ }
TimePoint aft = std::chrono::steady_clock::now();
fprintf(stderr, "build %s index with %u docs: %.3f ms\n", name, EFFECTIVE_DOCS, to_ms(aft - bef));
-
timing_nns(name, nns, sk_list);
- fprintf(stderr, "Quality for %s clean build with %u documents:\n", name, EFFECTIVE_DOCS);
+ fprintf(stderr, "Quality for %s [C] remove-oldest build with %u documents:\n", name, EFFECTIVE_DOCS);
quality_nns(nns, sk_list);
+}
- bef = std::chrono::steady_clock::now();
- for (uint32_t i = 0; i < EFFECTIVE_DOCS; ++i) {
- nns.removeDoc(i);
- }
- aft = std::chrono::steady_clock::now();
- fprintf(stderr, "build %s index removed %u docs: %.3f ms\n", name, EFFECTIVE_DOCS, to_ms(aft - bef));
-
- const uint32_t addFirst = NUM_DOCS - (NUM_DOCS_REMOVE * 3);
- const uint32_t addSecond = NUM_DOCS - (NUM_DOCS_REMOVE * 2);
-
- bef = std::chrono::steady_clock::now();
- for (uint32_t i = 0; i < addFirst; ++i) {
- nns.addDoc(i);
- }
- aft = std::chrono::steady_clock::now();
- fprintf(stderr, "build %s index with %u docs: %.3f ms\n", name, addFirst, to_ms(aft - bef));
-
- bef = std::chrono::steady_clock::now();
+template <typename FUNC>
+void bm_nns_interleave(const char *name, FUNC creator, std::vector<uint32_t> sk_list) {
+ std::unique_ptr<NNS_API> nnsp = creator();
+ NNS_API &nns = *nnsp;
+ TimePoint bef = std::chrono::steady_clock::now();
for (uint32_t i = 0; i < NUM_DOCS_REMOVE; ++i) {
nns.addDoc(EFFECTIVE_DOCS + i);
- nns.addDoc(addFirst + i);
}
- aft = std::chrono::steady_clock::now();
- fprintf(stderr, "build %s index added %u docs: %.3f ms\n",
- name, 2 * NUM_DOCS_REMOVE, to_ms(aft - bef));
-
- bef = std::chrono::steady_clock::now();
+ for (uint32_t i = 0; i < EFFECTIVE_DOCS - NUM_DOCS_REMOVE; ++i) {
+ nns.addDoc(i);
+ }
for (uint32_t i = 0; i < NUM_DOCS_REMOVE; ++i) {
nns.removeDoc(EFFECTIVE_DOCS + i);
- nns.addDoc(addSecond + i);
+ nns.addDoc(EFFECTIVE_DOCS - NUM_DOCS_REMOVE + i);
}
- aft = std::chrono::steady_clock::now();
- fprintf(stderr, "build %s index added %u and removed %u docs: %.3f ms\n",
- name, NUM_DOCS_REMOVE, NUM_DOCS_REMOVE, to_ms(aft - bef));
-
+ TimePoint aft = std::chrono::steady_clock::now();
+ fprintf(stderr, "build %s index with %u docs: %.3f ms\n", name, EFFECTIVE_DOCS, to_ms(aft - bef));
timing_nns(name, nns, sk_list);
- fprintf(stderr, "Quality for %s with %u documents some churn:\n", name, EFFECTIVE_DOCS);
+ fprintf(stderr, "Quality for %s [D] realistic build with %u documents:\n", name, EFFECTIVE_DOCS);
quality_nns(nns, sk_list);
+}
-#endif
-
-#if 0
- bef = std::chrono::steady_clock::now();
- fprintf(stderr, "removing and adding %u documents...\n", EFFECTIVE_DOCS);
- for (uint32_t i = 0; i < EFFECTIVE_DOCS; ++i) {
- nns.removeDoc(i);
+template <typename FUNC>
+void bm_nns_remove_old_add_new(const char *name, FUNC creator, std::vector<uint32_t> sk_list) {
+ std::unique_ptr<NNS_API> nnsp = creator();
+ NNS_API &nns = *nnsp;
+ TimePoint bef = std::chrono::steady_clock::now();
+ for (uint32_t i = 0; i < NUM_DOCS_REMOVE; ++i) {
+ nns.addDoc(EFFECTIVE_DOCS + i);
+ }
+ for (uint32_t i = 0; i < EFFECTIVE_DOCS - NUM_DOCS_REMOVE; ++i) {
nns.addDoc(i);
}
- aft = std::chrono::steady_clock::now();
- fprintf(stderr, "build %s index rem/add %u docs: %.3f ms\n",
- name, EFFECTIVE_DOCS, to_ms(aft - bef));
-
+ for (uint32_t i = 0; i < NUM_DOCS_REMOVE; ++i) {
+ nns.removeDoc(EFFECTIVE_DOCS + i);
+ }
+ for (uint32_t i = 0; i < NUM_DOCS_REMOVE; ++i) {
+ nns.addDoc(EFFECTIVE_DOCS - NUM_DOCS_REMOVE + i);
+ }
+ TimePoint aft = std::chrono::steady_clock::now();
+ fprintf(stderr, "build %s index with %u docs: %.3f ms\n", name, EFFECTIVE_DOCS, to_ms(aft - bef));
timing_nns(name, nns, sk_list);
- fprintf(stderr, "Quality for %s with %u documents full churn:\n", name, EFFECTIVE_DOCS);
+ fprintf(stderr, "Quality for %s [E] remove old, add new build with %u documents:\n", name, EFFECTIVE_DOCS);
quality_nns(nns, sk_list);
-#endif
+}
+
+template <typename FUNC>
+void benchmark_nns(const char *name, FUNC creator, std::vector<uint32_t> sk_list) {
+ bm_nns_simple(name, creator, sk_list);
+ bm_nns_remove_old(name, creator, sk_list);
+ bm_nns_interleave(name, creator, sk_list);
+ bm_nns_remove_old_add_new(name, creator, sk_list);
}
#if 0
TEST("require that Locality Sensitive Hashing mostly works") {
DocVectorAdapter adapter;
- std::unique_ptr<NNS_API> nns = make_rplsh_nns(NUM_DIMS, adapter);
- benchmark_nns("RPLSH", *nns, { 200, 1000 });
+ auto creator = [&adapter]() { return make_rplsh_nns(NUM_DIMS, adapter); };
+ benchmark_nns("RPLSH", creator, { 200, 1000 });
}
#endif
#if 0
TEST("require that Annoy via NNS api mostly works") {
DocVectorAdapter adapter;
- std::unique_ptr<NNS_API> nns = make_annoy_nns(NUM_DIMS, adapter);
- benchmark_nns("Annoy", *nns, { 8000, 10000 });
+ auto creator = [&adapter]() { return make_annoy_nns(NUM_DIMS, adapter); };
+ benchmark_nns("Annoy", creator, { 8000, 10000 });
}
#endif
#if 1
TEST("require that HNSW via NNS api mostly works") {
DocVectorAdapter adapter;
- std::unique_ptr<NNS_API> nns = make_hnsw_nns(NUM_DIMS, adapter);
- benchmark_nns("HNSW-like", *nns, { 100, 150, 200 });
+ auto creator = [&adapter]() { return make_hnsw_nns(NUM_DIMS, adapter); };
+ benchmark_nns("HNSW-like", creator, { 100, 150, 200 });
}
#endif
#if 0
TEST("require that HNSW wrapped api mostly works") {
DocVectorAdapter adapter;
- std::unique_ptr<NNS_API> nns = make_hnsw_wrap(NUM_DIMS, adapter);
- benchmark_nns("HNSW-wrap", *nns, { 100, 150, 200 });
+ auto creator = [&adapter]() { return make_hnsw_wrap(NUM_DIMS, adapter); };
+ benchmark_nns("HNSW-wrap", creator, { 100, 150, 200 });
}
#endif
@@ -498,17 +228,21 @@ TEST("require that HNSW wrapped api mostly works") {
*/
int main(int argc, char **argv) {
TEST_MASTER.init(__FILE__);
- std::string gist_dir = ".";
- if (argc > 1) {
- gist_dir = argv[1];
+ std::string data_set = "gist";
+ std::string data_dir = ".";
+ if (argc > 2) {
+ data_set = argv[1];
+ data_dir = argv[2];
+ } else if (argc > 1) {
+ data_dir = argv[1];
} else {
char *home = getenv("HOME");
if (home) {
- gist_dir = home;
- gist_dir += "/gist";
+ data_dir = home;
+ data_dir += "/" + data_set;
}
}
- read_data(gist_dir);
+ read_data(data_dir, data_set);
TEST_RUN_ALL();
return (TEST_MASTER.fini() ? 0 : 1);
}
diff --git a/eval/src/tests/ann/sift_benchmark.cpp b/eval/src/tests/ann/sift_benchmark.cpp
index 022c9404f5d..b2fa66cd0f1 100644
--- a/eval/src/tests/ann/sift_benchmark.cpp
+++ b/eval/src/tests/ann/sift_benchmark.cpp
@@ -13,173 +13,56 @@
#define NUM_DIMS 128
#define NUM_DOCS 1000000
#define NUM_Q 1000
+#define NUM_REACH 10000
#include "doc_vector_access.h"
#include "nns.h"
#include "for-sift-hit.h"
#include "for-sift-top-k.h"
+#include "std-random.h"
+#include "time-util.h"
+#include "point-vector.h"
+#include "read-vecs.h"
+#include "bruteforce-nns.h"
-std::vector<TopK> bruteforceResults;
-std::vector<float> tmp_v(NUM_DIMS);
-
-struct PointVector {
- float v[NUM_DIMS];
- using ConstArr = vespalib::ConstArrayRef<float>;
- operator ConstArr() const { return ConstArr(v, NUM_DIMS); }
-};
-
-static PointVector *aligned_alloc(size_t num) {
- size_t sz = num * sizeof(PointVector);
- double mega_bytes = sz / (1024.0*1024.0);
- fprintf(stderr, "allocate %.2f MB of vectors\n", mega_bytes);
- char *mem = (char *)malloc(sz + 512);
- mem += 512;
- size_t val = (size_t)mem;
- size_t unalign = val % 512;
- mem -= unalign;
- return reinterpret_cast<PointVector *>(mem);
-}
-
-static PointVector *generatedQueries = aligned_alloc(NUM_Q);
-static PointVector *generatedDocs = aligned_alloc(NUM_DOCS);
-
-struct DocVectorAdapter : public DocVectorAccess<float>
-{
- vespalib::ConstArrayRef<float> get(uint32_t docid) const override {
- ASSERT_TRUE(docid < NUM_DOCS);
- return generatedDocs[docid];
- }
-};
-
-double computeDistance(const PointVector &query, uint32_t docid) {
- const PointVector &docvector = generatedDocs[docid];
- return l2distCalc.l2sq_dist(query, docvector, tmp_v);
-}
-
-void read_queries(std::string fn) {
- int fd = open(fn.c_str(), O_RDONLY);
- ASSERT_TRUE(fd > 0);
- int d;
- size_t rv;
- fprintf(stderr, "reading %u queries from %s\n", NUM_Q, fn.c_str());
- for (uint32_t qid = 0; qid < NUM_Q; ++qid) {
- rv = read(fd, &d, 4);
- ASSERT_EQUAL(rv, 4u);
- ASSERT_EQUAL(d, NUM_DIMS);
- rv = read(fd, &generatedQueries[qid].v, NUM_DIMS*sizeof(float));
- ASSERT_EQUAL(rv, sizeof(PointVector));
- }
- close(fd);
-}
-
-void read_docs(std::string fn) {
- int fd = open(fn.c_str(), O_RDONLY);
- ASSERT_TRUE(fd > 0);
- int d;
- size_t rv;
- fprintf(stderr, "reading %u doc vectors from %s\n", NUM_DOCS, fn.c_str());
- for (uint32_t docid = 0; docid < NUM_DOCS; ++docid) {
- rv = read(fd, &d, 4);
- ASSERT_EQUAL(rv, 4u);
- ASSERT_EQUAL(d, NUM_DIMS);
- rv = read(fd, &generatedDocs[docid].v, NUM_DIMS*sizeof(float));
- ASSERT_EQUAL(rv, sizeof(PointVector));
- }
- close(fd);
-}
-
-using TimePoint = std::chrono::steady_clock::time_point;
-using Duration = std::chrono::steady_clock::duration;
-
-double to_ms(Duration elapsed) {
- std::chrono::duration<double, std::milli> ms(elapsed);
- return ms.count();
-}
-
-void read_data(const std::string& dir, const std::string& data_set) {
- fprintf(stderr, "read data set '%s' from directory '%s'\n", data_set.c_str(), dir.c_str());
- TimePoint bef = std::chrono::steady_clock::now();
- read_queries(dir + "/" + data_set + "_query.fvecs");
- TimePoint aft = std::chrono::steady_clock::now();
- fprintf(stderr, "read queries: %.3f ms\n", to_ms(aft - bef));
- bef = std::chrono::steady_clock::now();
- read_docs(dir + "/" + data_set + "_base.fvecs");
- aft = std::chrono::steady_clock::now();
- fprintf(stderr, "read docs: %.3f ms\n", to_ms(aft - bef));
-}
-
-
-struct BfHitComparator {
- bool operator() (const Hit &lhs, const Hit& rhs) const {
- if (lhs.distance < rhs.distance) return false;
- if (lhs.distance > rhs.distance) return true;
- return (lhs.docid > rhs.docid);
- }
-};
-
-class BfHitHeap {
-private:
- size_t _size;
- vespalib::PriorityQueue<Hit, BfHitComparator> _priQ;
-public:
- explicit BfHitHeap(size_t maxSize) : _size(maxSize), _priQ() {
- _priQ.reserve(maxSize);
- }
- ~BfHitHeap() {}
- void maybe_use(const Hit &hit) {
- if (_priQ.size() < _size) {
- _priQ.push(hit);
- } else if (hit.distance < _priQ.front().distance) {
- _priQ.front() = hit;
- _priQ.adjust();
- }
- }
- std::vector<Hit> bestHits() {
- std::vector<Hit> result;
- size_t i = _priQ.size();
- result.resize(i);
- while (i-- > 0) {
- result[i] = _priQ.front();
- _priQ.pop_front();
- }
- return result;
- }
-};
-
-TopK bruteforce_nns(const PointVector &query) {
+TopK bruteforce_nns_filter(const PointVector &query, const BitVector &blacklist) {
TopK result;
BfHitHeap heap(result.K);
for (uint32_t docid = 0; docid < NUM_DOCS; ++docid) {
+ if (blacklist.isSet(docid)) continue;
const PointVector &docvector = generatedDocs[docid];
- double d = l2distCalc.l2sq_dist(query, docvector, tmp_v);
+ double d = l2distCalc.l2sq_dist(query, docvector);
Hit h(docid, d);
heap.maybe_use(h);
}
std::vector<Hit> best = heap.bestHits();
+ EXPECT_EQUAL(best.size(), result.K);
for (size_t i = 0; i < result.K; ++i) {
result.hits[i] = best[i];
}
return result;
}
-void verifyBF(uint32_t qid) {
- const PointVector &query = generatedQueries[qid];
- TopK &result = bruteforceResults[qid];
- double min_distance = result.hits[0].distance;
- std::vector<double> all_c2;
- for (uint32_t i = 0; i < NUM_DOCS; ++i) {
- double dist = computeDistance(query, i);
- if (dist < min_distance) {
- fprintf(stderr, "WARN dist %.9g < mindist %.9g\n", dist, min_distance);
+void timing_bf_filter(int percent)
+{
+ BitVector blacklist(NUM_DOCS);
+ RndGen rnd;
+ for (uint32_t idx = 0; idx < NUM_DOCS; ++idx) {
+ if (rnd.nextUniform() < 0.01 * percent) {
+ blacklist.setBit(idx);
+ } else {
+ blacklist.clearBit(idx);
}
- EXPECT_FALSE(dist+0.000001 < min_distance);
- if (min_distance > 0) all_c2.push_back(dist / min_distance);
}
- if (all_c2.size() != NUM_DOCS) return;
- std::sort(all_c2.begin(), all_c2.end());
- for (uint32_t idx : { 1, 3, 10, 30, 100, 300, 1000, 3000, NUM_DOCS/2, NUM_DOCS-1}) {
- fprintf(stderr, "c2-factor[%u] = %.3f\n", idx, all_c2[idx]);
+ TimePoint bef = std::chrono::steady_clock::now();
+ for (int cnt = 0; cnt < NUM_Q; ++cnt) {
+ const PointVector &qv = generatedQueries[cnt];
+ auto res = bruteforce_nns_filter(qv, blacklist);
+ EXPECT_TRUE(res.hits[res.K - 1].distance > 0.0);
}
+ TimePoint aft = std::chrono::steady_clock::now();
+ fprintf(stderr, "timing for bruteforce filter %d %%: %.3f ms = %.3f ms/q\n",
+ percent, to_ms(aft - bef), to_ms(aft - bef)/NUM_Q);
}
TEST("require that brute force works") {
@@ -195,52 +78,90 @@ TEST("require that brute force works") {
for (int cnt = 0; cnt < NUM_Q; cnt = (cnt+1)*2) {
verifyBF(cnt);
}
+#if 1
+ for (uint32_t filter_percent : { 0, 1, 10, 50, 90, 95, 99 }) {
+ timing_bf_filter(filter_percent);
+ }
+#endif
}
using NNS_API = NNS<float>;
-TopK find_with_nns(uint32_t sk, NNS_API &nns, uint32_t qid) {
- TopK result;
+size_t search_with_filter(uint32_t sk, NNS_API &nns, uint32_t qid,
+ const BitVector &blacklist)
+{
const PointVector &qv = generatedQueries[qid];
vespalib::ConstArrayRef<float> query(qv.v, NUM_DIMS);
- auto rv = nns.topK(result.K, query, sk);
- for (size_t i = 0; i < result.K; ++i) {
- result.hits[i] = Hit(rv[i].docid, rv[i].sq.distance);
+ auto rv = nns.topKfilter(100, query, sk, blacklist);
+ return rv.size();
+}
+
+#include "find-with-nns.h"
+#include "verify-top-k.h"
+
+void verify_with_filter(uint32_t sk, NNS_API &nns, uint32_t qid,
+ const BitVector &blacklist)
+{
+ const PointVector &qv = generatedQueries[qid];
+ auto expected = bruteforce_nns_filter(qv, blacklist);
+ vespalib::ConstArrayRef<float> query(qv.v, NUM_DIMS);
+ auto rv = nns.topKfilter(expected.K, query, sk, blacklist);
+ TopK actual;
+ for (size_t i = 0; i < actual.K; ++i) {
+ actual.hits[i] = Hit(rv[i].docid, rv[i].sq.distance);
}
- return result;
+ verify_top_k(expected, actual, sk, qid);
}
-void verify_nns_quality(uint32_t sk, NNS_API &nns, uint32_t qid) {
- TopK perfect = bruteforceResults[qid];
- TopK result = find_with_nns(sk, nns, qid);
- int recall = perfect.recall(result);
- EXPECT_TRUE(recall > 40);
- double sum_error = 0.0;
- double c_factor = 1.0;
- for (size_t i = 0; i < result.K; ++i) {
- double factor = (result.hits[i].distance / perfect.hits[i].distance);
- if (factor < 0.99 || factor > 25) {
- fprintf(stderr, "hit[%zu] got distance %.3f, expected %.3f\n",
- i, result.hits[i].distance, perfect.hits[i].distance);
+void timing_nns_filter(const char *name, NNS_API &nns,
+ std::vector<uint32_t> sk_list, int percent)
+{
+ BitVector blacklist(NUM_DOCS);
+ RndGen rnd;
+ for (uint32_t idx = 0; idx < NUM_DOCS; ++idx) {
+ if (rnd.nextUniform() < 0.01 * percent) {
+ blacklist.setBit(idx);
+ } else {
+ blacklist.clearBit(idx);
}
- sum_error += factor;
- c_factor = std::max(c_factor, factor);
}
- EXPECT_TRUE(c_factor < 1.5);
- fprintf(stderr, "quality sk=%u: query %u: recall %d, c2-factor %.3f, avg c2: %.3f\n",
- sk, qid, recall, c_factor, sum_error / result.K);
- if (qid == 6) {
- for (size_t i = 0; i < 10; ++i) {
- fprintf(stderr, "topk[%zu] BF{%u %.3f} index{%u %.3f}\n",
- i,
- perfect.hits[i].docid, perfect.hits[i].distance,
- result.hits[i].docid, result.hits[i].distance);
+ for (uint32_t search_k : sk_list) {
+ TimePoint bef = std::chrono::steady_clock::now();
+ for (int cnt = 0; cnt < NUM_Q; ++cnt) {
+ uint32_t nh = search_with_filter(search_k, nns, cnt, blacklist);
+ EXPECT_EQUAL(nh, 100u);
+ }
+ TimePoint aft = std::chrono::steady_clock::now();
+ fprintf(stderr, "timing for %s filter %d %% search_k=%u: %.3f ms = %.3f ms/q\n",
+ name, percent, search_k, to_ms(aft - bef), to_ms(aft - bef)/NUM_Q);
+#if 0
+ fprintf(stderr, "Quality check for %s filter %d %%:\n", name, percent);
+ for (int cnt = 0; cnt < NUM_Q; ++cnt) {
+ verify_with_filter(search_k, nns, cnt, blacklist);
}
+#endif
}
}
-void benchmark_nns(const char *name, NNS_API &nns, std::vector<uint32_t> sk_list) {
+void timing_nns(const char *name, NNS_API &nns, std::vector<uint32_t> sk_list) {
+ for (uint32_t search_k : sk_list) {
+ TimePoint bef = std::chrono::steady_clock::now();
+ for (int cnt = 0; cnt < NUM_Q; ++cnt) {
+ find_with_nns(search_k, nns, cnt);
+ }
+ TimePoint aft = std::chrono::steady_clock::now();
+ fprintf(stderr, "timing for %s search_k=%u: %.3f ms = %.3f ms/q\n",
+ name, search_k, to_ms(aft - bef), to_ms(aft - bef)/NUM_Q);
+ }
+}
+
+#include "quality-nns.h"
+
+template <typename FUNC>
+void benchmark_nns(const char *name, FUNC creator, std::vector<uint32_t> sk_list) {
fprintf(stderr, "trying %s indexing...\n", name);
+ std::unique_ptr<NNS_API> nnsp = creator();
+ NNS_API &nns = *nnsp;
TimePoint bef = std::chrono::steady_clock::now();
for (uint32_t i = 0; i < NUM_DOCS; ++i) {
nns.addDoc(i);
@@ -250,50 +171,44 @@ void benchmark_nns(const char *name, NNS_API &nns, std::vector<uint32_t> sk_list
TimePoint aft = std::chrono::steady_clock::now();
fprintf(stderr, "build %s index: %.3f ms\n", name, to_ms(aft - bef));
- for (uint32_t search_k : sk_list) {
- bef = std::chrono::steady_clock::now();
- for (int cnt = 0; cnt < NUM_Q; ++cnt) {
- find_with_nns(search_k, nns, cnt);
- }
- aft = std::chrono::steady_clock::now();
- fprintf(stderr, "timing for %s search_k=%u: %.3f ms = %.3f ms/q\n",
- name, search_k, to_ms(aft - bef), to_ms(aft - bef)/NUM_Q);
- for (int cnt = 0; cnt < NUM_Q; ++cnt) {
- verify_nns_quality(search_k, nns, cnt);
- }
+ fprintf(stderr, "Timings for %s :\n", name);
+ timing_nns(name, nns, sk_list);
+ for (uint32_t filter_percent : { 0, 1, 10, 50, 90, 95, 99 }) {
+ timing_nns_filter(name, nns, sk_list, filter_percent);
}
+ fprintf(stderr, "Quality for %s :\n", name);
+ quality_nns(nns, sk_list);
}
-
-#if 1
+#if 0
TEST("require that Locality Sensitive Hashing mostly works") {
DocVectorAdapter adapter;
- std::unique_ptr<NNS_API> nns = make_rplsh_nns(NUM_DIMS, adapter);
- benchmark_nns("RPLSH", *nns, { 200, 1000 });
+ auto creator = [&adapter]() { return make_rplsh_nns(NUM_DIMS, adapter); };
+ benchmark_nns("RPLSH", creator, { 200, 1000 });
}
#endif
#if 1
TEST("require that Annoy via NNS api mostly works") {
DocVectorAdapter adapter;
- std::unique_ptr<NNS_API> nns = make_annoy_nns(NUM_DIMS, adapter);
- benchmark_nns("Annoy", *nns, { 8000, 10000 });
+ auto creator = [&adapter]() { return make_annoy_nns(NUM_DIMS, adapter); };
+ benchmark_nns("Annoy", creator, { 8000, 10000 });
}
#endif
#if 1
TEST("require that HNSW via NNS api mostly works") {
DocVectorAdapter adapter;
- std::unique_ptr<NNS_API> nns = make_hnsw_nns(NUM_DIMS, adapter);
- benchmark_nns("HNSW-like", *nns, { 100, 150, 200 });
+ auto creator = [&adapter]() { return make_hnsw_nns(NUM_DIMS, adapter); };
+ benchmark_nns("HNSW-like", creator, { 100, 150, 200 });
}
#endif
#if 0
TEST("require that HNSW wrapped api mostly works") {
DocVectorAdapter adapter;
- std::unique_ptr<NNS_API> nns = make_hnsw_wrap(NUM_DIMS, adapter);
- benchmark_nns("HNSW-wrap", *nns, { 100, 150, 200 });
+ auto creator = [&adapter]() { return make_hnsw_wrap(NUM_DIMS, adapter); };
+ benchmark_nns("HNSW-wrap", creator, { 100, 150, 200 });
}
#endif
diff --git a/eval/src/tests/ann/time-util.h b/eval/src/tests/ann/time-util.h
new file mode 100644
index 00000000000..2f5c2bdd583
--- /dev/null
+++ b/eval/src/tests/ann/time-util.h
@@ -0,0 +1,9 @@
+// Copyright Verizon Media. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+
+using TimePoint = std::chrono::steady_clock::time_point;
+using Duration = std::chrono::steady_clock::duration;
+
+double to_ms(Duration elapsed) {
+ std::chrono::duration<double, std::milli> ms(elapsed);
+ return ms.count();
+}
diff --git a/eval/src/tests/ann/verify-top-k.h b/eval/src/tests/ann/verify-top-k.h
new file mode 100644
index 00000000000..220c273d017
--- /dev/null
+++ b/eval/src/tests/ann/verify-top-k.h
@@ -0,0 +1,27 @@
+// Copyright 2020 Oath Inc. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+
+int verify_top_k(const TopK &perfect, const TopK &result, uint32_t sk, uint32_t qid) {
+ int recall = perfect.recall(result);
+ EXPECT_TRUE(recall > 40);
+ double sum_error = 0.0;
+ double c_factor = 1.0;
+ for (size_t i = 0; i < result.K; ++i) {
+ double factor = (result.hits[i].distance / perfect.hits[i].distance);
+ if (factor < 0.99 || factor > 25) {
+ fprintf(stderr, "hit[%zu] got distance %.3f, expected %.3f\n",
+ i, result.hits[i].distance, perfect.hits[i].distance);
+ }
+ sum_error += factor;
+ c_factor = std::max(c_factor, factor);
+ }
+ EXPECT_TRUE(c_factor < 1.5);
+ fprintf(stderr, "quality sk=%u: query %u: recall %d c2-factor %.3f avg c2: %.3f\n",
+ sk, qid, recall, c_factor, sum_error / result.K);
+ return recall;
+}
+
+int verify_nns_quality(uint32_t sk, NNS_API &nns, uint32_t qid) {
+ TopK perfect = bruteforceResults[qid];
+ TopK result = find_with_nns(sk, nns, qid);
+ return verify_top_k(perfect, result, sk, qid);
+}
diff --git a/eval/src/tests/ann/xp-annoy-nns.cpp b/eval/src/tests/ann/xp-annoy-nns.cpp
index f022aae5974..213e583d95a 100644
--- a/eval/src/tests/ann/xp-annoy-nns.cpp
+++ b/eval/src/tests/ann/xp-annoy-nns.cpp
@@ -27,6 +27,7 @@ struct Node {
virtual Node *addDoc(uint32_t docid, V vector, AnnoyLikeNns &meta) = 0;
virtual int remove(uint32_t docid, V vector) = 0;
virtual void findCandidates(std::set<uint32_t> &cands, V vector, NodeQueue &queue, double minDist) const = 0;
+ virtual void filterCandidates(std::set<uint32_t> &cands, V vector, NodeQueue &queue, double minDist, const BitVector &blacklist) const = 0;
virtual void stats(std::vector<uint32_t> &depths) = 0;
};
@@ -38,6 +39,7 @@ struct LeafNode : public Node {
Node *addDoc(uint32_t docid, V vector, AnnoyLikeNns &meta) override;
int remove(uint32_t docid, V vector) override;
void findCandidates(std::set<uint32_t> &cands, V vector, NodeQueue &queue, double minDist) const override;
+ void filterCandidates(std::set<uint32_t> &cands, V vector, NodeQueue &queue, double minDist, const BitVector &blacklist) const override;
Node *split(AnnoyLikeNns &meta);
virtual void stats(std::vector<uint32_t> &depths) override { depths.push_back(1); }
@@ -55,6 +57,7 @@ struct SplitNode : public Node {
Node *addDoc(uint32_t docid, V vector, AnnoyLikeNns &meta) override;
int remove(uint32_t docid, V vector) override;
void findCandidates(std::set<uint32_t> &cands, V vector, NodeQueue &queue, double minDist) const override;
+ void filterCandidates(std::set<uint32_t> &cands, V vector, NodeQueue &queue, double minDist, const BitVector &blacklist) const override;
double planeDistance(V vector) const;
virtual void stats(std::vector<uint32_t> &depths) override {
@@ -106,6 +109,8 @@ public:
}
std::vector<NnsHit> topK(uint32_t k, Vector vector, uint32_t search_k) override;
+ std::vector<NnsHit> topKfilter(uint32_t k, Vector vector, uint32_t search_k, const BitVector &bitvector) override;
+
V getVector(uint32_t docid) const { return _dva.get(docid); }
double uniformRnd() { return _rndGen.nextUniform(); }
uint32_t dims() const { return _numDims; }
@@ -304,6 +309,16 @@ LeafNode::findCandidates(std::set<uint32_t> &cands, V, NodeQueue &, double) cons
}
}
+void
+LeafNode::filterCandidates(std::set<uint32_t> &cands, V, NodeQueue &, double, const BitVector &blacklist) const
+{
+ for (uint32_t d : docids) {
+ if (blacklist.isSet(d)) continue;
+ cands.insert(d);
+ }
+}
+
+
SplitNode::~SplitNode()
{
delete leftChildren;
@@ -344,6 +359,15 @@ SplitNode::findCandidates(std::set<uint32_t> &, V vector, NodeQueue &queue, doub
queue.push(std::make_pair(std::min(d, minDist), rightChildren));
}
+void
+SplitNode::filterCandidates(std::set<uint32_t> &, V vector, NodeQueue &queue, double minDist, const BitVector &) const
+{
+ double d = planeDistance(vector);
+ // fprintf(stderr, "push 2 nodes dist %g\n", d);
+ queue.push(std::make_pair(std::min(-d, minDist), leftChildren));
+ queue.push(std::make_pair(std::min(d, minDist), rightChildren));
+}
+
std::vector<NnsHit>
AnnoyLikeNns::topK(uint32_t k, Vector vector, uint32_t search_k)
{
@@ -387,6 +411,40 @@ AnnoyLikeNns::topK(uint32_t k, Vector vector, uint32_t search_k)
return r;
}
+std::vector<NnsHit>
+AnnoyLikeNns::topKfilter(uint32_t k, Vector vector, uint32_t search_k, const BitVector &blacklist)
+{
+ ++find_top_k_cnt;
+ std::vector<NnsHit> r;
+ r.reserve(k);
+ std::set<uint32_t> candidates;
+ NodeQueue queue;
+ for (Node *root : _roots) {
+ double dist = std::numeric_limits<double>::max();
+ queue.push(std::make_pair(dist, root));
+ }
+ while ((candidates.size() < std::max(k, search_k)) && (queue.size() > 0)) {
+ const QueueNode& top = queue.top();
+ double md = top.first;
+ // fprintf(stderr, "find candidates: node with min distance %g\n", md);
+ Node *n = top.second;
+ queue.pop();
+ n->filterCandidates(candidates, vector, queue, md, blacklist);
+ ++find_cand_cnt;
+ }
+ for (uint32_t docid : candidates) {
+ if (blacklist.isSet(docid)) continue;
+ double dist = l2distCalc.l2sq_dist(vector, _dva.get(docid));
+ NnsHit hit(docid, SqDist(dist));
+ r.push_back(hit);
+ }
+ std::sort(r.begin(), r.end(), NnsHitComparatorLessDistance());
+ while (r.size() > k) r.pop_back();
+ return r;
+}
+
+
+
void
AnnoyLikeNns::dumpStats() {
fprintf(stderr, "stats for AnnoyLikeNns:\n");
diff --git a/eval/src/tests/ann/xp-hnsw-wrap.cpp b/eval/src/tests/ann/xp-hnsw-wrap.cpp
index 3eb01142dcd..45c7a974254 100644
--- a/eval/src/tests/ann/xp-hnsw-wrap.cpp
+++ b/eval/src/tests/ann/xp-hnsw-wrap.cpp
@@ -46,6 +46,34 @@ public:
}
return result;
}
+
+ std::vector<NnsHit> topKfilter(uint32_t k, Vector vector, uint32_t search_k, const BitVector &blacklist) override {
+ std::vector<NnsHit> reversed;
+ uint32_t adjusted_k = k+4;
+ uint32_t adjusted_sk = search_k+4;
+ for (int retry = 0; (retry < 5) && (reversed.size() < k); ++retry) {
+ reversed.clear();
+ _hnsw.setEf(adjusted_sk);
+ auto priQ = _hnsw.searchKnn(vector.cbegin(), adjusted_k);
+ while (! priQ.empty()) {
+ auto pair = priQ.top();
+ if (! blacklist.isSet(pair.second)) {
+ reversed.emplace_back(pair.second, SqDist(pair.first));
+ }
+ priQ.pop();
+ }
+ double got = 1 + reversed.size();
+ double factor = 1.25 * k / got;
+ adjusted_k *= factor;
+ adjusted_sk *= factor;
+ }
+ std::vector<NnsHit> result;
+ while (result.size() < k && !reversed.empty()) {
+ result.push_back(reversed.back());
+ reversed.pop_back();
+ }
+ return result;
+ }
};
std::unique_ptr<NNS<float>>
diff --git a/eval/src/tests/ann/xp-hnswlike-nns.cpp b/eval/src/tests/ann/xp-hnswlike-nns.cpp
index 5cdbdd8efa3..b7cae9f731c 100644
--- a/eval/src/tests/ann/xp-hnswlike-nns.cpp
+++ b/eval/src/tests/ann/xp-hnswlike-nns.cpp
@@ -1,11 +1,6 @@
// Copyright 2020 Oath Inc. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
-#include <algorithm>
-#include <assert.h>
-#include <queue>
-#include <cinttypes>
-#include "std-random.h"
-#include "nns.h"
+#include "hnsw-like.h"
/*
Todo:
@@ -32,356 +27,223 @@ static size_t distcalls_heuristic;
static size_t distcalls_shrink;
static size_t distcalls_refill;
static size_t refill_needed_calls;
-
-struct LinkList : std::vector<uint32_t>
-{
- bool has_link_to(uint32_t id) const {
- auto iter = std::find(begin(), end(), id);
- return (iter != end());
- }
- void remove_link(uint32_t id) {
- uint32_t last = back();
- for (iterator iter = begin(); iter != end(); ++iter) {
- if (*iter == id) {
- *iter = last;
- pop_back();
- return;
- }
- }
- fprintf(stderr, "BAD missing link to remove: %u\n", id);
- abort();
- }
-};
-
-struct Node {
- std::vector<LinkList> _links;
- Node(uint32_t , uint32_t numLevels, uint32_t M)
- : _links(numLevels)
- {
- for (uint32_t i = 0; i < _links.size(); ++i) {
- _links[i].reserve((i == 0) ? (2 * M + 1) : (M+1));
- }
- }
-};
-
-struct VisitedSet
+static size_t shrink_needed_calls;
+static size_t disconnected_weak_links;
+static size_t disconnected_for_symmetry;
+static size_t select_n_full;
+static size_t select_n_partial;
+
+
+
+HnswLikeNns::HnswLikeNns(uint32_t numDims, const DocVectorAccess<float> &dva)
+ : NNS(numDims, dva),
+ _nodes(),
+ _entryId(0),
+ _entryLevel(-1),
+ _M(16),
+ _efConstruction(200),
+ _levelMultiplier(1.0 / log(1.0 * _M)),
+ _rndGen(),
+ _ops_counter(0)
{
- using Mark = unsigned short;
- Mark *ptr;
- Mark curval;
- size_t sz;
- VisitedSet(const VisitedSet &) = delete;
- VisitedSet& operator=(const VisitedSet &) = delete;
- explicit VisitedSet(size_t size) {
- ptr = (Mark *)malloc(size * sizeof(Mark));
- curval = -1;
- sz = size;
- }
- void clear() {
- ++curval;
- if (curval == 0) {
- memset(ptr, 0, sz * sizeof(Mark));
- ++curval;
- }
- }
- ~VisitedSet() { free(ptr); }
- void mark(size_t id) { ptr[id] = curval; }
- bool isMarked(size_t id) const { return ptr[id] == curval; }
-};
-
-struct VisitedSetPool
-{
- std::unique_ptr<VisitedSet> lastUsed;
- VisitedSetPool() {
- lastUsed = std::make_unique<VisitedSet>(250);
- }
- ~VisitedSetPool() {}
- VisitedSet &get(size_t size) {
- if (size > lastUsed->sz) {
- lastUsed = std::make_unique<VisitedSet>(size*2);
- }
- lastUsed->clear();
- return *lastUsed;
- }
-};
-
-struct HnswHit {
- double dist;
- uint32_t docid;
- HnswHit(uint32_t di, SqDist sq) : dist(sq.distance), docid(di) {}
-};
-
-struct GreaterDist {
- bool operator() (const HnswHit &lhs, const HnswHit& rhs) const {
- return (rhs.dist < lhs.dist);
- }
-};
-struct LesserDist {
- bool operator() (const HnswHit &lhs, const HnswHit& rhs) const {
- return (lhs.dist < rhs.dist);
- }
-};
-
-using NearestList = std::vector<HnswHit>;
-
-struct NearestPriQ : std::priority_queue<HnswHit, NearestList, GreaterDist>
-{
-};
-
-struct FurthestPriQ : std::priority_queue<HnswHit, NearestList, LesserDist>
-{
- NearestList steal() {
- NearestList result;
- c.swap(result);
- return result;
- }
- const NearestList& peek() const { return c; }
-};
-
-class HnswLikeNns : public NNS<float>
-{
-private:
- std::vector<Node> _nodes;
- uint32_t _entryId;
- int _entryLevel;
- uint32_t _M;
- uint32_t _efConstruction;
- double _levelMultiplier;
- RndGen _rndGen;
- VisitedSetPool _visitedSetPool;
- size_t _ops_counter;
-
- double distance(Vector v, uint32_t id) const;
-
- double distance(uint32_t a, uint32_t b) const {
- Vector v = _dva.get(a);
- return distance(v, b);
- }
-
- int randomLevel() {
- double unif = _rndGen.nextUniform();
- double r = -log(1.0-unif) * _levelMultiplier;
- return (int) r;
- }
-
- uint32_t count_reachable() const;
- void dumpStats() const;
-
-public:
- HnswLikeNns(uint32_t numDims, const DocVectorAccess<float> &dva)
- : NNS(numDims, dva),
- _nodes(),
- _entryId(0),
- _entryLevel(-1),
- _M(16),
- _efConstruction(200),
- _levelMultiplier(1.0 / log(1.0 * _M)),
- _rndGen(),
- _ops_counter(0)
- {
- }
-
- ~HnswLikeNns() { dumpStats(); }
-
- LinkList& getLinkList(uint32_t docid, uint32_t level) {
- // assert(docid < _nodes.size());
- // assert(level < _nodes[docid]._links.size());
- return _nodes[docid]._links[level];
- }
-
- const LinkList& getLinkList(uint32_t docid, uint32_t level) const {
- return _nodes[docid]._links[level];
- }
+}
- // simple greedy search
- HnswHit search_layer_simple(Vector vector, HnswHit curPoint, uint32_t searchLevel) {
- bool keepGoing = true;
- while (keepGoing) {
- keepGoing = false;
- const LinkList& neighbors = getLinkList(curPoint.docid, searchLevel);
- for (uint32_t n_id : neighbors) {
- double dist = distance(vector, n_id);
- ++distcalls_simple;
- if (dist < curPoint.dist) {
- curPoint = HnswHit(n_id, SqDist(dist));
- keepGoing = true;
- }
+// simple greedy search
+HnswHit
+HnswLikeNns::search_layer_simple(Vector vector, HnswHit curPoint, uint32_t searchLevel) {
+ bool keepGoing = true;
+ while (keepGoing) {
+ keepGoing = false;
+ const LinkList& neighbors = getLinkList(curPoint.docid, searchLevel);
+ for (uint32_t n_id : neighbors) {
+ double dist = distance(vector, n_id);
+ ++distcalls_simple;
+ if (dist < curPoint.dist) {
+ curPoint = HnswHit(n_id, SqDist(dist));
+ keepGoing = true;
}
}
- return curPoint;
}
+ return curPoint;
+}
- void search_layer(Vector vector, FurthestPriQ &w,
- uint32_t ef, uint32_t searchLevel);
-
- bool haveCloserDistance(HnswHit e, const LinkList &r) const {
- for (uint32_t prevId : r) {
- double dist = distance(e.docid, prevId);
- ++distcalls_heuristic;
- if (dist < e.dist) return true;
- }
- return false;
+bool
+HnswLikeNns::haveCloserDistance(HnswHit e, const LinkList &r) const {
+ for (uint32_t prevId : r) {
+ double dist = distance(e.docid, prevId);
+ ++distcalls_heuristic;
+ if (dist < e.dist) return true;
}
+ return false;
+}
- LinkList select_neighbors(const NearestList &neighbors, uint32_t curMax) const;
-
- LinkList remove_weakest(const NearestList &neighbors, uint32_t curMax, LinkList &removed) const;
-
- void addDoc(uint32_t docid) override {
- Vector vector = _dva.get(docid);
- for (uint32_t id = _nodes.size(); id <= docid; ++id) {
- _nodes.emplace_back(id, 0, _M);
- }
- int level = randomLevel();
- assert(_nodes[docid]._links.size() == 0);
- _nodes[docid] = Node(docid, level+1, _M);
- if (_entryLevel < 0) {
- _entryId = docid;
- _entryLevel = level;
- track_ops();
- return;
- }
- int searchLevel = _entryLevel;
- double entryDist = distance(vector, _entryId);
- ++distcalls_other;
- HnswHit entryPoint(_entryId, SqDist(entryDist));
- while (searchLevel > level) {
- entryPoint = search_layer_simple(vector, entryPoint, searchLevel);
- --searchLevel;
- }
- searchLevel = std::min(level, _entryLevel);
- FurthestPriQ w;
- w.push(entryPoint);
- while (searchLevel >= 0) {
- search_layer(vector, w, _efConstruction, searchLevel);
- LinkList neighbors = select_neighbors(w.peek(), _M);
- connect_new_node(docid, neighbors, searchLevel);
- each_shrink_ifneeded(neighbors, searchLevel);
- --searchLevel;
- }
- if (level > _entryLevel) {
- _entryLevel = level;
- _entryId = docid;
- }
+void
+HnswLikeNns::addDoc(uint32_t docid) {
+ Vector vector = _dva.get(docid);
+ for (uint32_t id = _nodes.size(); id <= docid; ++id) {
+ _nodes.emplace_back(id, 0, _M);
+ }
+ int level = randomLevel();
+ assert(_nodes[docid]._links.size() == 0);
+ _nodes[docid] = Node(docid, level+1, _M);
+ if (_entryLevel < 0) {
+ _entryId = docid;
+ _entryLevel = level;
track_ops();
- }
-
- void track_ops() {
- _ops_counter++;
- if ((_ops_counter % 10000) == 0) {
- double div = _ops_counter;
- fprintf(stderr, "add / remove ops: %zu\n", _ops_counter);
- fprintf(stderr, "distance calls for layer: %zu is %.3f per op\n", distcalls_search_layer, distcalls_search_layer/ div);
- fprintf(stderr, "distance calls for heuristic: %zu is %.3f per op\n", distcalls_heuristic, distcalls_heuristic / div);
- fprintf(stderr, "distance calls for simple: %zu is %.3f per op\n", distcalls_simple, distcalls_simple / div);
- fprintf(stderr, "distance calls for shrink: %zu is %.3f per op\n", distcalls_shrink, distcalls_shrink / div);
- fprintf(stderr, "distance calls for refill: %zu is %.3f per op\n", distcalls_refill, distcalls_refill / div);
- fprintf(stderr, "distance calls for other: %zu is %.3f per op\n", distcalls_other, distcalls_other / div);
- fprintf(stderr, "refill needed calls: %zu is %.3f per op\n", refill_needed_calls, refill_needed_calls / div);
- }
- }
+ return;
+ }
+ int searchLevel = _entryLevel;
+ double entryDist = distance(vector, _entryId);
+ ++distcalls_other;
+ HnswHit entryPoint(_entryId, SqDist(entryDist));
+ while (searchLevel > level) {
+ entryPoint = search_layer_simple(vector, entryPoint, searchLevel);
+ --searchLevel;
+ }
+ searchLevel = std::min(level, _entryLevel);
+ FurthestPriQ w;
+ w.push(entryPoint);
+ while (searchLevel >= 0) {
+ search_layer(vector, w, _efConstruction, searchLevel);
+ LinkList neighbors = select_neighbors(w.peek(), _M);
+ connect_new_node(docid, neighbors, searchLevel);
+ each_shrink_ifneeded(neighbors, searchLevel);
+ --searchLevel;
+ }
+ if (level > _entryLevel) {
+ _entryLevel = level;
+ _entryId = docid;
+ }
+ track_ops();
+}
- void remove_link_from(uint32_t from_id, uint32_t remove_id, uint32_t level) {
- LinkList &links = getLinkList(from_id, level);
- links.remove_link(remove_id);
- }
+void
+HnswLikeNns::track_ops() {
+ _ops_counter++;
+ if ((_ops_counter % 10000) == 0) {
+ double div = _ops_counter;
+ fprintf(stderr, "add / remove ops: %zu\n", _ops_counter);
+ fprintf(stderr, "distance calls for layer: %zu is %.3f per op\n", distcalls_search_layer, distcalls_search_layer/ div);
+ fprintf(stderr, "distance calls for heuristic: %zu is %.3f per op\n", distcalls_heuristic, distcalls_heuristic / div);
+ fprintf(stderr, "distance calls for simple: %zu is %.3f per op\n", distcalls_simple, distcalls_simple / div);
+ fprintf(stderr, "distance calls for shrink: %zu is %.3f per op\n", distcalls_shrink, distcalls_shrink / div);
+ fprintf(stderr, "distance calls for refill: %zu is %.3f per op\n", distcalls_refill, distcalls_refill / div);
+ fprintf(stderr, "distance calls for other: %zu is %.3f per op\n", distcalls_other, distcalls_other / div);
+ fprintf(stderr, "refill needed calls: %zu is %.3f per op\n", refill_needed_calls, refill_needed_calls / div);
+ fprintf(stderr, "shrink needed calls: %zu is %.3f per op\n", shrink_needed_calls, shrink_needed_calls / div);
+ fprintf(stderr, "disconnected weak links: %zu is %.3f per op\n", disconnected_weak_links, disconnected_weak_links / div);
+ fprintf(stderr, "disconnected for symmetry: %zu is %.3f per op\n", disconnected_for_symmetry, disconnected_for_symmetry / div);
+ fprintf(stderr, "select neighbors: partial %zu vs full %zu\n", select_n_partial, select_n_full);
+ }
+}
- void refill_ifneeded(uint32_t my_id, const LinkList &replacements, uint32_t level) {
- LinkList &my_links = getLinkList(my_id, level);
- if (my_links.size() < 8) {
- ++refill_needed_calls;
- for (uint32_t repl_id : replacements) {
- if (repl_id == my_id) continue;
- if (my_links.has_link_to(repl_id)) continue;
- LinkList &other_links = getLinkList(repl_id, level);
- if (other_links.size() + 1 >= _M) continue;
- other_links.push_back(my_id);
- my_links.push_back(repl_id);
- if (my_links.size() >= _M) return;
- }
+void
+HnswLikeNns::refill_ifneeded(uint32_t my_id, const LinkList &replacements, uint32_t level) {
+ LinkList &my_links = getLinkList(my_id, level);
+ if (my_links.size() < 8) {
+ ++refill_needed_calls;
+ for (uint32_t repl_id : replacements) {
+ if (repl_id == my_id) continue;
+ if (my_links.has_link_to(repl_id)) continue;
+ LinkList &other_links = getLinkList(repl_id, level);
+ if (other_links.size() + 1 >= _M) continue;
+ other_links.push_back(my_id);
+ my_links.push_back(repl_id);
+ if (my_links.size() >= _M) return;
}
}
+}
- void connect_new_node(uint32_t id, const LinkList &neighbors, uint32_t level);
-
- void shrink_links(uint32_t shrink_id, uint32_t maxLinks, uint32_t level) {
- LinkList &links = getLinkList(shrink_id, level);
- NearestList distances;
- for (uint32_t n_id : links) {
- double n_dist = distance(shrink_id, n_id);
- ++distcalls_shrink;
- distances.emplace_back(n_id, SqDist(n_dist));
- }
- LinkList lostLinks;
- LinkList oldLinks = links;
- links = remove_weakest(distances, maxLinks, lostLinks);
- for (uint32_t lost_id : lostLinks) {
- remove_link_from(lost_id, shrink_id, level);
- refill_ifneeded(lost_id, oldLinks, level);
- }
+void
+HnswLikeNns::shrink_links(uint32_t shrink_id, uint32_t maxLinks, uint32_t level) {
+ LinkList &links = getLinkList(shrink_id, level);
+ NearestList distances;
+ for (uint32_t n_id : links) {
+ double n_dist = distance(shrink_id, n_id);
+ ++distcalls_shrink;
+ distances.emplace_back(n_id, SqDist(n_dist));
+ }
+ LinkList lostLinks;
+ LinkList oldLinks = links;
+ links = remove_weakest(distances, maxLinks, lostLinks);
+#define KEEP_SYM
+#ifdef KEEP_SYM
+ for (uint32_t lost_id : lostLinks) {
+ ++disconnected_for_symmetry;
+ remove_link_from(lost_id, shrink_id, level);
+ }
+#define DO_REFILL_AFTER_KEEP_SYM
+#ifdef DO_REFILL_AFTER_KEEP_SYM
+ for (uint32_t lost_id : lostLinks) {
+ refill_ifneeded(lost_id, oldLinks, level);
}
+#endif
+#endif
+}
- void each_shrink_ifneeded(const LinkList &neighbors, uint32_t level);
-
- void removeDoc(uint32_t docid) override {
- Node &node = _nodes[docid];
- bool need_new_entrypoint = (docid == _entryId);
- for (int level = node._links.size(); level-- > 0; ) {
- LinkList my_links;
- my_links.swap(node._links[level]);
- for (uint32_t n_id : my_links) {
- if (need_new_entrypoint) {
- _entryId = n_id;
- _entryLevel = level;
- need_new_entrypoint = false;
- }
- remove_link_from(n_id, docid, level);
- }
- for (uint32_t n_id : my_links) {
- refill_ifneeded(n_id, my_links, level);
+void
+HnswLikeNns::removeDoc(uint32_t docid) {
+ Node &node = _nodes[docid];
+ bool need_new_entrypoint = (docid == _entryId);
+ for (int level = node._links.size(); level-- > 0; ) {
+ LinkList my_links;
+ my_links.swap(node._links[level]);
+ for (uint32_t n_id : my_links) {
+ if (need_new_entrypoint) {
+ _entryId = n_id;
+ _entryLevel = level;
+ need_new_entrypoint = false;
}
- }
- node = Node(docid, 0, _M);
- if (need_new_entrypoint) {
- _entryLevel = -1;
- _entryId = 0;
- for (uint32_t i = 0; i < _nodes.size(); ++i) {
- if (_nodes[i]._links.size() > 0) {
- _entryId = i;
- _entryLevel = _nodes[i]._links.size() - 1;
- break;
- }
+ remove_link_from(n_id, docid, level);
+ }
+ while (! my_links.empty()) {
+ uint32_t n_id = my_links.back();
+ my_links.pop_back();
+ refill_ifneeded(n_id, my_links, level);
+ }
+ }
+ node = Node(docid, 0, _M);
+ if (need_new_entrypoint) {
+ _entryLevel = -1;
+ _entryId = 0;
+ for (uint32_t i = 0; i < _nodes.size(); ++i) {
+ if (_nodes[i]._links.size() > 0) {
+ _entryId = i;
+ _entryLevel = _nodes[i]._links.size() - 1;
+ break;
}
}
- track_ops();
}
+ track_ops();
+}
- std::vector<NnsHit> topK(uint32_t k, Vector vector, uint32_t search_k) override {
- std::vector<NnsHit> result;
- if (_entryLevel < 0) return result;
- double entryDist = distance(vector, _entryId);
- ++distcalls_other;
- HnswHit entryPoint(_entryId, SqDist(entryDist));
- int searchLevel = _entryLevel;
- FurthestPriQ w;
- w.push(entryPoint);
- while (searchLevel > 0) {
- search_layer(vector, w, std::min(k, search_k), searchLevel);
- --searchLevel;
- }
- search_layer(vector, w, std::max(k, search_k), 0);
- while (w.size() > k) {
- w.pop();
- }
- NearestList tmp = w.steal();
- std::sort(tmp.begin(), tmp.end(), LesserDist());
- result.reserve(tmp.size());
- for (const auto & hit : tmp) {
- result.emplace_back(hit.docid, SqDist(hit.dist));
- }
- return result;
+std::vector<NnsHit>
+HnswLikeNns::topK(uint32_t k, Vector vector, uint32_t search_k) {
+ std::vector<NnsHit> result;
+ if (_entryLevel < 0) return result;
+ double entryDist = distance(vector, _entryId);
+ ++distcalls_other;
+ HnswHit entryPoint(_entryId, SqDist(entryDist));
+ int searchLevel = _entryLevel;
+ while (searchLevel > 0) {
+ entryPoint = search_layer_simple(vector, entryPoint, searchLevel);
+ --searchLevel;
+ }
+ FurthestPriQ w;
+ w.push(entryPoint);
+ search_layer(vector, w, std::max(k, search_k), 0);
+ while (w.size() > k) {
+ w.pop();
+ }
+ NearestList tmp = w.steal();
+ std::sort(tmp.begin(), tmp.end(), LesserDist());
+ result.reserve(tmp.size());
+ for (const auto & hit : tmp) {
+ result.emplace_back(hit.docid, SqDist(hit.dist));
}
-};
+ return result;
+}
+
double
HnswLikeNns::distance(Vector v, uint32_t b) const
@@ -390,12 +252,40 @@ HnswLikeNns::distance(Vector v, uint32_t b) const
return l2distCalc.l2sq_dist(v, w);
}
+std::vector<NnsHit>
+HnswLikeNns::topKfilter(uint32_t k, Vector vector, uint32_t search_k, const BitVector &blacklist)
+{
+ std::vector<NnsHit> result;
+ if (_entryLevel < 0) return result;
+ double entryDist = distance(vector, _entryId);
+ ++distcalls_other;
+ HnswHit entryPoint(_entryId, SqDist(entryDist));
+ int searchLevel = _entryLevel;
+ while (searchLevel > 0) {
+ entryPoint = search_layer_simple(vector, entryPoint, searchLevel);
+ --searchLevel;
+ }
+ FurthestPriQ w;
+ w.push(entryPoint);
+ search_layer_with_filter(vector, w, std::max(k, search_k), 0, blacklist);
+ NearestList tmp = w.steal();
+ std::sort(tmp.begin(), tmp.end(), LesserDist());
+ result.reserve(std::min((size_t)k, tmp.size()));
+ for (const auto & hit : tmp) {
+ if (blacklist.isSet(hit.docid)) continue;
+ result.emplace_back(hit.docid, SqDist(hit.dist));
+ if (result.size() == k) break;
+ }
+ return result;
+}
+
void
HnswLikeNns::each_shrink_ifneeded(const LinkList &neighbors, uint32_t level) {
uint32_t maxLinks = (level > 0) ? _M : (2 * _M);
for (uint32_t old_id : neighbors) {
LinkList &oldLinks = getLinkList(old_id, level);
if (oldLinks.size() > maxLinks) {
+ ++shrink_needed_calls;
shrink_links(old_id, maxLinks, level);
}
}
@@ -437,6 +327,44 @@ HnswLikeNns::search_layer(Vector vector, FurthestPriQ &w,
return;
}
+void
+HnswLikeNns::search_layer_with_filter(Vector vector, FurthestPriQ &w,
+ uint32_t ef, uint32_t searchLevel,
+ const BitVector &blacklist)
+{
+ NearestPriQ candidates;
+ VisitedSet &visited = _visitedSetPool.get(_nodes.size());
+
+ for (const HnswHit & entry : w.peek()) {
+ candidates.push(entry);
+ visited.mark(entry.docid);
+ if (blacklist.isSet(entry.docid)) ++ef;
+ }
+ double limd = std::numeric_limits<double>::max();
+ while (! candidates.empty()) {
+ HnswHit cand = candidates.top();
+ if (cand.dist > limd) {
+ break;
+ }
+ candidates.pop();
+ for (uint32_t e_id : getLinkList(cand.docid, searchLevel)) {
+ if (visited.isMarked(e_id)) continue;
+ visited.mark(e_id);
+ double e_dist = distance(vector, e_id);
+ ++distcalls_search_layer;
+ if (e_dist < limd) {
+ candidates.emplace(e_id, SqDist(e_dist));
+ if (blacklist.isSet(e_id)) continue;
+ w.emplace(e_id, SqDist(e_dist));
+ if (w.size() > ef) {
+ w.pop();
+ limd = w.top().dist;
+ }
+ }
+ }
+ }
+}
+
LinkList
HnswLikeNns::remove_weakest(const NearestList &neighbors, uint32_t curMax, LinkList &lost) const
{
@@ -458,13 +386,13 @@ HnswLikeNns::remove_weakest(const NearestList &neighbors, uint32_t curMax, LinkL
return result;
}
+#define NO_BACKFILL
#ifdef NO_BACKFILL
LinkList
HnswLikeNns::select_neighbors(const NearestList &neighbors, uint32_t curMax) const
{
LinkList result;
result.reserve(curMax+1);
- bool needFiltering = (neighbors.size() > curMax);
NearestPriQ w;
for (const auto & entry : neighbors) {
w.push(entry);
@@ -472,12 +400,16 @@ HnswLikeNns::select_neighbors(const NearestList &neighbors, uint32_t curMax) con
while (! w.empty()) {
HnswHit e = w.top();
w.pop();
- if (needFiltering && haveCloserDistance(e, result)) {
+ if (haveCloserDistance(e, result)) {
continue;
}
result.push_back(e.docid);
- if (result.size() == curMax) return result;
+ if (result.size() == curMax) {
+ ++select_n_full;
+ return result;
+ }
}
+ ++select_n_partial;
return result;
}
#else
@@ -502,10 +434,10 @@ HnswLikeNns::select_neighbors(const NearestList &neighbors, uint32_t curMax) con
result.push_back(e.docid);
if (result.size() == curMax) return result;
}
- if (result.size() * 4 < curMax) {
+ if (result.size() * 4 < _M) {
for (uint32_t fill_id : backfill) {
result.push_back(fill_id);
- if (result.size() * 4 >= curMax) break;
+ if (result.size() * 2 >= _M) break;
}
}
return result;
@@ -576,7 +508,9 @@ HnswLikeNns::dumpStats() const {
for (uint32_t n_id : link_list) {
const LinkList &neigh_list = getLinkList(n_id, 0);
if (! neigh_list.has_link_to(id)) {
+#ifdef KEEP_SYM
fprintf(stderr, "BAD: %u has link to neighbor %u, but backlink is missing\n", id, n_id);
+#endif
all_sym = false;
}
}
diff --git a/eval/src/tests/ann/xp-lsh-nns.cpp b/eval/src/tests/ann/xp-lsh-nns.cpp
index 0ea119a9c70..c028a07a9d7 100644
--- a/eval/src/tests/ann/xp-lsh-nns.cpp
+++ b/eval/src/tests/ann/xp-lsh-nns.cpp
@@ -118,6 +118,7 @@ public:
}
}
std::vector<NnsHit> topK(uint32_t k, Vector vector, uint32_t search_k) override;
+ std::vector<NnsHit> topKfilter(uint32_t k, Vector vector, uint32_t search_k, const BitVector &bitvector) override;
V getVector(uint32_t docid) const { return _dva.get(docid); }
double uniformRnd() { return _rndGen.nextUniform(); }
@@ -196,6 +197,45 @@ public:
};
std::vector<NnsHit>
+RpLshNns::topKfilter(uint32_t k, Vector vector, uint32_t search_k, const BitVector &blacklist)
+{
+ std::vector<NnsHit> result;
+ result.reserve(k);
+
+ std::vector<float> tmp(_numDims);
+ vespalib::ArrayRef<float> tmpArr(tmp);
+
+ LsMaskHash query_hash = mask_hash_from_pv(vector, _transformationMatrix);
+ LshHitHeap heap(std::max(k, search_k));
+ int limit_hash_dist = 99999;
+ int skipCnt = 0;
+ int fullCnt = 0;
+ int whdcCnt = 0;
+ size_t docidLimit = _generated_doc_hashes.size();
+ for (uint32_t docid = 0; docid < docidLimit; ++docid) {
+ if (blacklist.isSet(docid)) continue;
+ int hd = hash_dist(query_hash, _generated_doc_hashes[docid]);
+ if (hd <= limit_hash_dist) {
+ ++fullCnt;
+ double dist = l2distCalc.l2sq_dist(vector, _dva.get(docid), tmpArr);
+ LshHit h(docid, dist, hd);
+ if (heap.maybe_use(h)) {
+ ++whdcCnt;
+ limit_hash_dist = heap.limitHashDistance();
+ }
+ } else {
+ ++skipCnt;
+ }
+ }
+ std::vector<LshHit> best = heap.bestLshHits();
+ size_t numHits = std::min((size_t)k, best.size());
+ for (size_t i = 0; i < numHits; ++i) {
+ result.emplace_back(best[i].docid, SqDist(best[i].distance));
+ }
+ return result;
+}
+
+std::vector<NnsHit>
RpLshNns::topK(uint32_t k, Vector vector, uint32_t search_k)
{
std::vector<NnsHit> result;