From 295e925a71197f4a10bc31dc25add77fc9a2ec78 Mon Sep 17 00:00:00 2001 From: Håvard Pettersen Date: Mon, 12 Sep 2022 14:14:38 +0000 Subject: add support for multi-bitvector global filter move some testing convenience into GlobalFilter make more realistic bitvectors for filter testing --- searchlib/CMakeLists.txt | 5 +- .../tensorattribute/tensorattribute_test.cpp | 6 +- .../tests/queryeval/global_filter/CMakeLists.txt | 9 ++ .../queryeval/global_filter/global_filter_test.cpp | 139 +++++++++++++++++++++ .../nearest_neighbor/nearest_neighbor_test.cpp | 7 +- .../tests/tensor/hnsw_index/hnsw_index_test.cpp | 15 +-- .../vespa/searchlib/queryeval/global_filter.cpp | 64 +++++++++- .../src/vespa/searchlib/queryeval/global_filter.h | 6 + 8 files changed, 227 insertions(+), 24 deletions(-) create mode 100644 searchlib/src/tests/queryeval/global_filter/CMakeLists.txt create mode 100644 searchlib/src/tests/queryeval/global_filter/global_filter_test.cpp diff --git a/searchlib/CMakeLists.txt b/searchlib/CMakeLists.txt index 4d0f520f666..80a02b0d928 100644 --- a/searchlib/CMakeLists.txt +++ b/searchlib/CMakeLists.txt @@ -89,9 +89,9 @@ vespa_define_module( src/tests/attribute/multi_value_mapping src/tests/attribute/multi_value_read_view src/tests/attribute/posting_list_merger + src/tests/attribute/posting_store src/tests/attribute/postinglist src/tests/attribute/postinglistattribute - src/tests/attribute/posting_store src/tests/attribute/reference_attribute src/tests/attribute/save_target src/tests/attribute/searchable @@ -112,8 +112,8 @@ vespa_define_module( src/tests/common/summaryfeatures src/tests/diskindex/bitvector src/tests/diskindex/diskindex - src/tests/diskindex/fieldwriter src/tests/diskindex/field_length_scanner + src/tests/diskindex/fieldwriter src/tests/diskindex/fusion src/tests/diskindex/pagedict4 src/tests/docstore/chunk @@ -193,6 +193,7 @@ vespa_define_module( src/tests/queryeval/equiv src/tests/queryeval/fake_searchable src/tests/queryeval/getnodeweight + 
src/tests/queryeval/global_filter src/tests/queryeval/matching_elements_search src/tests/queryeval/monitoring_search_iterator src/tests/queryeval/multibitvectoriterator diff --git a/searchlib/src/tests/attribute/tensorattribute/tensorattribute_test.cpp b/searchlib/src/tests/attribute/tensorattribute/tensorattribute_test.cpp index ebe96035fc8..3dda2eb6d95 100644 --- a/searchlib/src/tests/attribute/tensorattribute/tensorattribute_test.cpp +++ b/searchlib/src/tests/attribute/tensorattribute/tensorattribute_test.cpp @@ -1114,7 +1114,7 @@ TEST_F("NN blueprint handles empty filter (post-filtering)", NearestNeighborBlue TEST_F("NN blueprint handles strong filter (pre-filtering)", NearestNeighborBlueprintFixture) { auto bp = f.make_blueprint(); - auto filter = search::BitVector::create(11); + auto filter = search::BitVector::create(1,11); filter->setBit(3); filter->invalidateCachedCount(); auto strong_filter = GlobalFilter::create(std::move(filter)); @@ -1128,7 +1128,7 @@ TEST_F("NN blueprint handles strong filter (pre-filtering)", NearestNeighborBlue TEST_F("NN blueprint handles weak filter (pre-filtering)", NearestNeighborBlueprintFixture) { auto bp = f.make_blueprint(); - auto filter = search::BitVector::create(11); + auto filter = search::BitVector::create(1,11); filter->setBit(1); filter->setBit(3); filter->setBit(5); @@ -1147,7 +1147,7 @@ TEST_F("NN blueprint handles weak filter (pre-filtering)", NearestNeighborBluepr TEST_F("NN blueprint handles strong filter triggering exact search", NearestNeighborBlueprintFixture) { auto bp = f.make_blueprint(true, 0.2); - auto filter = search::BitVector::create(11); + auto filter = search::BitVector::create(1,11); filter->setBit(3); filter->invalidateCachedCount(); auto strong_filter = GlobalFilter::create(std::move(filter)); diff --git a/searchlib/src/tests/queryeval/global_filter/CMakeLists.txt b/searchlib/src/tests/queryeval/global_filter/CMakeLists.txt new file mode 100644 index 00000000000..2f768bf9d88 --- /dev/null +++ 
b/searchlib/src/tests/queryeval/global_filter/CMakeLists.txt @@ -0,0 +1,9 @@ +# Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +vespa_add_executable(searchlib_queryeval_global_filter_test_app TEST + SOURCES + global_filter_test.cpp + DEPENDS + searchlib + GTest::GTest +) +vespa_add_test(NAME searchlib_queryeval_global_filter_test_app COMMAND searchlib_queryeval_global_filter_test_app) diff --git a/searchlib/src/tests/queryeval/global_filter/global_filter_test.cpp b/searchlib/src/tests/queryeval/global_filter/global_filter_test.cpp new file mode 100644 index 00000000000..a64e75cecd6 --- /dev/null +++ b/searchlib/src/tests/queryeval/global_filter/global_filter_test.cpp @@ -0,0 +1,139 @@ +// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. + +#include +#include +#include +#include + +#include +#include + +using namespace testing; + +using search::BitVector; +using search::queryeval::GlobalFilter; +using vespalib::RequireFailedException; + +TEST(GlobalFilterTest, create_can_make_inactive_filter) { + auto filter = GlobalFilter::create(); + EXPECT_FALSE(filter->is_active()); +} + +void verify(const GlobalFilter &filter) { + EXPECT_TRUE(filter.is_active()); + EXPECT_EQ(filter.size(), 100); + EXPECT_EQ(filter.count(), 3); + for (size_t i = 1; i < 100; ++i) { + if (i == 11 || i == 22 || i == 33) { + EXPECT_TRUE(filter.check(i)); + } else { + EXPECT_FALSE(filter.check(i)); + } + } +} + +TEST(GlobalFilterTest, create_can_make_test_filter) { + auto docs = std::vector({11,22,33}); + auto filter = GlobalFilter::create(docs, 100); + verify(*filter); +} + +TEST(GlobalFilterTest, test_filter_requires_docs_in_order) { + auto docs = std::vector({11,33,22}); + EXPECT_THAT([&](){ GlobalFilter::create(docs, 100); }, Throws()); +} + +TEST(GlobalFilterTest, test_filter_requires_docs_in_range) { + auto docs = std::vector({11,22,133}); + EXPECT_THAT([&](){ GlobalFilter::create(docs, 
100); }, Throws()); +} + +TEST(GlobalFilterTest, test_filter_docid_0_not_allowed) { + auto docs = std::vector({0,22,33}); + EXPECT_THAT([&](){ GlobalFilter::create(docs, 100); }, Throws()); +} + +TEST(GlobalFilterTest, create_can_make_single_bitvector_filter) { + auto bits = BitVector::create(1, 100); + bits->setBit(11); + bits->setBit(22); + bits->setBit(33); + bits->invalidateCachedCount(); + EXPECT_EQ(bits->countTrueBits(), 3); + auto filter = GlobalFilter::create(std::move(bits)); + verify(*filter); +} + +TEST(GlobalFilterTest, global_filter_pointer_guard) { + auto inactive = GlobalFilter::create(); + auto active = GlobalFilter::create(BitVector::create(1,100)); + EXPECT_TRUE(active->is_active()); + EXPECT_FALSE(inactive->is_active()); + EXPECT_TRUE(active->ptr_if_active() == active.get()); + EXPECT_TRUE(inactive->ptr_if_active() == nullptr); +} + +TEST(GlobalFilterTest, create_can_make_multi_bitvector_filter) { + std::vector> bits; + bits.push_back(BitVector::create(1, 11)); + bits.push_back(BitVector::create(11, 23)); + bits.push_back(BitVector::create(23, 25)); + bits.push_back(BitVector::create(25, 100)); + bits[1]->setBit(11); + bits[1]->setBit(22); + bits[3]->setBit(33); + for (const auto &v: bits) { + v->invalidateCachedCount(); + } + auto filter = GlobalFilter::create(std::move(bits)); + verify(*filter); +} + +TEST(GlobalFilterTest, multi_bitvector_filter_with_empty_vectors) { + std::vector> bits; + bits.push_back(BitVector::create(1, 11)); + bits.push_back(BitVector::create(11, 23)); + bits.push_back(BitVector::create(23, 23)); + bits.push_back(BitVector::create(23, 23)); + bits.push_back(BitVector::create(23, 25)); + bits.push_back(BitVector::create(25, 100)); + bits[1]->setBit(11); + bits[1]->setBit(22); + bits[5]->setBit(33); + for (const auto &v: bits) { + v->invalidateCachedCount(); + } + auto filter = GlobalFilter::create(std::move(bits)); + verify(*filter); +} + +TEST(GlobalFilterTest, multi_bitvector_filter_with_no_vectors) { + std::vector> 
bits; + auto filter = GlobalFilter::create(std::move(bits)); + EXPECT_TRUE(filter->is_active()); + EXPECT_EQ(filter->size(), 0); + EXPECT_EQ(filter->count(), 0); +} + +TEST(GlobalFilterTest, multi_bitvector_filter_requires_no_gaps) { + std::vector> bits; + bits.push_back(BitVector::create(1, 11)); + bits.push_back(BitVector::create(12, 100)); + EXPECT_THAT([&](){ GlobalFilter::create(std::move(bits)); }, Throws()); +} + +TEST(GlobalFilterTest, multi_bitvector_filter_requires_no_overlap) { + std::vector> bits; + bits.push_back(BitVector::create(1, 11)); + bits.push_back(BitVector::create(10, 100)); + EXPECT_THAT([&](){ GlobalFilter::create(std::move(bits)); }, Throws()); +} + +TEST(GlobalFilterTest, multi_bitvector_filter_requires_correct_order) { + std::vector> bits; + bits.push_back(BitVector::create(11, 100)); + bits.push_back(BitVector::create(1, 11)); + EXPECT_THAT([&](){ GlobalFilter::create(std::move(bits)); }, Throws()); +} + +GTEST_MAIN_RUN_ALL_TESTS() diff --git a/searchlib/src/tests/queryeval/nearest_neighbor/nearest_neighbor_test.cpp b/searchlib/src/tests/queryeval/nearest_neighbor/nearest_neighbor_test.cpp index 33435f43618..f02681908d6 100644 --- a/searchlib/src/tests/queryeval/nearest_neighbor/nearest_neighbor_test.cpp +++ b/searchlib/src/tests/queryeval/nearest_neighbor/nearest_neighbor_test.cpp @@ -96,12 +96,7 @@ struct Fixture void setFilter(std::vector docids) { uint32_t sz = _attr->getNumDocs(); - auto bit_vector = BitVector::create(sz); - for (uint32_t id : docids) { - EXPECT_LT(id, sz); - bit_vector->setBit(id); - } - _global_filter = GlobalFilter::create(std::move(bit_vector)); + _global_filter = GlobalFilter::create(docids, sz); } void setTensor(uint32_t docId, const Value &tensor) { diff --git a/searchlib/src/tests/tensor/hnsw_index/hnsw_index_test.cpp b/searchlib/src/tests/tensor/hnsw_index/hnsw_index_test.cpp index 193bb04843c..7877b488065 100644 --- a/searchlib/src/tests/tensor/hnsw_index/hnsw_index_test.cpp +++ 
b/searchlib/src/tests/tensor/hnsw_index/hnsw_index_test.cpp @@ -82,10 +82,6 @@ public: ~HnswIndexTest() {} - const GlobalFilter *global_filter_ptr() const { - return global_filter->is_active() ? global_filter.get() : nullptr; - } - void init(bool heuristic_select_neighbors) { auto generator = std::make_unique(); level_generator = generator.get(); @@ -110,12 +106,7 @@ public: } void set_filter(std::vector docids) { uint32_t sz = 10; - auto bit_vector = BitVector::create(sz); - for (uint32_t id : docids) { - EXPECT_LT(id, sz); - bit_vector->setBit(id); - } - global_filter = GlobalFilter::create(std::move(bit_vector)); + global_filter = GlobalFilter::create(docids, sz); } GenerationHandler::Guard take_read_guard() { return gen_handler.takeGuard(); @@ -149,7 +140,7 @@ public: void expect_top_3(uint32_t docid, std::vector exp_hits) { uint32_t k = 3; auto qv = vectors.get_vector(docid); - auto rv = index->top_k_candidates(qv, k, global_filter_ptr()).peek(); + auto rv = index->top_k_candidates(qv, k, global_filter->ptr_if_active()).peek(); std::sort(rv.begin(), rv.end(), LesserDistance()); size_t idx = 0; for (const auto & hit : rv) { @@ -170,7 +161,7 @@ public: void check_with_distance_threshold(uint32_t docid) { auto qv = vectors.get_vector(docid); uint32_t k = 3; - auto rv = index->top_k_candidates(qv, k, global_filter_ptr()).peek(); + auto rv = index->top_k_candidates(qv, k, global_filter->ptr_if_active()).peek(); std::sort(rv.begin(), rv.end(), LesserDistance()); EXPECT_EQ(rv.size(), 3); EXPECT_LE(rv[0].distance, rv[1].distance); diff --git a/searchlib/src/vespa/searchlib/queryeval/global_filter.cpp b/searchlib/src/vespa/searchlib/queryeval/global_filter.cpp index 1a5d3d3dacd..2aff91974a7 100644 --- a/searchlib/src/vespa/searchlib/queryeval/global_filter.cpp +++ b/searchlib/src/vespa/searchlib/queryeval/global_filter.cpp @@ -1,6 +1,7 @@ // Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. 
#include "global_filter.h" +#include namespace search::queryeval { @@ -23,6 +24,31 @@ struct BitVectorFilter : public GlobalFilter { bool check(uint32_t docid) const override { return vector->testBit(docid); } }; +struct MultiBitVectorFilter : public GlobalFilter { + std::vector> vectors; + std::vector splits; + uint32_t total_size; + uint32_t total_count; + MultiBitVectorFilter(std::vector> vectors_in, + std::vector splits_in, + uint32_t total_size_in, + uint32_t total_count_in) + : vectors(std::move(vectors_in)), + splits(std::move(splits_in)), + total_size(total_size_in), + total_count(total_count_in) {} + bool is_active() const override { return true; } + uint32_t size() const override { return total_size; } + uint32_t count() const override { return total_count; } + bool check(uint32_t docid) const override { + size_t i = 0; + while ((i < splits.size()) && (docid >= splits[i])) { + ++i; + } + return vectors[i]->testBit(docid); + } +}; + } GlobalFilter::GlobalFilter() = default; @@ -34,8 +60,44 @@ GlobalFilter::create() { } std::shared_ptr -GlobalFilter::create(std::unique_ptr vector) { +GlobalFilter::create(std::vector docids, uint32_t size) +{ + uint32_t prev = 0; + auto bits = BitVector::create(1, size); + for (uint32_t docid: docids) { + REQUIRE(docid > prev); + REQUIRE(docid < size); + bits->setBit(docid); + prev = docid; + } + bits->invalidateCachedCount(); + return create(std::move(bits)); +} + +std::shared_ptr +GlobalFilter::create(std::unique_ptr vector) +{ return std::make_shared(std::move(vector)); } +std::shared_ptr +GlobalFilter::create(std::vector> vectors) +{ + uint32_t total_size = 0; + uint32_t total_count = 0; + std::vector splits; + for (size_t i = 0; i < vectors.size(); ++i) { + bool last = ((i + 1) == vectors.size()); + total_count += vectors[i]->countTrueBits(); + if (last) { + total_size = vectors[i]->size(); + } else { + REQUIRE_EQ(vectors[i]->size(), vectors[i + 1]->getStartIndex()); + splits.push_back(vectors[i]->size()); + } + } + 
return std::make_shared(std::move(vectors), std::move(splits), + total_size, total_count); +} + } diff --git a/searchlib/src/vespa/searchlib/queryeval/global_filter.h b/searchlib/src/vespa/searchlib/queryeval/global_filter.h index c6e08d5018d..8504367e5b7 100644 --- a/searchlib/src/vespa/searchlib/queryeval/global_filter.h +++ b/searchlib/src/vespa/searchlib/queryeval/global_filter.h @@ -26,8 +26,14 @@ public: virtual bool check(uint32_t docid) const = 0; virtual ~GlobalFilter(); + const GlobalFilter *ptr_if_active() const { + return is_active() ? this : nullptr; + } + static std::shared_ptr create(); + static std::shared_ptr create(std::vector docids, uint32_t size); static std::shared_ptr create(std::unique_ptr vector); + static std::shared_ptr create(std::vector> vectors); }; } // namespace -- cgit v1.2.3