diff options
Diffstat (limited to 'searchlib')
-rw-r--r-- | searchlib/src/tests/tensor/hnsw_index/hnsw_index_test.cpp | 2 | ||||
-rw-r--r-- | searchlib/src/vespa/searchlib/tensor/hnsw_index.cpp | 35 | ||||
-rw-r--r-- | searchlib/src/vespa/searchlib/tensor/hnsw_index.h | 3 |
3 files changed, 13 insertions, 27 deletions
diff --git a/searchlib/src/tests/tensor/hnsw_index/hnsw_index_test.cpp b/searchlib/src/tests/tensor/hnsw_index/hnsw_index_test.cpp index 7d38be7db4a..88e35a80bc9 100644 --- a/searchlib/src/tests/tensor/hnsw_index/hnsw_index_test.cpp +++ b/searchlib/src/tests/tensor/hnsw_index/hnsw_index_test.cpp @@ -122,7 +122,7 @@ public: void expect_top_3(uint32_t docid, std::vector<uint32_t> exp_hits) { uint32_t k = 3; auto qv = vectors.get_vector(docid); - auto rv = index->top_k_candidates(qv, k).peek(); + auto rv = index->top_k_candidates(qv, k, nullptr).peek(); std::sort(rv.begin(), rv.end(), LesserDistance()); size_t idx = 0; for (const auto & hit : rv) { diff --git a/searchlib/src/vespa/searchlib/tensor/hnsw_index.cpp b/searchlib/src/vespa/searchlib/tensor/hnsw_index.cpp index 7e89c6cd823..c0046d80aad 100644 --- a/searchlib/src/vespa/searchlib/tensor/hnsw_index.cpp +++ b/searchlib/src/vespa/searchlib/tensor/hnsw_index.cpp @@ -477,7 +477,7 @@ std::vector<NearestNeighborIndex::Neighbor> HnswIndex::find_top_k(uint32_t k, TypedCells vector, uint32_t explore_k) const { std::vector<Neighbor> result; - FurthestPriQ candidates = top_k_candidates(vector, std::max(k, explore_k)); + FurthestPriQ candidates = top_k_candidates(vector, std::max(k, explore_k), nullptr); while (candidates.size() > k) { candidates.pop(); } @@ -494,7 +494,7 @@ HnswIndex::find_top_k_with_filter(uint32_t k, TypedCells vector, const BitVector &filter, uint32_t explore_k) const { std::vector<Neighbor> result; - FurthestPriQ candidates = top_k_candidates(vector, std::max(k, explore_k), filter); + FurthestPriQ candidates = top_k_candidates(vector, std::max(k, explore_k), &filter); while (candidates.size() > k) { candidates.pop(); } @@ -507,27 +507,7 @@ HnswIndex::find_top_k_with_filter(uint32_t k, TypedCells vector, } FurthestPriQ -HnswIndex::top_k_candidates(const TypedCells &vector, uint32_t k) const -{ - FurthestPriQ best_neighbors; - if (get_entry_level() < 0) { - return best_neighbors; - } - uint32_t entry_docid = get_entry_docid(); - int search_level = get_entry_level(); - double entry_dist = calc_distance(vector, entry_docid); - HnswCandidate entry_point(entry_docid, entry_dist); - while (search_level > 0) { - entry_point = find_nearest_in_layer(vector, entry_point, search_level); - --search_level; - } - best_neighbors.push(entry_point); - search_layer(vector, k, best_neighbors, 0); - return best_neighbors; -} - -FurthestPriQ -HnswIndex::top_k_candidates(const TypedCells &vector, uint32_t k, const BitVector &filter) const +HnswIndex::top_k_candidates(const TypedCells &vector, uint32_t k, const BitVector *filter) const { if (get_entry_level() < 0) { FurthestPriQ empty; @@ -541,7 +521,14 @@ HnswIndex::top_k_candidates(const TypedCells &vector, uint32_t k, const BitVecto entry_point = find_nearest_in_layer(vector, entry_point, search_level); --search_level; } - return search_l0(vector, k, entry_point, filter); + if (filter) { + return search_l0(vector, k, entry_point, *filter); + } else { + FurthestPriQ best_neighbors; + best_neighbors.push(entry_point); + search_layer(vector, k, best_neighbors, 0); + return best_neighbors; + } } HnswNode diff --git a/searchlib/src/vespa/searchlib/tensor/hnsw_index.h b/searchlib/src/vespa/searchlib/tensor/hnsw_index.h index 07da6d0cb1a..68cfc783d87 100644 --- a/searchlib/src/vespa/searchlib/tensor/hnsw_index.h +++ b/searchlib/src/vespa/searchlib/tensor/hnsw_index.h @@ -142,8 +142,7 @@ public: const BitVector &filter, uint32_t explore_k) const override; const DistanceFunction *distance_function() const override { return _distance_func.get(); } - FurthestPriQ top_k_candidates(const TypedCells &vector, uint32_t k) const; - FurthestPriQ top_k_candidates(const TypedCells &vector, uint32_t k, const BitVector &filter) const; + FurthestPriQ top_k_candidates(const TypedCells &vector, uint32_t k, const BitVector *filter) const; uint32_t get_entry_docid() const { return _graph.entry_docid; } int32_t get_entry_level() const { return _graph.entry_level; } |