diff options
Diffstat (limited to 'eval/src/tests/ann/nns.h')
-rw-r--r-- | eval/src/tests/ann/nns.h | 26 |
1 files changed, 26 insertions, 0 deletions
diff --git a/eval/src/tests/ann/nns.h b/eval/src/tests/ann/nns.h index ffe2882188e..ef3e4b5d69c 100644 --- a/eval/src/tests/ann/nns.h +++ b/eval/src/tests/ann/nns.h @@ -37,6 +37,31 @@ struct NnsHitComparatorLessDocid { } }; +class BitVector { +private: + std::vector<uint64_t> _bits; +public: + BitVector(size_t sz) : _bits((sz+63)/64) {} + BitVector& setBit(size_t idx) { + uint64_t mask = 1; + mask <<= (idx%64); + _bits[idx/64] |= mask; + return *this; + } + bool isSet(size_t idx) const { + uint64_t mask = 1; + mask <<= (idx%64); + uint64_t word = _bits[idx/64]; + return (word & mask) != 0; + } + BitVector& clearBit(size_t idx) { + uint64_t mask = 1; + mask <<= (idx%64); + _bits[idx/64] &= ~mask; + return *this; + } +}; + template <typename FltType = float> class NNS { @@ -50,6 +75,7 @@ public: using Vector = vespalib::ConstArrayRef<FltType>; virtual std::vector<NnsHit> topK(uint32_t k, Vector vector, uint32_t search_k) = 0; + virtual std::vector<NnsHit> topKfilter(uint32_t k, Vector vector, uint32_t search_k, const BitVector &blacklist) = 0; virtual ~NNS() {} protected: uint32_t _numDims; |