diff options
Diffstat (limited to 'searchlib/src/tests')
5 files changed, 155 insertions, 21 deletions
diff --git a/searchlib/src/tests/attribute/tensorattribute/tensorattribute_test.cpp b/searchlib/src/tests/attribute/tensorattribute/tensorattribute_test.cpp index ebe96035fc8..3dda2eb6d95 100644 --- a/searchlib/src/tests/attribute/tensorattribute/tensorattribute_test.cpp +++ b/searchlib/src/tests/attribute/tensorattribute/tensorattribute_test.cpp @@ -1114,7 +1114,7 @@ TEST_F("NN blueprint handles empty filter (post-filtering)", NearestNeighborBlue TEST_F("NN blueprint handles strong filter (pre-filtering)", NearestNeighborBlueprintFixture) { auto bp = f.make_blueprint(); - auto filter = search::BitVector::create(11); + auto filter = search::BitVector::create(1,11); filter->setBit(3); filter->invalidateCachedCount(); auto strong_filter = GlobalFilter::create(std::move(filter)); @@ -1128,7 +1128,7 @@ TEST_F("NN blueprint handles strong filter (pre-filtering)", NearestNeighborBlue TEST_F("NN blueprint handles weak filter (pre-filtering)", NearestNeighborBlueprintFixture) { auto bp = f.make_blueprint(); - auto filter = search::BitVector::create(11); + auto filter = search::BitVector::create(1,11); filter->setBit(1); filter->setBit(3); filter->setBit(5); @@ -1147,7 +1147,7 @@ TEST_F("NN blueprint handles weak filter (pre-filtering)", NearestNeighborBluepr TEST_F("NN blueprint handles strong filter triggering exact search", NearestNeighborBlueprintFixture) { auto bp = f.make_blueprint(true, 0.2); - auto filter = search::BitVector::create(11); + auto filter = search::BitVector::create(1,11); filter->setBit(3); filter->invalidateCachedCount(); auto strong_filter = GlobalFilter::create(std::move(filter)); diff --git a/searchlib/src/tests/queryeval/global_filter/CMakeLists.txt b/searchlib/src/tests/queryeval/global_filter/CMakeLists.txt new file mode 100644 index 00000000000..2f768bf9d88 --- /dev/null +++ b/searchlib/src/tests/queryeval/global_filter/CMakeLists.txt @@ -0,0 +1,9 @@ +# Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +vespa_add_executable(searchlib_queryeval_global_filter_test_app TEST + SOURCES + global_filter_test.cpp + DEPENDS + searchlib + GTest::GTest +) +vespa_add_test(NAME searchlib_queryeval_global_filter_test_app COMMAND searchlib_queryeval_global_filter_test_app) diff --git a/searchlib/src/tests/queryeval/global_filter/global_filter_test.cpp b/searchlib/src/tests/queryeval/global_filter/global_filter_test.cpp new file mode 100644 index 00000000000..a64e75cecd6 --- /dev/null +++ b/searchlib/src/tests/queryeval/global_filter/global_filter_test.cpp @@ -0,0 +1,139 @@ +// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. + +#include <vespa/vespalib/gtest/gtest.h> +#include <vespa/vespalib/util/require.h> +#include <vespa/searchlib/queryeval/global_filter.h> +#include <vespa/searchlib/common/bitvector.h> + +#include <gmock/gmock.h> +#include <vector> + +using namespace testing; + +using search::BitVector; +using search::queryeval::GlobalFilter; +using vespalib::RequireFailedException; + +TEST(GlobalFilterTest, create_can_make_inactive_filter) { + auto filter = GlobalFilter::create(); + EXPECT_FALSE(filter->is_active()); +} + +void verify(const GlobalFilter &filter) { + EXPECT_TRUE(filter.is_active()); + EXPECT_EQ(filter.size(), 100); + EXPECT_EQ(filter.count(), 3); + for (size_t i = 1; i < 100; ++i) { + if (i == 11 || i == 22 || i == 33) { + EXPECT_TRUE(filter.check(i)); + } else { + EXPECT_FALSE(filter.check(i)); + } + } +} + +TEST(GlobalFilterTest, create_can_make_test_filter) { + auto docs = std::vector<uint32_t>({11,22,33}); + auto filter = GlobalFilter::create(docs, 100); + verify(*filter); +} + +TEST(GlobalFilterTest, test_filter_requires_docs_in_order) { + auto docs = std::vector<uint32_t>({11,33,22}); + EXPECT_THAT([&](){ GlobalFilter::create(docs, 100); }, Throws<RequireFailedException>()); +} + +TEST(GlobalFilterTest, test_filter_requires_docs_in_range) { + auto docs = std::vector<uint32_t>({11,22,133}); + EXPECT_THAT([&](){ GlobalFilter::create(docs, 100); }, Throws<RequireFailedException>()); +} + +TEST(GlobalFilterTest, test_filter_docid_0_not_allowed) { + auto docs = std::vector<uint32_t>({0,22,33}); + EXPECT_THAT([&](){ GlobalFilter::create(docs, 100); }, Throws<RequireFailedException>()); +} + +TEST(GlobalFilterTest, create_can_make_single_bitvector_filter) { + auto bits = BitVector::create(1, 100); + bits->setBit(11); + bits->setBit(22); + bits->setBit(33); + bits->invalidateCachedCount(); + EXPECT_EQ(bits->countTrueBits(), 3); + auto filter = GlobalFilter::create(std::move(bits)); + verify(*filter); +} + +TEST(GlobalFilterTest, global_filter_pointer_guard) { + auto inactive = GlobalFilter::create(); + auto active = GlobalFilter::create(BitVector::create(1,100)); + EXPECT_TRUE(active->is_active()); + EXPECT_FALSE(inactive->is_active()); + EXPECT_TRUE(active->ptr_if_active() == active.get()); + EXPECT_TRUE(inactive->ptr_if_active() == nullptr); +} + +TEST(GlobalFilterTest, create_can_make_multi_bitvector_filter) { + std::vector<std::unique_ptr<BitVector>> bits; + bits.push_back(BitVector::create(1, 11)); + bits.push_back(BitVector::create(11, 23)); + bits.push_back(BitVector::create(23, 25)); + bits.push_back(BitVector::create(25, 100)); + bits[1]->setBit(11); + bits[1]->setBit(22); + bits[3]->setBit(33); + for (const auto &v: bits) { + v->invalidateCachedCount(); + } + auto filter = GlobalFilter::create(std::move(bits)); + verify(*filter); +} + +TEST(GlobalFilterTest, multi_bitvector_filter_with_empty_vectors) { + std::vector<std::unique_ptr<BitVector>> bits; + bits.push_back(BitVector::create(1, 11)); + bits.push_back(BitVector::create(11, 23)); + bits.push_back(BitVector::create(23, 23)); + bits.push_back(BitVector::create(23, 23)); + bits.push_back(BitVector::create(23, 25)); + bits.push_back(BitVector::create(25, 100)); + bits[1]->setBit(11); + bits[1]->setBit(22); + bits[5]->setBit(33); + for (const auto &v: bits) { + v->invalidateCachedCount(); + } + auto filter = GlobalFilter::create(std::move(bits)); + verify(*filter); +} + +TEST(GlobalFilterTest, multi_bitvector_filter_with_no_vectors) { + std::vector<std::unique_ptr<BitVector>> bits; + auto filter = GlobalFilter::create(std::move(bits)); + EXPECT_TRUE(filter->is_active()); + EXPECT_EQ(filter->size(), 0); + EXPECT_EQ(filter->count(), 0); +} + +TEST(GlobalFilterTest, multi_bitvector_filter_requires_no_gaps) { + std::vector<std::unique_ptr<BitVector>> bits; + bits.push_back(BitVector::create(1, 11)); + bits.push_back(BitVector::create(12, 100)); + EXPECT_THAT([&](){ GlobalFilter::create(std::move(bits)); }, Throws<RequireFailedException>()); +} + +TEST(GlobalFilterTest, multi_bitvector_filter_requires_no_overlap) { + std::vector<std::unique_ptr<BitVector>> bits; + bits.push_back(BitVector::create(1, 11)); + bits.push_back(BitVector::create(10, 100)); + EXPECT_THAT([&](){ GlobalFilter::create(std::move(bits)); }, Throws<RequireFailedException>()); +} + +TEST(GlobalFilterTest, multi_bitvector_filter_requires_correct_order) { + std::vector<std::unique_ptr<BitVector>> bits; + bits.push_back(BitVector::create(11, 100)); + bits.push_back(BitVector::create(1, 11)); + EXPECT_THAT([&](){ GlobalFilter::create(std::move(bits)); }, Throws<RequireFailedException>()); +} + +GTEST_MAIN_RUN_ALL_TESTS() diff --git a/searchlib/src/tests/queryeval/nearest_neighbor/nearest_neighbor_test.cpp b/searchlib/src/tests/queryeval/nearest_neighbor/nearest_neighbor_test.cpp index 33435f43618..f02681908d6 100644 --- a/searchlib/src/tests/queryeval/nearest_neighbor/nearest_neighbor_test.cpp +++ b/searchlib/src/tests/queryeval/nearest_neighbor/nearest_neighbor_test.cpp @@ -96,12 +96,7 @@ struct Fixture void setFilter(std::vector<uint32_t> docids) { uint32_t sz = _attr->getNumDocs(); - auto bit_vector = BitVector::create(sz); - for (uint32_t id : docids) { - EXPECT_LT(id, sz); - bit_vector->setBit(id); - } - _global_filter = GlobalFilter::create(std::move(bit_vector)); + _global_filter = GlobalFilter::create(docids, sz); } void setTensor(uint32_t docId, const Value &tensor) { diff --git a/searchlib/src/tests/tensor/hnsw_index/hnsw_index_test.cpp b/searchlib/src/tests/tensor/hnsw_index/hnsw_index_test.cpp index 193bb04843c..7877b488065 100644 --- a/searchlib/src/tests/tensor/hnsw_index/hnsw_index_test.cpp +++ b/searchlib/src/tests/tensor/hnsw_index/hnsw_index_test.cpp @@ -82,10 +82,6 @@ public: ~HnswIndexTest() {} - const GlobalFilter *global_filter_ptr() const { - return global_filter->is_active() ? global_filter.get() : nullptr; - } - void init(bool heuristic_select_neighbors) { auto generator = std::make_unique<LevelGenerator>(); level_generator = generator.get(); @@ -110,12 +106,7 @@ public: } void set_filter(std::vector<uint32_t> docids) { uint32_t sz = 10; - auto bit_vector = BitVector::create(sz); - for (uint32_t id : docids) { - EXPECT_LT(id, sz); - bit_vector->setBit(id); - } - global_filter = GlobalFilter::create(std::move(bit_vector)); + global_filter = GlobalFilter::create(docids, sz); } GenerationHandler::Guard take_read_guard() { return gen_handler.takeGuard(); @@ -149,7 +140,7 @@ public: void expect_top_3(uint32_t docid, std::vector<uint32_t> exp_hits) { uint32_t k = 3; auto qv = vectors.get_vector(docid); - auto rv = index->top_k_candidates(qv, k, global_filter_ptr()).peek(); + auto rv = index->top_k_candidates(qv, k, global_filter->ptr_if_active()).peek(); std::sort(rv.begin(), rv.end(), LesserDistance()); size_t idx = 0; for (const auto & hit : rv) { @@ -170,7 +161,7 @@ public: void check_with_distance_threshold(uint32_t docid) { auto qv = vectors.get_vector(docid); uint32_t k = 3; - auto rv = index->top_k_candidates(qv, k, global_filter_ptr()).peek(); + auto rv = index->top_k_candidates(qv, k, global_filter->ptr_if_active()).peek(); std::sort(rv.begin(), rv.end(), LesserDistance()); EXPECT_EQ(rv.size(), 3); EXPECT_LE(rv[0].distance, rv[1].distance); |