diff options
Diffstat (limited to 'searchlib/src/tests')
4 files changed, 25 insertions, 15 deletions
diff --git a/searchlib/src/tests/attribute/tensorattribute/tensorattribute_test.cpp b/searchlib/src/tests/attribute/tensorattribute/tensorattribute_test.cpp index b93398e16a1..ebe96035fc8 100644 --- a/searchlib/src/tests/attribute/tensorattribute/tensorattribute_test.cpp +++ b/searchlib/src/tests/attribute/tensorattribute/tensorattribute_test.cpp @@ -260,7 +260,7 @@ public: return std::vector<Neighbor>(); } std::vector<Neighbor> find_top_k_with_filter(uint32_t k, vespalib::eval::TypedCells vector, - const search::BitVector& filter, uint32_t explore_k, + const GlobalFilter& filter, uint32_t explore_k, double distance_threshold) const override { (void) k; diff --git a/searchlib/src/tests/queryeval/blueprint/intermediate_blueprints_test.cpp b/searchlib/src/tests/queryeval/blueprint/intermediate_blueprints_test.cpp index 87de62dbfad..2379213b87b 100644 --- a/searchlib/src/tests/queryeval/blueprint/intermediate_blueprints_test.cpp +++ b/searchlib/src/tests/queryeval/blueprint/intermediate_blueprints_test.cpp @@ -80,7 +80,7 @@ TEST("test AndNot Blueprint") { a.addChild(ap(MyLeafSpec(20).addField(1, 1).want_global_filter().create())); EXPECT_EQUAL(true, a.getState().want_global_filter()); auto empty_global_filter = GlobalFilter::create(); - EXPECT_FALSE(empty_global_filter->has_filter()); + EXPECT_FALSE(empty_global_filter->is_active()); a.set_global_filter(*empty_global_filter, 1.0); EXPECT_EQUAL(false, got_global_filter(a.getChild(0))); EXPECT_EQUAL(true, got_global_filter(a.getChild(1))); diff --git a/searchlib/src/tests/queryeval/nearest_neighbor/nearest_neighbor_test.cpp b/searchlib/src/tests/queryeval/nearest_neighbor/nearest_neighbor_test.cpp index 1e341eab707..33435f43618 100644 --- a/searchlib/src/tests/queryeval/nearest_neighbor/nearest_neighbor_test.cpp +++ b/searchlib/src/tests/queryeval/nearest_neighbor/nearest_neighbor_test.cpp @@ -9,6 +9,7 @@ #include <vespa/searchlib/queryeval/nearest_neighbor_iterator.h> #include <vespa/searchlib/queryeval/nns_index_iterator.h> #include <vespa/searchlib/queryeval/simpleresult.h> +#include <vespa/searchlib/queryeval/global_filter.h> #include <vespa/searchlib/tensor/dense_tensor_attribute.h> #include <vespa/searchlib/tensor/distance_calculator.h> #include <vespa/searchlib/tensor/distance_function_factory.h> @@ -63,7 +64,7 @@ struct Fixture vespalib::string _typeSpec; std::shared_ptr<DenseTensorAttribute> _tensorAttr; std::shared_ptr<AttributeVector> _attr; - std::unique_ptr<BitVector> _global_filter; + std::shared_ptr<GlobalFilter> _global_filter; Fixture(const vespalib::string &typeSpec) : _cfg(BasicType::TENSOR, CollectionType::SINGLE), @@ -71,7 +72,7 @@ struct Fixture _typeSpec(typeSpec), _tensorAttr(), _attr(), - _global_filter() + _global_filter(GlobalFilter::create()) { _cfg.setTensorType(ValueType::from_spec(typeSpec)); _tensorAttr = makeAttr(); @@ -95,11 +96,12 @@ struct Fixture void setFilter(std::vector<uint32_t> docids) { uint32_t sz = _attr->getNumDocs(); - _global_filter = BitVector::create(sz); + auto bit_vector = BitVector::create(sz); for (uint32_t id : docids) { EXPECT_LT(id, sz); - _global_filter->setBit(id); + bit_vector->setBit(id); } + _global_filter = GlobalFilter::create(std::move(bit_vector)); } void setTensor(uint32_t docId, const Value &tensor) { @@ -130,7 +132,7 @@ SimpleResult find_matches(Fixture &env, const Value &qtv, double threshold = std DistanceCalculator dist_calc(attr, qtv, env.dist_fun()); NearestNeighborDistanceHeap dh(2); dh.set_distance_threshold(env.dist_fun().convert_threshold(threshold)); - const BitVector *filter = env._global_filter.get(); + const GlobalFilter &filter = *env._global_filter; auto search = NearestNeighborIterator::create(strict, tfmd, dist_calc, dh, filter); if (strict) { return SimpleResult().searchStrict(*search, attr.getNumDocs()); @@ -222,7 +224,8 @@ std::vector<feature_t> get_rawscores(Fixture &env, const Value &qtv) { auto &attr = *(env._tensorAttr); DistanceCalculator dist_calc(attr, qtv, env.dist_fun()); NearestNeighborDistanceHeap dh(2); - auto search = NearestNeighborIterator::create(strict, tfmd, dist_calc, dh, nullptr); + auto dummy_filter = GlobalFilter::create(); + auto search = NearestNeighborIterator::create(strict, tfmd, dist_calc, dh, *dummy_filter); uint32_t limit = attr.getNumDocs(); uint32_t docid = 1; search->initRange(docid, limit); diff --git a/searchlib/src/tests/tensor/hnsw_index/hnsw_index_test.cpp b/searchlib/src/tests/tensor/hnsw_index/hnsw_index_test.cpp index 3d1127e6bc4..193bb04843c 100644 --- a/searchlib/src/tests/tensor/hnsw_index/hnsw_index_test.cpp +++ b/searchlib/src/tests/tensor/hnsw_index/hnsw_index_test.cpp @@ -6,6 +6,7 @@ #include <vespa/searchlib/tensor/hnsw_index.h> #include <vespa/searchlib/tensor/random_level_generator.h> #include <vespa/searchlib/tensor/inv_log_level_generator.h> +#include <vespa/searchlib/queryeval/global_filter.h> #include <vespa/vespalib/datastore/compaction_spec.h> #include <vespa/vespalib/datastore/compaction_strategy.h> #include <vespa/vespalib/gtest/gtest.h> @@ -24,6 +25,7 @@ using vespalib::Slime; using search::BitVector; using vespalib::datastore::CompactionSpec; using vespalib::datastore::CompactionStrategy; +using search::queryeval::GlobalFilter; template <typename FloatType> class MyDocVectorAccess : public DocVectorAccess { @@ -61,14 +63,14 @@ using HnswIndexUP = std::unique_ptr<HnswIndex>; class HnswIndexTest : public ::testing::Test { public: FloatVectors vectors; - std::unique_ptr<BitVector> global_filter; + std::shared_ptr<GlobalFilter> global_filter; LevelGenerator* level_generator; GenerationHandler gen_handler; HnswIndexUP index; HnswIndexTest() : vectors(), - global_filter(), + global_filter(GlobalFilter::create()), level_generator(), gen_handler(), index() @@ -80,6 +82,10 @@ public: ~HnswIndexTest() {} + const GlobalFilter *global_filter_ptr() const { + return global_filter->is_active() ? global_filter.get() : nullptr; + } + void init(bool heuristic_select_neighbors) { auto generator = std::make_unique<LevelGenerator>(); level_generator = generator.get(); @@ -104,11 +110,12 @@ public: } void set_filter(std::vector<uint32_t> docids) { uint32_t sz = 10; - global_filter = BitVector::create(sz); + auto bit_vector = BitVector::create(sz); for (uint32_t id : docids) { EXPECT_LT(id, sz); - global_filter->setBit(id); + bit_vector->setBit(id); } + global_filter = GlobalFilter::create(std::move(bit_vector)); } GenerationHandler::Guard take_read_guard() { return gen_handler.takeGuard(); @@ -142,7 +149,7 @@ public: void expect_top_3(uint32_t docid, std::vector<uint32_t> exp_hits) { uint32_t k = 3; auto qv = vectors.get_vector(docid); - auto rv = index->top_k_candidates(qv, k, global_filter.get()).peek(); + auto rv = index->top_k_candidates(qv, k, global_filter_ptr()).peek(); std::sort(rv.begin(), rv.end(), LesserDistance()); size_t idx = 0; for (const auto & hit : rv) { @@ -163,12 +170,12 @@ public: void check_with_distance_threshold(uint32_t docid) { auto qv = vectors.get_vector(docid); uint32_t k = 3; - auto rv = index->top_k_candidates(qv, k, global_filter.get()).peek(); + auto rv = index->top_k_candidates(qv, k, global_filter_ptr()).peek(); std::sort(rv.begin(), rv.end(), LesserDistance()); EXPECT_EQ(rv.size(), 3); EXPECT_LE(rv[0].distance, rv[1].distance); double thr = (rv[0].distance + rv[1].distance) * 0.5; - auto got_by_docid = (global_filter) + auto got_by_docid = (global_filter->is_active()) ? index->find_top_k_with_filter(k, qv, *global_filter, k, thr) : index->find_top_k(k, qv, k, thr); EXPECT_EQ(got_by_docid.size(), 1); |