aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorHenning Baldersheim <balder@yahoo-inc.com>2022-09-13 17:45:31 +0200
committerGitHub <noreply@github.com>2022-09-13 17:45:31 +0200
commit04c3414342c1cc296f8a56d4112f77b1a463cc70 (patch)
tree1cfed55d6bd46843a0f43287cd73cc86f4184d09
parent2f325645ed04b978b84d2949ec860f2c0d722c58 (diff)
parent295e925a71197f4a10bc31dc25add77fc9a2ec78 (diff)
Merge pull request #24036 from vespa-engine/havardpe/multi-bitvector-global-filterv8.52.15
add support for multi-bitvector global filter
-rw-r--r--searchlib/CMakeLists.txt5
-rw-r--r--searchlib/src/tests/attribute/tensorattribute/tensorattribute_test.cpp6
-rw-r--r--searchlib/src/tests/queryeval/global_filter/CMakeLists.txt9
-rw-r--r--searchlib/src/tests/queryeval/global_filter/global_filter_test.cpp139
-rw-r--r--searchlib/src/tests/queryeval/nearest_neighbor/nearest_neighbor_test.cpp7
-rw-r--r--searchlib/src/tests/tensor/hnsw_index/hnsw_index_test.cpp15
-rw-r--r--searchlib/src/vespa/searchlib/queryeval/global_filter.cpp64
-rw-r--r--searchlib/src/vespa/searchlib/queryeval/global_filter.h6
8 files changed, 227 insertions, 24 deletions
diff --git a/searchlib/CMakeLists.txt b/searchlib/CMakeLists.txt
index 4d0f520f666..80a02b0d928 100644
--- a/searchlib/CMakeLists.txt
+++ b/searchlib/CMakeLists.txt
@@ -89,9 +89,9 @@ vespa_define_module(
src/tests/attribute/multi_value_mapping
src/tests/attribute/multi_value_read_view
src/tests/attribute/posting_list_merger
+ src/tests/attribute/posting_store
src/tests/attribute/postinglist
src/tests/attribute/postinglistattribute
- src/tests/attribute/posting_store
src/tests/attribute/reference_attribute
src/tests/attribute/save_target
src/tests/attribute/searchable
@@ -112,8 +112,8 @@ vespa_define_module(
src/tests/common/summaryfeatures
src/tests/diskindex/bitvector
src/tests/diskindex/diskindex
- src/tests/diskindex/fieldwriter
src/tests/diskindex/field_length_scanner
+ src/tests/diskindex/fieldwriter
src/tests/diskindex/fusion
src/tests/diskindex/pagedict4
src/tests/docstore/chunk
@@ -193,6 +193,7 @@ vespa_define_module(
src/tests/queryeval/equiv
src/tests/queryeval/fake_searchable
src/tests/queryeval/getnodeweight
+ src/tests/queryeval/global_filter
src/tests/queryeval/matching_elements_search
src/tests/queryeval/monitoring_search_iterator
src/tests/queryeval/multibitvectoriterator
diff --git a/searchlib/src/tests/attribute/tensorattribute/tensorattribute_test.cpp b/searchlib/src/tests/attribute/tensorattribute/tensorattribute_test.cpp
index ebe96035fc8..3dda2eb6d95 100644
--- a/searchlib/src/tests/attribute/tensorattribute/tensorattribute_test.cpp
+++ b/searchlib/src/tests/attribute/tensorattribute/tensorattribute_test.cpp
@@ -1114,7 +1114,7 @@ TEST_F("NN blueprint handles empty filter (post-filtering)", NearestNeighborBlue
TEST_F("NN blueprint handles strong filter (pre-filtering)", NearestNeighborBlueprintFixture)
{
auto bp = f.make_blueprint();
- auto filter = search::BitVector::create(11);
+ auto filter = search::BitVector::create(1,11);
filter->setBit(3);
filter->invalidateCachedCount();
auto strong_filter = GlobalFilter::create(std::move(filter));
@@ -1128,7 +1128,7 @@ TEST_F("NN blueprint handles strong filter (pre-filtering)", NearestNeighborBlue
TEST_F("NN blueprint handles weak filter (pre-filtering)", NearestNeighborBlueprintFixture)
{
auto bp = f.make_blueprint();
- auto filter = search::BitVector::create(11);
+ auto filter = search::BitVector::create(1,11);
filter->setBit(1);
filter->setBit(3);
filter->setBit(5);
@@ -1147,7 +1147,7 @@ TEST_F("NN blueprint handles weak filter (pre-filtering)", NearestNeighborBluepr
TEST_F("NN blueprint handles strong filter triggering exact search", NearestNeighborBlueprintFixture)
{
auto bp = f.make_blueprint(true, 0.2);
- auto filter = search::BitVector::create(11);
+ auto filter = search::BitVector::create(1,11);
filter->setBit(3);
filter->invalidateCachedCount();
auto strong_filter = GlobalFilter::create(std::move(filter));
diff --git a/searchlib/src/tests/queryeval/global_filter/CMakeLists.txt b/searchlib/src/tests/queryeval/global_filter/CMakeLists.txt
new file mode 100644
index 00000000000..2f768bf9d88
--- /dev/null
+++ b/searchlib/src/tests/queryeval/global_filter/CMakeLists.txt
@@ -0,0 +1,9 @@
+# Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+vespa_add_executable(searchlib_queryeval_global_filter_test_app TEST
+ SOURCES
+ global_filter_test.cpp
+ DEPENDS
+ searchlib
+ GTest::GTest
+)
+vespa_add_test(NAME searchlib_queryeval_global_filter_test_app COMMAND searchlib_queryeval_global_filter_test_app)
diff --git a/searchlib/src/tests/queryeval/global_filter/global_filter_test.cpp b/searchlib/src/tests/queryeval/global_filter/global_filter_test.cpp
new file mode 100644
index 00000000000..a64e75cecd6
--- /dev/null
+++ b/searchlib/src/tests/queryeval/global_filter/global_filter_test.cpp
@@ -0,0 +1,139 @@
+// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+
+#include <vespa/vespalib/gtest/gtest.h>
+#include <vespa/vespalib/util/require.h>
+#include <vespa/searchlib/queryeval/global_filter.h>
+#include <vespa/searchlib/common/bitvector.h>
+
+#include <gmock/gmock.h>
+#include <vector>
+
+using namespace testing;
+
+using search::BitVector;
+using search::queryeval::GlobalFilter;
+using vespalib::RequireFailedException;
+
+TEST(GlobalFilterTest, create_can_make_inactive_filter) {
+ auto filter = GlobalFilter::create();
+ EXPECT_FALSE(filter->is_active());
+}
+
+void verify(const GlobalFilter &filter) {
+ EXPECT_TRUE(filter.is_active());
+ EXPECT_EQ(filter.size(), 100);
+ EXPECT_EQ(filter.count(), 3);
+ for (size_t i = 1; i < 100; ++i) {
+ if (i == 11 || i == 22 || i == 33) {
+ EXPECT_TRUE(filter.check(i));
+ } else {
+ EXPECT_FALSE(filter.check(i));
+ }
+ }
+}
+
+TEST(GlobalFilterTest, create_can_make_test_filter) {
+ auto docs = std::vector<uint32_t>({11,22,33});
+ auto filter = GlobalFilter::create(docs, 100);
+ verify(*filter);
+}
+
+TEST(GlobalFilterTest, test_filter_requires_docs_in_order) {
+ auto docs = std::vector<uint32_t>({11,33,22});
+ EXPECT_THAT([&](){ GlobalFilter::create(docs, 100); }, Throws<RequireFailedException>());
+}
+
+TEST(GlobalFilterTest, test_filter_requires_docs_in_range) {
+ auto docs = std::vector<uint32_t>({11,22,133});
+ EXPECT_THAT([&](){ GlobalFilter::create(docs, 100); }, Throws<RequireFailedException>());
+}
+
+TEST(GlobalFilterTest, test_filter_docid_0_not_allowed) {
+ auto docs = std::vector<uint32_t>({0,22,33});
+ EXPECT_THAT([&](){ GlobalFilter::create(docs, 100); }, Throws<RequireFailedException>());
+}
+
+TEST(GlobalFilterTest, create_can_make_single_bitvector_filter) {
+ auto bits = BitVector::create(1, 100);
+ bits->setBit(11);
+ bits->setBit(22);
+ bits->setBit(33);
+ bits->invalidateCachedCount();
+ EXPECT_EQ(bits->countTrueBits(), 3);
+ auto filter = GlobalFilter::create(std::move(bits));
+ verify(*filter);
+}
+
+TEST(GlobalFilterTest, global_filter_pointer_guard) {
+ auto inactive = GlobalFilter::create();
+ auto active = GlobalFilter::create(BitVector::create(1,100));
+ EXPECT_TRUE(active->is_active());
+ EXPECT_FALSE(inactive->is_active());
+ EXPECT_TRUE(active->ptr_if_active() == active.get());
+ EXPECT_TRUE(inactive->ptr_if_active() == nullptr);
+}
+
+TEST(GlobalFilterTest, create_can_make_multi_bitvector_filter) {
+ std::vector<std::unique_ptr<BitVector>> bits;
+ bits.push_back(BitVector::create(1, 11));
+ bits.push_back(BitVector::create(11, 23));
+ bits.push_back(BitVector::create(23, 25));
+ bits.push_back(BitVector::create(25, 100));
+ bits[1]->setBit(11);
+ bits[1]->setBit(22);
+ bits[3]->setBit(33);
+ for (const auto &v: bits) {
+ v->invalidateCachedCount();
+ }
+ auto filter = GlobalFilter::create(std::move(bits));
+ verify(*filter);
+}
+
+TEST(GlobalFilterTest, multi_bitvector_filter_with_empty_vectors) {
+ std::vector<std::unique_ptr<BitVector>> bits;
+ bits.push_back(BitVector::create(1, 11));
+ bits.push_back(BitVector::create(11, 23));
+ bits.push_back(BitVector::create(23, 23));
+ bits.push_back(BitVector::create(23, 23));
+ bits.push_back(BitVector::create(23, 25));
+ bits.push_back(BitVector::create(25, 100));
+ bits[1]->setBit(11);
+ bits[1]->setBit(22);
+ bits[5]->setBit(33);
+ for (const auto &v: bits) {
+ v->invalidateCachedCount();
+ }
+ auto filter = GlobalFilter::create(std::move(bits));
+ verify(*filter);
+}
+
+TEST(GlobalFilterTest, multi_bitvector_filter_with_no_vectors) {
+ std::vector<std::unique_ptr<BitVector>> bits;
+ auto filter = GlobalFilter::create(std::move(bits));
+ EXPECT_TRUE(filter->is_active());
+ EXPECT_EQ(filter->size(), 0);
+ EXPECT_EQ(filter->count(), 0);
+}
+
+TEST(GlobalFilterTest, multi_bitvector_filter_requires_no_gaps) {
+ std::vector<std::unique_ptr<BitVector>> bits;
+ bits.push_back(BitVector::create(1, 11));
+ bits.push_back(BitVector::create(12, 100));
+ EXPECT_THAT([&](){ GlobalFilter::create(std::move(bits)); }, Throws<RequireFailedException>());
+}
+
+TEST(GlobalFilterTest, multi_bitvector_filter_requires_no_overlap) {
+ std::vector<std::unique_ptr<BitVector>> bits;
+ bits.push_back(BitVector::create(1, 11));
+ bits.push_back(BitVector::create(10, 100));
+ EXPECT_THAT([&](){ GlobalFilter::create(std::move(bits)); }, Throws<RequireFailedException>());
+}
+
+TEST(GlobalFilterTest, multi_bitvector_filter_requires_correct_order) {
+ std::vector<std::unique_ptr<BitVector>> bits;
+ bits.push_back(BitVector::create(11, 100));
+ bits.push_back(BitVector::create(1, 11));
+ EXPECT_THAT([&](){ GlobalFilter::create(std::move(bits)); }, Throws<RequireFailedException>());
+}
+
+GTEST_MAIN_RUN_ALL_TESTS()
diff --git a/searchlib/src/tests/queryeval/nearest_neighbor/nearest_neighbor_test.cpp b/searchlib/src/tests/queryeval/nearest_neighbor/nearest_neighbor_test.cpp
index 33435f43618..f02681908d6 100644
--- a/searchlib/src/tests/queryeval/nearest_neighbor/nearest_neighbor_test.cpp
+++ b/searchlib/src/tests/queryeval/nearest_neighbor/nearest_neighbor_test.cpp
@@ -96,12 +96,7 @@ struct Fixture
void setFilter(std::vector<uint32_t> docids) {
uint32_t sz = _attr->getNumDocs();
- auto bit_vector = BitVector::create(sz);
- for (uint32_t id : docids) {
- EXPECT_LT(id, sz);
- bit_vector->setBit(id);
- }
- _global_filter = GlobalFilter::create(std::move(bit_vector));
+ _global_filter = GlobalFilter::create(docids, sz);
}
void setTensor(uint32_t docId, const Value &tensor) {
diff --git a/searchlib/src/tests/tensor/hnsw_index/hnsw_index_test.cpp b/searchlib/src/tests/tensor/hnsw_index/hnsw_index_test.cpp
index 193bb04843c..7877b488065 100644
--- a/searchlib/src/tests/tensor/hnsw_index/hnsw_index_test.cpp
+++ b/searchlib/src/tests/tensor/hnsw_index/hnsw_index_test.cpp
@@ -82,10 +82,6 @@ public:
~HnswIndexTest() {}
- const GlobalFilter *global_filter_ptr() const {
- return global_filter->is_active() ? global_filter.get() : nullptr;
- }
-
void init(bool heuristic_select_neighbors) {
auto generator = std::make_unique<LevelGenerator>();
level_generator = generator.get();
@@ -110,12 +106,7 @@ public:
}
void set_filter(std::vector<uint32_t> docids) {
uint32_t sz = 10;
- auto bit_vector = BitVector::create(sz);
- for (uint32_t id : docids) {
- EXPECT_LT(id, sz);
- bit_vector->setBit(id);
- }
- global_filter = GlobalFilter::create(std::move(bit_vector));
+ global_filter = GlobalFilter::create(docids, sz);
}
GenerationHandler::Guard take_read_guard() {
return gen_handler.takeGuard();
@@ -149,7 +140,7 @@ public:
void expect_top_3(uint32_t docid, std::vector<uint32_t> exp_hits) {
uint32_t k = 3;
auto qv = vectors.get_vector(docid);
- auto rv = index->top_k_candidates(qv, k, global_filter_ptr()).peek();
+ auto rv = index->top_k_candidates(qv, k, global_filter->ptr_if_active()).peek();
std::sort(rv.begin(), rv.end(), LesserDistance());
size_t idx = 0;
for (const auto & hit : rv) {
@@ -170,7 +161,7 @@ public:
void check_with_distance_threshold(uint32_t docid) {
auto qv = vectors.get_vector(docid);
uint32_t k = 3;
- auto rv = index->top_k_candidates(qv, k, global_filter_ptr()).peek();
+ auto rv = index->top_k_candidates(qv, k, global_filter->ptr_if_active()).peek();
std::sort(rv.begin(), rv.end(), LesserDistance());
EXPECT_EQ(rv.size(), 3);
EXPECT_LE(rv[0].distance, rv[1].distance);
diff --git a/searchlib/src/vespa/searchlib/queryeval/global_filter.cpp b/searchlib/src/vespa/searchlib/queryeval/global_filter.cpp
index 1a5d3d3dacd..2aff91974a7 100644
--- a/searchlib/src/vespa/searchlib/queryeval/global_filter.cpp
+++ b/searchlib/src/vespa/searchlib/queryeval/global_filter.cpp
@@ -1,6 +1,7 @@
// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
#include "global_filter.h"
+#include <vespa/vespalib/util/require.h>
namespace search::queryeval {
@@ -23,6 +24,31 @@ struct BitVectorFilter : public GlobalFilter {
bool check(uint32_t docid) const override { return vector->testBit(docid); }
};
+struct MultiBitVectorFilter : public GlobalFilter {
+ std::vector<std::unique_ptr<BitVector>> vectors;
+ std::vector<uint32_t> splits;
+ uint32_t total_size;
+ uint32_t total_count;
+ MultiBitVectorFilter(std::vector<std::unique_ptr<BitVector>> vectors_in,
+ std::vector<uint32_t> splits_in,
+ uint32_t total_size_in,
+ uint32_t total_count_in)
+ : vectors(std::move(vectors_in)),
+ splits(std::move(splits_in)),
+ total_size(total_size_in),
+ total_count(total_count_in) {}
+ bool is_active() const override { return true; }
+ uint32_t size() const override { return total_size; }
+ uint32_t count() const override { return total_count; }
+ bool check(uint32_t docid) const override {
+ size_t i = 0;
+ while ((i < splits.size()) && (docid >= splits[i])) {
+ ++i;
+ }
+ return vectors[i]->testBit(docid);
+ }
+};
+
}
GlobalFilter::GlobalFilter() = default;
@@ -34,8 +60,44 @@ GlobalFilter::create() {
}
std::shared_ptr<GlobalFilter>
-GlobalFilter::create(std::unique_ptr<BitVector> vector) {
+GlobalFilter::create(std::vector<uint32_t> docids, uint32_t size)
+{
+ uint32_t prev = 0;
+ auto bits = BitVector::create(1, size);
+ for (uint32_t docid: docids) {
+ REQUIRE(docid > prev);
+ REQUIRE(docid < size);
+ bits->setBit(docid);
+ prev = docid;
+ }
+ bits->invalidateCachedCount();
+ return create(std::move(bits));
+}
+
+std::shared_ptr<GlobalFilter>
+GlobalFilter::create(std::unique_ptr<BitVector> vector)
+{
return std::make_shared<BitVectorFilter>(std::move(vector));
}
+std::shared_ptr<GlobalFilter>
+GlobalFilter::create(std::vector<std::unique_ptr<BitVector>> vectors)
+{
+ uint32_t total_size = 0;
+ uint32_t total_count = 0;
+ std::vector<uint32_t> splits;
+ for (size_t i = 0; i < vectors.size(); ++i) {
+ bool last = ((i + 1) == vectors.size());
+ total_count += vectors[i]->countTrueBits();
+ if (last) {
+ total_size = vectors[i]->size();
+ } else {
+ REQUIRE_EQ(vectors[i]->size(), vectors[i + 1]->getStartIndex());
+ splits.push_back(vectors[i]->size());
+ }
+ }
+ return std::make_shared<MultiBitVectorFilter>(std::move(vectors), std::move(splits),
+ total_size, total_count);
+}
+
}
diff --git a/searchlib/src/vespa/searchlib/queryeval/global_filter.h b/searchlib/src/vespa/searchlib/queryeval/global_filter.h
index c6e08d5018d..8504367e5b7 100644
--- a/searchlib/src/vespa/searchlib/queryeval/global_filter.h
+++ b/searchlib/src/vespa/searchlib/queryeval/global_filter.h
@@ -26,8 +26,14 @@ public:
virtual bool check(uint32_t docid) const = 0;
virtual ~GlobalFilter();
+ const GlobalFilter *ptr_if_active() const {
+ return is_active() ? this : nullptr;
+ }
+
static std::shared_ptr<GlobalFilter> create();
+ static std::shared_ptr<GlobalFilter> create(std::vector<uint32_t> docids, uint32_t size);
static std::shared_ptr<GlobalFilter> create(std::unique_ptr<BitVector> vector);
+ static std::shared_ptr<GlobalFilter> create(std::vector<std::unique_ptr<BitVector>> vectors);
};
} // namespace