aboutsummaryrefslogtreecommitdiffstats
path: root/eval/src/tests/ann/nns.h
diff options
context:
space:
mode:
authorArne Juul <arnej@verizonmedia.com>2020-02-12 10:30:39 +0000
committerArne Juul <arnej@verizonmedia.com>2020-02-24 12:22:33 +0000
commitcc3c709d6278ebd699d4f4c67f8f769c9b6fa177 (patch)
treedb8125a746a93250fe405b0a62828cc6558a5ef1 /eval/src/tests/ann/nns.h
parentffa2293de302d99051f7fc97d29c4dc606f045f1 (diff)
add and verify filter option
split out common subroutines
Diffstat (limited to 'eval/src/tests/ann/nns.h')
-rw-r--r--eval/src/tests/ann/nns.h26
1 files changed, 26 insertions, 0 deletions
diff --git a/eval/src/tests/ann/nns.h b/eval/src/tests/ann/nns.h
index ffe2882188e..ef3e4b5d69c 100644
--- a/eval/src/tests/ann/nns.h
+++ b/eval/src/tests/ann/nns.h
@@ -37,6 +37,31 @@ struct NnsHitComparatorLessDocid {
}
};
+class BitVector {
+private:
+ std::vector<uint64_t> _bits;
+public:
+ BitVector(size_t sz) : _bits((sz+63)/64) {}
+ BitVector& setBit(size_t idx) {
+ uint64_t mask = 1;
+ mask <<= (idx%64);
+ _bits[idx/64] |= mask;
+ return *this;
+ }
+ bool isSet(size_t idx) const {
+ uint64_t mask = 1;
+ mask <<= (idx%64);
+ uint64_t word = _bits[idx/64];
+ return (word & mask) != 0;
+ }
+ BitVector& clearBit(size_t idx) {
+ uint64_t mask = 1;
+ mask <<= (idx%64);
+ _bits[idx/64] &= ~mask;
+ return *this;
+ }
+};
+
template <typename FltType = float>
class NNS
{
@@ -50,6 +75,7 @@ public:
using Vector = vespalib::ConstArrayRef<FltType>;
virtual std::vector<NnsHit> topK(uint32_t k, Vector vector, uint32_t search_k) = 0;
+ virtual std::vector<NnsHit> topKfilter(uint32_t k, Vector vector, uint32_t search_k, const BitVector &blacklist) = 0;
virtual ~NNS() {}
protected:
uint32_t _numDims;