diff options
author | Arne H Juul <arnej27959@users.noreply.github.com> | 2023-04-20 09:09:33 +0200 |
---|---|---|
committer | GitHub <noreply@github.com> | 2023-04-20 09:09:33 +0200 |
commit | 3cde18008cff2d1c812ec86141f538b56d5248ab (patch) | |
tree | eaca23b97d05abf775dfaf68b495cae70107c450 /searchlib/src/tests | |
parent | 4bf83d5e87a8896ce3b6a14fb0889a2891053bf1 (diff) | |
parent | 732e4c4be8bbc5a43e3adae5db222301e630bd8c (diff) |
Merge pull request #26783 from vespa-engine/arnej/refactor-with-bound-distance
add minimal version of BoundDistanceFunction
Diffstat (limited to 'searchlib/src/tests')
4 files changed, 43 insertions, 19 deletions
diff --git a/searchlib/src/tests/attribute/tensorattribute/tensorattribute_test.cpp b/searchlib/src/tests/attribute/tensorattribute/tensorattribute_test.cpp index 28c50891225..e3c9e05073e 100644 --- a/searchlib/src/tests/attribute/tensorattribute/tensorattribute_test.cpp +++ b/searchlib/src/tests/attribute/tensorattribute/tensorattribute_test.cpp @@ -294,30 +294,33 @@ public: std::unique_ptr<NearestNeighborIndexLoader> make_loader(FastOS_FileInterface& file) override { return std::make_unique<MockIndexLoader>(_index_value, file); } - std::vector<Neighbor> find_top_k(uint32_t k, vespalib::eval::TypedCells vector, uint32_t explore_k, + std::vector<Neighbor> find_top_k(uint32_t k, + const search::tensor::BoundDistanceFunction &df, + uint32_t explore_k, double distance_threshold) const override { (void) k; - (void) vector; + (void) df; (void) explore_k; (void) distance_threshold; return std::vector<Neighbor>(); } - std::vector<Neighbor> find_top_k_with_filter(uint32_t k, vespalib::eval::TypedCells vector, + std::vector<Neighbor> find_top_k_with_filter(uint32_t k, + const search::tensor::BoundDistanceFunction &df, const GlobalFilter& filter, uint32_t explore_k, double distance_threshold) const override { (void) k; - (void) vector; + (void) df; (void) explore_k; (void) filter; (void) distance_threshold; return std::vector<Neighbor>(); } - const search::tensor::DistanceFunction *distance_function() const override { - static search::tensor::SquaredEuclideanDistance my_dist_fun(vespalib::eval::CellType::DOUBLE); - return &my_dist_fun; + search::tensor::DistanceFunctionFactory &distance_function_factory() const override { + static search::tensor::DistanceFunctionFactory::UP my_dist_fun = search::tensor::make_distance_function_factory(search::attribute::DistanceMetric::Euclidean, vespalib::eval::CellType::DOUBLE); + return *my_dist_fun; } }; diff --git a/searchlib/src/tests/queryeval/nearest_neighbor/nearest_neighbor_test.cpp 
b/searchlib/src/tests/queryeval/nearest_neighbor/nearest_neighbor_test.cpp index 2801bf90080..fd07529795a 100644 --- a/searchlib/src/tests/queryeval/nearest_neighbor/nearest_neighbor_test.cpp +++ b/searchlib/src/tests/queryeval/nearest_neighbor/nearest_neighbor_test.cpp @@ -134,7 +134,9 @@ SimpleResult find_matches(Fixture &env, const Value &qtv, double threshold = std auto md = MatchData::makeTestInstance(2, 2); auto &tfmd = *(md->resolveTermField(0)); auto &attr = *(env._attr); - DistanceCalculator dist_calc(attr, qtv, env.dist_fun()); + + auto dff = search::tensor::make_distance_function_factory(DistanceMetric::Euclidean, qtv.cells().type); + DistanceCalculator dist_calc(attr, dff->for_query_vector(qtv.cells())); NearestNeighborDistanceHeap dh(2); dh.set_distance_threshold(env.dist_fun().convert_threshold(threshold)); const GlobalFilter &filter = *env._global_filter; @@ -260,7 +262,8 @@ std::vector<feature_t> get_rawscores(Fixture &env, const Value &qtv) { auto md = MatchData::makeTestInstance(2, 2); auto &tfmd = *(md->resolveTermField(0)); auto &attr = *(env._attr); - DistanceCalculator dist_calc(attr, qtv, env.dist_fun()); + auto dff = search::tensor::make_distance_function_factory(DistanceMetric::Euclidean, qtv.cells().type); + DistanceCalculator dist_calc(attr, dff->for_query_vector(qtv.cells())); NearestNeighborDistanceHeap dh(2); auto dummy_filter = GlobalFilter::create(); auto search = NearestNeighborIterator::create(strict, tfmd, dist_calc, dh, *dummy_filter); @@ -333,7 +336,10 @@ TEST(NnsIndexIteratorTest, require_that_iterator_works_as_expected) { std::vector<NnsIndexIterator::Hit> hits{{2,4.0}, {3,9.0}, {5,1.0}, {8,16.0}, {9,36.0}}; auto md = MatchData::makeTestInstance(2, 2); auto &tfmd = *(md->resolveTermField(0)); - auto search = NnsIndexIterator::create(tfmd, hits, *euclid_d); + auto dff = search::tensor::make_distance_function_factory(DistanceMetric::Euclidean, CellType::DOUBLE); + vespalib::eval::TypedCells dummy; + auto df = 
dff->for_query_vector(dummy); + auto search = NnsIndexIterator::create(tfmd, hits, *df); search->initFullRange(); expect_not_match(*search, 1, 2); expect_match(*search, 2); diff --git a/searchlib/src/tests/tensor/hnsw_index/hnsw_index_test.cpp b/searchlib/src/tests/tensor/hnsw_index/hnsw_index_test.cpp index d9230849699..9f6216f5867 100644 --- a/searchlib/src/tests/tensor/hnsw_index/hnsw_index_test.cpp +++ b/searchlib/src/tests/tensor/hnsw_index/hnsw_index_test.cpp @@ -105,10 +105,16 @@ public: ~HnswIndexTest() {} + auto dff() { + return search::tensor::make_distance_function_factory( + search::attribute::DistanceMetric::Euclidean, + vespalib::eval::CellType::FLOAT); + } + void init(bool heuristic_select_neighbors) { auto generator = std::make_unique<LevelGenerator>(); level_generator = generator.get(); - index = std::make_unique<IndexType>(vectors, std::make_unique<SquaredEuclideanDistance>(vespalib::eval::CellType::FLOAT), + index = std::make_unique<IndexType>(vectors, dff(), std::move(generator), HnswIndexConfig(5, 2, 10, 0, heuristic_select_neighbors)); } @@ -165,9 +171,10 @@ public: uint32_t explore_k = 100; vespalib::ArrayRef qv_ref(qv); vespalib::eval::TypedCells qv_cells(qv_ref); + auto df = index->distance_function_factory().for_query_vector(qv_cells); auto got_by_docid = (global_filter->is_active()) ? 
- index->find_top_k_with_filter(k, qv_cells, *global_filter, explore_k, 10000.0) : - index->find_top_k(k, qv_cells, explore_k, 10000.0); + index->find_top_k_with_filter(k, *df, *global_filter, explore_k, 10000.0) : + index->find_top_k(k, *df, explore_k, 10000.0); std::vector<uint32_t> act; act.reserve(got_by_docid.size()); for (auto& hit : got_by_docid) { @@ -178,7 +185,8 @@ public: void expect_top_3(uint32_t docid, std::vector<uint32_t> exp_hits) { uint32_t k = 3; auto qv = vectors.get_vector(docid, 0); - auto rv = index->top_k_candidates(qv, k, global_filter->ptr_if_active()).peek(); + auto df = index->distance_function_factory().for_query_vector(qv); + auto rv = index->top_k_candidates(*df, k, global_filter->ptr_if_active()).peek(); std::sort(rv.begin(), rv.end(), LesserDistance()); size_t idx = 0; for (const auto & hit : rv) { @@ -189,7 +197,7 @@ public: if (exp_hits.size() == k) { std::vector<uint32_t> expected_by_docid = exp_hits; std::sort(expected_by_docid.begin(), expected_by_docid.end()); - auto got_by_docid = index->find_top_k(k, qv, k, 100100.25); + auto got_by_docid = index->find_top_k(k, *df, k, 100100.25); for (idx = 0; idx < k; ++idx) { EXPECT_EQ(expected_by_docid[idx], got_by_docid[idx].docid); } @@ -198,15 +206,16 @@ public: } void check_with_distance_threshold(uint32_t docid) { auto qv = vectors.get_vector(docid, 0); + auto df = index->distance_function_factory().for_query_vector(qv); uint32_t k = 3; - auto rv = index->top_k_candidates(qv, k, global_filter->ptr_if_active()).peek(); + auto rv = index->top_k_candidates(*df, k, global_filter->ptr_if_active()).peek(); std::sort(rv.begin(), rv.end(), LesserDistance()); EXPECT_EQ(rv.size(), 3); EXPECT_LE(rv[0].distance, rv[1].distance); double thr = (rv[0].distance + rv[1].distance) * 0.5; auto got_by_docid = (global_filter->is_active()) - ? index->find_top_k_with_filter(k, qv, *global_filter, k, thr) - : index->find_top_k(k, qv, k, thr); + ? 
index->find_top_k_with_filter(k, *df, *global_filter, k, thr) + : index->find_top_k(k, *df, k, thr); EXPECT_EQ(got_by_docid.size(), 1); EXPECT_EQ(got_by_docid[0].docid, index->get_docid(rv[0].nodeid)); for (const auto & hit : got_by_docid) { diff --git a/searchlib/src/tests/tensor/hnsw_index/stress_hnsw_mt.cpp b/searchlib/src/tests/tensor/hnsw_index/stress_hnsw_mt.cpp index ecf310798af..0dcd77ec392 100644 --- a/searchlib/src/tests/tensor/hnsw_index/stress_hnsw_mt.cpp +++ b/searchlib/src/tests/tensor/hnsw_index/stress_hnsw_mt.cpp @@ -261,9 +261,15 @@ public: ~Stressor() {} + auto dff() { + return search::tensor::make_distance_function_factory( + search::attribute::DistanceMetric::Euclidean, + vespalib::eval::CellType::FLOAT); + } + void init() { uint32_t m = 16; - index = std::make_unique<IndexType>(vectors, std::make_unique<FloatSqEuclideanDistance>(), + index = std::make_unique<IndexType>(vectors, dff(), std::make_unique<InvLogLevelGenerator>(m), HnswIndexConfig(2*m, m, 200, 10, true)); } |