summaryrefslogtreecommitdiffstats
path: root/searchlib/src/tests/tensor/distance_functions
diff options
context:
space:
mode:
authorArne Juul <arnej@yahooinc.com>2023-04-20 13:26:02 +0000
committerArne Juul <arnej@yahooinc.com>2023-04-24 07:58:38 +0000
commit06ae8f17e040fe27e19dd7bdf4857ce0c4ccaba1 (patch)
tree95635dda48f7f188f87e4b62ee2852a7cfb652c3 /searchlib/src/tests/tensor/distance_functions
parentbc7e1b8c1a96a49838e55a7b97c4de565abc9ad3 (diff)
also add BoundEuclideanDistance
Diffstat (limited to 'searchlib/src/tests/tensor/distance_functions')
-rw-r--r--searchlib/src/tests/tensor/distance_functions/distance_functions_test.cpp82
1 files changed, 73 insertions, 9 deletions
diff --git a/searchlib/src/tests/tensor/distance_functions/distance_functions_test.cpp b/searchlib/src/tests/tensor/distance_functions/distance_functions_test.cpp
index 86b83b2c651..ae283f3f2b2 100644
--- a/searchlib/src/tests/tensor/distance_functions/distance_functions_test.cpp
+++ b/searchlib/src/tests/tensor/distance_functions/distance_functions_test.cpp
@@ -44,6 +44,30 @@ void verify_geo_miles(const DistanceFunction *dist_fun,
}
}
+double computeEuclideanChecked(TypedCells a, TypedCells b) {
+ static EuclideanDistanceFunctionFactory<Int8Float> i8f_dff;
+ static EuclideanDistanceFunctionFactory<float> flt_dff;
+ static EuclideanDistanceFunctionFactory<double> dbl_dff;
+ auto d_n = dbl_dff.for_query_vector(a);
+ auto d_f = flt_dff.for_query_vector(a);
+ auto d_r = dbl_dff.for_query_vector(b);
+ auto d_i = dbl_dff.for_insertion_vector(a);
+ // normal:
+ double result = d_n->calc(b);
+ // insert is exactly same:
+ EXPECT_EQ(d_i->calc(b), result);
+ // reverse:
+ EXPECT_DOUBLE_EQ(d_r->calc(a), result);
+ // float factory:
+ EXPECT_FLOAT_EQ(d_f->calc(b), result);
+ if (a.type == vespalib::eval::CellType::INT8 ||
+ b.type == vespalib::eval::CellType::INT8)
+ {
+ auto d_8 = i8f_dff.for_query_vector(a);
+ EXPECT_DOUBLE_EQ(d_8->calc(b), result);
+ }
+ return result;
+}
TEST(DistanceFunctionsTest, euclidean_gives_expected_score)
{
@@ -59,15 +83,56 @@ TEST(DistanceFunctionsTest, euclidean_gives_expected_score)
std::vector<double> p5{0.0,-1.0, 0.0};
std::vector<double> p6{1.0, 2.0, 2.0};
- double n4 = euclid->calc(t(p0), t(p4));
+ double n4 = computeEuclideanChecked(t(p0), t(p4));
EXPECT_FLOAT_EQ(n4, 1.0);
- double d12 = euclid->calc(t(p1), t(p2));
+ double d12 = computeEuclideanChecked(t(p1), t(p2));
EXPECT_EQ(d12, 2.0);
EXPECT_DOUBLE_EQ(euclid->to_rawscore(d12), 1.0/(1.0 + sqrt(2.0)));
double threshold = euclid->convert_threshold(8.0);
EXPECT_EQ(threshold, 64.0);
threshold = euclid->convert_threshold(0.5);
EXPECT_EQ(threshold, 0.25);
+
+ // simple hand-checked distances:
+ EXPECT_EQ(computeEuclideanChecked(t(p0), t(p0)), 0.0);
+ EXPECT_EQ(computeEuclideanChecked(t(p0), t(p1)), 1.0);
+ EXPECT_EQ(computeEuclideanChecked(t(p0), t(p2)), 1.0);
+ EXPECT_EQ(computeEuclideanChecked(t(p0), t(p3)), 1.0);
+ EXPECT_EQ(computeEuclideanChecked(t(p0), t(p5)), 1.0);
+ EXPECT_EQ(computeEuclideanChecked(t(p0), t(p6)), 9.0);
+
+ EXPECT_EQ(computeEuclideanChecked(t(p1), t(p1)), 0.0);
+ EXPECT_EQ(computeEuclideanChecked(t(p1), t(p2)), 2.0);
+ EXPECT_EQ(computeEuclideanChecked(t(p1), t(p3)), 2.0);
+ EXPECT_EQ(computeEuclideanChecked(t(p1), t(p5)), 2.0);
+ EXPECT_EQ(computeEuclideanChecked(t(p1), t(p6)), 8.0);
+
+ EXPECT_EQ(computeEuclideanChecked(t(p2), t(p2)), 0.0);
+ EXPECT_EQ(computeEuclideanChecked(t(p2), t(p3)), 2.0);
+ EXPECT_EQ(computeEuclideanChecked(t(p2), t(p5)), 4.0);
+ EXPECT_EQ(computeEuclideanChecked(t(p2), t(p6)), 6.0);
+
+ EXPECT_EQ(computeEuclideanChecked(t(p3), t(p3)), 0.0);
+ EXPECT_EQ(computeEuclideanChecked(t(p3), t(p5)), 2.0);
+ EXPECT_EQ(computeEuclideanChecked(t(p3), t(p6)), 6.0);
+
+ EXPECT_EQ(computeEuclideanChecked(t(p5), t(p5)), 0.0);
+ EXPECT_EQ(computeEuclideanChecked(t(p5), t(p6)), 14.0);
+
+ EXPECT_EQ(computeEuclideanChecked(t(p6), t(p6)), 0.0);
+
+ // smoke test for bfloat16:
+ std::vector<vespalib::BFloat16> bf16v;
+ bf16v.emplace_back(1.0);
+ bf16v.emplace_back(1.0);
+ bf16v.emplace_back(1.0);
+ EXPECT_EQ(computeEuclideanChecked(t(bf16v), t(p0)), 3.0);
+ EXPECT_EQ(computeEuclideanChecked(t(bf16v), t(p1)), 2.0);
+ EXPECT_EQ(computeEuclideanChecked(t(bf16v), t(p2)), 2.0);
+ EXPECT_EQ(computeEuclideanChecked(t(bf16v), t(p3)), 2.0);
+ EXPECT_FLOAT_EQ(computeEuclideanChecked(t(bf16v), t(p4)), 0.5857863);
+ EXPECT_EQ(computeEuclideanChecked(t(bf16v), t(p5)), 6.0);
+ EXPECT_EQ(computeEuclideanChecked(t(bf16v), t(p6)), 2.0);
}
TEST(DistanceFunctionsTest, euclidean_int8_smoketest)
@@ -81,14 +146,13 @@ TEST(DistanceFunctionsTest, euclidean_int8_smoketest)
std::vector<Int8Float> p5{0.0,-1.0, 0.0};
std::vector<Int8Float> p7{-1.0, 2.0, -2.0};
- EXPECT_DOUBLE_EQ(1.0, euclid->calc(t(p0), t(p1)));
- EXPECT_DOUBLE_EQ(1.0, euclid->calc(t(p0), t(p5)));
- EXPECT_DOUBLE_EQ(9.0, euclid->calc(t(p0), t(p7)));
-
- EXPECT_DOUBLE_EQ(2.0, euclid->calc(t(p1), t(p5)));
- EXPECT_DOUBLE_EQ(12.0, euclid->calc(t(p1), t(p7)));
- EXPECT_DOUBLE_EQ(14.0, euclid->calc(t(p5), t(p7)));
+ EXPECT_DOUBLE_EQ(1.0, computeEuclideanChecked(t(p0), t(p1)));
+ EXPECT_DOUBLE_EQ(1.0, computeEuclideanChecked(t(p0), t(p5)));
+ EXPECT_DOUBLE_EQ(9.0, computeEuclideanChecked(t(p0), t(p7)));
+ EXPECT_DOUBLE_EQ(2.0, computeEuclideanChecked(t(p1), t(p5)));
+ EXPECT_DOUBLE_EQ(12.0, computeEuclideanChecked(t(p1), t(p7)));
+ EXPECT_DOUBLE_EQ(14.0, computeEuclideanChecked(t(p5), t(p7)));
}
double computeAngularChecked(TypedCells a, TypedCells b) {