diff options
Diffstat (limited to 'eval/src/tests/ann/xp-lsh-nns.cpp')
-rw-r--r-- | eval/src/tests/ann/xp-lsh-nns.cpp | 40 |
1 files changed, 40 insertions, 0 deletions
diff --git a/eval/src/tests/ann/xp-lsh-nns.cpp b/eval/src/tests/ann/xp-lsh-nns.cpp index 0ea119a9c70..c028a07a9d7 100644 --- a/eval/src/tests/ann/xp-lsh-nns.cpp +++ b/eval/src/tests/ann/xp-lsh-nns.cpp @@ -118,6 +118,7 @@ public: } } std::vector<NnsHit> topK(uint32_t k, Vector vector, uint32_t search_k) override; + std::vector<NnsHit> topKfilter(uint32_t k, Vector vector, uint32_t search_k, const BitVector &bitvector) override; V getVector(uint32_t docid) const { return _dva.get(docid); } double uniformRnd() { return _rndGen.nextUniform(); } @@ -196,6 +197,45 @@ public: }; std::vector<NnsHit> +RpLshNns::topKfilter(uint32_t k, Vector vector, uint32_t search_k, const BitVector &blacklist) +{ + std::vector<NnsHit> result; + result.reserve(k); + + std::vector<float> tmp(_numDims); + vespalib::ArrayRef<float> tmpArr(tmp); + + LsMaskHash query_hash = mask_hash_from_pv(vector, _transformationMatrix); + LshHitHeap heap(std::max(k, search_k)); + int limit_hash_dist = 99999; + int skipCnt = 0; + int fullCnt = 0; + int whdcCnt = 0; + size_t docidLimit = _generated_doc_hashes.size(); + for (uint32_t docid = 0; docid < docidLimit; ++docid) { + if (blacklist.isSet(docid)) continue; + int hd = hash_dist(query_hash, _generated_doc_hashes[docid]); + if (hd <= limit_hash_dist) { + ++fullCnt; + double dist = l2distCalc.l2sq_dist(vector, _dva.get(docid), tmpArr); + LshHit h(docid, dist, hd); + if (heap.maybe_use(h)) { + ++whdcCnt; + limit_hash_dist = heap.limitHashDistance(); + } + } else { + ++skipCnt; + } + } + std::vector<LshHit> best = heap.bestLshHits(); + size_t numHits = std::min((size_t)k, best.size()); + for (size_t i = 0; i < numHits; ++i) { + result.emplace_back(best[i].docid, SqDist(best[i].distance)); + } + return result; +} + +std::vector<NnsHit> RpLshNns::topK(uint32_t k, Vector vector, uint32_t search_k) { std::vector<NnsHit> result; |