diff options
Diffstat (limited to 'eval/src/tests/ann/xp-hnsw-wrap.cpp')
-rw-r--r-- | eval/src/tests/ann/xp-hnsw-wrap.cpp | 28 |
1 files changed, 28 insertions, 0 deletions
diff --git a/eval/src/tests/ann/xp-hnsw-wrap.cpp b/eval/src/tests/ann/xp-hnsw-wrap.cpp index 3eb01142dcd..45c7a974254 100644 --- a/eval/src/tests/ann/xp-hnsw-wrap.cpp +++ b/eval/src/tests/ann/xp-hnsw-wrap.cpp @@ -46,6 +46,34 @@ public: } return result; } + + std::vector<NnsHit> topKfilter(uint32_t k, Vector vector, uint32_t search_k, const BitVector &blacklist) override { + std::vector<NnsHit> reversed; + uint32_t adjusted_k = k+4; + uint32_t adjusted_sk = search_k+4; + for (int retry = 0; (retry < 5) && (reversed.size() < k); ++retry) { + reversed.clear(); + _hnsw.setEf(adjusted_sk); + auto priQ = _hnsw.searchKnn(vector.cbegin(), adjusted_k); + while (! priQ.empty()) { + auto pair = priQ.top(); + if (! blacklist.isSet(pair.second)) { + reversed.emplace_back(pair.second, SqDist(pair.first)); + } + priQ.pop(); + } + double got = 1 + reversed.size(); + double factor = 1.25 * k / got; + adjusted_k *= factor; + adjusted_sk *= factor; + } + std::vector<NnsHit> result; + while (result.size() < k && !reversed.empty()) { + result.push_back(reversed.back()); + reversed.pop_back(); + } + return result; + } }; std::unique_ptr<NNS<float>> |