summary refs log tree commit diff stats
path: root/searchlib/src/tests/tensor
diff options
context:
space:
mode:
author Arne Juul <arnej@verizonmedia.com> 2021-01-07 12:42:17 +0000
committer Arne Juul <arnej@verizonmedia.com> 2021-01-08 10:56:14 +0000
commit cf199f338efafad8c0af7de48094bd3d0037b96a (patch)
tree 2c63ecee902b297006edb41a67925e721bfddff6 /searchlib/src/tests/tensor
parent 8aa9ffda4324ddd5baff87be858063c6399a26ca (diff)
add distanceThreshold option for nearestNeighbor operator
Diffstat (limited to 'searchlib/src/tests/tensor')
-rw-r--r-- searchlib/src/tests/tensor/distance_functions/distance_functions_test.cpp 36
-rw-r--r-- searchlib/src/tests/tensor/hnsw_index/hnsw_index_test.cpp 11
2 files changed, 44 insertions, 3 deletions
diff --git a/searchlib/src/tests/tensor/distance_functions/distance_functions_test.cpp b/searchlib/src/tests/tensor/distance_functions/distance_functions_test.cpp
index 06fb95089fd..ee0a2aff80e 100644
--- a/searchlib/src/tests/tensor/distance_functions/distance_functions_test.cpp
+++ b/searchlib/src/tests/tensor/distance_functions/distance_functions_test.cpp
@@ -24,10 +24,12 @@ void verify_geo_miles(const DistanceFunction *dist_fun,
TypedCells t2(p2);
double abstract_distance = dist_fun->calc(t1, t2);
double raw_score = dist_fun->to_rawscore(abstract_distance);
- double m = ((1.0/raw_score)-1.0);
- double d_miles = m / 1.609344;
+ double km = ((1.0/raw_score)-1.0);
+ double d_miles = km / 1.609344;
EXPECT_GE(d_miles, exp_miles*0.99);
EXPECT_LE(d_miles, exp_miles*1.01);
+ double threshold = dist_fun->convert_threshold(km);
+ EXPECT_DOUBLE_EQ(threshold, abstract_distance);
}
@@ -50,6 +52,10 @@ TEST(DistanceFunctionsTest, euclidean_gives_expected_score)
double d12 = euclid->calc(t(p1), t(p2));
EXPECT_EQ(d12, 2.0);
EXPECT_DOUBLE_EQ(euclid->to_rawscore(d12), 1.0/(1.0 + sqrt(2.0)));
+ double threshold = euclid->convert_threshold(8.0);
+ EXPECT_EQ(threshold, 64.0);
+ threshold = euclid->convert_threshold(0.5);
+ EXPECT_EQ(threshold, 0.25);
}
TEST(DistanceFunctionsTest, angular_gives_expected_score)
@@ -75,19 +81,28 @@ TEST(DistanceFunctionsTest, angular_gives_expected_score)
EXPECT_DOUBLE_EQ(a23, 1.0);
EXPECT_FLOAT_EQ(angular->to_rawscore(a12), 1.0/(1.0 + pi/2));
+ double threshold = angular->convert_threshold(pi/2);
+ EXPECT_DOUBLE_EQ(threshold, 1.0);
+
double a14 = angular->calc(t(p1), t(p4));
double a24 = angular->calc(t(p2), t(p4));
EXPECT_FLOAT_EQ(a14, 0.5);
EXPECT_FLOAT_EQ(a24, 0.5);
EXPECT_FLOAT_EQ(angular->to_rawscore(a14), 1.0/(1.0 + pi/3));
+ threshold = angular->convert_threshold(pi/3);
+ EXPECT_DOUBLE_EQ(threshold, 0.5);
double a34 = angular->calc(t(p3), t(p4));
EXPECT_FLOAT_EQ(a34, (1.0 - 0.707107));
EXPECT_FLOAT_EQ(angular->to_rawscore(a34), 1.0/(1.0 + pi/4));
+ threshold = angular->convert_threshold(pi/4);
+ EXPECT_FLOAT_EQ(threshold, a34);
double a25 = angular->calc(t(p2), t(p5));
EXPECT_DOUBLE_EQ(a25, 2.0);
EXPECT_FLOAT_EQ(angular->to_rawscore(a25), 1.0/(1.0 + pi));
+ threshold = angular->convert_threshold(pi);
+ EXPECT_FLOAT_EQ(threshold, 2.0);
double a44 = angular->calc(t(p4), t(p4));
EXPECT_GE(a44, 0.0);
@@ -98,6 +113,8 @@ TEST(DistanceFunctionsTest, angular_gives_expected_score)
EXPECT_GE(a66, 0.0);
EXPECT_LT(a66, 0.000001);
EXPECT_FLOAT_EQ(angular->to_rawscore(a66), 1.0);
+ threshold = angular->convert_threshold(0.0);
+ EXPECT_FLOAT_EQ(threshold, 0.0);
double a16 = angular->calc(t(p1), t(p6));
double a26 = angular->calc(t(p2), t(p6));
@@ -127,6 +144,7 @@ TEST(DistanceFunctionsTest, innerproduct_gives_expected_score)
EXPECT_DOUBLE_EQ(i12, 1.0);
EXPECT_DOUBLE_EQ(i13, 1.0);
EXPECT_DOUBLE_EQ(i23, 1.0);
+
double i14 = innerproduct->calc(t(p1), t(p4));
double i24 = innerproduct->calc(t(p2), t(p4));
EXPECT_DOUBLE_EQ(i14, 0.5);
@@ -140,6 +158,13 @@ TEST(DistanceFunctionsTest, innerproduct_gives_expected_score)
double i44 = innerproduct->calc(t(p4), t(p4));
EXPECT_GE(i44, 0.0);
EXPECT_LT(i44, 0.000001);
+
+ double threshold = innerproduct->convert_threshold(0.25);
+ EXPECT_DOUBLE_EQ(threshold, 0.25);
+ threshold = innerproduct->convert_threshold(0.5);
+ EXPECT_DOUBLE_EQ(threshold, 0.5);
+ threshold = innerproduct->convert_threshold(1.0);
+ EXPECT_DOUBLE_EQ(threshold, 1.0);
}
TEST(DistanceFunctionsTest, hamming_gives_expected_score)
@@ -180,6 +205,13 @@ TEST(DistanceFunctionsTest, hamming_gives_expected_score)
double d25 = hamming->calc(t(points[2]), t(points[5]));
EXPECT_EQ(d25, 1.0);
EXPECT_DOUBLE_EQ(hamming->to_rawscore(d25), 1.0/(1.0 + 1.0));
+
+ double threshold = hamming->convert_threshold(0.25);
+ EXPECT_DOUBLE_EQ(threshold, 0.25);
+ threshold = hamming->convert_threshold(0.5);
+ EXPECT_DOUBLE_EQ(threshold, 0.5);
+ threshold = hamming->convert_threshold(1.0);
+ EXPECT_DOUBLE_EQ(threshold, 1.0);
}
TEST(GeoDegreesTest, gives_expected_score)
diff --git a/searchlib/src/tests/tensor/hnsw_index/hnsw_index_test.cpp b/searchlib/src/tests/tensor/hnsw_index/hnsw_index_test.cpp
index acc157709c0..d081c299a43 100644
--- a/searchlib/src/tests/tensor/hnsw_index/hnsw_index_test.cpp
+++ b/searchlib/src/tests/tensor/hnsw_index/hnsw_index_test.cpp
@@ -144,11 +144,20 @@ public:
if (exp_hits.size() == k) {
std::vector<uint32_t> expected_by_docid = exp_hits;
std::sort(expected_by_docid.begin(), expected_by_docid.end());
- auto got_by_docid = index->find_top_k(k, qv, k);
+ auto got_by_docid = index->find_top_k(k, qv, k, 100100.25);
for (idx = 0; idx < k; ++idx) {
EXPECT_EQ(expected_by_docid[idx], got_by_docid[idx].docid);
}
}
+ if ((rv.size() > 1) && (rv[0].distance < rv[1].distance)) {
+ double thr = (rv[0].distance + rv[1].distance) * 0.5;
+ auto got_by_docid = index->find_top_k_with_filter(k, qv, *global_filter, k, thr);
+ for (const auto & hit : got_by_docid) {
+ printf("hit docid=%u dist=%g (thr %g)\n", hit.docid, hit.distance, thr);
+ }
+ EXPECT_EQ(got_by_docid.size(), 1);
+ EXPECT_EQ(got_by_docid[0].docid, rv[0].docid);
+ }
}
};