From 2ad26b78130087c78651977e593583807a0a582b Mon Sep 17 00:00:00 2001 From: Arne Juul Date: Thu, 16 Jan 2020 13:12:29 +0000 Subject: add HNSW algorithm --- eval/src/tests/ann/CMakeLists.txt | 1 + eval/src/tests/ann/sift_benchmark.cpp | 8 + eval/src/tests/ann/xp-hnswlike-nns.cpp | 394 +++++++++++++++++++++++++++++++++ 3 files changed, 403 insertions(+) create mode 100644 eval/src/tests/ann/xp-hnswlike-nns.cpp (limited to 'eval') diff --git a/eval/src/tests/ann/CMakeLists.txt b/eval/src/tests/ann/CMakeLists.txt index d82b2311b22..05256d19f00 100644 --- a/eval/src/tests/ann/CMakeLists.txt +++ b/eval/src/tests/ann/CMakeLists.txt @@ -4,6 +4,7 @@ vespa_add_executable(eval_sift_benchmark_app SOURCES sift_benchmark.cpp xp-annoy-nns.cpp + xp-hnswlike-nns.cpp xp-lsh-nns.cpp DEPENDS vespaeval diff --git a/eval/src/tests/ann/sift_benchmark.cpp b/eval/src/tests/ann/sift_benchmark.cpp index 451d4e1ba50..b0e99b76f8a 100644 --- a/eval/src/tests/ann/sift_benchmark.cpp +++ b/eval/src/tests/ann/sift_benchmark.cpp @@ -276,6 +276,14 @@ TEST("require that Annoy via NNS api mostly works") { } #endif +#if 1 +TEST("require that HNSW via NNS api mostly works") { + DocVectorAdapter adapter; + std::unique_ptr nns = make_hnsw_nns(NUM_DIMS, adapter); + benchmark_nns("HNSW", *nns, { 100, 200 }); +} +#endif + int main(int argc, char **argv) { TEST_MASTER.init(__FILE__); diff --git a/eval/src/tests/ann/xp-hnswlike-nns.cpp b/eval/src/tests/ann/xp-hnswlike-nns.cpp new file mode 100644 index 00000000000..635f586921c --- /dev/null +++ b/eval/src/tests/ann/xp-hnswlike-nns.cpp @@ -0,0 +1,394 @@ +// Copyright 2019 Oath Inc. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. + +#include +#include +#include +#include +#include "nns.h" + +using LinkList = std::vector; + +struct Node { + std::vector _links; + Node(uint32_t , uint32_t numLevels, uint32_t M) + : _links(numLevels) + { + for (uint32_t i = 0; i < _links.size(); ++i) { + _links[i].reserve((i == 0) ? (2 * M + 1) : (M+1)); + } + } +}; + +struct VisitedSet +{ + using Mark = unsigned short; + Mark *ptr; + Mark curval; + size_t sz; + VisitedSet(const VisitedSet &) = delete; + VisitedSet& operator=(const VisitedSet &) = delete; + explicit VisitedSet(size_t size) { + ptr = (Mark *)malloc(size * sizeof(Mark)); + curval = -1; + sz = size; + } + void clear() { + ++curval; + if (curval == 0) { + memset(ptr, 0, sz * sizeof(Mark)); + ++curval; + } + } + ~VisitedSet() { free(ptr); } + void mark(size_t id) { ptr[id] = curval; } + bool isMarked(size_t id) const { return ptr[id] == curval; } +}; + +struct VisitedSetPool +{ + std::unique_ptr lastUsed; + VisitedSetPool() { + lastUsed = std::make_unique(250); + } + ~VisitedSetPool() {} + VisitedSet &get(size_t size) { + if (size > lastUsed->sz) { + lastUsed = std::make_unique(size*2); + } + lastUsed->clear(); + return *lastUsed; + } +}; + +struct HnswHit { + float dist; + uint32_t docid; + HnswHit(uint32_t di, SqDist sq) : dist(sq.distance), docid(di) {} +}; + + +using QueueEntry = HnswHit; +struct GreaterDist { + bool operator() (const QueueEntry &lhs, const QueueEntry& rhs) const { + return (rhs.dist < lhs.dist); + } +}; +struct LesserDist { + bool operator() (const QueueEntry &lhs, const QueueEntry& rhs) const { + return (lhs.dist < rhs.dist); + } +}; + +using NearestList = std::vector; + +struct NearestPriQ : std::priority_queue +{ +}; + +struct FurthestPriQ : std::priority_queue +{ + NearestList steal() { + NearestList result; + c.swap(result); + return result; + } + const NearestList& peek() const { return c; } +}; + +class HnswLikeNns : public NNS +{ +private: + std::vector _nodes; + uint32_t _entryId; + int _entryLevel; + uint32_t _M; + uint32_t _efConstruction; + double _levelMultiplier; + std::default_random_engine _rndGen; + VisitedSetPool _visitedSetPool; + + double distance(Vector v, uint32_t id) const; + + double distance(uint32_t a, uint32_t b) const { + Vector v = _dva.get(a); + return distance(v, b); + } + + int randomLevel() { + std::uniform_real_distribution distribution(0.0, 1.0); + double r = -log(distribution(_rndGen)) * _levelMultiplier; + return (int) r; + } + +public: + HnswLikeNns(uint32_t numDims, const DocVectorAccess &dva) + : NNS(numDims, dva), + _nodes(), + _entryId(0), + _entryLevel(-1), + _M(16), + _efConstruction(150), + _levelMultiplier(1.0 / log(1.0 * _M)) + { + _nodes.reserve(1234567); + } + + ~HnswLikeNns() {} + + LinkList& getLinkList(uint32_t docid, uint32_t level) { + // assert(docid < _nodes.size()); + // assert(level < _nodes[docid]._links.size()); + return _nodes[docid]._links[level]; + } + + // simple greedy search + QueueEntry search_layer_simple(Vector vector, QueueEntry curPoint, uint32_t searchLevel) { + bool keepGoing = true; + while (keepGoing) { + keepGoing = false; + const LinkList& neighbors = getLinkList(curPoint.docid, searchLevel); + for (uint32_t n_id : neighbors) { + double dist = distance(vector, n_id); + if (dist < curPoint.dist) { + curPoint = QueueEntry(n_id, SqDist(dist)); + keepGoing = true; + } + } + } + return curPoint; + } + + void search_layer_foradd(Vector vector, FurthestPriQ &w, + uint32_t ef, uint32_t searchLevel); + + FurthestPriQ search_layer(Vector vector, NearestList entryPoints, + uint32_t ef, uint32_t searchLevel) { + VisitedSet &visited = _visitedSetPool.get(_nodes.size()); + NearestPriQ candidates; + FurthestPriQ w; + for (auto point : entryPoints) { + candidates.push(point); + w.push(point); + visited.mark(point.docid); + } + double limd = std::numeric_limits::max(); + while (! candidates.empty()) { + QueueEntry cand = candidates.top(); + candidates.pop(); + if (cand.dist > limd) { + break; + } + for (uint32_t e_id : getLinkList(cand.docid, searchLevel)) { + if (visited.isMarked(e_id)) continue; + visited.mark(e_id); + double e_dist = distance(vector, e_id); + if (e_dist < limd) { + candidates.emplace(e_id, SqDist(e_dist)); + w.emplace(e_id, SqDist(e_dist)); + if (w.size() > ef) { + w.pop(); + limd = w.top().dist; + } + } + } + } + return w; + } + + bool haveCloserDistance(QueueEntry e, const LinkList &r) const { + for (uint32_t prevId : r) { + double dist = distance(e.docid, prevId); + if (dist < e.dist) return true; + } + return false; + } + + LinkList select_neighbors(NearestPriQ &&w, uint32_t curMax) const; + + LinkList select_neighbors(const NearestList &neighbors, uint32_t curMax) { + if (neighbors.size() <= curMax) { + LinkList result; + result.reserve(curMax+1); + for (const auto & entry : neighbors) { + result.push_back(entry.docid); + } + return result; + } + NearestPriQ w; + for (const QueueEntry & entry : neighbors) { + w.push(entry); + } + return select_neighbors(std::move(w), curMax); + } + + void addDoc(uint32_t docid) override { + Vector vector = _dva.get(docid); + for (uint32_t id = _nodes.size(); id <= docid; ++id) { + _nodes.emplace_back(id, 0, _M); + } + int level = randomLevel(); + assert(_nodes[docid]._links.size() == 0); + _nodes[docid] = Node(docid, level+1, _M); + if (_entryLevel < 0) { + _entryId = docid; + _entryLevel = level; + return; + } + int searchLevel = _entryLevel; + double entryDist = distance(vector, _entryId); + QueueEntry entryPoint(_entryId, SqDist(entryDist)); + while (searchLevel > level) { + entryPoint = search_layer_simple(vector, entryPoint, searchLevel); + --searchLevel; + } + searchLevel = std::min(level, _entryLevel); + FurthestPriQ w; + w.push(entryPoint); + while (searchLevel >= 0) { + search_layer_foradd(vector, w, _efConstruction, searchLevel); + uint32_t maxLinks = (searchLevel > 0) ? _M : (2 * _M); + LinkList neighbors = select_neighbors(w.peek(), maxLinks); + connect_new_node(docid, neighbors, searchLevel); + each_shrink_ifneeded(neighbors, searchLevel); + --searchLevel; + } + if (level > _entryLevel) { + _entryLevel = level; + _entryId = docid; + } + } + + void connect_new_node(uint32_t id, const LinkList &neighbors, uint32_t level); + + void each_shrink_ifneeded(const LinkList &neighbors, uint32_t level); + + void removeDoc(uint32_t ) override { + } + + std::vector topK(uint32_t k, Vector vector, uint32_t search_k) override { + std::vector result; + if (_entryLevel < 0) return result; + double entryDist = distance(vector, _entryId); + QueueEntry entryPoint(_entryId, SqDist(entryDist)); + int searchLevel = _entryLevel; + while (searchLevel > 0) { + entryPoint = search_layer_simple(vector, entryPoint, searchLevel); + --searchLevel; + } + NearestList entryPoints; + entryPoints.push_back(entryPoint); + FurthestPriQ w = search_layer(vector, entryPoints, std::max(k, search_k), 0); + if (w.size() < k) { + fprintf(stderr, "fewer than expected hits: k=%u, ks=%u, got=%zu\n", + k, search_k, w.size()); + } + while (w.size() > k) { + w.pop(); + } + std::vector reversed; + reversed.reserve(w.size()); + while (! w.empty()) { + reversed.push_back(w.top()); + w.pop(); + } + result.reserve(reversed.size()); + while (! reversed.empty()) { + const QueueEntry &hit = reversed.back(); + result.emplace_back(hit.docid, SqDist(hit.dist)); + reversed.pop_back(); + } + return result; + } +}; + +double +HnswLikeNns::distance(Vector v, uint32_t b) const +{ + Vector w = _dva.get(b); + return l2distCalc.l2sq_dist(v, w); +} + +void +HnswLikeNns::each_shrink_ifneeded(const LinkList &neighbors, uint32_t level) { + uint32_t maxLinks = (level > 0) ? _M : (2 * _M); + for (uint32_t old_id : neighbors) { + LinkList &oldLinks = getLinkList(old_id, level); + if (oldLinks.size() > maxLinks) { + NearestPriQ w; + for (uint32_t n_id : oldLinks) { + double n_dist = distance(old_id, n_id); + w.emplace(n_id, SqDist(n_dist)); + } + oldLinks = select_neighbors(std::move(w), maxLinks); + } + } +} + +void +HnswLikeNns::search_layer_foradd(Vector vector, FurthestPriQ &w, + uint32_t ef, uint32_t searchLevel) +{ + NearestPriQ candidates; + VisitedSet &visited = _visitedSetPool.get(_nodes.size()); + + for (const QueueEntry& entry : w.peek()) { + candidates.push(entry); + visited.mark(entry.docid); + } + + double limd = std::numeric_limits::max(); + while (! candidates.empty()) { + QueueEntry cand = candidates.top(); + candidates.pop(); + if (cand.dist > limd) { + break; + } + for (uint32_t e_id : getLinkList(cand.docid, searchLevel)) { + if (visited.isMarked(e_id)) continue; + visited.mark(e_id); + double e_dist = distance(vector, e_id); + if (e_dist < limd) { + candidates.emplace(e_id, SqDist(e_dist)); + w.emplace(e_id, SqDist(e_dist)); + if (w.size() > ef) { + w.pop(); + limd = w.top().dist; + } + } + } + } + return; +} + +LinkList +HnswLikeNns::select_neighbors(NearestPriQ &&w, uint32_t curMax) const { + LinkList result; + result.reserve(curMax+1); + while (! w.empty()) { + QueueEntry e = w.top(); + w.pop(); + if (haveCloserDistance(e, result)) continue; + result.push_back(e.docid); + if (result.size() >= curMax) break; + } + return result; +} + +void +HnswLikeNns::connect_new_node(uint32_t id, const LinkList &neighbors, uint32_t level) { + LinkList &newLinks = getLinkList(id, level); + for (uint32_t neigh_id : neighbors) { + LinkList &oldLinks = getLinkList(neigh_id, level); + newLinks.push_back(neigh_id); + oldLinks.push_back(id); + } +} + + +std::unique_ptr> +make_hnsw_nns(uint32_t numDims, const DocVectorAccess &dva) +{ + return std::make_unique(numDims, dva); +} + + -- cgit v1.2.3