diff options
author | Tor Egge <Tor.Egge@online.no> | 2022-11-25 12:45:11 +0100 |
---|---|---|
committer | Tor Egge <Tor.Egge@online.no> | 2022-11-25 12:45:11 +0100 |
commit | 1ce1178a7988f82fa83380b88269d26199bba799 (patch) | |
tree | 960c8d41594baebb546cfdc89c8f866ab7e5c1be /searchlib | |
parent | 59120a5bcb6f0e215b6486f7da6399f9f96630b6 (diff) |
Unit test hnsw index search with multiple nodes per document.
Diffstat (limited to 'searchlib')
-rw-r--r-- | searchlib/src/tests/tensor/hnsw_index/hnsw_index_test.cpp | 55 |
1 files changed, 47 insertions, 8 deletions
diff --git a/searchlib/src/tests/tensor/hnsw_index/hnsw_index_test.cpp b/searchlib/src/tests/tensor/hnsw_index/hnsw_index_test.cpp index bea371c78a8..06bca414f84 100644 --- a/searchlib/src/tests/tensor/hnsw_index/hnsw_index_test.cpp +++ b/searchlib/src/tests/tensor/hnsw_index/hnsw_index_test.cpp @@ -62,14 +62,13 @@ public: return *this; } vespalib::eval::TypedCells get_vector(uint32_t docid, uint32_t subspace) const override { - (void) subspace; - ArrayRef ref(_vectors[docid]); - return vespalib::eval::TypedCells(ref); + return get_vectors(docid).cells(subspace); } VectorBundle get_vectors(uint32_t docid) const override { ArrayRef ref(_vectors[docid]); - assert(_subspace_type.size() == ref.size()); - return VectorBundle(ref.data(), 1, _subspace_type); + assert((ref.size() % _subspace_type.size()) == 0); + uint32_t subspaces = ref.size() / _subspace_type.size(); + return VectorBundle(ref.data(), subspaces, _subspace_type); } void clear() { _vectors.clear(); } @@ -160,6 +159,20 @@ public: ASSERT_EQ(exp_levels.size(), act_node.size()); EXPECT_EQ(exp_levels, act_node.levels()); } + void expect_top_3_by_docid(const vespalib::string& label, std::vector<float> qv, std::vector<uint32_t> exp) { + SCOPED_TRACE(label); + uint32_t k = 3; + uint32_t explore_k = 100; + vespalib::ArrayRef qv_ref(qv); + vespalib::eval::TypedCells qv_cells(qv_ref); + auto got_by_docid = index->find_top_k(k, qv_cells, explore_k, 10000.0); + std::vector<uint32_t> act; + act.reserve(got_by_docid.size()); + for (auto& hit : got_by_docid) { + act.emplace_back(hit.docid); + } + EXPECT_EQ(exp, act); + } void expect_top_3(uint32_t docid, std::vector<uint32_t> exp_hits) { uint32_t k = 3; auto qv = vectors.get_vector(docid, 0); @@ -740,7 +753,35 @@ TYPED_TEST(HnswIndexTest, hnsw_graph_can_be_saved_and_loaded) this->init(false); this->load_index(data); this->check_savetest_index("after load"); - } +} + +using HnswMultiIndexTest = HnswIndexTest<HnswIndex<HnswIndexType::MULTI>>; + +TEST_F(HnswMultiIndexTest, duplicate_docid_is_removed) +{ + this->init(false); + this->vectors + .set(1, {0, 0, 0, 2}) + .set(2, {1, 0}) + .set(3, {1, 2}) + .set(4, {2, 0, 2, 2}); + /* + * 1 3 4 + * . . . + * 1 2 4 + */ + for (uint32_t docid = 1; docid < 5; ++docid) { + this->add_document(docid); + } + this->expect_top_3_by_docid("{0, 0}", {0, 0}, {1, 2, 4}); + this->expect_top_3_by_docid("{0, 1}", {0, 1}, {1, 2, 3}); + this->expect_top_3_by_docid("{0, 2}", {0, 2}, {1, 3, 4}); + this->expect_top_3_by_docid("{1, 0}", {1, 0}, {1, 2, 4}); + this->expect_top_3_by_docid("{1, 2}", {1, 2}, {1, 3, 4}); + this->expect_top_3_by_docid("{2, 0}", {2, 0}, {1, 2, 4}); + this->expect_top_3_by_docid("{2, 1}", {2, 1}, {2, 3, 4}); + this->expect_top_3_by_docid("{2, 2}", {2, 2}, {1, 3, 4}); +}; TEST(LevelGeneratorTest, gives_various_levels) { @@ -789,7 +830,6 @@ TEST(LevelGeneratorTest, gives_various_levels) EXPECT_TRUE(hist.size() < 14); } - template <typename IndexType> class TwoPhaseTest : public HnswIndexTest<IndexType> { public: @@ -853,5 +893,4 @@ TYPED_TEST(TwoPhaseTest, two_phase_add) this->expect_levels(nodeids[0], {{2}, {4}}); } - GTEST_MAIN_RUN_ALL_TESTS() |