// Copyright 2019 Oath Inc. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. #pragma once #include #include "doc_vector_access.h" #include "nns-l2.h" #include struct SqDist { double distance; explicit SqDist(double d) : distance(d) {} }; struct NnsHit { uint32_t docid; SqDist sq; NnsHit(uint32_t di, SqDist sqD) : docid(di), sq(sqD) {} }; struct NnsHitComparatorLessDistance { bool operator() (const NnsHit &lhs, const NnsHit& rhs) const { if (lhs.sq.distance > rhs.sq.distance) return false; if (lhs.sq.distance < rhs.sq.distance) return true; return (lhs.docid > rhs.docid); } }; struct NnsHitComparatorGreaterDistance { bool operator() (const NnsHit &lhs, const NnsHit& rhs) const { if (lhs.sq.distance < rhs.sq.distance) return false; if (lhs.sq.distance > rhs.sq.distance) return true; return (lhs.docid > rhs.docid); } }; struct NnsHitComparatorLessDocid { bool operator() (const NnsHit &lhs, const NnsHit& rhs) const { return (lhs.docid < rhs.docid); } }; template class NNS { public: NNS(uint32_t numDims, const DocVectorAccess &dva) : _numDims(numDims), _dva(dva) {} virtual void addDoc(uint32_t docid) = 0; virtual void removeDoc(uint32_t docid) = 0; using Vector = vespalib::ConstArrayRef; virtual std::vector topK(uint32_t k, Vector vector, uint32_t search_k) = 0; virtual ~NNS() {} protected: uint32_t _numDims; const DocVectorAccess &_dva; }; extern std::unique_ptr> make_annoy_nns(uint32_t numDims, const DocVectorAccess &dva); extern std::unique_ptr> make_rplsh_nns(uint32_t numDims, const DocVectorAccess &dva); extern std::unique_ptr> make_hnsw_nns(uint32_t numDims, const DocVectorAccess &dva);