diff options
author | Arne Juul <arnej@verizonmedia.com> | 2020-05-12 06:27:05 +0000 |
---|---|---|
committer | Arne Juul <arnej@verizonmedia.com> | 2020-05-12 06:27:05 +0000 |
commit | 469f5e10bfaf09c584a001f483f2fbd249f9bcd5 (patch) | |
tree | 606f7dc50cb607ced2045d49aea467316c97150f | |
parent | 5d9e9b8e651eb51cafb4d55dcec89900b331aa4f (diff) |
unit test with filter also
-rw-r--r-- | searchlib/src/tests/tensor/hnsw_index/hnsw_index_test.cpp | 23 |
1 files changed, 22 insertions, 1 deletions
diff --git a/searchlib/src/tests/tensor/hnsw_index/hnsw_index_test.cpp b/searchlib/src/tests/tensor/hnsw_index/hnsw_index_test.cpp index 88e35a80bc9..04f2076121f 100644 --- a/searchlib/src/tests/tensor/hnsw_index/hnsw_index_test.cpp +++ b/searchlib/src/tests/tensor/hnsw_index/hnsw_index_test.cpp @@ -1,6 +1,7 @@ // Copyright 2020 Oath Inc. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. #include <vespa/eval/tensor/dense/typed_cells.h> +#include <vespa/searchlib/common/bitvector.h> #include <vespa/searchlib/tensor/distance_functions.h> #include <vespa/searchlib/tensor/doc_vector_access.h> #include <vespa/searchlib/tensor/hnsw_index.h> @@ -19,6 +20,7 @@ using vespalib::MemoryUsage; using namespace search::tensor; using namespace vespalib::slime; using vespalib::Slime; +using search::BitVector; template <typename FloatType> @@ -56,12 +58,14 @@ using HnswIndexUP = std::unique_ptr<HnswIndex>; class HnswIndexTest : public ::testing::Test { public: FloatVectors vectors; + std::unique_ptr<BitVector> global_filter; LevelGenerator* level_generator; GenerationHandler gen_handler; HnswIndexUP index; HnswIndexTest() : vectors(), + global_filter(), level_generator(), gen_handler(), index() @@ -95,6 +99,14 @@ public: gen_handler.updateFirstUsedGeneration(); index->trim_hold_lists(gen_handler.getFirstUsedGeneration()); } + void set_filter(std::vector<uint32_t> docids) { + uint32_t sz = 10; + global_filter = BitVector::create(sz); + for (uint32_t id : docids) { + EXPECT_LT(id, sz); + global_filter->setBit(id); + } + } GenerationHandler::Guard take_read_guard() { return gen_handler.takeGuard(); } @@ -122,7 +134,7 @@ public: void expect_top_3(uint32_t docid, std::vector<uint32_t> exp_hits) { uint32_t k = 3; auto qv = vectors.get_vector(docid); - auto rv = index->top_k_candidates(qv, k, nullptr).peek(); + auto rv = index->top_k_candidates(qv, k, global_filter.get()).peek(); std::sort(rv.begin(), rv.end(), LesserDistance()); size_t idx = 0; for (const auto & hit : rv) { @@ -197,6 +209,15 @@ TEST_F(HnswIndexTest, 2d_vectors_inserted_in_level_0_graph_with_simple_select_ne expect_top_3(7, {7, 3, 2}); expect_top_3(8, {4, 3, 1}); expect_top_3(9, {7, 3, 2}); + + set_filter({2,3,4,6}); + expect_top_3(2, {2, 3}); + expect_top_3(4, {4, 3}); + expect_top_3(5, {6, 2}); + expect_top_3(6, {6, 2}); + expect_top_3(7, {3, 2}); + expect_top_3(8, {4, 3}); + expect_top_3(9, {3, 2}); } TEST_F(HnswIndexTest, 2d_vectors_inserted_and_removed) |