diff options
4 files changed, 17 insertions, 19 deletions
diff --git a/searchlib/src/tests/tensor/hnsw_index/hnsw_index_test.cpp b/searchlib/src/tests/tensor/hnsw_index/hnsw_index_test.cpp index 1b821a05c84..c6246bb8434 100644 --- a/searchlib/src/tests/tensor/hnsw_index/hnsw_index_test.cpp +++ b/searchlib/src/tests/tensor/hnsw_index/hnsw_index_test.cpp @@ -89,10 +89,10 @@ public: void expect_top_3(uint32_t docid, std::vector<uint32_t> exp_hits) { uint32_t k = 3; auto qv = vectors.get_vector(docid); - auto rv = index->top_k_candidates(qv, k); + auto rv = index->top_k_candidates(qv, k).peek(); + std::sort(rv.begin(), rv.end(), LesserDistance()); size_t idx = 0; for (const auto & hit : rv) { - // fprintf(stderr, "found docid %u dist %.1f\n", hit.docid, hit.distance); if (idx < exp_hits.size()) { EXPECT_EQ(hit.docid, exp_hits[idx++]); } @@ -100,7 +100,7 @@ public: if (exp_hits.size() == k) { std::vector<uint32_t> expected_by_docid = exp_hits; std::sort(expected_by_docid.begin(), expected_by_docid.end()); - std::vector<uint32_t> got_by_docid = index->find_top_k(qv, k); + std::vector<uint32_t> got_by_docid = index->find_top_k(k, qv, k); EXPECT_EQ(expected_by_docid, got_by_docid); } } diff --git a/searchlib/src/vespa/searchlib/tensor/hnsw_index.cpp b/searchlib/src/vespa/searchlib/tensor/hnsw_index.cpp index 4bed9097a5c..bb3076dfe95 100644 --- a/searchlib/src/vespa/searchlib/tensor/hnsw_index.cpp +++ b/searchlib/src/vespa/searchlib/tensor/hnsw_index.cpp @@ -311,26 +311,27 @@ HnswIndex::remove_document(uint32_t docid) } std::vector<uint32_t> -HnswIndex::find_top_k(TypedCells vector, uint32_t k) +HnswIndex::find_top_k(uint32_t k, TypedCells vector, uint32_t explore_k) { std::vector<uint32_t> result; - std::vector<HnswCandidate> candidates = top_k_candidates(vector, k + 100); - result.reserve(std::min((size_t)k, candidates.size())); - for (const HnswCandidate & hit : candidates) { + FurthestPriQ candidates = top_k_candidates(vector, explore_k); + while (candidates.size() > k) { + candidates.pop(); + } + result.reserve(candidates.size()); + for (const HnswCandidate & hit : candidates.peek()) { result.emplace_back(hit.docid); - if (result.size() == k) break; } std::sort(result.begin(), result.end()); return result; } - -std::vector<HnswCandidate> +FurthestPriQ HnswIndex::top_k_candidates(const TypedCells &vector, uint32_t k) { - std::vector<HnswCandidate> result; + FurthestPriQ best_neighbors; if (_entry_level < 0) { - return result; + return best_neighbors; } double entry_dist = calc_distance(vector, _entry_docid); HnswCandidate entry_point(_entry_docid, entry_dist); @@ -339,12 +340,9 @@ HnswIndex::top_k_candidates(const TypedCells &vector, uint32_t k) entry_point = find_nearest_in_layer(vector, entry_point, search_level); --search_level; } - FurthestPriQ best_neighbors; best_neighbors.push(entry_point); search_layer(vector, k, best_neighbors, 0); - result = best_neighbors.peek(); - std::sort(result.begin(), result.end(), LesserDistance()); - return result; + return best_neighbors; } HnswNode diff --git a/searchlib/src/vespa/searchlib/tensor/hnsw_index.h b/searchlib/src/vespa/searchlib/tensor/hnsw_index.h index 2454ff85884..814148072ca 100644 --- a/searchlib/src/vespa/searchlib/tensor/hnsw_index.h +++ b/searchlib/src/vespa/searchlib/tensor/hnsw_index.h @@ -134,8 +134,8 @@ public: void add_document(uint32_t docid) override; void remove_document(uint32_t docid) override; - std::vector<uint32_t> find_top_k(TypedCells vector, uint32_t k) override; - std::vector<HnswCandidate> top_k_candidates(const TypedCells &vector, uint32_t k); + std::vector<uint32_t> find_top_k(uint32_t k, TypedCells vector, uint32_t explore_k) override; + FurthestPriQ top_k_candidates(const TypedCells &vector, uint32_t k); // TODO: Add support for generation handling and cleanup (transfer_hold_lists, trim_hold_lists) diff --git a/searchlib/src/vespa/searchlib/tensor/nearest_neighbor_index.h b/searchlib/src/vespa/searchlib/tensor/nearest_neighbor_index.h index 718b96c92b4..2ae322fe76e 100644 --- a/searchlib/src/vespa/searchlib/tensor/nearest_neighbor_index.h +++ b/searchlib/src/vespa/searchlib/tensor/nearest_neighbor_index.h @@ -16,7 +16,7 @@ public: virtual ~NearestNeighborIndex() {} virtual void add_document(uint32_t docid) = 0; virtual void remove_document(uint32_t docid) = 0; - virtual std::vector<uint32_t> find_top_k(vespalib::tensor::TypedCells vector, uint32_t k) = 0; + virtual std::vector<uint32_t> find_top_k(uint32_t k, vespalib::tensor::TypedCells vector, uint32_t explore_k) = 0; }; } |