summary | refs | log | tree | commit | diff | stats
path: root/searchlib
diff options
context:
space:
mode:
author: Arne Juul <arnej@verizonmedia.com> 2020-05-11 13:56:45 +0000
committer: Arne Juul <arnej@verizonmedia.com> 2020-05-11 13:56:45 +0000
commit: 8020c63c052f207673497906dbcc311f248a90e0 (patch)
tree: afef12826377732de5f4fb81dbca56f840431068 /searchlib
parent: ed62bc8a2d08e4a57ae8cd29f0a8aca0e7dd086e (diff)
collapse top_k_candidates
Diffstat (limited to 'searchlib')
-rw-r--r-- searchlib/src/tests/tensor/hnsw_index/hnsw_index_test.cpp | 2
-rw-r--r-- searchlib/src/vespa/searchlib/tensor/hnsw_index.cpp | 35
-rw-r--r-- searchlib/src/vespa/searchlib/tensor/hnsw_index.h | 3
3 files changed, 13 insertions, 27 deletions
diff --git a/searchlib/src/tests/tensor/hnsw_index/hnsw_index_test.cpp b/searchlib/src/tests/tensor/hnsw_index/hnsw_index_test.cpp
index 7d38be7db4a..88e35a80bc9 100644
--- a/searchlib/src/tests/tensor/hnsw_index/hnsw_index_test.cpp
+++ b/searchlib/src/tests/tensor/hnsw_index/hnsw_index_test.cpp
@@ -122,7 +122,7 @@ public:
void expect_top_3(uint32_t docid, std::vector<uint32_t> exp_hits) {
uint32_t k = 3;
auto qv = vectors.get_vector(docid);
- auto rv = index->top_k_candidates(qv, k).peek();
+ auto rv = index->top_k_candidates(qv, k, nullptr).peek();
std::sort(rv.begin(), rv.end(), LesserDistance());
size_t idx = 0;
for (const auto & hit : rv) {
diff --git a/searchlib/src/vespa/searchlib/tensor/hnsw_index.cpp b/searchlib/src/vespa/searchlib/tensor/hnsw_index.cpp
index 7e89c6cd823..c0046d80aad 100644
--- a/searchlib/src/vespa/searchlib/tensor/hnsw_index.cpp
+++ b/searchlib/src/vespa/searchlib/tensor/hnsw_index.cpp
@@ -477,7 +477,7 @@ std::vector<NearestNeighborIndex::Neighbor>
HnswIndex::find_top_k(uint32_t k, TypedCells vector, uint32_t explore_k) const
{
std::vector<Neighbor> result;
- FurthestPriQ candidates = top_k_candidates(vector, std::max(k, explore_k));
+ FurthestPriQ candidates = top_k_candidates(vector, std::max(k, explore_k), nullptr);
while (candidates.size() > k) {
candidates.pop();
}
@@ -494,7 +494,7 @@ HnswIndex::find_top_k_with_filter(uint32_t k, TypedCells vector,
const BitVector &filter, uint32_t explore_k) const
{
std::vector<Neighbor> result;
- FurthestPriQ candidates = top_k_candidates(vector, std::max(k, explore_k), filter);
+ FurthestPriQ candidates = top_k_candidates(vector, std::max(k, explore_k), &filter);
while (candidates.size() > k) {
candidates.pop();
}
@@ -507,27 +507,7 @@ HnswIndex::find_top_k_with_filter(uint32_t k, TypedCells vector,
}
FurthestPriQ
-HnswIndex::top_k_candidates(const TypedCells &vector, uint32_t k) const
-{
- FurthestPriQ best_neighbors;
- if (get_entry_level() < 0) {
- return best_neighbors;
- }
- uint32_t entry_docid = get_entry_docid();
- int search_level = get_entry_level();
- double entry_dist = calc_distance(vector, entry_docid);
- HnswCandidate entry_point(entry_docid, entry_dist);
- while (search_level > 0) {
- entry_point = find_nearest_in_layer(vector, entry_point, search_level);
- --search_level;
- }
- best_neighbors.push(entry_point);
- search_layer(vector, k, best_neighbors, 0);
- return best_neighbors;
-}
-
-FurthestPriQ
-HnswIndex::top_k_candidates(const TypedCells &vector, uint32_t k, const BitVector &filter) const
+HnswIndex::top_k_candidates(const TypedCells &vector, uint32_t k, const BitVector *filter) const
{
if (get_entry_level() < 0) {
FurthestPriQ empty;
@@ -541,7 +521,14 @@ HnswIndex::top_k_candidates(const TypedCells &vector, uint32_t k, const BitVecto
entry_point = find_nearest_in_layer(vector, entry_point, search_level);
--search_level;
}
- return search_l0(vector, k, entry_point, filter);
+ if (filter) {
+ return search_l0(vector, k, entry_point, *filter);
+ } else {
+ FurthestPriQ best_neighbors;
+ best_neighbors.push(entry_point);
+ search_layer(vector, k, best_neighbors, 0);
+ return best_neighbors;
+ }
}
HnswNode
diff --git a/searchlib/src/vespa/searchlib/tensor/hnsw_index.h b/searchlib/src/vespa/searchlib/tensor/hnsw_index.h
index 07da6d0cb1a..68cfc783d87 100644
--- a/searchlib/src/vespa/searchlib/tensor/hnsw_index.h
+++ b/searchlib/src/vespa/searchlib/tensor/hnsw_index.h
@@ -142,8 +142,7 @@ public:
const BitVector &filter, uint32_t explore_k) const override;
const DistanceFunction *distance_function() const override { return _distance_func.get(); }
- FurthestPriQ top_k_candidates(const TypedCells &vector, uint32_t k) const;
- FurthestPriQ top_k_candidates(const TypedCells &vector, uint32_t k, const BitVector &filter) const;
+ FurthestPriQ top_k_candidates(const TypedCells &vector, uint32_t k, const BitVector *filter) const;
uint32_t get_entry_docid() const { return _graph.entry_docid; }
int32_t get_entry_level() const { return _graph.entry_level; }