summaryrefslogtreecommitdiffstats
path: root/searchlib/src
diff options
context:
space:
mode:
author: Arne Juul <arnej@verizonmedia.com> 2020-02-17 14:07:24 +0000
committer: Arne Juul <arnej@verizonmedia.com> 2020-02-17 14:07:24 +0000
commit: 1ab187d4152775a22d1b554fd2364ec5058cdfee (patch)
tree: 220c359f1dc244174748d041756d56655b814c77 /searchlib/src
parent: f4d3589c6b00c330246fa486afc5fe20f6a0f9fe (diff)
add search API in NearestNeighborIndex
Diffstat (limited to 'searchlib/src')
-rw-r--r-- searchlib/src/tests/tensor/hnsw_index/hnsw_index_test.cpp | 2
-rw-r--r-- searchlib/src/vespa/searchlib/tensor/hnsw_index.cpp | 17
-rw-r--r-- searchlib/src/vespa/searchlib/tensor/hnsw_index.h | 3
-rw-r--r-- searchlib/src/vespa/searchlib/tensor/nearest_neighbor_index.h | 3
4 files changed, 22 insertions, 3 deletions
diff --git a/searchlib/src/tests/tensor/hnsw_index/hnsw_index_test.cpp b/searchlib/src/tests/tensor/hnsw_index/hnsw_index_test.cpp
index 1965d58c0e7..805012d224c 100644
--- a/searchlib/src/tests/tensor/hnsw_index/hnsw_index_test.cpp
+++ b/searchlib/src/tests/tensor/hnsw_index/hnsw_index_test.cpp
@@ -87,7 +87,7 @@ public:
EXPECT_EQ(exp_levels, act_node.levels());
}
void expect_top_3(uint32_t docid, std::vector<uint32_t> exp_hits) {
- auto rv = index->find_top_k(vectors.get_vector(docid), 3);
+ auto rv = index->top_k_candidates(vectors.get_vector(docid), 3);
size_t idx = 0;
for (const auto & hit : rv) {
// fprintf(stderr, "found docid %u dist %.1f\n", hit.docid, hit.distance);
diff --git a/searchlib/src/vespa/searchlib/tensor/hnsw_index.cpp b/searchlib/src/vespa/searchlib/tensor/hnsw_index.cpp
index 85010be1e17..4bed9097a5c 100644
--- a/searchlib/src/vespa/searchlib/tensor/hnsw_index.cpp
+++ b/searchlib/src/vespa/searchlib/tensor/hnsw_index.cpp
@@ -310,8 +310,23 @@ HnswIndex::remove_document(uint32_t docid)
_node_refs[docid].store_release(invalid);
}
+/**
+ * Returns the docids of at most k hits for the given vector, sorted by
+ * ascending docid.
+ *
+ * The hits are the first k entries, in the order returned, of a call to
+ * top_k_candidates() that asks for k plus a fixed margin of extra candidates.
+ * Fewer than k docids are returned when fewer candidates are available,
+ * and none are returned when k is 0.
+ */
+std::vector<uint32_t>
+HnswIndex::find_top_k(TypedCells vector, uint32_t k)
+{
+    // Over-fetch by a fixed margin, saturating so that k + margin cannot wrap around.
+    constexpr uint32_t extra_candidates = 100;
+    const uint32_t fetch_k = (k > UINT32_MAX - extra_candidates) ? UINT32_MAX : (k + extra_candidates);
+    const std::vector<HnswCandidate> candidates = top_k_candidates(vector, fetch_k);
+    // Bound the copy up front: a size check done only after appending can
+    // never trigger for k == 0, so it would let every candidate through.
+    const size_t num_hits = std::min(static_cast<size_t>(k), candidates.size());
+    std::vector<uint32_t> result;
+    result.reserve(num_hits);
+    for (size_t i = 0; i < num_hits; ++i) {
+        result.emplace_back(candidates[i].docid);
+    }
+    std::sort(result.begin(), result.end());
+    return result;
+}
+
+
std::vector<HnswCandidate>
-HnswIndex::find_top_k(const TypedCells &vector, uint32_t k)
+HnswIndex::top_k_candidates(const TypedCells &vector, uint32_t k)
{
std::vector<HnswCandidate> result;
if (_entry_level < 0) {
diff --git a/searchlib/src/vespa/searchlib/tensor/hnsw_index.h b/searchlib/src/vespa/searchlib/tensor/hnsw_index.h
index af20dfd9a78..2454ff85884 100644
--- a/searchlib/src/vespa/searchlib/tensor/hnsw_index.h
+++ b/searchlib/src/vespa/searchlib/tensor/hnsw_index.h
@@ -134,7 +134,8 @@ public:
void add_document(uint32_t docid) override;
void remove_document(uint32_t docid) override;
- std::vector<HnswCandidate> find_top_k(const TypedCells &vector, uint32_t k);
+ std::vector<uint32_t> find_top_k(TypedCells vector, uint32_t k) override;
+ std::vector<HnswCandidate> top_k_candidates(const TypedCells &vector, uint32_t k);
// TODO: Add support for generation handling and cleanup (transfer_hold_lists, trim_hold_lists)
diff --git a/searchlib/src/vespa/searchlib/tensor/nearest_neighbor_index.h b/searchlib/src/vespa/searchlib/tensor/nearest_neighbor_index.h
index 2167157f6cb..718b96c92b4 100644
--- a/searchlib/src/vespa/searchlib/tensor/nearest_neighbor_index.h
+++ b/searchlib/src/vespa/searchlib/tensor/nearest_neighbor_index.h
@@ -3,6 +3,8 @@
#pragma once
#include <cstdint>
+#include <vector>
+#include <vespa/eval/tensor/dense/typed_cells.h>
namespace search::tensor {
@@ -14,6 +16,7 @@ public:
virtual ~NearestNeighborIndex() {}
virtual void add_document(uint32_t docid) = 0;
virtual void remove_document(uint32_t docid) = 0;
+ virtual std::vector<uint32_t> find_top_k(vespalib::tensor::TypedCells vector, uint32_t k) = 0;
};
}