aboutsummaryrefslogtreecommitdiffstats
path: root/eval/src/tests/ann/nns.h
diff options
context:
space:
mode:
Diffstat (limited to 'eval/src/tests/ann/nns.h')
-rw-r--r--eval/src/tests/ann/nns.h60
1 files changed, 60 insertions, 0 deletions
diff --git a/eval/src/tests/ann/nns.h b/eval/src/tests/ann/nns.h
new file mode 100644
index 00000000000..2e6666309bd
--- /dev/null
+++ b/eval/src/tests/ann/nns.h
@@ -0,0 +1,60 @@
+// Copyright 2019 Oath Inc. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+
+#pragma once
+#include <vespa/vespalib/util/arrayref.h>
+#include "doc_vector_access.h"
+#include "nns-l2.h"
+#include <memory>
+
+struct NnsHit {
+ uint32_t docid;
+ double sqDistance;
+ NnsHit(uint32_t di, double sqD)
+ : docid(di), sqDistance(sqD) {}
+};
+struct NnsHitComparatorLessDistance {
+ bool operator() (const NnsHit &lhs, const NnsHit& rhs) const {
+ if (lhs.sqDistance > rhs.sqDistance) return false;
+ if (lhs.sqDistance < rhs.sqDistance) return true;
+ return (lhs.docid > rhs.docid);
+ }
+};
+struct NnsHitComparatorGreaterDistance {
+ bool operator() (const NnsHit &lhs, const NnsHit& rhs) const {
+ if (lhs.sqDistance < rhs.sqDistance) return false;
+ if (lhs.sqDistance > rhs.sqDistance) return true;
+ return (lhs.docid > rhs.docid);
+ }
+};
+struct NnsHitComparatorLessDocid {
+ bool operator() (const NnsHit &lhs, const NnsHit& rhs) const {
+ return (lhs.docid < rhs.docid);
+ }
+};
+
+template <typename FltType = float>
+class NNS
+{
+public:
+ NNS(uint32_t numDims, const DocVectorAccess<FltType> &dva)
+ : _numDims(numDims), _dva(dva)
+ {}
+
+ virtual void addDoc(uint32_t docid) = 0;
+ virtual void removeDoc(uint32_t docid) = 0;
+
+ using Vector = vespalib::ConstArrayRef<FltType>;
+ virtual std::vector<NnsHit> topK(uint32_t k, Vector vector, uint32_t search_k) = 0;
+ virtual ~NNS() {}
+protected:
+ uint32_t _numDims;
+ const DocVectorAccess<FltType> &_dva;
+};
+
+extern
+std::unique_ptr<NNS<float>>
+make_annoy_nns(uint32_t numDims, const DocVectorAccess<float> &dva);
+
+extern
+std::unique_ptr<NNS<float>>
+make_rplsh_nns(uint32_t numDims, const DocVectorAccess<float> &dva);