about | summary | refs | log | tree | commit | diff | stats
path: root/eval/src/tests/ann/xp-annoy-nns.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'eval/src/tests/ann/xp-annoy-nns.cpp')
-rw-r--r-- eval/src/tests/ann/xp-annoy-nns.cpp 58
1 files changed, 58 insertions, 0 deletions
diff --git a/eval/src/tests/ann/xp-annoy-nns.cpp b/eval/src/tests/ann/xp-annoy-nns.cpp
index f022aae5974..213e583d95a 100644
--- a/eval/src/tests/ann/xp-annoy-nns.cpp
+++ b/eval/src/tests/ann/xp-annoy-nns.cpp
@@ -27,6 +27,7 @@ struct Node {
virtual Node *addDoc(uint32_t docid, V vector, AnnoyLikeNns &meta) = 0;
virtual int remove(uint32_t docid, V vector) = 0;
virtual void findCandidates(std::set<uint32_t> &cands, V vector, NodeQueue &queue, double minDist) const = 0;
+ virtual void filterCandidates(std::set<uint32_t> &cands, V vector, NodeQueue &queue, double minDist, const BitVector &blacklist) const = 0; // same parameters as findCandidates plus a blacklist bit vector
virtual void stats(std::vector<uint32_t> &depths) = 0;
};
@@ -38,6 +39,7 @@ struct LeafNode : public Node {
Node *addDoc(uint32_t docid, V vector, AnnoyLikeNns &meta) override;
int remove(uint32_t docid, V vector) override;
void findCandidates(std::set<uint32_t> &cands, V vector, NodeQueue &queue, double minDist) const override;
+ void filterCandidates(std::set<uint32_t> &cands, V vector, NodeQueue &queue, double minDist, const BitVector &blacklist) const override; // inserts the leaf's docids that are not set in blacklist
Node *split(AnnoyLikeNns &meta);
virtual void stats(std::vector<uint32_t> &depths) override { depths.push_back(1); }
@@ -55,6 +57,7 @@ struct SplitNode : public Node {
Node *addDoc(uint32_t docid, V vector, AnnoyLikeNns &meta) override;
int remove(uint32_t docid, V vector) override;
void findCandidates(std::set<uint32_t> &cands, V vector, NodeQueue &queue, double minDist) const override;
+ void filterCandidates(std::set<uint32_t> &cands, V vector, NodeQueue &queue, double minDist, const BitVector &blacklist) const override; // pushes both children onto queue; cands and blacklist are not used
double planeDistance(V vector) const;
virtual void stats(std::vector<uint32_t> &depths) override {
@@ -106,6 +109,8 @@ public:
}
std::vector<NnsHit> topK(uint32_t k, Vector vector, uint32_t search_k) override;
+ std::vector<NnsHit> topKfilter(uint32_t k, Vector vector, uint32_t search_k, const BitVector &bitvector) override; // returns at most k hits, skipping docids set in bitvector
+
V getVector(uint32_t docid) const { return _dva.get(docid); }
double uniformRnd() { return _rndGen.nextUniform(); }
uint32_t dims() const { return _numDims; }
@@ -304,6 +309,16 @@ LeafNode::findCandidates(std::set<uint32_t> &cands, V, NodeQueue &, double) cons
}
}
+// Leaf node: offers each stored docid as a candidate unless the blacklist
+// has it set.  The query vector, queue and distance arguments are unused.
+void
+LeafNode::filterCandidates(std::set<uint32_t> &cands, V, NodeQueue &, double, const BitVector &blacklist) const
+{
+    for (uint32_t docid : docids) {
+        if (! blacklist.isSet(docid)) cands.insert(docid);
+    }
+}
+
SplitNode::~SplitNode()
{
delete leftChildren;
@@ -344,6 +359,15 @@ SplitNode::findCandidates(std::set<uint32_t> &, V vector, NodeQueue &queue, doub
queue.push(std::make_pair(std::min(d, minDist), rightChildren));
}
+// Split node: inserts no candidates itself, only queues both children.
+void
+SplitNode::filterCandidates(std::set<uint32_t> &, V vector, NodeQueue &queue, double minDist, const BitVector &) const
+{
+    const double planeDist = planeDistance(vector);
+    queue.push(std::make_pair(std::min(-planeDist, minDist), leftChildren));
+    queue.push(std::make_pair(std::min(planeDist, minDist), rightChildren));
+}
+
std::vector<NnsHit>
AnnoyLikeNns::topK(uint32_t k, Vector vector, uint32_t search_k)
{
@@ -387,6 +411,40 @@ AnnoyLikeNns::topK(uint32_t k, Vector vector, uint32_t search_k)
return r;
}
+// Returns up to k hits for 'vector', skipping docids set in 'blacklist'.
+// Nodes are expanded until max(k, search_k) candidates are collected or the
+// queue runs dry; hits are ordered by NnsHitComparatorLessDistance.
+std::vector<NnsHit>
+AnnoyLikeNns::topKfilter(uint32_t k, Vector vector, uint32_t search_k, const BitVector &blacklist)
+{
+    ++find_top_k_cnt;
+    const size_t wanted = std::max(k, search_k);
+    std::set<uint32_t> candidates;
+    NodeQueue queue;
+    for (Node *root : _roots) {
+        queue.push(std::make_pair(std::numeric_limits<double>::max(), root));
+    }
+    while ((candidates.size() < wanted) && (queue.size() > 0)) {
+        // copy both fields out before pop() discards the top entry
+        double md = queue.top().first;
+        Node *n = queue.top().second;
+        queue.pop();
+        n->filterCandidates(candidates, vector, queue, md, blacklist);
+        ++find_cand_cnt;
+    }
+    std::vector<NnsHit> r;
+    r.reserve(candidates.size()); // may exceed k; avoids regrowth while filling
+    for (uint32_t docid : candidates) {
+        if (blacklist.isSet(docid)) continue;
+        double dist = l2distCalc.l2sq_dist(vector, _dva.get(docid));
+        r.emplace_back(docid, SqDist(dist));
+    }
+    auto kth = r.begin() + std::min<size_t>(k, r.size());
+    std::partial_sort(r.begin(), kth, r.end(), NnsHitComparatorLessDistance()); // only the k best need ordering
+    r.erase(kth, r.end());
+    return r;
+}
+
void
AnnoyLikeNns::dumpStats() {
fprintf(stderr, "stats for AnnoyLikeNns:\n");