aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorArne Juul <arnej@verizonmedia.com>2020-02-19 12:14:14 +0000
committerArne Juul <arnej@verizonmedia.com>2020-02-19 12:15:00 +0000
commit6d06888aeadc0f10f9c5ac1d0e780fdbb00701aa (patch)
treeee08f6a8ac404f91af59aae264c7d8c3c483e804
parentd3853d707b9fd95a7cdb9d0edf1ad8de245dc049 (diff)
expose explore_k in top level API
-rw-r--r--searchlib/src/tests/tensor/hnsw_index/hnsw_index_test.cpp6
-rw-r--r--searchlib/src/vespa/searchlib/tensor/hnsw_index.cpp24
-rw-r--r--searchlib/src/vespa/searchlib/tensor/hnsw_index.h4
-rw-r--r--searchlib/src/vespa/searchlib/tensor/nearest_neighbor_index.h2
4 files changed, 17 insertions, 19 deletions
diff --git a/searchlib/src/tests/tensor/hnsw_index/hnsw_index_test.cpp b/searchlib/src/tests/tensor/hnsw_index/hnsw_index_test.cpp
index 1b821a05c84..c6246bb8434 100644
--- a/searchlib/src/tests/tensor/hnsw_index/hnsw_index_test.cpp
+++ b/searchlib/src/tests/tensor/hnsw_index/hnsw_index_test.cpp
@@ -89,10 +89,10 @@ public:
void expect_top_3(uint32_t docid, std::vector<uint32_t> exp_hits) {
uint32_t k = 3;
auto qv = vectors.get_vector(docid);
- auto rv = index->top_k_candidates(qv, k);
+ auto rv = index->top_k_candidates(qv, k).peek();
+ std::sort(rv.begin(), rv.end(), LesserDistance());
size_t idx = 0;
for (const auto & hit : rv) {
- // fprintf(stderr, "found docid %u dist %.1f\n", hit.docid, hit.distance);
if (idx < exp_hits.size()) {
EXPECT_EQ(hit.docid, exp_hits[idx++]);
}
@@ -100,7 +100,7 @@ public:
if (exp_hits.size() == k) {
std::vector<uint32_t> expected_by_docid = exp_hits;
std::sort(expected_by_docid.begin(), expected_by_docid.end());
- std::vector<uint32_t> got_by_docid = index->find_top_k(qv, k);
+ std::vector<uint32_t> got_by_docid = index->find_top_k(k, qv, k);
EXPECT_EQ(expected_by_docid, got_by_docid);
}
}
diff --git a/searchlib/src/vespa/searchlib/tensor/hnsw_index.cpp b/searchlib/src/vespa/searchlib/tensor/hnsw_index.cpp
index 4bed9097a5c..bb3076dfe95 100644
--- a/searchlib/src/vespa/searchlib/tensor/hnsw_index.cpp
+++ b/searchlib/src/vespa/searchlib/tensor/hnsw_index.cpp
@@ -311,26 +311,27 @@ HnswIndex::remove_document(uint32_t docid)
}
std::vector<uint32_t>
-HnswIndex::find_top_k(TypedCells vector, uint32_t k)
+HnswIndex::find_top_k(uint32_t k, TypedCells vector, uint32_t explore_k)
{
std::vector<uint32_t> result;
- std::vector<HnswCandidate> candidates = top_k_candidates(vector, k + 100);
- result.reserve(std::min((size_t)k, candidates.size()));
- for (const HnswCandidate & hit : candidates) {
+ FurthestPriQ candidates = top_k_candidates(vector, explore_k);
+ while (candidates.size() > k) {
+ candidates.pop();
+ }
+ result.reserve(candidates.size());
+ for (const HnswCandidate & hit : candidates.peek()) {
result.emplace_back(hit.docid);
- if (result.size() == k) break;
}
std::sort(result.begin(), result.end());
return result;
}
-
-std::vector<HnswCandidate>
+FurthestPriQ
HnswIndex::top_k_candidates(const TypedCells &vector, uint32_t k)
{
- std::vector<HnswCandidate> result;
+ FurthestPriQ best_neighbors;
if (_entry_level < 0) {
- return result;
+ return best_neighbors;
}
double entry_dist = calc_distance(vector, _entry_docid);
HnswCandidate entry_point(_entry_docid, entry_dist);
@@ -339,12 +340,9 @@ HnswIndex::top_k_candidates(const TypedCells &vector, uint32_t k)
entry_point = find_nearest_in_layer(vector, entry_point, search_level);
--search_level;
}
- FurthestPriQ best_neighbors;
best_neighbors.push(entry_point);
search_layer(vector, k, best_neighbors, 0);
- result = best_neighbors.peek();
- std::sort(result.begin(), result.end(), LesserDistance());
- return result;
+ return best_neighbors;
}
HnswNode
diff --git a/searchlib/src/vespa/searchlib/tensor/hnsw_index.h b/searchlib/src/vespa/searchlib/tensor/hnsw_index.h
index 2454ff85884..814148072ca 100644
--- a/searchlib/src/vespa/searchlib/tensor/hnsw_index.h
+++ b/searchlib/src/vespa/searchlib/tensor/hnsw_index.h
@@ -134,8 +134,8 @@ public:
void add_document(uint32_t docid) override;
void remove_document(uint32_t docid) override;
- std::vector<uint32_t> find_top_k(TypedCells vector, uint32_t k) override;
- std::vector<HnswCandidate> top_k_candidates(const TypedCells &vector, uint32_t k);
+ std::vector<uint32_t> find_top_k(uint32_t k, TypedCells vector, uint32_t explore_k) override;
+ FurthestPriQ top_k_candidates(const TypedCells &vector, uint32_t k);
// TODO: Add support for generation handling and cleanup (transfer_hold_lists, trim_hold_lists)
diff --git a/searchlib/src/vespa/searchlib/tensor/nearest_neighbor_index.h b/searchlib/src/vespa/searchlib/tensor/nearest_neighbor_index.h
index 718b96c92b4..2ae322fe76e 100644
--- a/searchlib/src/vespa/searchlib/tensor/nearest_neighbor_index.h
+++ b/searchlib/src/vespa/searchlib/tensor/nearest_neighbor_index.h
@@ -16,7 +16,7 @@ public:
virtual ~NearestNeighborIndex() {}
virtual void add_document(uint32_t docid) = 0;
virtual void remove_document(uint32_t docid) = 0;
- virtual std::vector<uint32_t> find_top_k(vespalib::tensor::TypedCells vector, uint32_t k) = 0;
+ virtual std::vector<uint32_t> find_top_k(uint32_t k, vespalib::tensor::TypedCells vector, uint32_t explore_k) = 0;
};
}