diff options
author | Arne Juul <arnej@verizonmedia.com> | 2021-01-07 12:42:17 +0000 |
---|---|---|
committer | Arne Juul <arnej@verizonmedia.com> | 2021-01-08 10:56:14 +0000 |
commit | cf199f338efafad8c0af7de48094bd3d0037b96a (patch) | |
tree | 2c63ecee902b297006edb41a67925e721bfddff6 /searchlib/src/tests/tensor | |
parent | 8aa9ffda4324ddd5baff87be858063c6399a26ca (diff) |
add distanceThreshold option for nearestNeighbor operator
Diffstat (limited to 'searchlib/src/tests/tensor')
-rw-r--r-- | searchlib/src/tests/tensor/distance_functions/distance_functions_test.cpp | 36 | ||||
-rw-r--r-- | searchlib/src/tests/tensor/hnsw_index/hnsw_index_test.cpp | 11 |
2 files changed, 44 insertions, 3 deletions
diff --git a/searchlib/src/tests/tensor/distance_functions/distance_functions_test.cpp b/searchlib/src/tests/tensor/distance_functions/distance_functions_test.cpp index 06fb95089fd..ee0a2aff80e 100644 --- a/searchlib/src/tests/tensor/distance_functions/distance_functions_test.cpp +++ b/searchlib/src/tests/tensor/distance_functions/distance_functions_test.cpp @@ -24,10 +24,12 @@ void verify_geo_miles(const DistanceFunction *dist_fun, TypedCells t2(p2); double abstract_distance = dist_fun->calc(t1, t2); double raw_score = dist_fun->to_rawscore(abstract_distance); - double m = ((1.0/raw_score)-1.0); - double d_miles = m / 1.609344; + double km = ((1.0/raw_score)-1.0); + double d_miles = km / 1.609344; EXPECT_GE(d_miles, exp_miles*0.99); EXPECT_LE(d_miles, exp_miles*1.01); + double threshold = dist_fun->convert_threshold(km); + EXPECT_DOUBLE_EQ(threshold, abstract_distance); } @@ -50,6 +52,10 @@ TEST(DistanceFunctionsTest, euclidean_gives_expected_score) double d12 = euclid->calc(t(p1), t(p2)); EXPECT_EQ(d12, 2.0); EXPECT_DOUBLE_EQ(euclid->to_rawscore(d12), 1.0/(1.0 + sqrt(2.0))); + double threshold = euclid->convert_threshold(8.0); + EXPECT_EQ(threshold, 64.0); + threshold = euclid->convert_threshold(0.5); + EXPECT_EQ(threshold, 0.25); } TEST(DistanceFunctionsTest, angular_gives_expected_score) @@ -75,19 +81,28 @@ TEST(DistanceFunctionsTest, angular_gives_expected_score) EXPECT_DOUBLE_EQ(a23, 1.0); EXPECT_FLOAT_EQ(angular->to_rawscore(a12), 1.0/(1.0 + pi/2)); + double threshold = angular->convert_threshold(pi/2); + EXPECT_DOUBLE_EQ(threshold, 1.0); + double a14 = angular->calc(t(p1), t(p4)); double a24 = angular->calc(t(p2), t(p4)); EXPECT_FLOAT_EQ(a14, 0.5); EXPECT_FLOAT_EQ(a24, 0.5); EXPECT_FLOAT_EQ(angular->to_rawscore(a14), 1.0/(1.0 + pi/3)); + threshold = angular->convert_threshold(pi/3); + EXPECT_DOUBLE_EQ(threshold, 0.5); double a34 = angular->calc(t(p3), t(p4)); EXPECT_FLOAT_EQ(a34, (1.0 - 0.707107)); EXPECT_FLOAT_EQ(angular->to_rawscore(a34), 1.0/(1.0 + pi/4)); + threshold = angular->convert_threshold(pi/4); + EXPECT_FLOAT_EQ(threshold, a34); double a25 = angular->calc(t(p2), t(p5)); EXPECT_DOUBLE_EQ(a25, 2.0); EXPECT_FLOAT_EQ(angular->to_rawscore(a25), 1.0/(1.0 + pi)); + threshold = angular->convert_threshold(pi); + EXPECT_FLOAT_EQ(threshold, 2.0); double a44 = angular->calc(t(p4), t(p4)); EXPECT_GE(a44, 0.0); @@ -98,6 +113,8 @@ TEST(DistanceFunctionsTest, angular_gives_expected_score) EXPECT_GE(a66, 0.0); EXPECT_LT(a66, 0.000001); EXPECT_FLOAT_EQ(angular->to_rawscore(a66), 1.0); + threshold = angular->convert_threshold(0.0); + EXPECT_FLOAT_EQ(threshold, 0.0); double a16 = angular->calc(t(p1), t(p6)); double a26 = angular->calc(t(p2), t(p6)); @@ -127,6 +144,7 @@ TEST(DistanceFunctionsTest, innerproduct_gives_expected_score) EXPECT_DOUBLE_EQ(i12, 1.0); EXPECT_DOUBLE_EQ(i13, 1.0); EXPECT_DOUBLE_EQ(i23, 1.0); + double i14 = innerproduct->calc(t(p1), t(p4)); double i24 = innerproduct->calc(t(p2), t(p4)); EXPECT_DOUBLE_EQ(i14, 0.5); @@ -140,6 +158,13 @@ TEST(DistanceFunctionsTest, innerproduct_gives_expected_score) double i44 = innerproduct->calc(t(p4), t(p4)); EXPECT_GE(i44, 0.0); EXPECT_LT(i44, 0.000001); + + double threshold = innerproduct->convert_threshold(0.25); + EXPECT_DOUBLE_EQ(threshold, 0.25); + threshold = innerproduct->convert_threshold(0.5); + EXPECT_DOUBLE_EQ(threshold, 0.5); + threshold = innerproduct->convert_threshold(1.0); + EXPECT_DOUBLE_EQ(threshold, 1.0); } TEST(DistanceFunctionsTest, hamming_gives_expected_score) @@ -180,6 +205,13 @@ TEST(DistanceFunctionsTest, hamming_gives_expected_score) double d25 = hamming->calc(t(points[2]), t(points[5])); EXPECT_EQ(d25, 1.0); EXPECT_DOUBLE_EQ(hamming->to_rawscore(d25), 1.0/(1.0 + 1.0)); + + double threshold = hamming->convert_threshold(0.25); + EXPECT_DOUBLE_EQ(threshold, 0.25); + threshold = hamming->convert_threshold(0.5); + EXPECT_DOUBLE_EQ(threshold, 0.5); + threshold = hamming->convert_threshold(1.0); + EXPECT_DOUBLE_EQ(threshold, 1.0); } TEST(GeoDegreesTest, gives_expected_score) diff --git a/searchlib/src/tests/tensor/hnsw_index/hnsw_index_test.cpp b/searchlib/src/tests/tensor/hnsw_index/hnsw_index_test.cpp index acc157709c0..d081c299a43 100644 --- a/searchlib/src/tests/tensor/hnsw_index/hnsw_index_test.cpp +++ b/searchlib/src/tests/tensor/hnsw_index/hnsw_index_test.cpp @@ -144,11 +144,20 @@ public: if (exp_hits.size() == k) { std::vector<uint32_t> expected_by_docid = exp_hits; std::sort(expected_by_docid.begin(), expected_by_docid.end()); - auto got_by_docid = index->find_top_k(k, qv, k); + auto got_by_docid = index->find_top_k(k, qv, k, 100100.25); for (idx = 0; idx < k; ++idx) { EXPECT_EQ(expected_by_docid[idx], got_by_docid[idx].docid); } } + if ((rv.size() > 1) && (rv[0].distance < rv[1].distance)) { + double thr = (rv[0].distance + rv[1].distance) * 0.5; + auto got_by_docid = index->find_top_k_with_filter(k, qv, *global_filter, k, thr); + for (const auto & hit : got_by_docid) { + printf("hit docid=%u dist=%g (thr %g)\n", hit.docid, hit.distance, thr); + } + EXPECT_EQ(got_by_docid.size(), 1); + EXPECT_EQ(got_by_docid[0].docid, rv[0].docid); + } } }; |