diff options
author | Håvard Pettersen <havardpe@yahooinc.com> | 2022-09-12 14:14:38 +0000 |
---|---|---|
committer | Håvard Pettersen <havardpe@yahooinc.com> | 2022-09-13 12:08:31 +0000 |
commit | 295e925a71197f4a10bc31dc25add77fc9a2ec78 (patch) | |
tree | 077da7b222b727fdcd8efba97293dabf741aa7d2 | |
parent | be02b6f3d580b78c6d5c36428b302649fb7f0717 (diff) |
add support for multi-bitvector global filter
move some testing convenience into GlobalFilter
make more realistic bitvectors for filter testing
8 files changed, 227 insertions, 24 deletions
diff --git a/searchlib/CMakeLists.txt b/searchlib/CMakeLists.txt index 4d0f520f666..80a02b0d928 100644 --- a/searchlib/CMakeLists.txt +++ b/searchlib/CMakeLists.txt @@ -89,9 +89,9 @@ vespa_define_module( src/tests/attribute/multi_value_mapping src/tests/attribute/multi_value_read_view src/tests/attribute/posting_list_merger + src/tests/attribute/posting_store src/tests/attribute/postinglist src/tests/attribute/postinglistattribute - src/tests/attribute/posting_store src/tests/attribute/reference_attribute src/tests/attribute/save_target src/tests/attribute/searchable @@ -112,8 +112,8 @@ vespa_define_module( src/tests/common/summaryfeatures src/tests/diskindex/bitvector src/tests/diskindex/diskindex - src/tests/diskindex/fieldwriter src/tests/diskindex/field_length_scanner + src/tests/diskindex/fieldwriter src/tests/diskindex/fusion src/tests/diskindex/pagedict4 src/tests/docstore/chunk @@ -193,6 +193,7 @@ vespa_define_module( src/tests/queryeval/equiv src/tests/queryeval/fake_searchable src/tests/queryeval/getnodeweight + src/tests/queryeval/global_filter src/tests/queryeval/matching_elements_search src/tests/queryeval/monitoring_search_iterator src/tests/queryeval/multibitvectoriterator diff --git a/searchlib/src/tests/attribute/tensorattribute/tensorattribute_test.cpp b/searchlib/src/tests/attribute/tensorattribute/tensorattribute_test.cpp index ebe96035fc8..3dda2eb6d95 100644 --- a/searchlib/src/tests/attribute/tensorattribute/tensorattribute_test.cpp +++ b/searchlib/src/tests/attribute/tensorattribute/tensorattribute_test.cpp @@ -1114,7 +1114,7 @@ TEST_F("NN blueprint handles empty filter (post-filtering)", NearestNeighborBlue TEST_F("NN blueprint handles strong filter (pre-filtering)", NearestNeighborBlueprintFixture) { auto bp = f.make_blueprint(); - auto filter = search::BitVector::create(11); + auto filter = search::BitVector::create(1,11); filter->setBit(3); filter->invalidateCachedCount(); auto strong_filter = GlobalFilter::create(std::move(filter)); @@ -1128,7 +1128,7 @@ TEST_F("NN blueprint handles strong filter (pre-filtering)", NearestNeighborBlue TEST_F("NN blueprint handles weak filter (pre-filtering)", NearestNeighborBlueprintFixture) { auto bp = f.make_blueprint(); - auto filter = search::BitVector::create(11); + auto filter = search::BitVector::create(1,11); filter->setBit(1); filter->setBit(3); filter->setBit(5); @@ -1147,7 +1147,7 @@ TEST_F("NN blueprint handles weak filter (pre-filtering)", NearestNeighborBluepr TEST_F("NN blueprint handles strong filter triggering exact search", NearestNeighborBlueprintFixture) { auto bp = f.make_blueprint(true, 0.2); - auto filter = search::BitVector::create(11); + auto filter = search::BitVector::create(1,11); filter->setBit(3); filter->invalidateCachedCount(); auto strong_filter = GlobalFilter::create(std::move(filter)); diff --git a/searchlib/src/tests/queryeval/global_filter/CMakeLists.txt b/searchlib/src/tests/queryeval/global_filter/CMakeLists.txt new file mode 100644 index 00000000000..2f768bf9d88 --- /dev/null +++ b/searchlib/src/tests/queryeval/global_filter/CMakeLists.txt @@ -0,0 +1,9 @@ +# Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +vespa_add_executable(searchlib_queryeval_global_filter_test_app TEST + SOURCES + global_filter_test.cpp + DEPENDS + searchlib + GTest::GTest +) +vespa_add_test(NAME searchlib_queryeval_global_filter_test_app COMMAND searchlib_queryeval_global_filter_test_app) diff --git a/searchlib/src/tests/queryeval/global_filter/global_filter_test.cpp b/searchlib/src/tests/queryeval/global_filter/global_filter_test.cpp new file mode 100644 index 00000000000..a64e75cecd6 --- /dev/null +++ b/searchlib/src/tests/queryeval/global_filter/global_filter_test.cpp @@ -0,0 +1,139 @@ +// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. + +#include <vespa/vespalib/gtest/gtest.h> +#include <vespa/vespalib/util/require.h> +#include <vespa/searchlib/queryeval/global_filter.h> +#include <vespa/searchlib/common/bitvector.h> + +#include <gmock/gmock.h> +#include <vector> + +using namespace testing; + +using search::BitVector; +using search::queryeval::GlobalFilter; +using vespalib::RequireFailedException; + +TEST(GlobalFilterTest, create_can_make_inactive_filter) { + auto filter = GlobalFilter::create(); + EXPECT_FALSE(filter->is_active()); +} + +void verify(const GlobalFilter &filter) { + EXPECT_TRUE(filter.is_active()); + EXPECT_EQ(filter.size(), 100); + EXPECT_EQ(filter.count(), 3); + for (size_t i = 1; i < 100; ++i) { + if (i == 11 || i == 22 || i == 33) { + EXPECT_TRUE(filter.check(i)); + } else { + EXPECT_FALSE(filter.check(i)); + } + } +} + +TEST(GlobalFilterTest, create_can_make_test_filter) { + auto docs = std::vector<uint32_t>({11,22,33}); + auto filter = GlobalFilter::create(docs, 100); + verify(*filter); +} + +TEST(GlobalFilterTest, test_filter_requires_docs_in_order) { + auto docs = std::vector<uint32_t>({11,33,22}); + EXPECT_THAT([&](){ GlobalFilter::create(docs, 100); }, Throws<RequireFailedException>()); +} + +TEST(GlobalFilterTest, test_filter_requires_docs_in_range) { + auto docs = std::vector<uint32_t>({11,22,133}); + EXPECT_THAT([&](){ GlobalFilter::create(docs, 100); }, Throws<RequireFailedException>()); +} + +TEST(GlobalFilterTest, test_filter_docid_0_not_allowed) { + auto docs = std::vector<uint32_t>({0,22,33}); + EXPECT_THAT([&](){ GlobalFilter::create(docs, 100); }, Throws<RequireFailedException>()); +} + +TEST(GlobalFilterTest, create_can_make_single_bitvector_filter) { + auto bits = BitVector::create(1, 100); + bits->setBit(11); + bits->setBit(22); + bits->setBit(33); + bits->invalidateCachedCount(); + EXPECT_EQ(bits->countTrueBits(), 3); + auto filter = GlobalFilter::create(std::move(bits)); + verify(*filter); +} + +TEST(GlobalFilterTest, global_filter_pointer_guard) { + auto inactive = GlobalFilter::create(); + auto active = GlobalFilter::create(BitVector::create(1,100)); + EXPECT_TRUE(active->is_active()); + EXPECT_FALSE(inactive->is_active()); + EXPECT_TRUE(active->ptr_if_active() == active.get()); + EXPECT_TRUE(inactive->ptr_if_active() == nullptr); +} + +TEST(GlobalFilterTest, create_can_make_multi_bitvector_filter) { + std::vector<std::unique_ptr<BitVector>> bits; + bits.push_back(BitVector::create(1, 11)); + bits.push_back(BitVector::create(11, 23)); + bits.push_back(BitVector::create(23, 25)); + bits.push_back(BitVector::create(25, 100)); + bits[1]->setBit(11); + bits[1]->setBit(22); + bits[3]->setBit(33); + for (const auto &v: bits) { + v->invalidateCachedCount(); + } + auto filter = GlobalFilter::create(std::move(bits)); + verify(*filter); +} + +TEST(GlobalFilterTest, multi_bitvector_filter_with_empty_vectors) { + std::vector<std::unique_ptr<BitVector>> bits; + bits.push_back(BitVector::create(1, 11)); + bits.push_back(BitVector::create(11, 23)); + bits.push_back(BitVector::create(23, 23)); + bits.push_back(BitVector::create(23, 23)); + bits.push_back(BitVector::create(23, 25)); + bits.push_back(BitVector::create(25, 100)); + bits[1]->setBit(11); + bits[1]->setBit(22); + bits[5]->setBit(33); + for (const auto &v: bits) { + v->invalidateCachedCount(); + } + auto filter = GlobalFilter::create(std::move(bits)); + verify(*filter); +} + +TEST(GlobalFilterTest, multi_bitvector_filter_with_no_vectors) { + std::vector<std::unique_ptr<BitVector>> bits; + auto filter = GlobalFilter::create(std::move(bits)); + EXPECT_TRUE(filter->is_active()); + EXPECT_EQ(filter->size(), 0); + EXPECT_EQ(filter->count(), 0); +} + +TEST(GlobalFilterTest, multi_bitvector_filter_requires_no_gaps) { + std::vector<std::unique_ptr<BitVector>> bits; + bits.push_back(BitVector::create(1, 11)); + bits.push_back(BitVector::create(12, 100)); + EXPECT_THAT([&](){ GlobalFilter::create(std::move(bits)); }, Throws<RequireFailedException>()); +} + +TEST(GlobalFilterTest, multi_bitvector_filter_requires_no_overlap) { + std::vector<std::unique_ptr<BitVector>> bits; + bits.push_back(BitVector::create(1, 11)); + bits.push_back(BitVector::create(10, 100)); + EXPECT_THAT([&](){ GlobalFilter::create(std::move(bits)); }, Throws<RequireFailedException>()); +} + +TEST(GlobalFilterTest, multi_bitvector_filter_requires_correct_order) { + std::vector<std::unique_ptr<BitVector>> bits; + bits.push_back(BitVector::create(11, 100)); + bits.push_back(BitVector::create(1, 11)); + EXPECT_THAT([&](){ GlobalFilter::create(std::move(bits)); }, Throws<RequireFailedException>()); +} + +GTEST_MAIN_RUN_ALL_TESTS() diff --git a/searchlib/src/tests/queryeval/nearest_neighbor/nearest_neighbor_test.cpp b/searchlib/src/tests/queryeval/nearest_neighbor/nearest_neighbor_test.cpp index 33435f43618..f02681908d6 100644 --- a/searchlib/src/tests/queryeval/nearest_neighbor/nearest_neighbor_test.cpp +++ b/searchlib/src/tests/queryeval/nearest_neighbor/nearest_neighbor_test.cpp @@ -96,12 +96,7 @@ struct Fixture void setFilter(std::vector<uint32_t> docids) { uint32_t sz = _attr->getNumDocs(); - auto bit_vector = BitVector::create(sz); - for (uint32_t id : docids) { - EXPECT_LT(id, sz); - bit_vector->setBit(id); - } - _global_filter = GlobalFilter::create(std::move(bit_vector)); + _global_filter = GlobalFilter::create(docids, sz); } void setTensor(uint32_t docId, const Value &tensor) { diff --git a/searchlib/src/tests/tensor/hnsw_index/hnsw_index_test.cpp b/searchlib/src/tests/tensor/hnsw_index/hnsw_index_test.cpp index 193bb04843c..7877b488065 100644 --- a/searchlib/src/tests/tensor/hnsw_index/hnsw_index_test.cpp +++ b/searchlib/src/tests/tensor/hnsw_index/hnsw_index_test.cpp @@ -82,10 +82,6 @@ public: ~HnswIndexTest() {} - const GlobalFilter *global_filter_ptr() const { - return global_filter->is_active() ? global_filter.get() : nullptr; - } - void init(bool heuristic_select_neighbors) { auto generator = std::make_unique<LevelGenerator>(); level_generator = generator.get(); @@ -110,12 +106,7 @@ public: } void set_filter(std::vector<uint32_t> docids) { uint32_t sz = 10; - auto bit_vector = BitVector::create(sz); - for (uint32_t id : docids) { - EXPECT_LT(id, sz); - bit_vector->setBit(id); - } - global_filter = GlobalFilter::create(std::move(bit_vector)); + global_filter = GlobalFilter::create(docids, sz); } GenerationHandler::Guard take_read_guard() { return gen_handler.takeGuard(); @@ -149,7 +140,7 @@ public: void expect_top_3(uint32_t docid, std::vector<uint32_t> exp_hits) { uint32_t k = 3; auto qv = vectors.get_vector(docid); - auto rv = index->top_k_candidates(qv, k, global_filter_ptr()).peek(); + auto rv = index->top_k_candidates(qv, k, global_filter->ptr_if_active()).peek(); std::sort(rv.begin(), rv.end(), LesserDistance()); size_t idx = 0; for (const auto & hit : rv) { @@ -170,7 +161,7 @@ public: void check_with_distance_threshold(uint32_t docid) { auto qv = vectors.get_vector(docid); uint32_t k = 3; - auto rv = index->top_k_candidates(qv, k, global_filter_ptr()).peek(); + auto rv = index->top_k_candidates(qv, k, global_filter->ptr_if_active()).peek(); std::sort(rv.begin(), rv.end(), LesserDistance()); EXPECT_EQ(rv.size(), 3); EXPECT_LE(rv[0].distance, rv[1].distance); diff --git a/searchlib/src/vespa/searchlib/queryeval/global_filter.cpp b/searchlib/src/vespa/searchlib/queryeval/global_filter.cpp index 1a5d3d3dacd..2aff91974a7 100644 --- a/searchlib/src/vespa/searchlib/queryeval/global_filter.cpp +++ b/searchlib/src/vespa/searchlib/queryeval/global_filter.cpp @@ -1,6 +1,7 @@ // Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. #include "global_filter.h" +#include <vespa/vespalib/util/require.h> namespace search::queryeval { @@ -23,6 +24,31 @@ struct BitVectorFilter : public GlobalFilter { bool check(uint32_t docid) const override { return vector->testBit(docid); } }; +struct MultiBitVectorFilter : public GlobalFilter { + std::vector<std::unique_ptr<BitVector>> vectors; + std::vector<uint32_t> splits; + uint32_t total_size; + uint32_t total_count; + MultiBitVectorFilter(std::vector<std::unique_ptr<BitVector>> vectors_in, + std::vector<uint32_t> splits_in, + uint32_t total_size_in, + uint32_t total_count_in) + : vectors(std::move(vectors_in)), + splits(std::move(splits_in)), + total_size(total_size_in), + total_count(total_count_in) {} + bool is_active() const override { return true; } + uint32_t size() const override { return total_size; } + uint32_t count() const override { return total_count; } + bool check(uint32_t docid) const override { + size_t i = 0; + while ((i < splits.size()) && (docid >= splits[i])) { + ++i; + } + return vectors[i]->testBit(docid); + } +}; + } GlobalFilter::GlobalFilter() = default; @@ -34,8 +60,44 @@ GlobalFilter::create() { } std::shared_ptr<GlobalFilter> -GlobalFilter::create(std::unique_ptr<BitVector> vector) { +GlobalFilter::create(std::vector<uint32_t> docids, uint32_t size) +{ + uint32_t prev = 0; + auto bits = BitVector::create(1, size); + for (uint32_t docid: docids) { + REQUIRE(docid > prev); + REQUIRE(docid < size); + bits->setBit(docid); + prev = docid; + } + bits->invalidateCachedCount(); + return create(std::move(bits)); +} + +std::shared_ptr<GlobalFilter> +GlobalFilter::create(std::unique_ptr<BitVector> vector) +{ return std::make_shared<BitVectorFilter>(std::move(vector)); } +std::shared_ptr<GlobalFilter> +GlobalFilter::create(std::vector<std::unique_ptr<BitVector>> vectors) +{ + uint32_t total_size = 0; + uint32_t total_count = 0; + std::vector<uint32_t> splits; + for (size_t i = 0; i < vectors.size(); ++i) { + bool last = ((i + 1) == vectors.size()); + total_count += vectors[i]->countTrueBits(); + if (last) { + total_size = vectors[i]->size(); + } else { + REQUIRE_EQ(vectors[i]->size(), vectors[i + 1]->getStartIndex()); + splits.push_back(vectors[i]->size()); + } + } + return std::make_shared<MultiBitVectorFilter>(std::move(vectors), std::move(splits), + total_size, total_count); +} + } diff --git a/searchlib/src/vespa/searchlib/queryeval/global_filter.h b/searchlib/src/vespa/searchlib/queryeval/global_filter.h index c6e08d5018d..8504367e5b7 100644 --- a/searchlib/src/vespa/searchlib/queryeval/global_filter.h +++ b/searchlib/src/vespa/searchlib/queryeval/global_filter.h @@ -26,8 +26,14 @@ public: virtual bool check(uint32_t docid) const = 0; virtual ~GlobalFilter(); + const GlobalFilter *ptr_if_active() const { + return is_active() ? this : nullptr; + } + static std::shared_ptr<GlobalFilter> create(); + static std::shared_ptr<GlobalFilter> create(std::vector<uint32_t> docids, uint32_t size); static std::shared_ptr<GlobalFilter> create(std::unique_ptr<BitVector> vector); + static std::shared_ptr<GlobalFilter> create(std::vector<std::unique_ptr<BitVector>> vectors); }; } // namespace |