summaryrefslogtreecommitdiffstats
path: root/searchlib/src/tests/tensor/distance_functions/distance_functions_test.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'searchlib/src/tests/tensor/distance_functions/distance_functions_test.cpp')
-rw-r--r--searchlib/src/tests/tensor/distance_functions/distance_functions_test.cpp84
1 files changed, 14 insertions, 70 deletions
diff --git a/searchlib/src/tests/tensor/distance_functions/distance_functions_test.cpp b/searchlib/src/tests/tensor/distance_functions/distance_functions_test.cpp
index 9b8ad0d26ce..a1b29c90986 100644
--- a/searchlib/src/tests/tensor/distance_functions/distance_functions_test.cpp
+++ b/searchlib/src/tests/tensor/distance_functions/distance_functions_test.cpp
@@ -76,10 +76,6 @@ namespace { const double sq_root_half = std::sqrt(0.5); }
TEST(DistanceFunctionsTest, euclidean_gives_expected_score)
{
- auto ct = vespalib::eval::CellType::DOUBLE;
-
- auto euclid = make_distance_function(DistanceMetric::Euclidean, ct);
-
std::vector<double> p0{0.0, 0.0, 0.0};
std::vector<double> p1{1.0, 0.0, 0.0};
std::vector<double> p2{0.0, 1.0, 0.0};
@@ -92,6 +88,9 @@ TEST(DistanceFunctionsTest, euclidean_gives_expected_score)
EXPECT_FLOAT_EQ(n4, 1.0);
double d12 = computeEuclideanChecked(t(p1), t(p2));
EXPECT_EQ(d12, 2.0);
+
+ EuclideanDistanceFunctionFactory<double> dff;
+ auto euclid = dff.for_query_vector(t(p0));
EXPECT_DOUBLE_EQ(euclid->to_rawscore(d12), 1.0/(1.0 + sqrt(2.0)));
double threshold = euclid->convert_threshold(8.0);
EXPECT_EQ(threshold, 64.0);
@@ -142,10 +141,6 @@ TEST(DistanceFunctionsTest, euclidean_gives_expected_score)
TEST(DistanceFunctionsTest, euclidean_int8_smoketest)
{
- auto ct = vespalib::eval::CellType::INT8;
-
- auto euclid = make_distance_function(DistanceMetric::Euclidean, ct);
-
std::vector<Int8Float> p0{0.0, 0.0, 0.0};
std::vector<Int8Float> p1{1.0, 0.0, 0.0};
std::vector<Int8Float> p5{0.0,-1.0, 0.0};
@@ -345,60 +340,10 @@ TEST(DistanceFunctionsTest, prenormalized_angular_gives_expected_score)
EXPECT_DOUBLE_EQ(threshold, 1.0);
}
-TEST(DistanceFunctionsTest, innerproduct_gives_expected_score)
-{
- auto ct = vespalib::eval::CellType::DOUBLE;
-
- auto innerproduct = make_distance_function(DistanceMetric::InnerProduct, ct);
-
- std::vector<double> p0{0.0, 0.0, 0.0};
- std::vector<double> p1{1.0, 0.0, 0.0};
- std::vector<double> p2{0.0, 1.0, 0.0};
- std::vector<double> p3{0.0, 0.0, 1.0};
- std::vector<double> p4{0.5, 0.5, sq_root_half};
- std::vector<double> p5{0.0,-1.0, 0.0};
- std::vector<double> p6{1.0, 2.0, 2.0};
-
- double i12 = innerproduct->calc(t(p1), t(p2));
- double i13 = innerproduct->calc(t(p1), t(p3));
- double i23 = innerproduct->calc(t(p2), t(p3));
- EXPECT_DOUBLE_EQ(i12, 1.0);
- EXPECT_DOUBLE_EQ(i13, 1.0);
- EXPECT_DOUBLE_EQ(i23, 1.0);
-
- double i14 = innerproduct->calc(t(p1), t(p4));
- double i24 = innerproduct->calc(t(p2), t(p4));
- EXPECT_DOUBLE_EQ(i14, 0.5);
- EXPECT_DOUBLE_EQ(i24, 0.5);
- double i34 = innerproduct->calc(t(p3), t(p4));
- EXPECT_FLOAT_EQ(i34, 1.0 - sq_root_half);
-
- double i25 = innerproduct->calc(t(p2), t(p5));
- EXPECT_DOUBLE_EQ(i25, 2.0);
-
- double i44 = innerproduct->calc(t(p4), t(p4));
- EXPECT_GE(i44, 0.0);
- EXPECT_LT(i44, 0.000001);
-
- double i66 = innerproduct->calc(t(p6), t(p6));
- EXPECT_GE(i66, 0.0);
- EXPECT_LT(i66, 0.000001);
-
- double threshold = innerproduct->convert_threshold(0.25);
- EXPECT_DOUBLE_EQ(threshold, 0.25);
- threshold = innerproduct->convert_threshold(0.5);
- EXPECT_DOUBLE_EQ(threshold, 0.5);
- threshold = innerproduct->convert_threshold(1.0);
- EXPECT_DOUBLE_EQ(threshold, 1.0);
-}
TEST(DistanceFunctionsTest, hamming_gives_expected_score)
{
- static HammingDistanceFunctionFactory<Int8Float> dff;
- auto ct = vespalib::eval::CellType::DOUBLE;
-
- auto hamming = make_distance_function(DistanceMetric::Hamming, ct);
-
+ static HammingDistanceFunctionFactory<double> dff;
std::vector<std::vector<double>>
points{{0.0, 0.0, 0.0},
{1.0, 0.0, 0.0},
@@ -407,31 +352,30 @@ TEST(DistanceFunctionsTest, hamming_gives_expected_score)
{0.5, 0.5, 0.5},
{0.0,-1.0, 1.0},
{1.0, 1.0, 1.0}};
+ auto hamming = dff.for_query_vector(t(points[0]));
for (const auto & p : points) {
- double h0 = hamming->calc(t(p), t(p));
- EXPECT_EQ(h0, 0.0);
- EXPECT_EQ(hamming->to_rawscore(h0), 1.0);
auto dist_fun = dff.for_query_vector(t(p));
- EXPECT_EQ(dist_fun->calc(t(p)), 0.0);
+ double h0 = dist_fun->calc(t(p));
+ EXPECT_EQ(h0, 0.0);
EXPECT_EQ(dist_fun->to_rawscore(h0), 1.0);
}
- double d12 = hamming->calc(t(points[1]), t(points[2]));
+ double d12 = dff.for_query_vector(t(points[1]))->calc(t(points[2]));
EXPECT_EQ(d12, 3.0);
EXPECT_DOUBLE_EQ(hamming->to_rawscore(d12), 1.0/(1.0 + 3.0));
- double d16 = hamming->calc(t(points[1]), t(points[6]));
+ double d16 = dff.for_query_vector(t(points[1]))->calc(t(points[6]));
EXPECT_EQ(d16, 2.0);
EXPECT_DOUBLE_EQ(hamming->to_rawscore(d16), 1.0/(1.0 + 2.0));
- double d23 = hamming->calc(t(points[2]), t(points[3]));
+ double d23 = dff.for_query_vector(t(points[2]))->calc(t(points[3]));
EXPECT_EQ(d23, 3.0);
EXPECT_DOUBLE_EQ(hamming->to_rawscore(d23), 1.0/(1.0 + 3.0));
- double d24 = hamming->calc(t(points[2]), t(points[4]));
+ double d24 = dff.for_query_vector(t(points[2]))->calc(t(points[4]));
EXPECT_EQ(d24, 3.0);
EXPECT_DOUBLE_EQ(hamming->to_rawscore(d24), 1.0/(1.0 + 3.0));
- double d25 = hamming->calc(t(points[2]), t(points[5]));
+ double d25 = dff.for_query_vector(t(points[2]))->calc(t(points[5]));
EXPECT_EQ(d25, 1.0);
EXPECT_DOUBLE_EQ(hamming->to_rawscore(d25), 1.0/(1.0 + 1.0));
@@ -445,8 +389,8 @@ TEST(DistanceFunctionsTest, hamming_gives_expected_score)
std::vector<Int8Float> bytes_a = { 0, 1, 2, 4, 8, 16, 32, 64, -128, 0, 1, 2, 4, 8, 16, 32, 64, -128, 0, 1, 2 };
std::vector<Int8Float> bytes_b = { 1, 2, 2, 4, 8, 16, 32, 65, -128, 0, 1, 0, 4, 8, 16, 32, 64, -128, 0, 1, -1 };
// expect diff: 1 2 1 1 7
- EXPECT_EQ(hamming->calc(TypedCells(bytes_a), TypedCells(bytes_b)), 12.0);
- auto dist_fun = dff.for_query_vector(TypedCells(bytes_a));
+ HammingDistanceFunctionFactory<Int8Float> factory_int8;
+ auto dist_fun = factory_int8.for_query_vector(TypedCells(bytes_a));
EXPECT_EQ(dist_fun->calc(TypedCells(bytes_b)), 12.0);
}