diff options
author | Henning Baldersheim <balder@yahoo-inc.com> | 2022-09-12 13:21:52 +0200 |
---|---|---|
committer | GitHub <noreply@github.com> | 2022-09-12 13:21:52 +0200 |
commit | 33c97cd5a14070178a1499fb7c3abe2e00e663fa (patch) | |
tree | c9184d8fa429b53afe09119406595aab2ed067fb | |
parent | 2823e2353618573aa0513309e07cbd3c62d519f1 (diff) | |
parent | 58fbe6f2e0d30ba239036987134a246828246542 (diff) |
Merge pull request #24011 from vespa-engine/havardpe/global-filter-as-interface
GlobalFilter is now an interface
13 files changed, 118 insertions, 83 deletions
diff --git a/searchcore/src/tests/proton/matching/query_test.cpp b/searchcore/src/tests/proton/matching/query_test.cpp index eacb4108686..78bb679a7dc 100644 --- a/searchcore/src/tests/proton/matching/query_test.cpp +++ b/searchcore/src/tests/proton/matching/query_test.cpp @@ -1158,21 +1158,20 @@ Test::global_filter_is_calculated_and_handled() auto res = Query::handle_global_filter(bp, docid_limit, 0, 0.3, nullptr); EXPECT_TRUE(res); EXPECT_TRUE(bp.filter); - EXPECT_TRUE(bp.filter->has_filter()); + EXPECT_TRUE(bp.filter->is_active()); EXPECT_EQUAL(0.3, bp.estimated_hit_ratio); - auto* bv = bp.filter->filter(); - EXPECT_EQUAL(3u, bv->countTrueBits()); - EXPECT_TRUE(bv->testBit(3)); - EXPECT_TRUE(bv->testBit(5)); - EXPECT_TRUE(bv->testBit(7)); + EXPECT_EQUAL(3u, bp.filter->count()); + EXPECT_TRUE(bp.filter->check(3)); + EXPECT_TRUE(bp.filter->check(5)); + EXPECT_TRUE(bp.filter->check(7)); } { // estimated_hit_ratio > global_filter_upper_limit GlobalFilterBlueprint bp(result, true); auto res = Query::handle_global_filter(bp, docid_limit, 0, 0.29, nullptr); EXPECT_TRUE(res); EXPECT_TRUE(bp.filter); - EXPECT_FALSE(bp.filter->has_filter()); + EXPECT_FALSE(bp.filter->is_active()); EXPECT_EQUAL(0.3, bp.estimated_hit_ratio); } } diff --git a/searchlib/src/tests/attribute/tensorattribute/tensorattribute_test.cpp b/searchlib/src/tests/attribute/tensorattribute/tensorattribute_test.cpp index b93398e16a1..ebe96035fc8 100644 --- a/searchlib/src/tests/attribute/tensorattribute/tensorattribute_test.cpp +++ b/searchlib/src/tests/attribute/tensorattribute/tensorattribute_test.cpp @@ -260,7 +260,7 @@ public: return std::vector<Neighbor>(); } std::vector<Neighbor> find_top_k_with_filter(uint32_t k, vespalib::eval::TypedCells vector, - const search::BitVector& filter, uint32_t explore_k, + const GlobalFilter& filter, uint32_t explore_k, double distance_threshold) const override { (void) k; diff --git a/searchlib/src/tests/queryeval/blueprint/intermediate_blueprints_test.cpp b/searchlib/src/tests/queryeval/blueprint/intermediate_blueprints_test.cpp index 87de62dbfad..2379213b87b 100644 --- a/searchlib/src/tests/queryeval/blueprint/intermediate_blueprints_test.cpp +++ b/searchlib/src/tests/queryeval/blueprint/intermediate_blueprints_test.cpp @@ -80,7 +80,7 @@ TEST("test AndNot Blueprint") { a.addChild(ap(MyLeafSpec(20).addField(1, 1).want_global_filter().create())); EXPECT_EQUAL(true, a.getState().want_global_filter()); auto empty_global_filter = GlobalFilter::create(); - EXPECT_FALSE(empty_global_filter->has_filter()); + EXPECT_FALSE(empty_global_filter->is_active()); a.set_global_filter(*empty_global_filter, 1.0); EXPECT_EQUAL(false, got_global_filter(a.getChild(0))); EXPECT_EQUAL(true, got_global_filter(a.getChild(1))); diff --git a/searchlib/src/tests/queryeval/nearest_neighbor/nearest_neighbor_test.cpp b/searchlib/src/tests/queryeval/nearest_neighbor/nearest_neighbor_test.cpp index 1e341eab707..33435f43618 100644 --- a/searchlib/src/tests/queryeval/nearest_neighbor/nearest_neighbor_test.cpp +++ b/searchlib/src/tests/queryeval/nearest_neighbor/nearest_neighbor_test.cpp @@ -9,6 +9,7 @@ #include <vespa/searchlib/queryeval/nearest_neighbor_iterator.h> #include <vespa/searchlib/queryeval/nns_index_iterator.h> #include <vespa/searchlib/queryeval/simpleresult.h> +#include <vespa/searchlib/queryeval/global_filter.h> #include <vespa/searchlib/tensor/dense_tensor_attribute.h> #include <vespa/searchlib/tensor/distance_calculator.h> #include <vespa/searchlib/tensor/distance_function_factory.h> @@ -63,7 +64,7 @@ struct Fixture vespalib::string _typeSpec; std::shared_ptr<DenseTensorAttribute> _tensorAttr; std::shared_ptr<AttributeVector> _attr; - std::unique_ptr<BitVector> _global_filter; + std::shared_ptr<GlobalFilter> _global_filter; Fixture(const vespalib::string &typeSpec) : _cfg(BasicType::TENSOR, CollectionType::SINGLE), @@ -71,7 +72,7 @@ struct Fixture _typeSpec(typeSpec), _tensorAttr(), _attr(), - _global_filter() + _global_filter(GlobalFilter::create()) { _cfg.setTensorType(ValueType::from_spec(typeSpec)); _tensorAttr = makeAttr(); @@ -95,11 +96,12 @@ struct Fixture void setFilter(std::vector<uint32_t> docids) { uint32_t sz = _attr->getNumDocs(); - _global_filter = BitVector::create(sz); + auto bit_vector = BitVector::create(sz); for (uint32_t id : docids) { EXPECT_LT(id, sz); - _global_filter->setBit(id); + bit_vector->setBit(id); } + _global_filter = GlobalFilter::create(std::move(bit_vector)); } void setTensor(uint32_t docId, const Value &tensor) { @@ -130,7 +132,7 @@ SimpleResult find_matches(Fixture &env, const Value &qtv, double threshold = std DistanceCalculator dist_calc(attr, qtv, env.dist_fun()); NearestNeighborDistanceHeap dh(2); dh.set_distance_threshold(env.dist_fun().convert_threshold(threshold)); - const BitVector *filter = env._global_filter.get(); + const GlobalFilter &filter = *env._global_filter; auto search = NearestNeighborIterator::create(strict, tfmd, dist_calc, dh, filter); if (strict) { return SimpleResult().searchStrict(*search, attr.getNumDocs()); @@ -222,7 +224,8 @@ std::vector<feature_t> get_rawscores(Fixture &env, const Value &qtv) { auto &attr = *(env._tensorAttr); DistanceCalculator dist_calc(attr, qtv, env.dist_fun()); NearestNeighborDistanceHeap dh(2); - auto search = NearestNeighborIterator::create(strict, tfmd, dist_calc, dh, nullptr); + auto dummy_filter = GlobalFilter::create(); + auto search = NearestNeighborIterator::create(strict, tfmd, dist_calc, dh, *dummy_filter); uint32_t limit = attr.getNumDocs(); uint32_t docid = 1; search->initRange(docid, limit); diff --git a/searchlib/src/tests/tensor/hnsw_index/hnsw_index_test.cpp b/searchlib/src/tests/tensor/hnsw_index/hnsw_index_test.cpp index 3d1127e6bc4..193bb04843c 100644 --- a/searchlib/src/tests/tensor/hnsw_index/hnsw_index_test.cpp +++ b/searchlib/src/tests/tensor/hnsw_index/hnsw_index_test.cpp @@ -6,6 +6,7 @@ #include <vespa/searchlib/tensor/hnsw_index.h> #include <vespa/searchlib/tensor/random_level_generator.h> #include <vespa/searchlib/tensor/inv_log_level_generator.h> +#include <vespa/searchlib/queryeval/global_filter.h> #include <vespa/vespalib/datastore/compaction_spec.h> #include <vespa/vespalib/datastore/compaction_strategy.h> #include <vespa/vespalib/gtest/gtest.h> @@ -24,6 +25,7 @@ using vespalib::Slime; using search::BitVector; using vespalib::datastore::CompactionSpec; using vespalib::datastore::CompactionStrategy; +using search::queryeval::GlobalFilter; template <typename FloatType> class MyDocVectorAccess : public DocVectorAccess { @@ -61,14 +63,14 @@ using HnswIndexUP = std::unique_ptr<HnswIndex>; class HnswIndexTest : public ::testing::Test { public: FloatVectors vectors; - std::unique_ptr<BitVector> global_filter; + std::shared_ptr<GlobalFilter> global_filter; LevelGenerator* level_generator; GenerationHandler gen_handler; HnswIndexUP index; HnswIndexTest() : vectors(), - global_filter(), + global_filter(GlobalFilter::create()), level_generator(), gen_handler(), index() @@ -80,6 +82,10 @@ public: ~HnswIndexTest() {} + const GlobalFilter *global_filter_ptr() const { + return global_filter->is_active() ? global_filter.get() : nullptr; + } + void init(bool heuristic_select_neighbors) { auto generator = std::make_unique<LevelGenerator>(); level_generator = generator.get(); @@ -104,11 +110,12 @@ public: } void set_filter(std::vector<uint32_t> docids) { uint32_t sz = 10; - global_filter = BitVector::create(sz); + auto bit_vector = BitVector::create(sz); for (uint32_t id : docids) { EXPECT_LT(id, sz); - global_filter->setBit(id); + bit_vector->setBit(id); } + global_filter = GlobalFilter::create(std::move(bit_vector)); } GenerationHandler::Guard take_read_guard() { return gen_handler.takeGuard(); @@ -142,7 +149,7 @@ public: void expect_top_3(uint32_t docid, std::vector<uint32_t> exp_hits) { uint32_t k = 3; auto qv = vectors.get_vector(docid); - auto rv = index->top_k_candidates(qv, k, global_filter.get()).peek(); + auto rv = index->top_k_candidates(qv, k, global_filter_ptr()).peek(); std::sort(rv.begin(), rv.end(), LesserDistance()); size_t idx = 0; for (const auto & hit : rv) { @@ -163,12 +170,12 @@ public: void check_with_distance_threshold(uint32_t docid) { auto qv = vectors.get_vector(docid); uint32_t k = 3; - auto rv = index->top_k_candidates(qv, k, global_filter.get()).peek(); + auto rv = index->top_k_candidates(qv, k, global_filter_ptr()).peek(); std::sort(rv.begin(), rv.end(), LesserDistance()); EXPECT_EQ(rv.size(), 3); EXPECT_LE(rv[0].distance, rv[1].distance); double thr = (rv[0].distance + rv[1].distance) * 0.5; - auto got_by_docid = (global_filter) + auto got_by_docid = (global_filter->is_active()) ? index->find_top_k_with_filter(k, qv, *global_filter, k, thr) : index->find_top_k(k, qv, k, thr); EXPECT_EQ(got_by_docid.size(), 1); diff --git a/searchlib/src/vespa/searchlib/queryeval/global_filter.cpp b/searchlib/src/vespa/searchlib/queryeval/global_filter.cpp index b1995c7ab1c..1a5d3d3dacd 100644 --- a/searchlib/src/vespa/searchlib/queryeval/global_filter.cpp +++ b/searchlib/src/vespa/searchlib/queryeval/global_filter.cpp @@ -1,3 +1,41 @@ // Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. #include "global_filter.h" + +namespace search::queryeval { + +namespace { + +struct Inactive : GlobalFilter { + bool is_active() const override { return false; } + uint32_t size() const override { abort(); } + uint32_t count() const override { abort(); } + bool check(uint32_t) const override { abort(); } +}; + +struct BitVectorFilter : public GlobalFilter { + std::unique_ptr<BitVector> vector; + BitVectorFilter(std::unique_ptr<BitVector> vector_in) + : vector(std::move(vector_in)) {} + bool is_active() const override { return true; } + uint32_t size() const override { return vector->size(); } + uint32_t count() const override { return vector->countTrueBits(); } + bool check(uint32_t docid) const override { return vector->testBit(docid); } +}; + +} + +GlobalFilter::GlobalFilter() = default; +GlobalFilter::~GlobalFilter() = default; + +std::shared_ptr<GlobalFilter> +GlobalFilter::create() { + return std::make_shared<Inactive>(); +} + +std::shared_ptr<GlobalFilter> +GlobalFilter::create(std::unique_ptr<BitVector> vector) { + return std::make_shared<BitVectorFilter>(std::move(vector)); +} + +} diff --git a/searchlib/src/vespa/searchlib/queryeval/global_filter.h b/searchlib/src/vespa/searchlib/queryeval/global_filter.h index 9a2a77ed119..c6e08d5018d 100644 --- a/searchlib/src/vespa/searchlib/queryeval/global_filter.h +++ b/searchlib/src/vespa/searchlib/queryeval/global_filter.h @@ -8,39 +8,26 @@ namespace search::queryeval { /** - * Hold ownership of a global filter that can be taken - * into account by adaptive query operators. The owned - * bitvector should be a white-list (documents that may - * possibly become hits have their bit set, documents - * that are certain to be filtered away should have theirs - * cleared). + * Hold ownership of a global filter that can be taken into account by + * adaptive query operators. The owned 'bitvector' should be a + * white-list (documents that may possibly become hits have their bit + * set, documents that are certain to be filtered away should have + * theirs cleared). **/ class GlobalFilter : public std::enable_shared_from_this<GlobalFilter> { -private: - struct ctor_tag {}; - std::unique_ptr<search::BitVector> bit_vector; - public: + GlobalFilter(); GlobalFilter(const GlobalFilter &) = delete; GlobalFilter(GlobalFilter &&) = delete; - - GlobalFilter(ctor_tag, std::unique_ptr<search::BitVector> bit_vector_in) noexcept - : bit_vector(std::move(bit_vector_in)) - {} - - GlobalFilter(ctor_tag) noexcept : bit_vector() {} - - ~GlobalFilter() {} - - template<typename ... Params> - static std::shared_ptr<GlobalFilter> create(Params&& ... params) { - return std::make_shared<GlobalFilter>(ctor_tag(), std::forward<Params>(params)...); - } - - const search::BitVector *filter() const { return bit_vector.get(); } - - bool has_filter() const { return bool(bit_vector); } + virtual bool is_active() const = 0; + virtual uint32_t size() const = 0; + virtual uint32_t count() const = 0; + virtual bool check(uint32_t docid) const = 0; + virtual ~GlobalFilter(); + + static std::shared_ptr<GlobalFilter> create(); + static std::shared_ptr<GlobalFilter> create(std::unique_ptr<BitVector> vector); }; } // namespace diff --git a/searchlib/src/vespa/searchlib/queryeval/nearest_neighbor_blueprint.cpp b/searchlib/src/vespa/searchlib/queryeval/nearest_neighbor_blueprint.cpp index 6a891341afd..993156e04e6 100644 --- a/searchlib/src/vespa/searchlib/queryeval/nearest_neighbor_blueprint.cpp +++ b/searchlib/src/vespa/searchlib/queryeval/nearest_neighbor_blueprint.cpp @@ -80,8 +80,8 @@ NearestNeighborBlueprint::set_global_filter(const GlobalFilter &global_filter, d auto nns_index = _attr_tensor.nearest_neighbor_index(); if (_approximate && nns_index) { uint32_t est_hits = _attr_tensor.get_num_docs(); - if (_global_filter->has_filter()) { // pre-filtering case - _global_filter_hits = _global_filter->filter()->countTrueBits(); + if (_global_filter->is_active()) { // pre-filtering case + _global_filter_hits = _global_filter->count(); _global_filter_hit_ratio = static_cast<double>(_global_filter_hits.value()) / est_hits; if (_global_filter_hit_ratio.value() < _global_filter_lower_limit) { _algorithm = Algorithm::EXACT_FALLBACK; @@ -108,9 +108,8 @@ NearestNeighborBlueprint::perform_top_k(const search::tensor::NearestNeighborInd { auto lhs = _query_tensor.cells(); uint32_t k = _adjusted_target_hits; - if (_global_filter->has_filter()) { - auto filter = _global_filter->filter(); - _found_hits = nns_index->find_top_k_with_filter(k, lhs, *filter, k + _explore_additional_hits, _distance_threshold); + if (_global_filter->is_active()) { + _found_hits = nns_index->find_top_k_with_filter(k, lhs, *_global_filter, k + _explore_additional_hits, _distance_threshold); _algorithm = Algorithm::INDEX_TOP_K_WITH_FILTER; } else { _found_hits = nns_index->find_top_k(k, lhs, k + _explore_additional_hits, _distance_threshold); @@ -131,7 +130,7 @@ NearestNeighborBlueprint::createLeafSearch(const search::fef::TermFieldMatchData ; } return NearestNeighborIterator::create(strict, tfmd, *_distance_calc, - _distance_heap, _global_filter->filter()); + _distance_heap, *_global_filter); } void @@ -151,7 +150,7 @@ NearestNeighborBlueprint::visitMembers(vespalib::ObjectVisitor& visitor) const visitor.openStruct("global_filter", "GlobalFilter"); visitor.visitBool("wanted", getState().want_global_filter()); visitor.visitBool("set", _global_filter_set); - visitor.visitBool("calculated", _global_filter->has_filter()); + visitor.visitBool("calculated", _global_filter->is_active()); visitor.visitFloat("lower_limit", _global_filter_lower_limit); visitor.visitFloat("upper_limit", _global_filter_upper_limit); if (_global_filter_hits.has_value()) { diff --git a/searchlib/src/vespa/searchlib/queryeval/nearest_neighbor_iterator.cpp b/searchlib/src/vespa/searchlib/queryeval/nearest_neighbor_iterator.cpp index e06fcc614d8..b3f8195676d 100644 --- a/searchlib/src/vespa/searchlib/queryeval/nearest_neighbor_iterator.cpp +++ b/searchlib/src/vespa/searchlib/queryeval/nearest_neighbor_iterator.cpp @@ -1,6 +1,7 @@ // Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. #include "nearest_neighbor_iterator.h" +#include "global_filter.h" #include <vespa/searchlib/common/bitvector.h> #include <vespa/searchlib/tensor/distance_calculator.h> #include <vespa/searchlib/tensor/distance_function.h> @@ -47,7 +48,7 @@ public: void doSeek(uint32_t docId) override { double distanceLimit = params().distanceHeap.distanceLimit(); while (__builtin_expect((docId < getEndId()), true)) { - if ((!has_filter) || params().filter->testBit(docId)) { + if ((!has_filter) || params().filter.check(docId)) { double d = computeDistance(docId, distanceLimit); if (d <= distanceLimit) { _lastScore = d; @@ -106,11 +107,10 @@ NearestNeighborIterator::create( fef::TermFieldMatchData &tfmd, const search::tensor::DistanceCalculator &distance_calc, NearestNeighborDistanceHeap &distanceHeap, - const search::BitVector *filter) - + const GlobalFilter &filter) { Params params(tfmd, distance_calc, distanceHeap, filter); - if (filter) { + if (filter.is_active()) { return resolve_strict<true>(strict, params); } else { return resolve_strict<false>(strict, params); diff --git a/searchlib/src/vespa/searchlib/queryeval/nearest_neighbor_iterator.h b/searchlib/src/vespa/searchlib/queryeval/nearest_neighbor_iterator.h index 0d8f70d15c2..f06e62f9cc1 100644 --- a/searchlib/src/vespa/searchlib/queryeval/nearest_neighbor_iterator.h +++ b/searchlib/src/vespa/searchlib/queryeval/nearest_neighbor_iterator.h @@ -14,6 +14,8 @@ namespace search::tensor { class DistanceCalculator; } namespace search::queryeval { +class GlobalFilter; + class NearestNeighborIterator : public SearchIterator { public: @@ -24,12 +26,12 @@ public: fef::TermFieldMatchData &tfmd; const search::tensor::DistanceCalculator &distance_calc; NearestNeighborDistanceHeap &distanceHeap; - const search::BitVector *filter; + const GlobalFilter &filter; Params(fef::TermFieldMatchData &tfmd_in, const search::tensor::DistanceCalculator &distance_calc_in, NearestNeighborDistanceHeap &distanceHeap_in, - const search::BitVector *filter_in) + const GlobalFilter &filter_in) : tfmd(tfmd_in), distance_calc(distance_calc_in), distanceHeap(distanceHeap_in), @@ -46,7 +48,7 @@ public: fef::TermFieldMatchData &tfmd, const search::tensor::DistanceCalculator &distance_calc, NearestNeighborDistanceHeap &distanceHeap, - const search::BitVector *filter); + const GlobalFilter &filter); const Params& params() const { return _params; } private: diff --git a/searchlib/src/vespa/searchlib/tensor/hnsw_index.cpp b/searchlib/src/vespa/searchlib/tensor/hnsw_index.cpp index 2ee1b268449..fa6c9a347aa 100644 --- a/searchlib/src/vespa/searchlib/tensor/hnsw_index.cpp +++ b/searchlib/src/vespa/searchlib/tensor/hnsw_index.cpp @@ -9,6 +9,7 @@ #include "random_level_generator.h" #include <vespa/searchlib/attribute/address_space_components.h> #include <vespa/searchlib/attribute/address_space_usage.h> +#include <vespa/searchlib/queryeval/global_filter.h> #include <vespa/searchlib/util/fileutil.h> #include <vespa/searchlib/util/state_explorer_utils.h> #include <vespa/vespalib/data/slime/cursor.h> @@ -214,7 +215,7 @@ HnswIndex::calc_distance(const TypedCells& lhs, uint32_t rhs_docid) const } uint32_t -HnswIndex::estimate_visited_nodes(uint32_t level, uint32_t doc_id_limit, uint32_t neighbors_to_find, const search::BitVector* filter) const +HnswIndex::estimate_visited_nodes(uint32_t level, uint32_t doc_id_limit, uint32_t neighbors_to_find, const GlobalFilter* filter) const { uint32_t m_for_level = max_links_for_level(level); uint64_t base_estimate = uint64_t(m_for_level) * neighbors_to_find + 100; @@ -224,7 +225,7 @@ HnswIndex::estimate_visited_nodes(uint32_t level, uint32_t doc_id_limit, uint32_ if (!filter) { return base_estimate; } - uint32_t true_bits = filter->countTrueBits(); + uint32_t true_bits = filter->count(); if (true_bits == 0) { return doc_id_limit; } @@ -260,7 +261,7 @@ HnswIndex::find_nearest_in_layer(const TypedCells& input, const HnswCandidate& e template <class VisitedTracker> void HnswIndex::search_layer_helper(const TypedCells& input, uint32_t neighbors_to_find, - FurthestPriQ& best_neighbors, uint32_t level, const search::BitVector *filter, + FurthestPriQ& best_neighbors, uint32_t level, const GlobalFilter *filter, uint32_t doc_id_limit, uint32_t estimated_visited_nodes) const { NearestPriQ candidates; @@ -271,7 +272,7 @@ HnswIndex::search_layer_helper(const TypedCells& input, uint32_t neighbors_to_fi } candidates.push(entry); visited.mark(entry.docid); - if (filter && !filter->testBit(entry.docid)) { + if (filter && !filter->check(entry.docid)) { assert(best_neighbors.size() == 1); best_neighbors.pop(); } @@ -297,7 +298,7 @@ HnswIndex::search_layer_helper(const TypedCells& input, uint32_t neighbors_to_fi double dist_to_input = calc_distance(input, neighbor_docid); if (dist_to_input < limit_dist) { candidates.emplace(neighbor_docid, neighbor_ref, dist_to_input); - if ((!filter) || filter->testBit(neighbor_docid)) { + if ((!filter) || filter->check(neighbor_docid)) { best_neighbors.emplace(neighbor_docid, neighbor_ref, dist_to_input); if (best_neighbors.size() > neighbors_to_find) { best_neighbors.pop(); @@ -311,7 +312,7 @@ HnswIndex::search_layer_helper(const TypedCells& input, uint32_t neighbors_to_fi void HnswIndex::search_layer(const TypedCells& input, uint32_t neighbors_to_find, - FurthestPriQ& best_neighbors, uint32_t level, const search::BitVector *filter) const + FurthestPriQ& best_neighbors, uint32_t level, const GlobalFilter *filter) const { uint32_t doc_id_limit = _graph.node_refs_size.load(std::memory_order_acquire); if (filter) { @@ -698,7 +699,7 @@ struct NeighborsByDocId { std::vector<NearestNeighborIndex::Neighbor> HnswIndex::top_k_by_docid(uint32_t k, TypedCells vector, - const BitVector *filter, uint32_t explore_k, + const GlobalFilter *filter, uint32_t explore_k, double distance_threshold) const { std::vector<Neighbor> result; @@ -724,14 +725,14 @@ HnswIndex::find_top_k(uint32_t k, TypedCells vector, uint32_t explore_k, std::vector<NearestNeighborIndex::Neighbor> HnswIndex::find_top_k_with_filter(uint32_t k, TypedCells vector, - const BitVector &filter, uint32_t explore_k, + const GlobalFilter &filter, uint32_t explore_k, double distance_threshold) const { return top_k_by_docid(k, vector, &filter, explore_k, distance_threshold); } FurthestPriQ -HnswIndex::top_k_candidates(const TypedCells &vector, uint32_t k, const BitVector *filter) const +HnswIndex::top_k_candidates(const TypedCells &vector, uint32_t k, const GlobalFilter *filter) const { FurthestPriQ best_neighbors; auto entry = _graph.get_entry_node(); diff --git a/searchlib/src/vespa/searchlib/tensor/hnsw_index.h b/searchlib/src/vespa/searchlib/tensor/hnsw_index.h index 3f5a9d514ed..e3ffada1fc2 100644 --- a/searchlib/src/vespa/searchlib/tensor/hnsw_index.h +++ b/searchlib/src/vespa/searchlib/tensor/hnsw_index.h @@ -135,7 +135,7 @@ protected: double calc_distance(uint32_t lhs_docid, uint32_t rhs_docid) const; double calc_distance(const TypedCells& lhs, uint32_t rhs_docid) const; - uint32_t estimate_visited_nodes(uint32_t level, uint32_t doc_id_limit, uint32_t neighbors_to_find, const search::BitVector* filter) const; + uint32_t estimate_visited_nodes(uint32_t level, uint32_t doc_id_limit, uint32_t neighbors_to_find, const GlobalFilter* filter) const; /** * Performs a greedy search in the given layer to find the candidate that is nearest the input vector. @@ -143,13 +143,13 @@ protected: HnswCandidate find_nearest_in_layer(const TypedCells& input, const HnswCandidate& entry_point, uint32_t level) const; template <class VisitedTracker> void search_layer_helper(const TypedCells& input, uint32_t neighbors_to_find, FurthestPriQ& found_neighbors, - uint32_t level, const search::BitVector *filter, + uint32_t level, const GlobalFilter *filter, uint32_t doc_id_limit, uint32_t estimated_visited_nodes) const; void search_layer(const TypedCells& input, uint32_t neighbors_to_find, FurthestPriQ& found_neighbors, - uint32_t level, const search::BitVector *filter = nullptr) const; + uint32_t level, const GlobalFilter *filter = nullptr) const; std::vector<Neighbor> top_k_by_docid(uint32_t k, TypedCells vector, - const BitVector *filter, uint32_t explore_k, + const GlobalFilter *filter, uint32_t explore_k, double distance_threshold) const; struct PreparedFirstAddDoc : public PrepareResult {}; @@ -206,11 +206,11 @@ public: std::vector<Neighbor> find_top_k(uint32_t k, TypedCells vector, uint32_t explore_k, double distance_threshold) const override; std::vector<Neighbor> find_top_k_with_filter(uint32_t k, TypedCells vector, - const BitVector &filter, uint32_t explore_k, + const GlobalFilter &filter, uint32_t explore_k, double distance_threshold) const override; const DistanceFunction *distance_function() const override { return _distance_func.get(); } - FurthestPriQ top_k_candidates(const TypedCells &vector, uint32_t k, const BitVector *filter) const; + FurthestPriQ top_k_candidates(const TypedCells &vector, uint32_t k, const GlobalFilter *filter) const; uint32_t get_entry_docid() const { return _graph.get_entry_node().docid; } int32_t get_entry_level() const { return _graph.get_entry_node().level; } diff --git a/searchlib/src/vespa/searchlib/tensor/nearest_neighbor_index.h b/searchlib/src/vespa/searchlib/tensor/nearest_neighbor_index.h index 530d3e1036d..51d66fdd14d 100644 --- a/searchlib/src/vespa/searchlib/tensor/nearest_neighbor_index.h +++ b/searchlib/src/vespa/searchlib/tensor/nearest_neighbor_index.h @@ -20,10 +20,8 @@ namespace vespalib::slime { struct Inserter; } namespace search::fileutil { class LoadedBuffer; } -namespace search { -class AddressSpaceUsage; -class BitVector; -} +namespace search { class AddressSpaceUsage; } +namespace search::queryeval { class GlobalFilter; } namespace search::tensor { @@ -35,6 +33,7 @@ class NearestNeighborIndexSaver; */ class NearestNeighborIndex { public: + using GlobalFilter = search::queryeval::GlobalFilter; using CompactionSpec = vespalib::datastore::CompactionSpec; using CompactionStrategy = vespalib::datastore::CompactionStrategy; using generation_t = vespalib::GenerationHandler::generation_t; @@ -101,7 +100,7 @@ public: // only return neighbors where the corresponding filter bit is set virtual std::vector<Neighbor> find_top_k_with_filter(uint32_t k, vespalib::eval::TypedCells vector, - const BitVector &filter, + const GlobalFilter &filter, uint32_t explore_k, double distance_threshold) const = 0; |