aboutsummaryrefslogtreecommitdiffstats
path: root/eval/src/tests/ann/xp-lsh-nns.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'eval/src/tests/ann/xp-lsh-nns.cpp')
-rw-r--r--eval/src/tests/ann/xp-lsh-nns.cpp40
1 files changed, 40 insertions, 0 deletions
diff --git a/eval/src/tests/ann/xp-lsh-nns.cpp b/eval/src/tests/ann/xp-lsh-nns.cpp
index 0ea119a9c70..c028a07a9d7 100644
--- a/eval/src/tests/ann/xp-lsh-nns.cpp
+++ b/eval/src/tests/ann/xp-lsh-nns.cpp
@@ -118,6 +118,7 @@ public:
}
}
std::vector<NnsHit> topK(uint32_t k, Vector vector, uint32_t search_k) override;
+ std::vector<NnsHit> topKfilter(uint32_t k, Vector vector, uint32_t search_k, const BitVector &bitvector) override;
V getVector(uint32_t docid) const { return _dva.get(docid); }
double uniformRnd() { return _rndGen.nextUniform(); }
@@ -196,6 +197,45 @@ public:
};
std::vector<NnsHit>
+RpLshNns::topKfilter(uint32_t k, Vector vector, uint32_t search_k, const BitVector &blacklist)
+{
+ std::vector<NnsHit> result;
+ result.reserve(k);
+
+ std::vector<float> tmp(_numDims);
+ vespalib::ArrayRef<float> tmpArr(tmp);
+
+ LsMaskHash query_hash = mask_hash_from_pv(vector, _transformationMatrix);
+ LshHitHeap heap(std::max(k, search_k));
+ int limit_hash_dist = 99999;
+ int skipCnt = 0;
+ int fullCnt = 0;
+ int whdcCnt = 0;
+ size_t docidLimit = _generated_doc_hashes.size();
+ for (uint32_t docid = 0; docid < docidLimit; ++docid) {
+ if (blacklist.isSet(docid)) continue;
+ int hd = hash_dist(query_hash, _generated_doc_hashes[docid]);
+ if (hd <= limit_hash_dist) {
+ ++fullCnt;
+ double dist = l2distCalc.l2sq_dist(vector, _dva.get(docid), tmpArr);
+ LshHit h(docid, dist, hd);
+ if (heap.maybe_use(h)) {
+ ++whdcCnt;
+ limit_hash_dist = heap.limitHashDistance();
+ }
+ } else {
+ ++skipCnt;
+ }
+ }
+ std::vector<LshHit> best = heap.bestLshHits();
+ size_t numHits = std::min((size_t)k, best.size());
+ for (size_t i = 0; i < numHits; ++i) {
+ result.emplace_back(best[i].docid, SqDist(best[i].distance));
+ }
+ return result;
+}
+
+std::vector<NnsHit>
RpLshNns::topK(uint32_t k, Vector vector, uint32_t search_k)
{
std::vector<NnsHit> result;