summaryrefslogtreecommitdiffstats
path: root/searchlib/src/tests
diff options
context:
space:
mode:
Diffstat (limited to 'searchlib/src/tests')
-rw-r--r--searchlib/src/tests/attribute/tensorattribute/tensorattribute_test.cpp2
-rw-r--r--searchlib/src/tests/queryeval/blueprint/intermediate_blueprints_test.cpp2
-rw-r--r--searchlib/src/tests/queryeval/nearest_neighbor/nearest_neighbor_test.cpp15
-rw-r--r--searchlib/src/tests/tensor/hnsw_index/hnsw_index_test.cpp21
4 files changed, 25 insertions, 15 deletions
diff --git a/searchlib/src/tests/attribute/tensorattribute/tensorattribute_test.cpp b/searchlib/src/tests/attribute/tensorattribute/tensorattribute_test.cpp
index b93398e16a1..ebe96035fc8 100644
--- a/searchlib/src/tests/attribute/tensorattribute/tensorattribute_test.cpp
+++ b/searchlib/src/tests/attribute/tensorattribute/tensorattribute_test.cpp
@@ -260,7 +260,7 @@ public:
return std::vector<Neighbor>();
}
std::vector<Neighbor> find_top_k_with_filter(uint32_t k, vespalib::eval::TypedCells vector,
- const search::BitVector& filter, uint32_t explore_k,
+ const GlobalFilter& filter, uint32_t explore_k,
double distance_threshold) const override
{
(void) k;
diff --git a/searchlib/src/tests/queryeval/blueprint/intermediate_blueprints_test.cpp b/searchlib/src/tests/queryeval/blueprint/intermediate_blueprints_test.cpp
index 87de62dbfad..2379213b87b 100644
--- a/searchlib/src/tests/queryeval/blueprint/intermediate_blueprints_test.cpp
+++ b/searchlib/src/tests/queryeval/blueprint/intermediate_blueprints_test.cpp
@@ -80,7 +80,7 @@ TEST("test AndNot Blueprint") {
a.addChild(ap(MyLeafSpec(20).addField(1, 1).want_global_filter().create()));
EXPECT_EQUAL(true, a.getState().want_global_filter());
auto empty_global_filter = GlobalFilter::create();
- EXPECT_FALSE(empty_global_filter->has_filter());
+ EXPECT_FALSE(empty_global_filter->is_active());
a.set_global_filter(*empty_global_filter, 1.0);
EXPECT_EQUAL(false, got_global_filter(a.getChild(0)));
EXPECT_EQUAL(true, got_global_filter(a.getChild(1)));
diff --git a/searchlib/src/tests/queryeval/nearest_neighbor/nearest_neighbor_test.cpp b/searchlib/src/tests/queryeval/nearest_neighbor/nearest_neighbor_test.cpp
index 1e341eab707..33435f43618 100644
--- a/searchlib/src/tests/queryeval/nearest_neighbor/nearest_neighbor_test.cpp
+++ b/searchlib/src/tests/queryeval/nearest_neighbor/nearest_neighbor_test.cpp
@@ -9,6 +9,7 @@
#include <vespa/searchlib/queryeval/nearest_neighbor_iterator.h>
#include <vespa/searchlib/queryeval/nns_index_iterator.h>
#include <vespa/searchlib/queryeval/simpleresult.h>
+#include <vespa/searchlib/queryeval/global_filter.h>
#include <vespa/searchlib/tensor/dense_tensor_attribute.h>
#include <vespa/searchlib/tensor/distance_calculator.h>
#include <vespa/searchlib/tensor/distance_function_factory.h>
@@ -63,7 +64,7 @@ struct Fixture
vespalib::string _typeSpec;
std::shared_ptr<DenseTensorAttribute> _tensorAttr;
std::shared_ptr<AttributeVector> _attr;
- std::unique_ptr<BitVector> _global_filter;
+ std::shared_ptr<GlobalFilter> _global_filter;
Fixture(const vespalib::string &typeSpec)
: _cfg(BasicType::TENSOR, CollectionType::SINGLE),
@@ -71,7 +72,7 @@ struct Fixture
_typeSpec(typeSpec),
_tensorAttr(),
_attr(),
- _global_filter()
+ _global_filter(GlobalFilter::create())
{
_cfg.setTensorType(ValueType::from_spec(typeSpec));
_tensorAttr = makeAttr();
@@ -95,11 +96,12 @@ struct Fixture
void setFilter(std::vector<uint32_t> docids) {
uint32_t sz = _attr->getNumDocs();
- _global_filter = BitVector::create(sz);
+ auto bit_vector = BitVector::create(sz);
for (uint32_t id : docids) {
EXPECT_LT(id, sz);
- _global_filter->setBit(id);
+ bit_vector->setBit(id);
}
+ _global_filter = GlobalFilter::create(std::move(bit_vector));
}
void setTensor(uint32_t docId, const Value &tensor) {
@@ -130,7 +132,7 @@ SimpleResult find_matches(Fixture &env, const Value &qtv, double threshold = std
DistanceCalculator dist_calc(attr, qtv, env.dist_fun());
NearestNeighborDistanceHeap dh(2);
dh.set_distance_threshold(env.dist_fun().convert_threshold(threshold));
- const BitVector *filter = env._global_filter.get();
+ const GlobalFilter &filter = *env._global_filter;
auto search = NearestNeighborIterator::create(strict, tfmd, dist_calc, dh, filter);
if (strict) {
return SimpleResult().searchStrict(*search, attr.getNumDocs());
@@ -222,7 +224,8 @@ std::vector<feature_t> get_rawscores(Fixture &env, const Value &qtv) {
auto &attr = *(env._tensorAttr);
DistanceCalculator dist_calc(attr, qtv, env.dist_fun());
NearestNeighborDistanceHeap dh(2);
- auto search = NearestNeighborIterator::create(strict, tfmd, dist_calc, dh, nullptr);
+ auto dummy_filter = GlobalFilter::create();
+ auto search = NearestNeighborIterator::create(strict, tfmd, dist_calc, dh, *dummy_filter);
uint32_t limit = attr.getNumDocs();
uint32_t docid = 1;
search->initRange(docid, limit);
diff --git a/searchlib/src/tests/tensor/hnsw_index/hnsw_index_test.cpp b/searchlib/src/tests/tensor/hnsw_index/hnsw_index_test.cpp
index 3d1127e6bc4..193bb04843c 100644
--- a/searchlib/src/tests/tensor/hnsw_index/hnsw_index_test.cpp
+++ b/searchlib/src/tests/tensor/hnsw_index/hnsw_index_test.cpp
@@ -6,6 +6,7 @@
#include <vespa/searchlib/tensor/hnsw_index.h>
#include <vespa/searchlib/tensor/random_level_generator.h>
#include <vespa/searchlib/tensor/inv_log_level_generator.h>
+#include <vespa/searchlib/queryeval/global_filter.h>
#include <vespa/vespalib/datastore/compaction_spec.h>
#include <vespa/vespalib/datastore/compaction_strategy.h>
#include <vespa/vespalib/gtest/gtest.h>
@@ -24,6 +25,7 @@ using vespalib::Slime;
using search::BitVector;
using vespalib::datastore::CompactionSpec;
using vespalib::datastore::CompactionStrategy;
+using search::queryeval::GlobalFilter;
template <typename FloatType>
class MyDocVectorAccess : public DocVectorAccess {
@@ -61,14 +63,14 @@ using HnswIndexUP = std::unique_ptr<HnswIndex>;
class HnswIndexTest : public ::testing::Test {
public:
FloatVectors vectors;
- std::unique_ptr<BitVector> global_filter;
+ std::shared_ptr<GlobalFilter> global_filter;
LevelGenerator* level_generator;
GenerationHandler gen_handler;
HnswIndexUP index;
HnswIndexTest()
: vectors(),
- global_filter(),
+ global_filter(GlobalFilter::create()),
level_generator(),
gen_handler(),
index()
@@ -80,6 +82,10 @@ public:
~HnswIndexTest() {}
+ const GlobalFilter *global_filter_ptr() const {
+ return global_filter->is_active() ? global_filter.get() : nullptr;
+ }
+
void init(bool heuristic_select_neighbors) {
auto generator = std::make_unique<LevelGenerator>();
level_generator = generator.get();
@@ -104,11 +110,12 @@ public:
}
void set_filter(std::vector<uint32_t> docids) {
uint32_t sz = 10;
- global_filter = BitVector::create(sz);
+ auto bit_vector = BitVector::create(sz);
for (uint32_t id : docids) {
EXPECT_LT(id, sz);
- global_filter->setBit(id);
+ bit_vector->setBit(id);
}
+ global_filter = GlobalFilter::create(std::move(bit_vector));
}
GenerationHandler::Guard take_read_guard() {
return gen_handler.takeGuard();
@@ -142,7 +149,7 @@ public:
void expect_top_3(uint32_t docid, std::vector<uint32_t> exp_hits) {
uint32_t k = 3;
auto qv = vectors.get_vector(docid);
- auto rv = index->top_k_candidates(qv, k, global_filter.get()).peek();
+ auto rv = index->top_k_candidates(qv, k, global_filter_ptr()).peek();
std::sort(rv.begin(), rv.end(), LesserDistance());
size_t idx = 0;
for (const auto & hit : rv) {
@@ -163,12 +170,12 @@ public:
void check_with_distance_threshold(uint32_t docid) {
auto qv = vectors.get_vector(docid);
uint32_t k = 3;
- auto rv = index->top_k_candidates(qv, k, global_filter.get()).peek();
+ auto rv = index->top_k_candidates(qv, k, global_filter_ptr()).peek();
std::sort(rv.begin(), rv.end(), LesserDistance());
EXPECT_EQ(rv.size(), 3);
EXPECT_LE(rv[0].distance, rv[1].distance);
double thr = (rv[0].distance + rv[1].distance) * 0.5;
- auto got_by_docid = (global_filter)
+ auto got_by_docid = (global_filter->is_active())
? index->find_top_k_with_filter(k, qv, *global_filter, k, thr)
: index->find_top_k(k, qv, k, thr);
EXPECT_EQ(got_by_docid.size(), 1);