aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
author Arne Juul <arnej@verizonmedia.com> 2020-05-11 11:19:47 +0000
committer Arne Juul <arnej@verizonmedia.com> 2020-05-11 11:20:28 +0000
commit ed62bc8a2d08e4a57ae8cd29f0a8aca0e7dd086e (patch)
tree f978dab2870a25e020ff5365db52eaddecfbcec2
parent d6e059759286443da0e30abb9212baf3b8c281ab (diff)
allow filter in HNSW index
-rw-r--r-- searchlib/src/tests/attribute/tensorattribute/tensorattribute_test.cpp 10
-rw-r--r-- searchlib/src/vespa/searchlib/tensor/hnsw_index.cpp 80
-rw-r--r-- searchlib/src/vespa/searchlib/tensor/hnsw_index.h 5
-rw-r--r-- searchlib/src/vespa/searchlib/tensor/nearest_neighbor_index.h 8
4 files changed, 103 insertions, 0 deletions
diff --git a/searchlib/src/tests/attribute/tensorattribute/tensorattribute_test.cpp b/searchlib/src/tests/attribute/tensorattribute/tensorattribute_test.cpp
index 18a6a5a8188..592a8aa6a36 100644
--- a/searchlib/src/tests/attribute/tensorattribute/tensorattribute_test.cpp
+++ b/searchlib/src/tests/attribute/tensorattribute/tensorattribute_test.cpp
@@ -186,6 +186,16 @@ public:
(void) explore_k;
return std::vector<Neighbor>();
}
+ std::vector<Neighbor> find_top_k_with_filter(uint32_t k, vespalib::tensor::TypedCells vector, // stub override in this test file: uses none of its arguments
+ const search::BitVector& filter, uint32_t explore_k) const override
+ {
+ (void) k; // cast to void: parameter is intentionally unused
+ (void) vector; // unused
+ (void) explore_k; // unused
+ (void) filter; // unused
+ return std::vector<Neighbor>(); // always reports an empty neighbor list
+ }
+
const search::tensor::DistanceFunction *distance_function() const override { return nullptr; }
};
diff --git a/searchlib/src/vespa/searchlib/tensor/hnsw_index.cpp b/searchlib/src/vespa/searchlib/tensor/hnsw_index.cpp
index 612f30cc64f..7e89c6cd823 100644
--- a/searchlib/src/vespa/searchlib/tensor/hnsw_index.cpp
+++ b/searchlib/src/vespa/searchlib/tensor/hnsw_index.cpp
@@ -246,6 +246,51 @@ HnswIndex::search_layer(const TypedCells& input, uint32_t neighbors_to_find, Fur
}
}
+FurthestPriQ // Filtered search of graph level 0: only docids whose filter bit is set are kept as hits.
+HnswIndex::search_l0(const TypedCells& input, uint32_t neighbors_to_find,
+ HnswCandidate entry_point, const BitVector &filter) const
+{
+ FurthestPriQ best_neighbors; // result queue: receives only docids that pass the filter
+ const uint32_t level = 0; // this routine always walks the level-0 link arrays
+ NearestPriQ candidates; // traversal frontier: receives nodes whether or not they pass the filter
+ uint32_t doc_id_limit = _graph.node_refs.size(); // linked docids at or above this limit are skipped in the loop below
+ auto visited = _visited_set_pool.get(doc_id_limit); // tracks docids whose distance has already been computed
+
+ assert(entry_point.docid < doc_id_limit);
+ candidates.push(entry_point);
+ visited.mark(entry_point.docid);
+ if (filter.testBit(entry_point.docid)) { // the entry point counts as a hit only if it passes the filter
+ best_neighbors.push(entry_point);
+ }
+ double limit_dist = std::numeric_limits<double>::max(); // no distance pruning until the result queue has overflowed once
+
+ while (!candidates.empty()) {
+ auto cand = candidates.top(); // presumably the nearest unexplored candidate (per the NearestPriQ name) - TODO confirm
+ if (cand.distance > limit_dist) { // stop once the candidate on top lies beyond the current limit
+ break;
+ }
+ candidates.pop();
+ for (uint32_t neighbor_docid : _graph.get_link_array(cand.docid, level)) {
+ if ((neighbor_docid >= doc_id_limit) || visited.is_marked(neighbor_docid)) { // skip out-of-range and already-visited docids
+ continue;
+ }
+ visited.mark(neighbor_docid);
+ double dist_to_input = calc_distance(input, neighbor_docid);
+ if (dist_to_input < limit_dist) {
+ candidates.emplace(neighbor_docid, dist_to_input); // queued before the filter test, so the walk can pass through filtered-out nodes
+ if (filter.testBit(neighbor_docid)) {
+ best_neighbors.emplace(neighbor_docid, dist_to_input);
+ if (best_neighbors.size() > neighbors_to_find) { // trim back to neighbors_to_find hits after an overflow
+ best_neighbors.pop(); // presumably drops the furthest hit (per the FurthestPriQ name) - TODO confirm
+ limit_dist = best_neighbors.top().distance; // NOTE(review): if neighbors_to_find is 0 and the entry point failed the filter, the queue is empty here - confirm callers never pass 0
+ }
+ }
+ }
+ }
+ }
+ return best_neighbors;
+}
+
HnswIndex::HnswIndex(const DocVectorAccess& vectors, DistanceFunction::UP distance_func,
RandomLevelGenerator::UP level_generator, const Config& cfg)
:
@@ -444,6 +489,23 @@ HnswIndex::find_top_k(uint32_t k, TypedCells vector, uint32_t explore_k) const
return result;
}
+std::vector<NearestNeighborIndex::Neighbor> // Filtered top-k search: returns at most k neighbors, ordered by NeighborsByDocId.
+HnswIndex::find_top_k_with_filter(uint32_t k, TypedCells vector,
+ const BitVector &filter, uint32_t explore_k) const
+{
+ std::vector<Neighbor> result;
+ FurthestPriQ candidates = top_k_candidates(vector, std::max(k, explore_k), filter); // search with the larger of k and explore_k
+ while (candidates.size() > k) { // then trim the queue down to at most k entries
+ candidates.pop(); // presumably removes the furthest entry first (per the FurthestPriQ name) - TODO confirm
+ }
+ result.reserve(candidates.size());
+ for (const HnswCandidate & hit : candidates.peek()) { // copy the remaining hits out as (docid, distance) pairs
+ result.emplace_back(hit.docid, hit.distance);
+ }
+ std::sort(result.begin(), result.end(), NeighborsByDocId()); // final order follows this comparator, not queue order
+ return result;
+}
+
FurthestPriQ
HnswIndex::top_k_candidates(const TypedCells &vector, uint32_t k) const
{
@@ -464,6 +526,24 @@ HnswIndex::top_k_candidates(const TypedCells &vector, uint32_t k) const
return best_neighbors;
}
+FurthestPriQ // Filter-aware overload: descends the upper levels unfiltered, then runs the filtered level-0 search.
+HnswIndex::top_k_candidates(const TypedCells &vector, uint32_t k, const BitVector &filter) const
+{
+ if (get_entry_level() < 0) { // negative entry level: return an empty queue (presumably the graph has no entry point - TODO confirm)
+ FurthestPriQ empty;
+ return empty;
+ }
+ uint32_t entry_docid = get_entry_docid();
+ int search_level = get_entry_level();
+ double entry_dist = calc_distance(vector, entry_docid);
+ HnswCandidate entry_point(entry_docid, entry_dist); // start from the graph entry node and its distance to the query vector
+ while (search_level > 0) { // levels above 0: refine the entry point only; the filter is not passed here
+ entry_point = find_nearest_in_layer(vector, entry_point, search_level);
+ --search_level;
+ }
+ return search_l0(vector, k, entry_point, filter); // the filter is passed only to the level-0 search
+}
+
HnswNode
HnswIndex::get_node(uint32_t docid) const
{
diff --git a/searchlib/src/vespa/searchlib/tensor/hnsw_index.h b/searchlib/src/vespa/searchlib/tensor/hnsw_index.h
index 6a7496e8696..07da6d0cb1a 100644
--- a/searchlib/src/vespa/searchlib/tensor/hnsw_index.h
+++ b/searchlib/src/vespa/searchlib/tensor/hnsw_index.h
@@ -116,6 +116,8 @@ protected:
*/
HnswCandidate find_nearest_in_layer(const TypedCells& input, const HnswCandidate& entry_point, uint32_t level) const;
void search_layer(const TypedCells& input, uint32_t neighbors_to_find, FurthestPriQ& found_neighbors, uint32_t level) const;
+ FurthestPriQ search_l0(const TypedCells& input, uint32_t neighbors_to_find, // level-0 search that takes a filter, starting from entry_point
+ HnswCandidate entry_point, const BitVector &filter) const;
public:
HnswIndex(const DocVectorAccess& vectors, DistanceFunction::UP distance_func,
@@ -136,9 +138,12 @@ public:
bool load(const fileutil::LoadedBuffer& buf) override;
std::vector<Neighbor> find_top_k(uint32_t k, TypedCells vector, uint32_t explore_k) const override;
+ std::vector<Neighbor> find_top_k_with_filter(uint32_t k, TypedCells vector, // variant of find_top_k that additionally takes a BitVector filter
+ const BitVector &filter, uint32_t explore_k) const override;
const DistanceFunction *distance_function() const override { return _distance_func.get(); }
FurthestPriQ top_k_candidates(const TypedCells &vector, uint32_t k) const;
+ FurthestPriQ top_k_candidates(const TypedCells &vector, uint32_t k, const BitVector &filter) const; // overload of top_k_candidates that also takes a filter
uint32_t get_entry_docid() const { return _graph.entry_docid; }
int32_t get_entry_level() const { return _graph.entry_level; }
diff --git a/searchlib/src/vespa/searchlib/tensor/nearest_neighbor_index.h b/searchlib/src/vespa/searchlib/tensor/nearest_neighbor_index.h
index aca2ce2af66..c2d37f2d59a 100644
--- a/searchlib/src/vespa/searchlib/tensor/nearest_neighbor_index.h
+++ b/searchlib/src/vespa/searchlib/tensor/nearest_neighbor_index.h
@@ -14,6 +14,8 @@ namespace vespalib::slime { struct Inserter; }
namespace search::fileutil { class LoadedBuffer; }
+namespace search { class BitVector; }
+
namespace search::tensor {
class NearestNeighborIndexSaver;
@@ -53,6 +55,12 @@ public:
vespalib::tensor::TypedCells vector,
uint32_t explore_k) const = 0;
+ // Only return neighbors where the corresponding bit in 'filter' is set.
+ virtual std::vector<Neighbor> find_top_k_with_filter(uint32_t k,
+ vespalib::tensor::TypedCells vector,
+ const BitVector &filter,
+ uint32_t explore_k) const = 0;
+
virtual const DistanceFunction *distance_function() const = 0;
};