diff options
Diffstat (limited to 'searchlib/src/tests/tensor')
-rw-r--r-- | searchlib/src/tests/tensor/hnsw_index/hnsw_index_test.cpp | 25 | ||||
-rw-r--r-- | searchlib/src/tests/tensor/hnsw_index/stress_hnsw_mt.cpp | 8 |
2 files changed, 24 insertions, 9 deletions
diff --git a/searchlib/src/tests/tensor/hnsw_index/hnsw_index_test.cpp b/searchlib/src/tests/tensor/hnsw_index/hnsw_index_test.cpp index d9230849699..768157412f9 100644 --- a/searchlib/src/tests/tensor/hnsw_index/hnsw_index_test.cpp +++ b/searchlib/src/tests/tensor/hnsw_index/hnsw_index_test.cpp @@ -105,10 +105,16 @@ public: ~HnswIndexTest() {} + auto dff() { + return search::tensor::make_distance_function_factory( + search::attribute::DistanceMetric::Euclidean, + vespalib::eval::CellType::FLOAT); + } + void init(bool heuristic_select_neighbors) { auto generator = std::make_unique<LevelGenerator>(); level_generator = generator.get(); - index = std::make_unique<IndexType>(vectors, std::make_unique<SquaredEuclideanDistance>(vespalib::eval::CellType::FLOAT), + index = std::make_unique<IndexType>(vectors, dff(), std::move(generator), HnswIndexConfig(5, 2, 10, 0, heuristic_select_neighbors)); } @@ -165,9 +171,10 @@ public: uint32_t explore_k = 100; vespalib::ArrayRef qv_ref(qv); vespalib::eval::TypedCells qv_cells(qv_ref); + auto df = index->distance_function_factory().forQueryVector(qv_cells); auto got_by_docid = (global_filter->is_active()) ? - index->find_top_k_with_filter(k, qv_cells, *global_filter, explore_k, 10000.0) : - index->find_top_k(k, qv_cells, explore_k, 10000.0); + index->find_top_k_with_filter(k, *df, *global_filter, explore_k, 10000.0) : + index->find_top_k(k, *df, explore_k, 10000.0); std::vector<uint32_t> act; act.reserve(got_by_docid.size()); for (auto& hit : got_by_docid) { @@ -178,7 +185,8 @@ public: void expect_top_3(uint32_t docid, std::vector<uint32_t> exp_hits) { uint32_t k = 3; auto qv = vectors.get_vector(docid, 0); - auto rv = index->top_k_candidates(qv, k, global_filter->ptr_if_active()).peek(); + auto df = index->distance_function_factory().forQueryVector(qv); + auto rv = index->top_k_candidates(*df, k, global_filter->ptr_if_active()).peek(); std::sort(rv.begin(), rv.end(), LesserDistance()); size_t idx = 0; for (const auto & hit : rv) { @@ -189,7 +197,7 @@ public: if (exp_hits.size() == k) { std::vector<uint32_t> expected_by_docid = exp_hits; std::sort(expected_by_docid.begin(), expected_by_docid.end()); - auto got_by_docid = index->find_top_k(k, qv, k, 100100.25); + auto got_by_docid = index->find_top_k(k, *df, k, 100100.25); for (idx = 0; idx < k; ++idx) { EXPECT_EQ(expected_by_docid[idx], got_by_docid[idx].docid); } @@ -198,15 +206,16 @@ public: } void check_with_distance_threshold(uint32_t docid) { auto qv = vectors.get_vector(docid, 0); + auto df = index->distance_function_factory().forQueryVector(qv); uint32_t k = 3; - auto rv = index->top_k_candidates(qv, k, global_filter->ptr_if_active()).peek(); + auto rv = index->top_k_candidates(*df, k, global_filter->ptr_if_active()).peek(); std::sort(rv.begin(), rv.end(), LesserDistance()); EXPECT_EQ(rv.size(), 3); EXPECT_LE(rv[0].distance, rv[1].distance); double thr = (rv[0].distance + rv[1].distance) * 0.5; auto got_by_docid = (global_filter->is_active()) - ? index->find_top_k_with_filter(k, qv, *global_filter, k, thr) - : index->find_top_k(k, qv, k, thr); + ? index->find_top_k_with_filter(k, *df, *global_filter, k, thr) + : index->find_top_k(k, *df, k, thr); EXPECT_EQ(got_by_docid.size(), 1); EXPECT_EQ(got_by_docid[0].docid, index->get_docid(rv[0].nodeid)); for (const auto & hit : got_by_docid) { diff --git a/searchlib/src/tests/tensor/hnsw_index/stress_hnsw_mt.cpp b/searchlib/src/tests/tensor/hnsw_index/stress_hnsw_mt.cpp index ecf310798af..0dcd77ec392 100644 --- a/searchlib/src/tests/tensor/hnsw_index/stress_hnsw_mt.cpp +++ b/searchlib/src/tests/tensor/hnsw_index/stress_hnsw_mt.cpp @@ -261,9 +261,15 @@ public: ~Stressor() {} + auto dff() { + return search::tensor::make_distance_function_factory( + search::attribute::DistanceMetric::Euclidean, + vespalib::eval::CellType::FLOAT); + } + void init() { uint32_t m = 16; - index = std::make_unique<IndexType>(vectors, std::make_unique<FloatSqEuclideanDistance>(), + index = std::make_unique<IndexType>(vectors, dff(), std::make_unique<InvLogLevelGenerator>(m), HnswIndexConfig(2*m, m, 200, 10, true)); } |