summaryrefslogtreecommitdiffstats
path: root/searchlib
diff options
context:
space:
mode:
authorArne Juul <arnej@verizonmedia.com>2020-05-12 06:27:05 +0000
committerArne Juul <arnej@verizonmedia.com>2020-05-12 06:27:05 +0000
commit469f5e10bfaf09c584a001f483f2fbd249f9bcd5 (patch)
tree606f7dc50cb607ced2045d49aea467316c97150f /searchlib
parent5d9e9b8e651eb51cafb4d55dcec89900b331aa4f (diff)
unit test with filter also
Diffstat (limited to 'searchlib')
-rw-r--r--searchlib/src/tests/tensor/hnsw_index/hnsw_index_test.cpp23
1 files changed, 22 insertions, 1 deletions
diff --git a/searchlib/src/tests/tensor/hnsw_index/hnsw_index_test.cpp b/searchlib/src/tests/tensor/hnsw_index/hnsw_index_test.cpp
index 88e35a80bc9..04f2076121f 100644
--- a/searchlib/src/tests/tensor/hnsw_index/hnsw_index_test.cpp
+++ b/searchlib/src/tests/tensor/hnsw_index/hnsw_index_test.cpp
@@ -1,6 +1,7 @@
// Copyright 2020 Oath Inc. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
#include <vespa/eval/tensor/dense/typed_cells.h>
+#include <vespa/searchlib/common/bitvector.h>
#include <vespa/searchlib/tensor/distance_functions.h>
#include <vespa/searchlib/tensor/doc_vector_access.h>
#include <vespa/searchlib/tensor/hnsw_index.h>
@@ -19,6 +20,7 @@ using vespalib::MemoryUsage;
using namespace search::tensor;
using namespace vespalib::slime;
using vespalib::Slime;
+using search::BitVector;
template <typename FloatType>
@@ -56,12 +58,14 @@ using HnswIndexUP = std::unique_ptr<HnswIndex>;
class HnswIndexTest : public ::testing::Test {
public:
FloatVectors vectors;
+ std::unique_ptr<BitVector> global_filter;
LevelGenerator* level_generator;
GenerationHandler gen_handler;
HnswIndexUP index;
HnswIndexTest()
: vectors(),
+ global_filter(),
level_generator(),
gen_handler(),
index()
@@ -95,6 +99,14 @@ public:
gen_handler.updateFirstUsedGeneration();
index->trim_hold_lists(gen_handler.getFirstUsedGeneration());
}
+ void set_filter(std::vector<uint32_t> docids) {
+ uint32_t sz = 10;
+ global_filter = BitVector::create(sz);
+ for (uint32_t id : docids) {
+ EXPECT_LT(id, sz);
+ global_filter->setBit(id);
+ }
+ }
GenerationHandler::Guard take_read_guard() {
return gen_handler.takeGuard();
}
@@ -122,7 +134,7 @@ public:
void expect_top_3(uint32_t docid, std::vector<uint32_t> exp_hits) {
uint32_t k = 3;
auto qv = vectors.get_vector(docid);
- auto rv = index->top_k_candidates(qv, k, nullptr).peek();
+ auto rv = index->top_k_candidates(qv, k, global_filter.get()).peek();
std::sort(rv.begin(), rv.end(), LesserDistance());
size_t idx = 0;
for (const auto & hit : rv) {
@@ -197,6 +209,15 @@ TEST_F(HnswIndexTest, 2d_vectors_inserted_in_level_0_graph_with_simple_select_ne
expect_top_3(7, {7, 3, 2});
expect_top_3(8, {4, 3, 1});
expect_top_3(9, {7, 3, 2});
+
+ set_filter({2,3,4,6});
+ expect_top_3(2, {2, 3});
+ expect_top_3(4, {4, 3});
+ expect_top_3(5, {6, 2});
+ expect_top_3(6, {6, 2});
+ expect_top_3(7, {3, 2});
+ expect_top_3(8, {4, 3});
+ expect_top_3(9, {3, 2});
}
TEST_F(HnswIndexTest, 2d_vectors_inserted_and_removed)