aboutsummaryrefslogtreecommitdiffstats
path: root/searchlib/src/tests/tensor/distance_functions/distance_functions_test.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'searchlib/src/tests/tensor/distance_functions/distance_functions_test.cpp')
-rw-r--r--searchlib/src/tests/tensor/distance_functions/distance_functions_test.cpp39
1 files changed, 30 insertions, 9 deletions
diff --git a/searchlib/src/tests/tensor/distance_functions/distance_functions_test.cpp b/searchlib/src/tests/tensor/distance_functions/distance_functions_test.cpp
index 283a38ec95d..7d0f741e362 100644
--- a/searchlib/src/tests/tensor/distance_functions/distance_functions_test.cpp
+++ b/searchlib/src/tests/tensor/distance_functions/distance_functions_test.cpp
@@ -45,32 +45,53 @@ TEST(DistanceFunctionsTest, gives_expected_score)
std::vector<double> p3{0.0, 0.0, 1.0};
std::vector<double> p4{0.5, 0.5, 0.707107};
std::vector<double> p5{0.0,-1.0, 0.0};
+ std::vector<double> p6{1.0, 2.0, 2.0};
double n4 = euclid->calc(t(p0), t(p4));
- EXPECT_GT(n4, 0.99999);
- EXPECT_LT(n4, 1.00001);
+ EXPECT_FLOAT_EQ(n4, 1.0);
double d12 = euclid->calc(t(p1), t(p2));
EXPECT_EQ(d12, 2.0);
+ EXPECT_DOUBLE_EQ(euclid->to_rawscore(d12), 1.0/(1.0 + sqrt(2.0)));
+ constexpr double pi = 3.14159265358979323846;
double a12 = angular->calc(t(p1), t(p2));
double a13 = angular->calc(t(p1), t(p3));
double a23 = angular->calc(t(p2), t(p3));
- EXPECT_DOUBLE_EQ(a12, 0.5);
- EXPECT_DOUBLE_EQ(a13, 0.5);
- EXPECT_DOUBLE_EQ(a23, 0.5);
+ EXPECT_DOUBLE_EQ(a12, 1.0);
+ EXPECT_DOUBLE_EQ(a13, 1.0);
+ EXPECT_DOUBLE_EQ(a23, 1.0);
+ EXPECT_FLOAT_EQ(angular->to_rawscore(a12), 1.0/(1.0 + pi/2));
+
double a14 = angular->calc(t(p1), t(p4));
double a24 = angular->calc(t(p2), t(p4));
- EXPECT_FLOAT_EQ(a14, 0.25);
- EXPECT_FLOAT_EQ(a24, 0.25);
+ EXPECT_FLOAT_EQ(a14, 0.5);
+ EXPECT_FLOAT_EQ(a24, 0.5);
+ EXPECT_FLOAT_EQ(angular->to_rawscore(a14), 1.0/(1.0 + pi/3));
+
double a34 = angular->calc(t(p3), t(p4));
- EXPECT_FLOAT_EQ(a34, (1.0 - 0.707107)*0.5);
+ EXPECT_FLOAT_EQ(a34, (1.0 - 0.707107));
+ EXPECT_FLOAT_EQ(angular->to_rawscore(a34), 1.0/(1.0 + pi/4));
double a25 = angular->calc(t(p2), t(p5));
- EXPECT_DOUBLE_EQ(a25, 1.0);
+ EXPECT_DOUBLE_EQ(a25, 2.0);
+ EXPECT_FLOAT_EQ(angular->to_rawscore(a25), 1.0/(1.0 + pi));
double a44 = angular->calc(t(p4), t(p4));
EXPECT_GE(a44, 0.0);
EXPECT_LT(a44, 0.000001);
+ EXPECT_FLOAT_EQ(angular->to_rawscore(a44), 1.0);
+
+ double a66 = angular->calc(t(p6), t(p6));
+ EXPECT_GE(a66, 0.0);
+ EXPECT_LT(a66, 0.000001);
+ EXPECT_FLOAT_EQ(angular->to_rawscore(a66), 1.0);
+
+ double a16 = angular->calc(t(p1), t(p6));
+ double a26 = angular->calc(t(p2), t(p6));
+ double a36 = angular->calc(t(p3), t(p6));
+ EXPECT_FLOAT_EQ(a16, 1.0 - (1.0/3.0));
+ EXPECT_FLOAT_EQ(a26, 1.0 - (2.0/3.0));
+ EXPECT_FLOAT_EQ(a36, 1.0 - (2.0/3.0));
double i12 = innerproduct->calc(t(p1), t(p2));
double i13 = innerproduct->calc(t(p1), t(p3));