diff options
author | Tor Egge <Tor.Egge@online.no> | 2023-01-17 15:09:30 +0100 |
---|---|---|
committer | Tor Egge <Tor.Egge@online.no> | 2023-01-17 15:09:30 +0100 |
commit | bc73414ce1c04f6c4e1c8105c4b95fabb618ef14 (patch) | |
tree | a942acf229ce98a7c19b5b80c79598fae2b2d4bc /searchlib | |
parent | a02f6dede476db726235d22a42ca38523fe9493d (diff) |
Pass range checked docid to check member function on global filter.
Diffstat (limited to 'searchlib')
-rw-r--r-- | searchlib/src/tests/tensor/hnsw_index/hnsw_index_test.cpp | 31 | ||||
-rw-r--r-- | searchlib/src/vespa/searchlib/tensor/hnsw_index.cpp | 46 |
2 files changed, 71 insertions, 6 deletions
diff --git a/searchlib/src/tests/tensor/hnsw_index/hnsw_index_test.cpp b/searchlib/src/tests/tensor/hnsw_index/hnsw_index_test.cpp index 34085aad112..dd248b07fff 100644 --- a/searchlib/src/tests/tensor/hnsw_index/hnsw_index_test.cpp +++ b/searchlib/src/tests/tensor/hnsw_index/hnsw_index_test.cpp @@ -165,7 +165,9 @@ public: uint32_t explore_k = 100; vespalib::ArrayRef qv_ref(qv); vespalib::eval::TypedCells qv_cells(qv_ref); - auto got_by_docid = index->find_top_k(k, qv_cells, explore_k, 10000.0); + auto got_by_docid = (global_filter->is_active()) ? + index->find_top_k_with_filter(k, qv_cells, *global_filter, explore_k, 10000.0) : + index->find_top_k(k, qv_cells, explore_k, 10000.0); std::vector<uint32_t> act; act.reserve(got_by_docid.size()); for (auto& hit : got_by_docid) { @@ -760,6 +762,29 @@ TYPED_TEST(HnswIndexTest, hnsw_graph_can_be_saved_and_loaded) using HnswMultiIndexTest = HnswIndexTest<HnswIndex<HnswIndexType::MULTI>>; +namespace { + +class MyGlobalFilter : public GlobalFilter { + std::shared_ptr<GlobalFilter> _filter; + mutable uint32_t _max_docid; +public: + MyGlobalFilter(std::shared_ptr<GlobalFilter> filter) + : _filter(std::move(filter)), + _max_docid(0) + { + } + bool is_active() const override { return _filter->is_active(); } + uint32_t size() const override { return _filter->size(); } + uint32_t count() const override { return _filter->count(); } + bool check(uint32_t docid) const override { + _max_docid = std::max(_max_docid, docid); + return _filter->check(docid); + } + uint32_t max_docid() const noexcept { return _max_docid; } +}; + +} + TEST_F(HnswMultiIndexTest, duplicate_docid_is_removed) { this->init(false); @@ -786,6 +811,10 @@ TEST_F(HnswMultiIndexTest, duplicate_docid_is_removed) this->expect_top_3_by_docid("{2, 0}", {2, 0}, {1, 2, 4}); this->expect_top_3_by_docid("{2, 1}", {2, 1}, {2, 3, 4}); this->expect_top_3_by_docid("{2, 2}", {2, 2}, {1, 3, 4}); + auto filter = std::make_shared<MyGlobalFilter>(GlobalFilter::create({1, 2}, 3)); + global_filter = filter; + this->expect_top_3_by_docid("{2,2}", {2, 2}, {1, 2}); + EXPECT_EQ(2, filter->max_docid()); }; TEST(LevelGeneratorTest, gives_various_levels) diff --git a/searchlib/src/vespa/searchlib/tensor/hnsw_index.cpp b/searchlib/src/vespa/searchlib/tensor/hnsw_index.cpp index 1529611753b..03759e8a5cc 100644 --- a/searchlib/src/vespa/searchlib/tensor/hnsw_index.cpp +++ b/searchlib/src/vespa/searchlib/tensor/hnsw_index.cpp @@ -28,6 +28,7 @@ namespace search::tensor { using search::AddressSpaceComponents; using search::StateExplorerUtils; +using search::queryeval::GlobalFilter; using vespalib::datastore::CompactionStrategy; using vespalib::datastore::EntryRef; @@ -59,6 +60,42 @@ bool operator< (const PairDist &a, const PairDist &b) { return (a.distance < b.distance); } +template <HnswIndexType type> +class GlobalFilterWrapper; + +template <> +class GlobalFilterWrapper<HnswIndexType::SINGLE> { + const GlobalFilter *_filter; +public: + GlobalFilterWrapper(const GlobalFilter *filter) + : _filter(filter) + { + } + + bool check(uint32_t docid) const noexcept { return !_filter || _filter->check(docid); } + + void clamp_nodeid_limit(uint32_t& nodeid_limit) { + if (_filter) { + nodeid_limit = std::min(nodeid_limit, _filter->size()); + } + } +}; + +template <> +class GlobalFilterWrapper<HnswIndexType::MULTI> { + const GlobalFilter *_filter; + uint32_t _docid_limit; +public: + GlobalFilterWrapper(const GlobalFilter *filter) + : _filter(filter), + _docid_limit(filter ? filter->size() : 0u) + { + } + + bool check(uint32_t docid) const noexcept { return !_filter || (docid < _docid_limit && _filter->check(docid)); } + static void clamp_nodeid_limit(uint32_t&) { } +}; + } template <HnswIndexType type> @@ -297,6 +334,8 @@ HnswIndex<type>::search_layer_helper(const TypedCells& input, uint32_t neighbors uint32_t nodeid_limit, uint32_t estimated_visited_nodes) const { NearestPriQ candidates; + GlobalFilterWrapper<type> filter_wrapper(filter); + filter_wrapper.clamp_nodeid_limit(nodeid_limit); VisitedTracker visited(nodeid_limit, estimated_visited_nodes); for (const auto &entry : best_neighbors.peek()) { if (entry.nodeid >= nodeid_limit) { @@ -304,7 +343,7 @@ HnswIndex<type>::search_layer_helper(const TypedCells& input, uint32_t neighbors } candidates.push(entry); visited.mark(entry.nodeid); - if (filter && !filter->check(entry.nodeid)) { + if (!filter_wrapper.check(entry.docid)) { assert(best_neighbors.peek().size() == 1); best_neighbors.pop(); } @@ -333,7 +372,7 @@ HnswIndex<type>::search_layer_helper(const TypedCells& input, uint32_t neighbors double dist_to_input = calc_distance(input, neighbor_docid, neighbor_subspace); if (dist_to_input < limit_dist) { candidates.emplace(neighbor_nodeid, neighbor_ref, dist_to_input); - if ((!filter) || filter->check(neighbor_nodeid)) { + if (filter_wrapper.check(neighbor_docid)) { best_neighbors.emplace(neighbor_nodeid, neighbor_docid, neighbor_ref, dist_to_input); while (best_neighbors.size() > neighbors_to_find) { best_neighbors.pop(); @@ -352,9 +391,6 @@ HnswIndex<type>::search_layer(const TypedCells& input, uint32_t neighbors_to_fin BestNeighbors& best_neighbors, uint32_t level, const GlobalFilter *filter) const { uint32_t nodeid_limit = _graph.nodes_size.load(std::memory_order_acquire); - if (filter) { - nodeid_limit = std::min(filter->size(), nodeid_limit); - } uint32_t estimated_visited_nodes = estimate_visited_nodes(level, nodeid_limit, neighbors_to_find, filter); if (estimated_visited_nodes >= nodeid_limit / 128) { search_layer_helper<BitVectorVisitedTracker>(input, neighbors_to_find, best_neighbors, level, filter, nodeid_limit, estimated_visited_nodes); |