summaryrefslogtreecommitdiffstats
path: root/searchlib
diff options
context:
space:
mode:
authorArne Juul <arnej@verizonmedia.com>2020-05-11 14:07:46 +0000
committerArne Juul <arnej@verizonmedia.com>2020-05-11 14:07:46 +0000
commitc0d652b21938dedd9d076131c90460e52e2cf9e4 (patch)
treed51fd41f5b0bd1c2180a70fd64d6efd335e3d8e1 /searchlib
parent8020c63c052f207673497906dbcc311f248a90e0 (diff)
de-duplicate find_top_k code
Diffstat (limited to 'searchlib')
-rw-r--r--searchlib/src/vespa/searchlib/tensor/hnsw_index.cpp23
-rw-r--r--searchlib/src/vespa/searchlib/tensor/hnsw_index.h3
2 files changed, 13 insertions, 13 deletions
diff --git a/searchlib/src/vespa/searchlib/tensor/hnsw_index.cpp b/searchlib/src/vespa/searchlib/tensor/hnsw_index.cpp
index c0046d80aad..50e6e70405e 100644
--- a/searchlib/src/vespa/searchlib/tensor/hnsw_index.cpp
+++ b/searchlib/src/vespa/searchlib/tensor/hnsw_index.cpp
@@ -474,10 +474,11 @@ struct NeighborsByDocId {
};
std::vector<NearestNeighborIndex::Neighbor>
-HnswIndex::find_top_k(uint32_t k, TypedCells vector, uint32_t explore_k) const
+HnswIndex::top_k_by_docid(uint32_t k, TypedCells vector,
+ const BitVector *filter, uint32_t explore_k) const
{
std::vector<Neighbor> result;
- FurthestPriQ candidates = top_k_candidates(vector, std::max(k, explore_k), nullptr);
+ FurthestPriQ candidates = top_k_candidates(vector, std::max(k, explore_k), filter);
while (candidates.size() > k) {
candidates.pop();
}
@@ -490,20 +491,16 @@ HnswIndex::find_top_k(uint32_t k, TypedCells vector, uint32_t explore_k) const
}
std::vector<NearestNeighborIndex::Neighbor>
+HnswIndex::find_top_k(uint32_t k, TypedCells vector, uint32_t explore_k) const
+{
+ return top_k_by_docid(k, vector, nullptr, explore_k);
+}
+
+std::vector<NearestNeighborIndex::Neighbor>
HnswIndex::find_top_k_with_filter(uint32_t k, TypedCells vector,
const BitVector &filter, uint32_t explore_k) const
{
- std::vector<Neighbor> result;
- FurthestPriQ candidates = top_k_candidates(vector, std::max(k, explore_k), &filter);
- while (candidates.size() > k) {
- candidates.pop();
- }
- result.reserve(candidates.size());
- for (const HnswCandidate & hit : candidates.peek()) {
- result.emplace_back(hit.docid, hit.distance);
- }
- std::sort(result.begin(), result.end(), NeighborsByDocId());
- return result;
+ return top_k_by_docid(k, vector, &filter, explore_k);
}
FurthestPriQ
diff --git a/searchlib/src/vespa/searchlib/tensor/hnsw_index.h b/searchlib/src/vespa/searchlib/tensor/hnsw_index.h
index 68cfc783d87..252ed01bff0 100644
--- a/searchlib/src/vespa/searchlib/tensor/hnsw_index.h
+++ b/searchlib/src/vespa/searchlib/tensor/hnsw_index.h
@@ -119,6 +119,9 @@ protected:
FurthestPriQ search_l0(const TypedCells& input, uint32_t neighbors_to_find,
HnswCandidate entry_point, const BitVector &filter) const;
+ std::vector<Neighbor> top_k_by_docid(uint32_t k, TypedCells vector,
+ const BitVector *filter, uint32_t explore_k) const;
+
public:
HnswIndex(const DocVectorAccess& vectors, DistanceFunction::UP distance_func,
RandomLevelGenerator::UP level_generator, const Config& cfg);