diff options
-rw-r--r-- | searchlib/src/vespa/searchlib/tensor/hnsw_index.cpp | 11 | ||||
-rw-r--r-- | searchlib/src/vespa/searchlib/tensor/hnsw_index.h | 2 |
2 files changed, 8 insertions, 5 deletions
diff --git a/searchlib/src/vespa/searchlib/tensor/hnsw_index.cpp b/searchlib/src/vespa/searchlib/tensor/hnsw_index.cpp index f79727cd6bd..09608a9abbe 100644 --- a/searchlib/src/vespa/searchlib/tensor/hnsw_index.cpp +++ b/searchlib/src/vespa/searchlib/tensor/hnsw_index.cpp @@ -246,11 +246,12 @@ void HnswIndex::search_layer(const TypedCells& input, uint32_t neighbors_to_find, FurthestPriQ& best_neighbors, uint32_t level) const { NearestPriQ candidates; - // TODO: Add proper handling of visited set. - auto visited = BitVector::create(_node_refs.size()); + uint32_t doc_id_limit = _node_refs.size(); + auto visited = _visited_set_pool.get(doc_id_limit); for (const auto &entry : best_neighbors.peek()) { + assert(entry.docid < doc_id_limit); candidates.push(entry); - visited->setBit(entry.docid); + visited.mark(entry.docid); } double limit_dist = std::numeric_limits<double>::max(); @@ -261,10 +262,10 @@ HnswIndex::search_layer(const TypedCells& input, uint32_t neighbors_to_find, Fur } candidates.pop(); for (uint32_t neighbor_docid : get_link_array(cand.docid, level)) { - if (visited->testBit(neighbor_docid)) { + if ((neighbor_docid >= doc_id_limit) || visited.is_marked(neighbor_docid)) { continue; } - visited->setBit(neighbor_docid); + visited.mark(neighbor_docid); double dist_to_input = calc_distance(input, neighbor_docid); if (dist_to_input < limit_dist) { candidates.emplace(neighbor_docid, dist_to_input); diff --git a/searchlib/src/vespa/searchlib/tensor/hnsw_index.h b/searchlib/src/vespa/searchlib/tensor/hnsw_index.h index ca87047f61b..4066316a991 100644 --- a/searchlib/src/vespa/searchlib/tensor/hnsw_index.h +++ b/searchlib/src/vespa/searchlib/tensor/hnsw_index.h @@ -14,6 +14,7 @@ #include <vespa/vespalib/datastore/atomic_entry_ref.h> #include <vespa/vespalib/datastore/entryref.h> #include <vespa/vespalib/util/rcuvector.h> +#include <vespa/vespalib/util/reusable_set_pool.h> namespace search::tensor { @@ -87,6 +88,7 @@ protected: NodeRefVector _node_refs; NodeStore _nodes; LinkStore _links; + mutable vespalib::ReusableSetPool _visited_set_pool; uint32_t _entry_docid; int _entry_level; |