diff options
Diffstat (limited to 'eval/src/tests/ann/xp-annoy-nns.cpp')
-rw-r--r-- | eval/src/tests/ann/xp-annoy-nns.cpp | 58 |
1 files changed, 58 insertions, 0 deletions
diff --git a/eval/src/tests/ann/xp-annoy-nns.cpp b/eval/src/tests/ann/xp-annoy-nns.cpp index f022aae5974..213e583d95a 100644 --- a/eval/src/tests/ann/xp-annoy-nns.cpp +++ b/eval/src/tests/ann/xp-annoy-nns.cpp @@ -27,6 +27,7 @@ struct Node { virtual Node *addDoc(uint32_t docid, V vector, AnnoyLikeNns &meta) = 0; virtual int remove(uint32_t docid, V vector) = 0; virtual void findCandidates(std::set<uint32_t> &cands, V vector, NodeQueue &queue, double minDist) const = 0; + virtual void filterCandidates(std::set<uint32_t> &cands, V vector, NodeQueue &queue, double minDist, const BitVector &blacklist) const = 0; virtual void stats(std::vector<uint32_t> &depths) = 0; }; @@ -38,6 +39,7 @@ struct LeafNode : public Node { Node *addDoc(uint32_t docid, V vector, AnnoyLikeNns &meta) override; int remove(uint32_t docid, V vector) override; void findCandidates(std::set<uint32_t> &cands, V vector, NodeQueue &queue, double minDist) const override; + void filterCandidates(std::set<uint32_t> &cands, V vector, NodeQueue &queue, double minDist, const BitVector &blacklist) const override; Node *split(AnnoyLikeNns &meta); virtual void stats(std::vector<uint32_t> &depths) override { depths.push_back(1); } @@ -55,6 +57,7 @@ struct SplitNode : public Node { Node *addDoc(uint32_t docid, V vector, AnnoyLikeNns &meta) override; int remove(uint32_t docid, V vector) override; void findCandidates(std::set<uint32_t> &cands, V vector, NodeQueue &queue, double minDist) const override; + void filterCandidates(std::set<uint32_t> &cands, V vector, NodeQueue &queue, double minDist, const BitVector &blacklist) const override; double planeDistance(V vector) const; virtual void stats(std::vector<uint32_t> &depths) override { @@ -106,6 +109,8 @@ public: } std::vector<NnsHit> topK(uint32_t k, Vector vector, uint32_t search_k) override; + std::vector<NnsHit> topKfilter(uint32_t k, Vector vector, uint32_t search_k, const BitVector &bitvector) override; + V getVector(uint32_t docid) const { return _dva.get(docid); } double uniformRnd() { return _rndGen.nextUniform(); } uint32_t dims() const { return _numDims; } @@ -304,6 +309,16 @@ LeafNode::findCandidates(std::set<uint32_t> &cands, V, NodeQueue &, double) cons } } +void +LeafNode::filterCandidates(std::set<uint32_t> &cands, V, NodeQueue &, double, const BitVector &blacklist) const +{ + for (uint32_t d : docids) { + if (blacklist.isSet(d)) continue; + cands.insert(d); + } +} + + SplitNode::~SplitNode() { delete leftChildren; @@ -344,6 +359,15 @@ SplitNode::findCandidates(std::set<uint32_t> &, V vector, NodeQueue &queue, doub queue.push(std::make_pair(std::min(d, minDist), rightChildren)); } +void +SplitNode::filterCandidates(std::set<uint32_t> &, V vector, NodeQueue &queue, double minDist, const BitVector &) const +{ + double d = planeDistance(vector); + // fprintf(stderr, "push 2 nodes dist %g\n", d); + queue.push(std::make_pair(std::min(-d, minDist), leftChildren)); + queue.push(std::make_pair(std::min(d, minDist), rightChildren)); +} + std::vector<NnsHit> AnnoyLikeNns::topK(uint32_t k, Vector vector, uint32_t search_k) { @@ -387,6 +411,40 @@ AnnoyLikeNns::topK(uint32_t k, Vector vector, uint32_t search_k) return r; } +std::vector<NnsHit> +AnnoyLikeNns::topKfilter(uint32_t k, Vector vector, uint32_t search_k, const BitVector &blacklist) +{ + ++find_top_k_cnt; + std::vector<NnsHit> r; + r.reserve(k); + std::set<uint32_t> candidates; + NodeQueue queue; + for (Node *root : _roots) { + double dist = std::numeric_limits<double>::max(); + queue.push(std::make_pair(dist, root)); + } + while ((candidates.size() < std::max(k, search_k)) && (queue.size() > 0)) { + const QueueNode& top = queue.top(); + double md = top.first; + // fprintf(stderr, "find candidates: node with min distance %g\n", md); + Node *n = top.second; + queue.pop(); + n->filterCandidates(candidates, vector, queue, md, blacklist); + ++find_cand_cnt; + } + for (uint32_t docid : candidates) { + if (blacklist.isSet(docid)) continue; + double dist = l2distCalc.l2sq_dist(vector, _dva.get(docid)); + NnsHit hit(docid, SqDist(dist)); + r.push_back(hit); + } + std::sort(r.begin(), r.end(), NnsHitComparatorLessDistance()); + while (r.size() > k) r.pop_back(); + return r; +} + + + void AnnoyLikeNns::dumpStats() { fprintf(stderr, "stats for AnnoyLikeNns:\n"); |