summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorHåvard Pettersen <havardpe@yahooinc.com>2022-09-12 08:26:02 +0000
committerHåvard Pettersen <havardpe@yahooinc.com>2022-09-12 08:27:19 +0000
commit58fbe6f2e0d30ba239036987134a246828246542 (patch)
treea20efd64ec8d27bf6c54f61c99cab0cebfed67f9
parent0d13b7f7c153a36f785978463f19d58157360639 (diff)
GlobalFilter is now an interface
instead of a shared optional BitVector
-rw-r--r--searchcore/src/tests/proton/matching/query_test.cpp13
-rw-r--r--searchlib/src/tests/attribute/tensorattribute/tensorattribute_test.cpp2
-rw-r--r--searchlib/src/tests/queryeval/blueprint/intermediate_blueprints_test.cpp2
-rw-r--r--searchlib/src/tests/queryeval/nearest_neighbor/nearest_neighbor_test.cpp15
-rw-r--r--searchlib/src/tests/tensor/hnsw_index/hnsw_index_test.cpp21
-rw-r--r--searchlib/src/vespa/searchlib/queryeval/global_filter.cpp38
-rw-r--r--searchlib/src/vespa/searchlib/queryeval/global_filter.h41
-rw-r--r--searchlib/src/vespa/searchlib/queryeval/nearest_neighbor_blueprint.cpp13
-rw-r--r--searchlib/src/vespa/searchlib/queryeval/nearest_neighbor_iterator.cpp8
-rw-r--r--searchlib/src/vespa/searchlib/queryeval/nearest_neighbor_iterator.h8
-rw-r--r--searchlib/src/vespa/searchlib/tensor/hnsw_index.cpp19
-rw-r--r--searchlib/src/vespa/searchlib/tensor/hnsw_index.h12
-rw-r--r--searchlib/src/vespa/searchlib/tensor/nearest_neighbor_index.h9
13 files changed, 118 insertions, 83 deletions
diff --git a/searchcore/src/tests/proton/matching/query_test.cpp b/searchcore/src/tests/proton/matching/query_test.cpp
index eacb4108686..78bb679a7dc 100644
--- a/searchcore/src/tests/proton/matching/query_test.cpp
+++ b/searchcore/src/tests/proton/matching/query_test.cpp
@@ -1158,21 +1158,20 @@ Test::global_filter_is_calculated_and_handled()
auto res = Query::handle_global_filter(bp, docid_limit, 0, 0.3, nullptr);
EXPECT_TRUE(res);
EXPECT_TRUE(bp.filter);
- EXPECT_TRUE(bp.filter->has_filter());
+ EXPECT_TRUE(bp.filter->is_active());
EXPECT_EQUAL(0.3, bp.estimated_hit_ratio);
- auto* bv = bp.filter->filter();
- EXPECT_EQUAL(3u, bv->countTrueBits());
- EXPECT_TRUE(bv->testBit(3));
- EXPECT_TRUE(bv->testBit(5));
- EXPECT_TRUE(bv->testBit(7));
+ EXPECT_EQUAL(3u, bp.filter->count());
+ EXPECT_TRUE(bp.filter->check(3));
+ EXPECT_TRUE(bp.filter->check(5));
+ EXPECT_TRUE(bp.filter->check(7));
}
{ // estimated_hit_ratio > global_filter_upper_limit
GlobalFilterBlueprint bp(result, true);
auto res = Query::handle_global_filter(bp, docid_limit, 0, 0.29, nullptr);
EXPECT_TRUE(res);
EXPECT_TRUE(bp.filter);
- EXPECT_FALSE(bp.filter->has_filter());
+ EXPECT_FALSE(bp.filter->is_active());
EXPECT_EQUAL(0.3, bp.estimated_hit_ratio);
}
}
diff --git a/searchlib/src/tests/attribute/tensorattribute/tensorattribute_test.cpp b/searchlib/src/tests/attribute/tensorattribute/tensorattribute_test.cpp
index b93398e16a1..ebe96035fc8 100644
--- a/searchlib/src/tests/attribute/tensorattribute/tensorattribute_test.cpp
+++ b/searchlib/src/tests/attribute/tensorattribute/tensorattribute_test.cpp
@@ -260,7 +260,7 @@ public:
return std::vector<Neighbor>();
}
std::vector<Neighbor> find_top_k_with_filter(uint32_t k, vespalib::eval::TypedCells vector,
- const search::BitVector& filter, uint32_t explore_k,
+ const GlobalFilter& filter, uint32_t explore_k,
double distance_threshold) const override
{
(void) k;
diff --git a/searchlib/src/tests/queryeval/blueprint/intermediate_blueprints_test.cpp b/searchlib/src/tests/queryeval/blueprint/intermediate_blueprints_test.cpp
index 87de62dbfad..2379213b87b 100644
--- a/searchlib/src/tests/queryeval/blueprint/intermediate_blueprints_test.cpp
+++ b/searchlib/src/tests/queryeval/blueprint/intermediate_blueprints_test.cpp
@@ -80,7 +80,7 @@ TEST("test AndNot Blueprint") {
a.addChild(ap(MyLeafSpec(20).addField(1, 1).want_global_filter().create()));
EXPECT_EQUAL(true, a.getState().want_global_filter());
auto empty_global_filter = GlobalFilter::create();
- EXPECT_FALSE(empty_global_filter->has_filter());
+ EXPECT_FALSE(empty_global_filter->is_active());
a.set_global_filter(*empty_global_filter, 1.0);
EXPECT_EQUAL(false, got_global_filter(a.getChild(0)));
EXPECT_EQUAL(true, got_global_filter(a.getChild(1)));
diff --git a/searchlib/src/tests/queryeval/nearest_neighbor/nearest_neighbor_test.cpp b/searchlib/src/tests/queryeval/nearest_neighbor/nearest_neighbor_test.cpp
index 1e341eab707..33435f43618 100644
--- a/searchlib/src/tests/queryeval/nearest_neighbor/nearest_neighbor_test.cpp
+++ b/searchlib/src/tests/queryeval/nearest_neighbor/nearest_neighbor_test.cpp
@@ -9,6 +9,7 @@
#include <vespa/searchlib/queryeval/nearest_neighbor_iterator.h>
#include <vespa/searchlib/queryeval/nns_index_iterator.h>
#include <vespa/searchlib/queryeval/simpleresult.h>
+#include <vespa/searchlib/queryeval/global_filter.h>
#include <vespa/searchlib/tensor/dense_tensor_attribute.h>
#include <vespa/searchlib/tensor/distance_calculator.h>
#include <vespa/searchlib/tensor/distance_function_factory.h>
@@ -63,7 +64,7 @@ struct Fixture
vespalib::string _typeSpec;
std::shared_ptr<DenseTensorAttribute> _tensorAttr;
std::shared_ptr<AttributeVector> _attr;
- std::unique_ptr<BitVector> _global_filter;
+ std::shared_ptr<GlobalFilter> _global_filter;
Fixture(const vespalib::string &typeSpec)
: _cfg(BasicType::TENSOR, CollectionType::SINGLE),
@@ -71,7 +72,7 @@ struct Fixture
_typeSpec(typeSpec),
_tensorAttr(),
_attr(),
- _global_filter()
+ _global_filter(GlobalFilter::create())
{
_cfg.setTensorType(ValueType::from_spec(typeSpec));
_tensorAttr = makeAttr();
@@ -95,11 +96,12 @@ struct Fixture
void setFilter(std::vector<uint32_t> docids) {
uint32_t sz = _attr->getNumDocs();
- _global_filter = BitVector::create(sz);
+ auto bit_vector = BitVector::create(sz);
for (uint32_t id : docids) {
EXPECT_LT(id, sz);
- _global_filter->setBit(id);
+ bit_vector->setBit(id);
}
+ _global_filter = GlobalFilter::create(std::move(bit_vector));
}
void setTensor(uint32_t docId, const Value &tensor) {
@@ -130,7 +132,7 @@ SimpleResult find_matches(Fixture &env, const Value &qtv, double threshold = std
DistanceCalculator dist_calc(attr, qtv, env.dist_fun());
NearestNeighborDistanceHeap dh(2);
dh.set_distance_threshold(env.dist_fun().convert_threshold(threshold));
- const BitVector *filter = env._global_filter.get();
+ const GlobalFilter &filter = *env._global_filter;
auto search = NearestNeighborIterator::create(strict, tfmd, dist_calc, dh, filter);
if (strict) {
return SimpleResult().searchStrict(*search, attr.getNumDocs());
@@ -222,7 +224,8 @@ std::vector<feature_t> get_rawscores(Fixture &env, const Value &qtv) {
auto &attr = *(env._tensorAttr);
DistanceCalculator dist_calc(attr, qtv, env.dist_fun());
NearestNeighborDistanceHeap dh(2);
- auto search = NearestNeighborIterator::create(strict, tfmd, dist_calc, dh, nullptr);
+ auto dummy_filter = GlobalFilter::create();
+ auto search = NearestNeighborIterator::create(strict, tfmd, dist_calc, dh, *dummy_filter);
uint32_t limit = attr.getNumDocs();
uint32_t docid = 1;
search->initRange(docid, limit);
diff --git a/searchlib/src/tests/tensor/hnsw_index/hnsw_index_test.cpp b/searchlib/src/tests/tensor/hnsw_index/hnsw_index_test.cpp
index 3d1127e6bc4..193bb04843c 100644
--- a/searchlib/src/tests/tensor/hnsw_index/hnsw_index_test.cpp
+++ b/searchlib/src/tests/tensor/hnsw_index/hnsw_index_test.cpp
@@ -6,6 +6,7 @@
#include <vespa/searchlib/tensor/hnsw_index.h>
#include <vespa/searchlib/tensor/random_level_generator.h>
#include <vespa/searchlib/tensor/inv_log_level_generator.h>
+#include <vespa/searchlib/queryeval/global_filter.h>
#include <vespa/vespalib/datastore/compaction_spec.h>
#include <vespa/vespalib/datastore/compaction_strategy.h>
#include <vespa/vespalib/gtest/gtest.h>
@@ -24,6 +25,7 @@ using vespalib::Slime;
using search::BitVector;
using vespalib::datastore::CompactionSpec;
using vespalib::datastore::CompactionStrategy;
+using search::queryeval::GlobalFilter;
template <typename FloatType>
class MyDocVectorAccess : public DocVectorAccess {
@@ -61,14 +63,14 @@ using HnswIndexUP = std::unique_ptr<HnswIndex>;
class HnswIndexTest : public ::testing::Test {
public:
FloatVectors vectors;
- std::unique_ptr<BitVector> global_filter;
+ std::shared_ptr<GlobalFilter> global_filter;
LevelGenerator* level_generator;
GenerationHandler gen_handler;
HnswIndexUP index;
HnswIndexTest()
: vectors(),
- global_filter(),
+ global_filter(GlobalFilter::create()),
level_generator(),
gen_handler(),
index()
@@ -80,6 +82,10 @@ public:
~HnswIndexTest() {}
+ const GlobalFilter *global_filter_ptr() const {
+ return global_filter->is_active() ? global_filter.get() : nullptr;
+ }
+
void init(bool heuristic_select_neighbors) {
auto generator = std::make_unique<LevelGenerator>();
level_generator = generator.get();
@@ -104,11 +110,12 @@ public:
}
void set_filter(std::vector<uint32_t> docids) {
uint32_t sz = 10;
- global_filter = BitVector::create(sz);
+ auto bit_vector = BitVector::create(sz);
for (uint32_t id : docids) {
EXPECT_LT(id, sz);
- global_filter->setBit(id);
+ bit_vector->setBit(id);
}
+ global_filter = GlobalFilter::create(std::move(bit_vector));
}
GenerationHandler::Guard take_read_guard() {
return gen_handler.takeGuard();
@@ -142,7 +149,7 @@ public:
void expect_top_3(uint32_t docid, std::vector<uint32_t> exp_hits) {
uint32_t k = 3;
auto qv = vectors.get_vector(docid);
- auto rv = index->top_k_candidates(qv, k, global_filter.get()).peek();
+ auto rv = index->top_k_candidates(qv, k, global_filter_ptr()).peek();
std::sort(rv.begin(), rv.end(), LesserDistance());
size_t idx = 0;
for (const auto & hit : rv) {
@@ -163,12 +170,12 @@ public:
void check_with_distance_threshold(uint32_t docid) {
auto qv = vectors.get_vector(docid);
uint32_t k = 3;
- auto rv = index->top_k_candidates(qv, k, global_filter.get()).peek();
+ auto rv = index->top_k_candidates(qv, k, global_filter_ptr()).peek();
std::sort(rv.begin(), rv.end(), LesserDistance());
EXPECT_EQ(rv.size(), 3);
EXPECT_LE(rv[0].distance, rv[1].distance);
double thr = (rv[0].distance + rv[1].distance) * 0.5;
- auto got_by_docid = (global_filter)
+ auto got_by_docid = (global_filter->is_active())
? index->find_top_k_with_filter(k, qv, *global_filter, k, thr)
: index->find_top_k(k, qv, k, thr);
EXPECT_EQ(got_by_docid.size(), 1);
diff --git a/searchlib/src/vespa/searchlib/queryeval/global_filter.cpp b/searchlib/src/vespa/searchlib/queryeval/global_filter.cpp
index b1995c7ab1c..1a5d3d3dacd 100644
--- a/searchlib/src/vespa/searchlib/queryeval/global_filter.cpp
+++ b/searchlib/src/vespa/searchlib/queryeval/global_filter.cpp
@@ -1,3 +1,41 @@
// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
#include "global_filter.h"
+
+namespace search::queryeval {
+
+namespace {
+
+struct Inactive : GlobalFilter {
+ bool is_active() const override { return false; }
+ uint32_t size() const override { abort(); }
+ uint32_t count() const override { abort(); }
+ bool check(uint32_t) const override { abort(); }
+};
+
+struct BitVectorFilter : public GlobalFilter {
+ std::unique_ptr<BitVector> vector;
+ BitVectorFilter(std::unique_ptr<BitVector> vector_in)
+ : vector(std::move(vector_in)) {}
+ bool is_active() const override { return true; }
+ uint32_t size() const override { return vector->size(); }
+ uint32_t count() const override { return vector->countTrueBits(); }
+ bool check(uint32_t docid) const override { return vector->testBit(docid); }
+};
+
+}
+
+GlobalFilter::GlobalFilter() = default;
+GlobalFilter::~GlobalFilter() = default;
+
+std::shared_ptr<GlobalFilter>
+GlobalFilter::create() {
+ return std::make_shared<Inactive>();
+}
+
+std::shared_ptr<GlobalFilter>
+GlobalFilter::create(std::unique_ptr<BitVector> vector) {
+ return std::make_shared<BitVectorFilter>(std::move(vector));
+}
+
+}
diff --git a/searchlib/src/vespa/searchlib/queryeval/global_filter.h b/searchlib/src/vespa/searchlib/queryeval/global_filter.h
index 9a2a77ed119..c6e08d5018d 100644
--- a/searchlib/src/vespa/searchlib/queryeval/global_filter.h
+++ b/searchlib/src/vespa/searchlib/queryeval/global_filter.h
@@ -8,39 +8,26 @@
namespace search::queryeval {
/**
- * Hold ownership of a global filter that can be taken
- * into account by adaptive query operators. The owned
- * bitvector should be a white-list (documents that may
- * possibly become hits have their bit set, documents
- * that are certain to be filtered away should have theirs
- * cleared).
+ * Hold ownership of a global filter that can be taken into account by
+ * adaptive query operators. The owned 'bitvector' should be a
+ * white-list (documents that may possibly become hits have their bit
+ * set, documents that are certain to be filtered away should have
+ * theirs cleared).
**/
class GlobalFilter : public std::enable_shared_from_this<GlobalFilter>
{
-private:
- struct ctor_tag {};
- std::unique_ptr<search::BitVector> bit_vector;
-
public:
+ GlobalFilter();
GlobalFilter(const GlobalFilter &) = delete;
GlobalFilter(GlobalFilter &&) = delete;
-
- GlobalFilter(ctor_tag, std::unique_ptr<search::BitVector> bit_vector_in) noexcept
- : bit_vector(std::move(bit_vector_in))
- {}
-
- GlobalFilter(ctor_tag) noexcept : bit_vector() {}
-
- ~GlobalFilter() {}
-
- template<typename ... Params>
- static std::shared_ptr<GlobalFilter> create(Params&& ... params) {
- return std::make_shared<GlobalFilter>(ctor_tag(), std::forward<Params>(params)...);
- }
-
- const search::BitVector *filter() const { return bit_vector.get(); }
-
- bool has_filter() const { return bool(bit_vector); }
+ virtual bool is_active() const = 0;
+ virtual uint32_t size() const = 0;
+ virtual uint32_t count() const = 0;
+ virtual bool check(uint32_t docid) const = 0;
+ virtual ~GlobalFilter();
+
+ static std::shared_ptr<GlobalFilter> create();
+ static std::shared_ptr<GlobalFilter> create(std::unique_ptr<BitVector> vector);
};
} // namespace
diff --git a/searchlib/src/vespa/searchlib/queryeval/nearest_neighbor_blueprint.cpp b/searchlib/src/vespa/searchlib/queryeval/nearest_neighbor_blueprint.cpp
index 6a891341afd..993156e04e6 100644
--- a/searchlib/src/vespa/searchlib/queryeval/nearest_neighbor_blueprint.cpp
+++ b/searchlib/src/vespa/searchlib/queryeval/nearest_neighbor_blueprint.cpp
@@ -80,8 +80,8 @@ NearestNeighborBlueprint::set_global_filter(const GlobalFilter &global_filter, d
auto nns_index = _attr_tensor.nearest_neighbor_index();
if (_approximate && nns_index) {
uint32_t est_hits = _attr_tensor.get_num_docs();
- if (_global_filter->has_filter()) { // pre-filtering case
- _global_filter_hits = _global_filter->filter()->countTrueBits();
+ if (_global_filter->is_active()) { // pre-filtering case
+ _global_filter_hits = _global_filter->count();
_global_filter_hit_ratio = static_cast<double>(_global_filter_hits.value()) / est_hits;
if (_global_filter_hit_ratio.value() < _global_filter_lower_limit) {
_algorithm = Algorithm::EXACT_FALLBACK;
@@ -108,9 +108,8 @@ NearestNeighborBlueprint::perform_top_k(const search::tensor::NearestNeighborInd
{
auto lhs = _query_tensor.cells();
uint32_t k = _adjusted_target_hits;
- if (_global_filter->has_filter()) {
- auto filter = _global_filter->filter();
- _found_hits = nns_index->find_top_k_with_filter(k, lhs, *filter, k + _explore_additional_hits, _distance_threshold);
+ if (_global_filter->is_active()) {
+ _found_hits = nns_index->find_top_k_with_filter(k, lhs, *_global_filter, k + _explore_additional_hits, _distance_threshold);
_algorithm = Algorithm::INDEX_TOP_K_WITH_FILTER;
} else {
_found_hits = nns_index->find_top_k(k, lhs, k + _explore_additional_hits, _distance_threshold);
@@ -131,7 +130,7 @@ NearestNeighborBlueprint::createLeafSearch(const search::fef::TermFieldMatchData
;
}
return NearestNeighborIterator::create(strict, tfmd, *_distance_calc,
- _distance_heap, _global_filter->filter());
+ _distance_heap, *_global_filter);
}
void
@@ -151,7 +150,7 @@ NearestNeighborBlueprint::visitMembers(vespalib::ObjectVisitor& visitor) const
visitor.openStruct("global_filter", "GlobalFilter");
visitor.visitBool("wanted", getState().want_global_filter());
visitor.visitBool("set", _global_filter_set);
- visitor.visitBool("calculated", _global_filter->has_filter());
+ visitor.visitBool("calculated", _global_filter->is_active());
visitor.visitFloat("lower_limit", _global_filter_lower_limit);
visitor.visitFloat("upper_limit", _global_filter_upper_limit);
if (_global_filter_hits.has_value()) {
diff --git a/searchlib/src/vespa/searchlib/queryeval/nearest_neighbor_iterator.cpp b/searchlib/src/vespa/searchlib/queryeval/nearest_neighbor_iterator.cpp
index e06fcc614d8..b3f8195676d 100644
--- a/searchlib/src/vespa/searchlib/queryeval/nearest_neighbor_iterator.cpp
+++ b/searchlib/src/vespa/searchlib/queryeval/nearest_neighbor_iterator.cpp
@@ -1,6 +1,7 @@
// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
#include "nearest_neighbor_iterator.h"
+#include "global_filter.h"
#include <vespa/searchlib/common/bitvector.h>
#include <vespa/searchlib/tensor/distance_calculator.h>
#include <vespa/searchlib/tensor/distance_function.h>
@@ -47,7 +48,7 @@ public:
void doSeek(uint32_t docId) override {
double distanceLimit = params().distanceHeap.distanceLimit();
while (__builtin_expect((docId < getEndId()), true)) {
- if ((!has_filter) || params().filter->testBit(docId)) {
+ if ((!has_filter) || params().filter.check(docId)) {
double d = computeDistance(docId, distanceLimit);
if (d <= distanceLimit) {
_lastScore = d;
@@ -106,11 +107,10 @@ NearestNeighborIterator::create(
fef::TermFieldMatchData &tfmd,
const search::tensor::DistanceCalculator &distance_calc,
NearestNeighborDistanceHeap &distanceHeap,
- const search::BitVector *filter)
-
+ const GlobalFilter &filter)
{
Params params(tfmd, distance_calc, distanceHeap, filter);
- if (filter) {
+ if (filter.is_active()) {
return resolve_strict<true>(strict, params);
} else {
return resolve_strict<false>(strict, params);
diff --git a/searchlib/src/vespa/searchlib/queryeval/nearest_neighbor_iterator.h b/searchlib/src/vespa/searchlib/queryeval/nearest_neighbor_iterator.h
index 0d8f70d15c2..f06e62f9cc1 100644
--- a/searchlib/src/vespa/searchlib/queryeval/nearest_neighbor_iterator.h
+++ b/searchlib/src/vespa/searchlib/queryeval/nearest_neighbor_iterator.h
@@ -14,6 +14,8 @@ namespace search::tensor { class DistanceCalculator; }
namespace search::queryeval {
+class GlobalFilter;
+
class NearestNeighborIterator : public SearchIterator
{
public:
@@ -24,12 +26,12 @@ public:
fef::TermFieldMatchData &tfmd;
const search::tensor::DistanceCalculator &distance_calc;
NearestNeighborDistanceHeap &distanceHeap;
- const search::BitVector *filter;
+ const GlobalFilter &filter;
Params(fef::TermFieldMatchData &tfmd_in,
const search::tensor::DistanceCalculator &distance_calc_in,
NearestNeighborDistanceHeap &distanceHeap_in,
- const search::BitVector *filter_in)
+ const GlobalFilter &filter_in)
: tfmd(tfmd_in),
distance_calc(distance_calc_in),
distanceHeap(distanceHeap_in),
@@ -46,7 +48,7 @@ public:
fef::TermFieldMatchData &tfmd,
const search::tensor::DistanceCalculator &distance_calc,
NearestNeighborDistanceHeap &distanceHeap,
- const search::BitVector *filter);
+ const GlobalFilter &filter);
const Params& params() const { return _params; }
private:
diff --git a/searchlib/src/vespa/searchlib/tensor/hnsw_index.cpp b/searchlib/src/vespa/searchlib/tensor/hnsw_index.cpp
index 2ee1b268449..fa6c9a347aa 100644
--- a/searchlib/src/vespa/searchlib/tensor/hnsw_index.cpp
+++ b/searchlib/src/vespa/searchlib/tensor/hnsw_index.cpp
@@ -9,6 +9,7 @@
#include "random_level_generator.h"
#include <vespa/searchlib/attribute/address_space_components.h>
#include <vespa/searchlib/attribute/address_space_usage.h>
+#include <vespa/searchlib/queryeval/global_filter.h>
#include <vespa/searchlib/util/fileutil.h>
#include <vespa/searchlib/util/state_explorer_utils.h>
#include <vespa/vespalib/data/slime/cursor.h>
@@ -214,7 +215,7 @@ HnswIndex::calc_distance(const TypedCells& lhs, uint32_t rhs_docid) const
}
uint32_t
-HnswIndex::estimate_visited_nodes(uint32_t level, uint32_t doc_id_limit, uint32_t neighbors_to_find, const search::BitVector* filter) const
+HnswIndex::estimate_visited_nodes(uint32_t level, uint32_t doc_id_limit, uint32_t neighbors_to_find, const GlobalFilter* filter) const
{
uint32_t m_for_level = max_links_for_level(level);
uint64_t base_estimate = uint64_t(m_for_level) * neighbors_to_find + 100;
@@ -224,7 +225,7 @@ HnswIndex::estimate_visited_nodes(uint32_t level, uint32_t doc_id_limit, uint32_
if (!filter) {
return base_estimate;
}
- uint32_t true_bits = filter->countTrueBits();
+ uint32_t true_bits = filter->count();
if (true_bits == 0) {
return doc_id_limit;
}
@@ -260,7 +261,7 @@ HnswIndex::find_nearest_in_layer(const TypedCells& input, const HnswCandidate& e
template <class VisitedTracker>
void
HnswIndex::search_layer_helper(const TypedCells& input, uint32_t neighbors_to_find,
- FurthestPriQ& best_neighbors, uint32_t level, const search::BitVector *filter,
+ FurthestPriQ& best_neighbors, uint32_t level, const GlobalFilter *filter,
uint32_t doc_id_limit, uint32_t estimated_visited_nodes) const
{
NearestPriQ candidates;
@@ -271,7 +272,7 @@ HnswIndex::search_layer_helper(const TypedCells& input, uint32_t neighbors_to_fi
}
candidates.push(entry);
visited.mark(entry.docid);
- if (filter && !filter->testBit(entry.docid)) {
+ if (filter && !filter->check(entry.docid)) {
assert(best_neighbors.size() == 1);
best_neighbors.pop();
}
@@ -297,7 +298,7 @@ HnswIndex::search_layer_helper(const TypedCells& input, uint32_t neighbors_to_fi
double dist_to_input = calc_distance(input, neighbor_docid);
if (dist_to_input < limit_dist) {
candidates.emplace(neighbor_docid, neighbor_ref, dist_to_input);
- if ((!filter) || filter->testBit(neighbor_docid)) {
+ if ((!filter) || filter->check(neighbor_docid)) {
best_neighbors.emplace(neighbor_docid, neighbor_ref, dist_to_input);
if (best_neighbors.size() > neighbors_to_find) {
best_neighbors.pop();
@@ -311,7 +312,7 @@ HnswIndex::search_layer_helper(const TypedCells& input, uint32_t neighbors_to_fi
void
HnswIndex::search_layer(const TypedCells& input, uint32_t neighbors_to_find,
- FurthestPriQ& best_neighbors, uint32_t level, const search::BitVector *filter) const
+ FurthestPriQ& best_neighbors, uint32_t level, const GlobalFilter *filter) const
{
uint32_t doc_id_limit = _graph.node_refs_size.load(std::memory_order_acquire);
if (filter) {
@@ -698,7 +699,7 @@ struct NeighborsByDocId {
std::vector<NearestNeighborIndex::Neighbor>
HnswIndex::top_k_by_docid(uint32_t k, TypedCells vector,
- const BitVector *filter, uint32_t explore_k,
+ const GlobalFilter *filter, uint32_t explore_k,
double distance_threshold) const
{
std::vector<Neighbor> result;
@@ -724,14 +725,14 @@ HnswIndex::find_top_k(uint32_t k, TypedCells vector, uint32_t explore_k,
std::vector<NearestNeighborIndex::Neighbor>
HnswIndex::find_top_k_with_filter(uint32_t k, TypedCells vector,
- const BitVector &filter, uint32_t explore_k,
+ const GlobalFilter &filter, uint32_t explore_k,
double distance_threshold) const
{
return top_k_by_docid(k, vector, &filter, explore_k, distance_threshold);
}
FurthestPriQ
-HnswIndex::top_k_candidates(const TypedCells &vector, uint32_t k, const BitVector *filter) const
+HnswIndex::top_k_candidates(const TypedCells &vector, uint32_t k, const GlobalFilter *filter) const
{
FurthestPriQ best_neighbors;
auto entry = _graph.get_entry_node();
diff --git a/searchlib/src/vespa/searchlib/tensor/hnsw_index.h b/searchlib/src/vespa/searchlib/tensor/hnsw_index.h
index 3f5a9d514ed..e3ffada1fc2 100644
--- a/searchlib/src/vespa/searchlib/tensor/hnsw_index.h
+++ b/searchlib/src/vespa/searchlib/tensor/hnsw_index.h
@@ -135,7 +135,7 @@ protected:
double calc_distance(uint32_t lhs_docid, uint32_t rhs_docid) const;
double calc_distance(const TypedCells& lhs, uint32_t rhs_docid) const;
- uint32_t estimate_visited_nodes(uint32_t level, uint32_t doc_id_limit, uint32_t neighbors_to_find, const search::BitVector* filter) const;
+ uint32_t estimate_visited_nodes(uint32_t level, uint32_t doc_id_limit, uint32_t neighbors_to_find, const GlobalFilter* filter) const;
/**
* Performs a greedy search in the given layer to find the candidate that is nearest the input vector.
@@ -143,13 +143,13 @@ protected:
HnswCandidate find_nearest_in_layer(const TypedCells& input, const HnswCandidate& entry_point, uint32_t level) const;
template <class VisitedTracker>
void search_layer_helper(const TypedCells& input, uint32_t neighbors_to_find, FurthestPriQ& found_neighbors,
- uint32_t level, const search::BitVector *filter,
+ uint32_t level, const GlobalFilter *filter,
uint32_t doc_id_limit,
uint32_t estimated_visited_nodes) const;
void search_layer(const TypedCells& input, uint32_t neighbors_to_find, FurthestPriQ& found_neighbors,
- uint32_t level, const search::BitVector *filter = nullptr) const;
+ uint32_t level, const GlobalFilter *filter = nullptr) const;
std::vector<Neighbor> top_k_by_docid(uint32_t k, TypedCells vector,
- const BitVector *filter, uint32_t explore_k,
+ const GlobalFilter *filter, uint32_t explore_k,
double distance_threshold) const;
struct PreparedFirstAddDoc : public PrepareResult {};
@@ -206,11 +206,11 @@ public:
std::vector<Neighbor> find_top_k(uint32_t k, TypedCells vector, uint32_t explore_k,
double distance_threshold) const override;
std::vector<Neighbor> find_top_k_with_filter(uint32_t k, TypedCells vector,
- const BitVector &filter, uint32_t explore_k,
+ const GlobalFilter &filter, uint32_t explore_k,
double distance_threshold) const override;
const DistanceFunction *distance_function() const override { return _distance_func.get(); }
- FurthestPriQ top_k_candidates(const TypedCells &vector, uint32_t k, const BitVector *filter) const;
+ FurthestPriQ top_k_candidates(const TypedCells &vector, uint32_t k, const GlobalFilter *filter) const;
uint32_t get_entry_docid() const { return _graph.get_entry_node().docid; }
int32_t get_entry_level() const { return _graph.get_entry_node().level; }
diff --git a/searchlib/src/vespa/searchlib/tensor/nearest_neighbor_index.h b/searchlib/src/vespa/searchlib/tensor/nearest_neighbor_index.h
index 530d3e1036d..51d66fdd14d 100644
--- a/searchlib/src/vespa/searchlib/tensor/nearest_neighbor_index.h
+++ b/searchlib/src/vespa/searchlib/tensor/nearest_neighbor_index.h
@@ -20,10 +20,8 @@ namespace vespalib::slime { struct Inserter; }
namespace search::fileutil { class LoadedBuffer; }
-namespace search {
-class AddressSpaceUsage;
-class BitVector;
-}
+namespace search { class AddressSpaceUsage; }
+namespace search::queryeval { class GlobalFilter; }
namespace search::tensor {
@@ -35,6 +33,7 @@ class NearestNeighborIndexSaver;
*/
class NearestNeighborIndex {
public:
+ using GlobalFilter = search::queryeval::GlobalFilter;
using CompactionSpec = vespalib::datastore::CompactionSpec;
using CompactionStrategy = vespalib::datastore::CompactionStrategy;
using generation_t = vespalib::GenerationHandler::generation_t;
@@ -101,7 +100,7 @@ public:
// only return neighbors where the corresponding filter bit is set
virtual std::vector<Neighbor> find_top_k_with_filter(uint32_t k,
vespalib::eval::TypedCells vector,
- const BitVector &filter,
+ const GlobalFilter &filter,
uint32_t explore_k,
double distance_threshold) const = 0;