aboutsummaryrefslogtreecommitdiffstats
path: root/eval/src/tests/ann/nns.h
diff options
context:
space:
mode:
Diffstat (limited to 'eval/src/tests/ann/nns.h')
-rw-r--r--eval/src/tests/ann/nns.h26
1 files changed, 26 insertions, 0 deletions
diff --git a/eval/src/tests/ann/nns.h b/eval/src/tests/ann/nns.h
index ffe2882188e..ef3e4b5d69c 100644
--- a/eval/src/tests/ann/nns.h
+++ b/eval/src/tests/ann/nns.h
@@ -37,6 +37,31 @@ struct NnsHitComparatorLessDocid {
}
};
+class BitVector {
+private:
+ std::vector<uint64_t> _bits;
+public:
+ BitVector(size_t sz) : _bits((sz+63)/64) {}
+ BitVector& setBit(size_t idx) {
+ uint64_t mask = 1;
+ mask <<= (idx%64);
+ _bits[idx/64] |= mask;
+ return *this;
+ }
+ bool isSet(size_t idx) const {
+ uint64_t mask = 1;
+ mask <<= (idx%64);
+ uint64_t word = _bits[idx/64];
+ return (word & mask) != 0;
+ }
+ BitVector& clearBit(size_t idx) {
+ uint64_t mask = 1;
+ mask <<= (idx%64);
+ _bits[idx/64] &= ~mask;
+ return *this;
+ }
+};
+
template <typename FltType = float>
class NNS
{
@@ -50,6 +75,7 @@ public:
using Vector = vespalib::ConstArrayRef<FltType>;
virtual std::vector<NnsHit> topK(uint32_t k, Vector vector, uint32_t search_k) = 0;
+ virtual std::vector<NnsHit> topKfilter(uint32_t k, Vector vector, uint32_t search_k, const BitVector &blacklist) = 0;
virtual ~NNS() {}
protected:
uint32_t _numDims;