diff options
author | Arne Juul <arnej@verizonmedia.com> | 2020-05-11 11:19:47 +0000 |
---|---|---|
committer | Arne Juul <arnej@verizonmedia.com> | 2020-05-11 11:20:28 +0000 |
commit | ed62bc8a2d08e4a57ae8cd29f0a8aca0e7dd086e (patch) | |
tree | f978dab2870a25e020ff5365db52eaddecfbcec2 | |
parent | d6e059759286443da0e30abb9212baf3b8c281ab (diff) |
allow filter in HNSW index
4 files changed, 103 insertions, 0 deletions
diff --git a/searchlib/src/tests/attribute/tensorattribute/tensorattribute_test.cpp b/searchlib/src/tests/attribute/tensorattribute/tensorattribute_test.cpp index 18a6a5a8188..592a8aa6a36 100644 --- a/searchlib/src/tests/attribute/tensorattribute/tensorattribute_test.cpp +++ b/searchlib/src/tests/attribute/tensorattribute/tensorattribute_test.cpp @@ -186,6 +186,16 @@ public: (void) explore_k; return std::vector<Neighbor>(); } + std::vector<Neighbor> find_top_k_with_filter(uint32_t k, vespalib::tensor::TypedCells vector, + const search::BitVector& filter, uint32_t explore_k) const override + { + (void) k; + (void) vector; + (void) explore_k; + (void) filter; + return std::vector<Neighbor>(); + } + const search::tensor::DistanceFunction *distance_function() const override { return nullptr; } }; diff --git a/searchlib/src/vespa/searchlib/tensor/hnsw_index.cpp b/searchlib/src/vespa/searchlib/tensor/hnsw_index.cpp index 612f30cc64f..7e89c6cd823 100644 --- a/searchlib/src/vespa/searchlib/tensor/hnsw_index.cpp +++ b/searchlib/src/vespa/searchlib/tensor/hnsw_index.cpp @@ -246,6 +246,51 @@ HnswIndex::search_layer(const TypedCells& input, uint32_t neighbors_to_find, Fur } } +FurthestPriQ +HnswIndex::search_l0(const TypedCells& input, uint32_t neighbors_to_find, + HnswCandidate entry_point, const BitVector &filter) const +{ + FurthestPriQ best_neighbors; + const uint32_t level = 0; + NearestPriQ candidates; + uint32_t doc_id_limit = _graph.node_refs.size(); + auto visited = _visited_set_pool.get(doc_id_limit); + + assert(entry_point.docid < doc_id_limit); + candidates.push(entry_point); + visited.mark(entry_point.docid); + if (filter.testBit(entry_point.docid)) { + best_neighbors.push(entry_point); + } + double limit_dist = std::numeric_limits<double>::max(); + + while (!candidates.empty()) { + auto cand = candidates.top(); + if (cand.distance > limit_dist) { + break; + } + candidates.pop(); + for (uint32_t neighbor_docid : _graph.get_link_array(cand.docid, level)) { + if ((neighbor_docid >= doc_id_limit) || visited.is_marked(neighbor_docid)) { + continue; + } + visited.mark(neighbor_docid); + double dist_to_input = calc_distance(input, neighbor_docid); + if (dist_to_input < limit_dist) { + candidates.emplace(neighbor_docid, dist_to_input); + if (filter.testBit(neighbor_docid)) { + best_neighbors.emplace(neighbor_docid, dist_to_input); + if (best_neighbors.size() > neighbors_to_find) { + best_neighbors.pop(); + limit_dist = best_neighbors.top().distance; + } + } + } + } + } + return best_neighbors; +} + HnswIndex::HnswIndex(const DocVectorAccess& vectors, DistanceFunction::UP distance_func, RandomLevelGenerator::UP level_generator, const Config& cfg) : @@ -444,6 +489,23 @@ HnswIndex::find_top_k(uint32_t k, TypedCells vector, uint32_t explore_k) const return result; } +std::vector<NearestNeighborIndex::Neighbor> +HnswIndex::find_top_k_with_filter(uint32_t k, TypedCells vector, + const BitVector &filter, uint32_t explore_k) const +{ + std::vector<Neighbor> result; + FurthestPriQ candidates = top_k_candidates(vector, std::max(k, explore_k), filter); + while (candidates.size() > k) { + candidates.pop(); + } + result.reserve(candidates.size()); + for (const HnswCandidate & hit : candidates.peek()) { + result.emplace_back(hit.docid, hit.distance); + } + std::sort(result.begin(), result.end(), NeighborsByDocId()); + return result; +} + FurthestPriQ HnswIndex::top_k_candidates(const TypedCells &vector, uint32_t k) const { @@ -464,6 +526,24 @@ HnswIndex::top_k_candidates(const TypedCells &vector, uint32_t k) const return best_neighbors; } +FurthestPriQ +HnswIndex::top_k_candidates(const TypedCells &vector, uint32_t k, const BitVector &filter) const +{ + if (get_entry_level() < 0) { + FurthestPriQ empty; + return empty; + } + uint32_t entry_docid = get_entry_docid(); + int search_level = get_entry_level(); + double entry_dist = calc_distance(vector, entry_docid); + HnswCandidate entry_point(entry_docid, entry_dist); + while (search_level > 0) { + entry_point = find_nearest_in_layer(vector, entry_point, search_level); + --search_level; + } + return search_l0(vector, k, entry_point, filter); +} + HnswNode HnswIndex::get_node(uint32_t docid) const { diff --git a/searchlib/src/vespa/searchlib/tensor/hnsw_index.h b/searchlib/src/vespa/searchlib/tensor/hnsw_index.h index 6a7496e8696..07da6d0cb1a 100644 --- a/searchlib/src/vespa/searchlib/tensor/hnsw_index.h +++ b/searchlib/src/vespa/searchlib/tensor/hnsw_index.h @@ -116,6 +116,8 @@ protected: */ HnswCandidate find_nearest_in_layer(const TypedCells& input, const HnswCandidate& entry_point, uint32_t level) const; void search_layer(const TypedCells& input, uint32_t neighbors_to_find, FurthestPriQ& found_neighbors, uint32_t level) const; + FurthestPriQ search_l0(const TypedCells& input, uint32_t neighbors_to_find, + HnswCandidate entry_point, const BitVector &filter) const; public: HnswIndex(const DocVectorAccess& vectors, DistanceFunction::UP distance_func, @@ -136,9 +138,12 @@ public: bool load(const fileutil::LoadedBuffer& buf) override; std::vector<Neighbor> find_top_k(uint32_t k, TypedCells vector, uint32_t explore_k) const override; + std::vector<Neighbor> find_top_k_with_filter(uint32_t k, TypedCells vector, + const BitVector &filter, uint32_t explore_k) const override; const DistanceFunction *distance_function() const override { return _distance_func.get(); } FurthestPriQ top_k_candidates(const TypedCells &vector, uint32_t k) const; + FurthestPriQ top_k_candidates(const TypedCells &vector, uint32_t k, const BitVector &filter) const; uint32_t get_entry_docid() const { return _graph.entry_docid; } int32_t get_entry_level() const { return _graph.entry_level; } diff --git a/searchlib/src/vespa/searchlib/tensor/nearest_neighbor_index.h b/searchlib/src/vespa/searchlib/tensor/nearest_neighbor_index.h index aca2ce2af66..c2d37f2d59a 100644 --- a/searchlib/src/vespa/searchlib/tensor/nearest_neighbor_index.h +++ b/searchlib/src/vespa/searchlib/tensor/nearest_neighbor_index.h @@ -14,6 +14,8 @@ namespace vespalib::slime { struct Inserter; } namespace search::fileutil { class LoadedBuffer; } +namespace search { class BitVector; } + namespace search::tensor { class NearestNeighborIndexSaver; @@ -53,6 +55,12 @@ public: vespalib::tensor::TypedCells vector, uint32_t explore_k) const = 0; + // only return neighbors where the corresponding filter bit is set + virtual std::vector<Neighbor> find_top_k_with_filter(uint32_t k, + vespalib::tensor::TypedCells vector, + const BitVector &filter, + uint32_t explore_k) const = 0; + virtual const DistanceFunction *distance_function() const = 0; }; |