aboutsummaryrefslogtreecommitdiffstats
path: root/eval/src/tests/ann/xp-hnsw-wrap.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'eval/src/tests/ann/xp-hnsw-wrap.cpp')
-rw-r--r--eval/src/tests/ann/xp-hnsw-wrap.cpp28
1 files changed, 28 insertions, 0 deletions
diff --git a/eval/src/tests/ann/xp-hnsw-wrap.cpp b/eval/src/tests/ann/xp-hnsw-wrap.cpp
index 3eb01142dcd..45c7a974254 100644
--- a/eval/src/tests/ann/xp-hnsw-wrap.cpp
+++ b/eval/src/tests/ann/xp-hnsw-wrap.cpp
@@ -46,6 +46,34 @@ public:
}
return result;
}
+
+ std::vector<NnsHit> topKfilter(uint32_t k, Vector vector, uint32_t search_k, const BitVector &blacklist) override {
+ std::vector<NnsHit> reversed;
+ uint32_t adjusted_k = k+4;
+ uint32_t adjusted_sk = search_k+4;
+ for (int retry = 0; (retry < 5) && (reversed.size() < k); ++retry) {
+ reversed.clear();
+ _hnsw.setEf(adjusted_sk);
+ auto priQ = _hnsw.searchKnn(vector.cbegin(), adjusted_k);
+ while (! priQ.empty()) {
+ auto pair = priQ.top();
+ if (! blacklist.isSet(pair.second)) {
+ reversed.emplace_back(pair.second, SqDist(pair.first));
+ }
+ priQ.pop();
+ }
+ double got = 1 + reversed.size();
+ double factor = 1.25 * k / got;
+ adjusted_k *= factor;
+ adjusted_sk *= factor;
+ }
+ std::vector<NnsHit> result;
+ while (result.size() < k && !reversed.empty()) {
+ result.push_back(reversed.back());
+ reversed.pop_back();
+ }
+ return result;
+ }
};
std::unique_ptr<NNS<float>>