aboutsummaryrefslogtreecommitdiffstats
path: root/eval/src/tests/ann/xp-lsh-nns.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'eval/src/tests/ann/xp-lsh-nns.cpp')
-rw-r--r-- eval/src/tests/ann/xp-lsh-nns.cpp 243
1 files changed, 243 insertions, 0 deletions
diff --git a/eval/src/tests/ann/xp-lsh-nns.cpp b/eval/src/tests/ann/xp-lsh-nns.cpp
new file mode 100644
index 00000000000..285985167c0
--- /dev/null
+++ b/eval/src/tests/ann/xp-lsh-nns.cpp
@@ -0,0 +1,243 @@
+// Copyright 2019 Oath Inc. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+
+#include "nns.h"
+#include "std-random.h"
+#include <assert.h>
+#include <string.h>
+#include <algorithm>
+#include <queue>
+#include <set>
+#include <vespa/vespalib/util/priority_queue.h>
+
+// Shorthand for the vector type handed to the hashing code in this file.
+using V = vespalib::ConstArrayRef<float>;
+
+// Number of 64-bit words in each hash, i.e. 64*NUM_HASH_WORDS hash bits per vector.
+#define NUM_HASH_WORDS 4
+// Per-word share of masked-out bits: IGNORE_BITS*NUM_HASH_WORDS is the rank
+// (counted from the smallest squared projection) that mask_hash_from_pv uses
+// as its masking threshold.
+#define IGNORE_BITS 32
+// Number of distinct values a hash distance can take
+// (0 through 64*NUM_HASH_WORDS inclusive).
+#define HIST_SIZE (64*NUM_HASH_WORDS + 1)
+
+// Fixed-width bit hash for one vector (bits) together with a per-bit mask.
+// A default-constructed value has every bit set in both arrays; that value is
+// also what gets stored for docids that currently have no computed hash.
+struct LsMaskHash {
+    uint64_t bits[NUM_HASH_WORDS];
+    uint64_t mask[NUM_HASH_WORDS];
+    LsMaskHash() {
+        for (size_t word = 0; word < NUM_HASH_WORDS; ++word) {
+            bits[word] = ~uint64_t(0);
+            mask[word] = ~uint64_t(0);
+        }
+    }
+};
+
+// Masked Hamming distance between two hashes.
+// A bit position contributes 1 when the two hashes differ there AND at least
+// one of the two masks has that position set.
+// Returns a value in [0, 64*NUM_HASH_WORDS].
+static inline int hash_dist(const LsMaskHash &h1, const LsMaskHash &h2) {
+    int cnt = 0;
+    for (size_t o = 0; o < NUM_HASH_WORDS; ++o) {
+        const uint64_t differing = h1.bits[o] ^ h2.bits[o];
+        const uint64_t significant = h1.mask[o] | h2.mask[o];
+        // The "ll" variant takes unsigned long long (at least 64 bits on every
+        // target); the "l" variant takes unsigned long, which is only 32 bits
+        // on LLP64 and 32-bit platforms and would drop the upper half.
+        cnt += __builtin_popcountll(differing & significant);
+    }
+    return cnt;
+}
+
+// One row of the projection matrix: `dims` float coefficients, all starting at zero.
+struct Multiplier {
+    std::vector<float> multiplier;
+    // explicit, so that a bare size never converts silently into a matrix row.
+    explicit Multiplier(size_t dims) : multiplier(dims, 0.0f) {}
+};
+
+// Computes the masked hash of vector p from the projection matrix rpMatrix.
+//
+// For each of the 64*NUM_HASH_WORDS rows, the product of the row with p is
+// computed (via l2distCalc.product). The hash bit for a row is 1 when that
+// product is positive. The mask bit for a row is 1 when the squared product is
+// strictly greater than the (IGNORE_BITS*NUM_HASH_WORDS)-th smallest squared
+// product, so the rows with the smallest magnitudes are masked out.
+// Within each word, row 0 ends up in the most significant bit.
+//
+// rpMatrix is taken by const reference: it holds one float vector per hash
+// bit, and this function runs for every added document and every query.
+// rpMatrix must have at least 64*NUM_HASH_WORDS rows.
+LsMaskHash mask_hash_from_pv(V p, const std::vector<Multiplier> &rpMatrix) {
+    constexpr size_t num_bits = 64 * NUM_HASH_WORDS;
+    constexpr size_t num_ignored = IGNORE_BITS * NUM_HASH_WORDS;
+    static_assert(num_ignored <= num_bits, "cannot ignore more bits than the hash has");
+    assert(rpMatrix.size() >= num_bits);
+
+    LsMaskHash result;
+    // Kept as double so the comparison against the threshold below is done on
+    // exactly the values the threshold was picked from.
+    double squared[NUM_HASH_WORDS][64];
+    std::vector<double> ranked;
+    ranked.reserve(num_bits);
+    for (size_t o = 0; o < NUM_HASH_WORDS; ++o) {
+        uint64_t hash = 0;
+        for (size_t bit = 0; bit < 64; ++bit) {
+            hash <<= 1u;
+            V m = rpMatrix[bit + 64*o].multiplier;
+            double dotproduct = l2distCalc.product(m, p);
+            if (dotproduct > 0.0) {
+                hash |= 1u;
+            }
+            const double sq = dotproduct * dotproduct;
+            squared[o][bit] = sq;
+            ranked.push_back(sq);
+        }
+        result.bits[o] = hash;
+    }
+
+    // Threshold: the num_ignored-th smallest square. Only that single order
+    // statistic is needed, so a partial selection replaces a full sort.
+    // With nothing to ignore, a negative threshold keeps every bit.
+    double lim = -1.0;
+    if (num_ignored > 0) {
+        auto nth = ranked.begin() + (num_ignored - 1);
+        std::nth_element(ranked.begin(), nth, ranked.end());
+        lim = *nth;
+    }
+
+    for (size_t o = 0; o < NUM_HASH_WORDS; ++o) {
+        uint64_t mask = 0;
+        for (size_t bit = 0; bit < 64; ++bit) {
+            mask <<= 1u;
+            if (squared[o][bit] > lim) {
+                mask |= 1u;
+            }
+        }
+        result.mask[o] = mask;
+    }
+    return result;
+}
+
+// NNS<float> implementation that stores one LsMaskHash per docid.
+// The hashes are computed by mask_hash_from_pv from a projection matrix whose
+// coefficients are drawn from _rndGen.nextNormal() at construction time.
+class RpLshNns : public NNS<float>
+{
+private:
+    RndGen _rndGen;
+    std::vector<Multiplier> _transformationMatrix;
+    // Indexed by docid; slots without a computed hash hold LsMaskHash().
+    std::vector<LsMaskHash> _generated_doc_hashes;
+
+public:
+    // Builds the NUM_HASH_WORDS*64 rows of the projection matrix, each with
+    // numDims coefficients, and logs the masking configuration to stderr.
+    RpLshNns(uint32_t numDims, const DocVectorAccess<float> &dva)
+        : NNS(numDims, dva), _rndGen()
+    {
+        const size_t numRows = NUM_HASH_WORDS*64;
+        _transformationMatrix.reserve(numRows);
+        while (_transformationMatrix.size() < numRows) {
+            _transformationMatrix.emplace_back(numDims);
+            for (float &coefficient : _transformationMatrix.back().multiplier) {
+                coefficient = _rndGen.nextNormal();
+            }
+        }
+        fprintf(stderr, "ignore bits for lsh: %d*%d=%d\n",
+                IGNORE_BITS, NUM_HASH_WORDS, IGNORE_BITS*NUM_HASH_WORDS);
+        _generated_doc_hashes.reserve(100000);
+    }
+
+    ~RpLshNns() {
+    }
+
+    // Computes the hash of the vector stored for docid and records it,
+    // overwriting any hash already held for that docid.
+    void addDoc(uint32_t docid) override {
+        V docVector = _dva.get(docid);
+        const LsMaskHash hash = mask_hash_from_pv(docVector, _transformationMatrix);
+        if (docid < _generated_doc_hashes.size()) {
+            _generated_doc_hashes[docid] = hash;
+            return;
+        }
+        // Fill any gap below docid with placeholder hashes, then append.
+        while (_generated_doc_hashes.size() < docid) {
+            _generated_doc_hashes.push_back(LsMaskHash());
+        }
+        _generated_doc_hashes.push_back(hash);
+    }
+
+    // Resets the stored hash for docid to the placeholder value; the slot
+    // itself is kept. Unknown docids are ignored.
+    void removeDoc(uint32_t docid) override {
+        if (docid >= _generated_doc_hashes.size()) {
+            return;
+        }
+        _generated_doc_hashes[docid] = LsMaskHash();
+    }
+
+    std::vector<NnsHit> topK(uint32_t k, Vector vector, uint32_t search_k) override;
+
+    V getVector(uint32_t docid) const { return _dva.get(docid); }
+    double uniformRnd() { return _rndGen.nextUniform(); }
+    uint32_t dims() const { return _numDims; }
+};
+
+
+// One candidate hit: a docid, the distance value computed for it, and the
+// hash distance that was recorded alongside it.
+struct LshHit {
+    double distance;
+    uint32_t docid;
+    int hash_distance;
+    LshHit() : distance(0.0), docid(0u), hash_distance(0) {}
+    // id is uint32_t to match the docid field and the docids passed in by
+    // callers, so no value takes a detour through a signed int.
+    LshHit(uint32_t id, double dist, int hd = 0)
+        : distance(dist), docid(id), hash_distance(hd) {}
+};
+
+// Ordering predicate for LshHit: true when lhs has the larger distance, or,
+// when neither distance is smaller than the other, when lhs has the larger docid.
+struct LshHitComparator {
+    bool operator() (const LshHit &lhs, const LshHit& rhs) const {
+        const bool lhsCloser = lhs.distance < rhs.distance;
+        const bool lhsFarther = lhs.distance > rhs.distance;
+        return !lhsCloser && (lhsFarther || lhs.docid > rhs.docid);
+    }
+};
+
+// Bounded collection of the best hits seen so far, plus a histogram of the
+// hash distances of the hits currently held.
+// NOTE(review): presumably _priQ.front() is the hit that LshHitComparator
+// orders first, i.e. the held hit with the largest distance - confirm against
+// vespalib::PriorityQueue.
+class LshHitHeap {
+private:
+    size_t _size;
+    vespalib::PriorityQueue<LshHit, LshHitComparator> _priQ;
+    std::vector<int> hd_histogram;
+
+    // Adds one to the histogram bucket for hd, growing the histogram if needed.
+    void count_hash_distance(uint32_t hd) {
+        if (hd_histogram.size() <= hd) {
+            hd_histogram.resize(static_cast<size_t>(hd) + 1, 0);
+        }
+        ++hd_histogram[hd];
+    }
+
+public:
+    explicit LshHitHeap(size_t maxSize) : _size(maxSize), _priQ() {
+        _priQ.reserve(maxSize);
+    }
+    ~LshHitHeap() {}
+
+    // Offers a hit to the collection.
+    // While fewer than maxSize hits are held, the hit is stored and false is
+    // returned. Once full, the hit replaces the front hit only when its
+    // distance is smaller; true is returned exactly in that replacement case.
+    bool maybe_use(const LshHit &hit) {
+        if (_size == 0) {
+            // Nothing can be held; front() must not be touched on an empty queue.
+            return false;
+        }
+        if (_priQ.size() < _size) {
+            _priQ.push(hit);
+            count_hash_distance(hit.hash_distance);
+            return false;
+        }
+        if (!(hit.distance < _priQ.front().distance)) {
+            return false;
+        }
+        const uint32_t evictedHd = _priQ.front().hash_distance;
+        count_hash_distance(hit.hash_distance);
+        --hd_histogram[evictedHd];
+        _priQ.front() = hit;
+        _priQ.adjust();
+        return true;
+    }
+
+    // Returns the smallest hash distance d such that the held hits with hash
+    // distance <= d number at least three quarters (rounded down) of all held
+    // hits; 99999 when the histogram has no such bucket.
+    int limitHashDistance() {
+        const size_t needed = (3 * _priQ.size()) / 4;
+        uint32_t sum = 0;
+        for (uint32_t i = 0; i < hd_histogram.size(); ++i) {
+            sum += hd_histogram[i];
+            if (sum >= needed) return i;
+        }
+        return 99999;
+    }
+
+    // Empties the queue and returns its hits; the hit popped first is placed
+    // last, so the result is in reverse pop order.
+    std::vector<LshHit> bestLshHits() {
+        std::vector<LshHit> result;
+        size_t sz = _priQ.size();
+        result.resize(sz);
+        for (size_t i = sz; i-- > 0; ) {
+            result[i] = _priQ.front();
+            _priQ.pop_front();
+        }
+        return result;
+    }
+};
+
+// Returns up to k hits for the query vector.
+//
+// Every docid slot is visited. The hash distance between the query hash and
+// the stored hash is computed first; the full l2sq_dist is computed only when
+// that hash distance does not exceed the current limit. Candidates are kept in
+// an LshHitHeap of capacity max(k, search_k); each time the heap replaces a
+// hit, the limit is refreshed from heap.limitHashDistance().
+//
+// NOTE(review): slots holding the placeholder hash (removed docids, or gaps
+// filled by addDoc) are scanned like any other, and _dva.get() is called for
+// them whenever they pass the hash-distance limit - confirm that is intended.
+std::vector<NnsHit>
+RpLshNns::topK(uint32_t k, Vector vector, uint32_t search_k)
+{
+    std::vector<NnsHit> result;
+    if (k == 0) {
+        // No hits can be returned, so skip the scan (and never build a
+        // zero-capacity heap when search_k is 0 as well).
+        return result;
+    }
+    result.reserve(k);
+
+    // Scratch space handed to l2sq_dist.
+    std::vector<float> tmp(_numDims);
+    vespalib::ArrayRef<float> tmpArr(tmp);
+
+    const LsMaskHash query_hash = mask_hash_from_pv(vector, _transformationMatrix);
+    LshHitHeap heap(std::max(k, search_k));
+    int limit_hash_dist = 99999;  // larger than any possible hash distance
+    const size_t docidLimit = _generated_doc_hashes.size();
+    for (uint32_t docid = 0; docid < docidLimit; ++docid) {
+        const int hd = hash_dist(query_hash, _generated_doc_hashes[docid]);
+        if (hd > limit_hash_dist) {
+            continue;
+        }
+        const double dist = l2distCalc.l2sq_dist(vector, _dva.get(docid), tmpArr);
+        if (heap.maybe_use(LshHit(docid, dist, hd))) {
+            limit_hash_dist = heap.limitHashDistance();
+        }
+    }
+
+    const std::vector<LshHit> best = heap.bestLshHits();
+    const size_t numHits = std::min(static_cast<size_t>(k), best.size());
+    for (size_t i = 0; i < numHits; ++i) {
+        result.emplace_back(best[i].docid, best[i].distance);
+    }
+    return result;
+}
+
+// Factory: creates an RpLshNns for numDims-dimensional vectors read through
+// dva and hands it back behind the NNS<float> interface.
+std::unique_ptr<NNS<float>>
+make_rplsh_nns(uint32_t numDims, const DocVectorAccess<float> &dva)
+{
+    std::unique_ptr<NNS<float>> nns = std::make_unique<RpLshNns>(numDims, dva);
+    return nns;
+}