summaryrefslogtreecommitdiffstats
path: root/eval/src/tests/ann/xp-hnsw-wrap.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'eval/src/tests/ann/xp-hnsw-wrap.cpp')
-rw-r--r--eval/src/tests/ann/xp-hnsw-wrap.cpp54
1 files changed, 54 insertions, 0 deletions
diff --git a/eval/src/tests/ann/xp-hnsw-wrap.cpp b/eval/src/tests/ann/xp-hnsw-wrap.cpp
new file mode 100644
index 00000000000..30202fde534
--- /dev/null
+++ b/eval/src/tests/ann/xp-hnsw-wrap.cpp
@@ -0,0 +1,54 @@
+// Copyright 2019 Oath Inc. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+
+#include "nns.h"
+#include <iostream>
+#include "/git/hnswlib/hnswlib/hnswlib.h"
+
+class HnswWrapNns : public NNS<float>
+{
+private:
+ using Implementation = hnswlib::HierarchicalNSW<float>;
+ hnswlib::L2Space _l2space;
+ Implementation _hnsw;
+
+public:
+ HnswWrapNns(uint32_t numDims, const DocVectorAccess<float> &dva)
+ : NNS(numDims, dva),
+ _l2space(numDims),
+ _hnsw(&_l2space, 1000000, 16, 200)
+ {
+ }
+
+ ~HnswWrapNns() {}
+
+ void addDoc(uint32_t docid) override {
+ Vector vector = _dva.get(docid);
+ _hnsw.addPoint(vector.cbegin(), docid);
+ }
+
+ void removeDoc(uint32_t ) override {
+ }
+
+ std::vector<NnsHit> topK(uint32_t k, Vector vector, uint32_t search_k) override {
+ std::vector<NnsHit> reversed;
+ auto priQ = _hnsw.searchKnn(vector.cbegin(), std::max(k, search_k));
+ while (! priQ.empty()) {
+ auto pair = priQ.top();
+ reversed.emplace_back(pair.second, SqDist(pair.first));
+ priQ.pop();
+ }
+ std::vector<NnsHit> result;
+ while (result.size() < k && !reversed.empty()) {
+ result.push_back(reversed.back());
+ reversed.pop_back();
+ }
+ return result;
+ }
+};
+
+std::unique_ptr<NNS<float>>
+make_hnsw_wrap(uint32_t numDims, const DocVectorAccess<float> &dva)
+{
+ NNS<float> *p = new HnswWrapNns(numDims, dva);
+ return std::unique_ptr<NNS<float>>(p);
+}