summary | refs | log | tree | commit | diff | stats
path: root/searchlib
diff options
context:
space:
mode:
author: Tor Egge <Tor.Egge@online.no> 2023-01-17 15:09:30 +0100
committer: Tor Egge <Tor.Egge@online.no> 2023-01-17 15:09:30 +0100
commit: bc73414ce1c04f6c4e1c8105c4b95fabb618ef14 (patch)
tree: a942acf229ce98a7c19b5b80c79598fae2b2d4bc /searchlib
parent: a02f6dede476db726235d22a42ca38523fe9493d (diff)
Pass range checked docid to check member function on global filter.
Diffstat (limited to 'searchlib')
-rw-r--r-- searchlib/src/tests/tensor/hnsw_index/hnsw_index_test.cpp 31
-rw-r--r-- searchlib/src/vespa/searchlib/tensor/hnsw_index.cpp 46
2 files changed, 71 insertions, 6 deletions
diff --git a/searchlib/src/tests/tensor/hnsw_index/hnsw_index_test.cpp b/searchlib/src/tests/tensor/hnsw_index/hnsw_index_test.cpp
index 34085aad112..dd248b07fff 100644
--- a/searchlib/src/tests/tensor/hnsw_index/hnsw_index_test.cpp
+++ b/searchlib/src/tests/tensor/hnsw_index/hnsw_index_test.cpp
@@ -165,7 +165,9 @@ public:
uint32_t explore_k = 100;
vespalib::ArrayRef qv_ref(qv);
vespalib::eval::TypedCells qv_cells(qv_ref);
- auto got_by_docid = index->find_top_k(k, qv_cells, explore_k, 10000.0);
+ auto got_by_docid = (global_filter->is_active()) ?
+ index->find_top_k_with_filter(k, qv_cells, *global_filter, explore_k, 10000.0) :
+ index->find_top_k(k, qv_cells, explore_k, 10000.0);
std::vector<uint32_t> act;
act.reserve(got_by_docid.size());
for (auto& hit : got_by_docid) {
@@ -760,6 +762,29 @@ TYPED_TEST(HnswIndexTest, hnsw_graph_can_be_saved_and_loaded)
using HnswMultiIndexTest = HnswIndexTest<HnswIndex<HnswIndexType::MULTI>>;
+namespace {
+
+class MyGlobalFilter : public GlobalFilter {
+ std::shared_ptr<GlobalFilter> _filter;
+ mutable uint32_t _max_docid;
+public:
+ MyGlobalFilter(std::shared_ptr<GlobalFilter> filter)
+ : _filter(std::move(filter)),
+ _max_docid(0)
+ {
+ }
+ bool is_active() const override { return _filter->is_active(); }
+ uint32_t size() const override { return _filter->size(); }
+ uint32_t count() const override { return _filter->count(); }
+ bool check(uint32_t docid) const override {
+ _max_docid = std::max(_max_docid, docid);
+ return _filter->check(docid);
+ }
+ uint32_t max_docid() const noexcept { return _max_docid; }
+};
+
+}
+
TEST_F(HnswMultiIndexTest, duplicate_docid_is_removed)
{
this->init(false);
@@ -786,6 +811,10 @@ TEST_F(HnswMultiIndexTest, duplicate_docid_is_removed)
this->expect_top_3_by_docid("{2, 0}", {2, 0}, {1, 2, 4});
this->expect_top_3_by_docid("{2, 1}", {2, 1}, {2, 3, 4});
this->expect_top_3_by_docid("{2, 2}", {2, 2}, {1, 3, 4});
+ auto filter = std::make_shared<MyGlobalFilter>(GlobalFilter::create({1, 2}, 3));
+ global_filter = filter;
+ this->expect_top_3_by_docid("{2,2}", {2, 2}, {1, 2});
+ EXPECT_EQ(2, filter->max_docid());
};
TEST(LevelGeneratorTest, gives_various_levels)
diff --git a/searchlib/src/vespa/searchlib/tensor/hnsw_index.cpp b/searchlib/src/vespa/searchlib/tensor/hnsw_index.cpp
index 1529611753b..03759e8a5cc 100644
--- a/searchlib/src/vespa/searchlib/tensor/hnsw_index.cpp
+++ b/searchlib/src/vespa/searchlib/tensor/hnsw_index.cpp
@@ -28,6 +28,7 @@ namespace search::tensor {
using search::AddressSpaceComponents;
using search::StateExplorerUtils;
+using search::queryeval::GlobalFilter;
using vespalib::datastore::CompactionStrategy;
using vespalib::datastore::EntryRef;
@@ -59,6 +60,42 @@ bool operator< (const PairDist &a, const PairDist &b) {
return (a.distance < b.distance);
}
+template <HnswIndexType type>
+class GlobalFilterWrapper;
+
+template <>
+class GlobalFilterWrapper<HnswIndexType::SINGLE> {
+ const GlobalFilter *_filter;
+public:
+ GlobalFilterWrapper(const GlobalFilter *filter)
+ : _filter(filter)
+ {
+ }
+
+ bool check(uint32_t docid) const noexcept { return !_filter || _filter->check(docid); }
+
+ void clamp_nodeid_limit(uint32_t& nodeid_limit) {
+ if (_filter) {
+ nodeid_limit = std::min(nodeid_limit, _filter->size());
+ }
+ }
+};
+
+template <>
+class GlobalFilterWrapper<HnswIndexType::MULTI> {
+ const GlobalFilter *_filter;
+ uint32_t _docid_limit;
+public:
+ GlobalFilterWrapper(const GlobalFilter *filter)
+ : _filter(filter),
+ _docid_limit(filter ? filter->size() : 0u)
+ {
+ }
+
+ bool check(uint32_t docid) const noexcept { return !_filter || (docid < _docid_limit && _filter->check(docid)); }
+ static void clamp_nodeid_limit(uint32_t&) { }
+};
+
}
template <HnswIndexType type>
@@ -297,6 +334,8 @@ HnswIndex<type>::search_layer_helper(const TypedCells& input, uint32_t neighbors
uint32_t nodeid_limit, uint32_t estimated_visited_nodes) const
{
NearestPriQ candidates;
+ GlobalFilterWrapper<type> filter_wrapper(filter);
+ filter_wrapper.clamp_nodeid_limit(nodeid_limit);
VisitedTracker visited(nodeid_limit, estimated_visited_nodes);
for (const auto &entry : best_neighbors.peek()) {
if (entry.nodeid >= nodeid_limit) {
@@ -304,7 +343,7 @@ HnswIndex<type>::search_layer_helper(const TypedCells& input, uint32_t neighbors
}
candidates.push(entry);
visited.mark(entry.nodeid);
- if (filter && !filter->check(entry.nodeid)) {
+ if (!filter_wrapper.check(entry.docid)) {
assert(best_neighbors.peek().size() == 1);
best_neighbors.pop();
}
@@ -333,7 +372,7 @@ HnswIndex<type>::search_layer_helper(const TypedCells& input, uint32_t neighbors
double dist_to_input = calc_distance(input, neighbor_docid, neighbor_subspace);
if (dist_to_input < limit_dist) {
candidates.emplace(neighbor_nodeid, neighbor_ref, dist_to_input);
- if ((!filter) || filter->check(neighbor_nodeid)) {
+ if (filter_wrapper.check(neighbor_docid)) {
best_neighbors.emplace(neighbor_nodeid, neighbor_docid, neighbor_ref, dist_to_input);
while (best_neighbors.size() > neighbors_to_find) {
best_neighbors.pop();
@@ -352,9 +391,6 @@ HnswIndex<type>::search_layer(const TypedCells& input, uint32_t neighbors_to_fin
BestNeighbors& best_neighbors, uint32_t level, const GlobalFilter *filter) const
{
uint32_t nodeid_limit = _graph.nodes_size.load(std::memory_order_acquire);
- if (filter) {
- nodeid_limit = std::min(filter->size(), nodeid_limit);
- }
uint32_t estimated_visited_nodes = estimate_visited_nodes(level, nodeid_limit, neighbors_to_find, filter);
if (estimated_visited_nodes >= nodeid_limit / 128) {
search_layer_helper<BitVectorVisitedTracker>(input, neighbors_to_find, best_neighbors, level, filter, nodeid_limit, estimated_visited_nodes);