summaryrefslogtreecommitdiffstats
path: root/searchlib/src/vespa/searchlib/queryeval/global_filter.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'searchlib/src/vespa/searchlib/queryeval/global_filter.cpp')
-rw-r--r--searchlib/src/vespa/searchlib/queryeval/global_filter.cpp64
1 files changed, 63 insertions, 1 deletions
diff --git a/searchlib/src/vespa/searchlib/queryeval/global_filter.cpp b/searchlib/src/vespa/searchlib/queryeval/global_filter.cpp
index 1a5d3d3dacd..2aff91974a7 100644
--- a/searchlib/src/vespa/searchlib/queryeval/global_filter.cpp
+++ b/searchlib/src/vespa/searchlib/queryeval/global_filter.cpp
@@ -1,6 +1,7 @@
// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
#include "global_filter.h"
+#include <vespa/vespalib/util/require.h>
namespace search::queryeval {
@@ -23,6 +24,31 @@ struct BitVectorFilter : public GlobalFilter {
bool check(uint32_t docid) const override { return vector->testBit(docid); }
};
+struct MultiBitVectorFilter : public GlobalFilter {
+ std::vector<std::unique_ptr<BitVector>> vectors;
+ std::vector<uint32_t> splits;
+ uint32_t total_size;
+ uint32_t total_count;
+ MultiBitVectorFilter(std::vector<std::unique_ptr<BitVector>> vectors_in,
+ std::vector<uint32_t> splits_in,
+ uint32_t total_size_in,
+ uint32_t total_count_in)
+ : vectors(std::move(vectors_in)),
+ splits(std::move(splits_in)),
+ total_size(total_size_in),
+ total_count(total_count_in) {}
+ bool is_active() const override { return true; }
+ uint32_t size() const override { return total_size; }
+ uint32_t count() const override { return total_count; }
+ bool check(uint32_t docid) const override {
+ size_t i = 0;
+ while ((i < splits.size()) && (docid >= splits[i])) {
+ ++i;
+ }
+ return vectors[i]->testBit(docid);
+ }
+};
+
}
GlobalFilter::GlobalFilter() = default;
@@ -34,8 +60,44 @@ GlobalFilter::create() {
}
std::shared_ptr<GlobalFilter>
-GlobalFilter::create(std::unique_ptr<BitVector> vector) {
+GlobalFilter::create(std::vector<uint32_t> docids, uint32_t size)
+{
+ uint32_t prev = 0;
+ auto bits = BitVector::create(1, size);
+ for (uint32_t docid: docids) {
+ REQUIRE(docid > prev);
+ REQUIRE(docid < size);
+ bits->setBit(docid);
+ prev = docid;
+ }
+ bits->invalidateCachedCount();
+ return create(std::move(bits));
+}
+
+std::shared_ptr<GlobalFilter>
+GlobalFilter::create(std::unique_ptr<BitVector> vector)
+{
return std::make_shared<BitVectorFilter>(std::move(vector));
}
+std::shared_ptr<GlobalFilter>
+GlobalFilter::create(std::vector<std::unique_ptr<BitVector>> vectors)
+{
+ uint32_t total_size = 0;
+ uint32_t total_count = 0;
+ std::vector<uint32_t> splits;
+ for (size_t i = 0; i < vectors.size(); ++i) {
+ bool last = ((i + 1) == vectors.size());
+ total_count += vectors[i]->countTrueBits();
+ if (last) {
+ total_size = vectors[i]->size();
+ } else {
+ REQUIRE_EQ(vectors[i]->size(), vectors[i + 1]->getStartIndex());
+ splits.push_back(vectors[i]->size());
+ }
+ }
+ return std::make_shared<MultiBitVectorFilter>(std::move(vectors), std::move(splits),
+ total_size, total_count);
+}
+
}