aboutsummaryrefslogtreecommitdiffstats
path: root/searchlib/src/tests/tensor
diff options
context:
space:
mode:
authorArne Juul <arnej@yahooinc.com>2023-04-19 09:41:47 +0000
committerArne Juul <arnej@yahooinc.com>2023-04-19 09:43:08 +0000
commit3880d66a21f151e97ac6fb892aa56909591e830e (patch)
tree2f4b4fb1415a7e1bab5556db550ab8bc871788a6 /searchlib/src/tests/tensor
parent4f542728c8f13882470b7bdc55fe9909fd2ffe81 (diff)
add mimimal version of BoundDistanceFunction
Diffstat (limited to 'searchlib/src/tests/tensor')
-rw-r--r--searchlib/src/tests/tensor/hnsw_index/hnsw_index_test.cpp25
-rw-r--r--searchlib/src/tests/tensor/hnsw_index/stress_hnsw_mt.cpp8
2 files changed, 24 insertions, 9 deletions
diff --git a/searchlib/src/tests/tensor/hnsw_index/hnsw_index_test.cpp b/searchlib/src/tests/tensor/hnsw_index/hnsw_index_test.cpp
index d9230849699..768157412f9 100644
--- a/searchlib/src/tests/tensor/hnsw_index/hnsw_index_test.cpp
+++ b/searchlib/src/tests/tensor/hnsw_index/hnsw_index_test.cpp
@@ -105,10 +105,16 @@ public:
~HnswIndexTest() {}
+ auto dff() {
+ return search::tensor::make_distance_function_factory(
+ search::attribute::DistanceMetric::Euclidean,
+ vespalib::eval::CellType::FLOAT);
+ }
+
void init(bool heuristic_select_neighbors) {
auto generator = std::make_unique<LevelGenerator>();
level_generator = generator.get();
- index = std::make_unique<IndexType>(vectors, std::make_unique<SquaredEuclideanDistance>(vespalib::eval::CellType::FLOAT),
+ index = std::make_unique<IndexType>(vectors, dff(),
std::move(generator),
HnswIndexConfig(5, 2, 10, 0, heuristic_select_neighbors));
}
@@ -165,9 +171,10 @@ public:
uint32_t explore_k = 100;
vespalib::ArrayRef qv_ref(qv);
vespalib::eval::TypedCells qv_cells(qv_ref);
+ auto df = index->distance_function_factory().forQueryVector(qv_cells);
auto got_by_docid = (global_filter->is_active()) ?
- index->find_top_k_with_filter(k, qv_cells, *global_filter, explore_k, 10000.0) :
- index->find_top_k(k, qv_cells, explore_k, 10000.0);
+ index->find_top_k_with_filter(k, *df, *global_filter, explore_k, 10000.0) :
+ index->find_top_k(k, *df, explore_k, 10000.0);
std::vector<uint32_t> act;
act.reserve(got_by_docid.size());
for (auto& hit : got_by_docid) {
@@ -178,7 +185,8 @@ public:
void expect_top_3(uint32_t docid, std::vector<uint32_t> exp_hits) {
uint32_t k = 3;
auto qv = vectors.get_vector(docid, 0);
- auto rv = index->top_k_candidates(qv, k, global_filter->ptr_if_active()).peek();
+ auto df = index->distance_function_factory().forQueryVector(qv);
+ auto rv = index->top_k_candidates(*df, k, global_filter->ptr_if_active()).peek();
std::sort(rv.begin(), rv.end(), LesserDistance());
size_t idx = 0;
for (const auto & hit : rv) {
@@ -189,7 +197,7 @@ public:
if (exp_hits.size() == k) {
std::vector<uint32_t> expected_by_docid = exp_hits;
std::sort(expected_by_docid.begin(), expected_by_docid.end());
- auto got_by_docid = index->find_top_k(k, qv, k, 100100.25);
+ auto got_by_docid = index->find_top_k(k, *df, k, 100100.25);
for (idx = 0; idx < k; ++idx) {
EXPECT_EQ(expected_by_docid[idx], got_by_docid[idx].docid);
}
@@ -198,15 +206,16 @@ public:
}
void check_with_distance_threshold(uint32_t docid) {
auto qv = vectors.get_vector(docid, 0);
+ auto df = index->distance_function_factory().forQueryVector(qv);
uint32_t k = 3;
- auto rv = index->top_k_candidates(qv, k, global_filter->ptr_if_active()).peek();
+ auto rv = index->top_k_candidates(*df, k, global_filter->ptr_if_active()).peek();
std::sort(rv.begin(), rv.end(), LesserDistance());
EXPECT_EQ(rv.size(), 3);
EXPECT_LE(rv[0].distance, rv[1].distance);
double thr = (rv[0].distance + rv[1].distance) * 0.5;
auto got_by_docid = (global_filter->is_active())
- ? index->find_top_k_with_filter(k, qv, *global_filter, k, thr)
- : index->find_top_k(k, qv, k, thr);
+ ? index->find_top_k_with_filter(k, *df, *global_filter, k, thr)
+ : index->find_top_k(k, *df, k, thr);
EXPECT_EQ(got_by_docid.size(), 1);
EXPECT_EQ(got_by_docid[0].docid, index->get_docid(rv[0].nodeid));
for (const auto & hit : got_by_docid) {
diff --git a/searchlib/src/tests/tensor/hnsw_index/stress_hnsw_mt.cpp b/searchlib/src/tests/tensor/hnsw_index/stress_hnsw_mt.cpp
index ecf310798af..0dcd77ec392 100644
--- a/searchlib/src/tests/tensor/hnsw_index/stress_hnsw_mt.cpp
+++ b/searchlib/src/tests/tensor/hnsw_index/stress_hnsw_mt.cpp
@@ -261,9 +261,15 @@ public:
~Stressor() {}
+ auto dff() {
+ return search::tensor::make_distance_function_factory(
+ search::attribute::DistanceMetric::Euclidean,
+ vespalib::eval::CellType::FLOAT);
+ }
+
void init() {
uint32_t m = 16;
- index = std::make_unique<IndexType>(vectors, std::make_unique<FloatSqEuclideanDistance>(),
+ index = std::make_unique<IndexType>(vectors, dff(),
std::make_unique<InvLogLevelGenerator>(m),
HnswIndexConfig(2*m, m, 200, 10, true));
}