From 06ae8f17e040fe27e19dd7bdf4857ce0c4ccaba1 Mon Sep 17 00:00:00 2001 From: Arne Juul Date: Thu, 20 Apr 2023 13:26:02 +0000 Subject: also add BoundEuclideanDistance --- .../distance_functions/distance_functions_test.cpp | 82 +++++++++++++++++++--- 1 file changed, 73 insertions(+), 9 deletions(-) (limited to 'searchlib/src/tests/tensor/distance_functions') diff --git a/searchlib/src/tests/tensor/distance_functions/distance_functions_test.cpp b/searchlib/src/tests/tensor/distance_functions/distance_functions_test.cpp index 86b83b2c651..ae283f3f2b2 100644 --- a/searchlib/src/tests/tensor/distance_functions/distance_functions_test.cpp +++ b/searchlib/src/tests/tensor/distance_functions/distance_functions_test.cpp @@ -44,6 +44,30 @@ void verify_geo_miles(const DistanceFunction *dist_fun, } } +double computeEuclideanChecked(TypedCells a, TypedCells b) { + static EuclideanDistanceFunctionFactory i8f_dff; + static EuclideanDistanceFunctionFactory flt_dff; + static EuclideanDistanceFunctionFactory dbl_dff; + auto d_n = dbl_dff.for_query_vector(a); + auto d_f = flt_dff.for_query_vector(a); + auto d_r = dbl_dff.for_query_vector(b); + auto d_i = dbl_dff.for_insertion_vector(a); + // normal: + double result = d_n->calc(b); + // insert is exactly same: + EXPECT_EQ(d_i->calc(b), result); + // reverse: + EXPECT_DOUBLE_EQ(d_r->calc(a), result); + // float factory: + EXPECT_FLOAT_EQ(d_f->calc(b), result); + if (a.type == vespalib::eval::CellType::INT8 || + b.type == vespalib::eval::CellType::INT8) + { + auto d_8 = i8f_dff.for_query_vector(a); + EXPECT_DOUBLE_EQ(d_8->calc(b), result); + } + return result; +} TEST(DistanceFunctionsTest, euclidean_gives_expected_score) { @@ -59,15 +83,56 @@ TEST(DistanceFunctionsTest, euclidean_gives_expected_score) std::vector p5{0.0,-1.0, 0.0}; std::vector p6{1.0, 2.0, 2.0}; - double n4 = euclid->calc(t(p0), t(p4)); + double n4 = computeEuclideanChecked(t(p0), t(p4)); EXPECT_FLOAT_EQ(n4, 1.0); - double d12 = euclid->calc(t(p1), t(p2)); + double d12 = computeEuclideanChecked(t(p1), t(p2)); EXPECT_EQ(d12, 2.0); EXPECT_DOUBLE_EQ(euclid->to_rawscore(d12), 1.0/(1.0 + sqrt(2.0))); double threshold = euclid->convert_threshold(8.0); EXPECT_EQ(threshold, 64.0); threshold = euclid->convert_threshold(0.5); EXPECT_EQ(threshold, 0.25); + + // simple hand-checked distances: + EXPECT_EQ(computeEuclideanChecked(t(p0), t(p0)), 0.0); + EXPECT_EQ(computeEuclideanChecked(t(p0), t(p1)), 1.0); + EXPECT_EQ(computeEuclideanChecked(t(p0), t(p2)), 1.0); + EXPECT_EQ(computeEuclideanChecked(t(p0), t(p3)), 1.0); + EXPECT_EQ(computeEuclideanChecked(t(p0), t(p5)), 1.0); + EXPECT_EQ(computeEuclideanChecked(t(p0), t(p6)), 9.0); + + EXPECT_EQ(computeEuclideanChecked(t(p1), t(p1)), 0.0); + EXPECT_EQ(computeEuclideanChecked(t(p1), t(p2)), 2.0); + EXPECT_EQ(computeEuclideanChecked(t(p1), t(p3)), 2.0); + EXPECT_EQ(computeEuclideanChecked(t(p1), t(p5)), 2.0); + EXPECT_EQ(computeEuclideanChecked(t(p1), t(p6)), 8.0); + + EXPECT_EQ(computeEuclideanChecked(t(p2), t(p2)), 0.0); + EXPECT_EQ(computeEuclideanChecked(t(p2), t(p3)), 2.0); + EXPECT_EQ(computeEuclideanChecked(t(p2), t(p5)), 4.0); + EXPECT_EQ(computeEuclideanChecked(t(p2), t(p6)), 6.0); + + EXPECT_EQ(computeEuclideanChecked(t(p3), t(p3)), 0.0); + EXPECT_EQ(computeEuclideanChecked(t(p3), t(p5)), 2.0); + EXPECT_EQ(computeEuclideanChecked(t(p3), t(p6)), 6.0); + + EXPECT_EQ(computeEuclideanChecked(t(p5), t(p5)), 0.0); + EXPECT_EQ(computeEuclideanChecked(t(p5), t(p6)), 14.0); + + EXPECT_EQ(computeEuclideanChecked(t(p6), t(p6)), 0.0); + + // smoke test for bfloat16: + std::vector bf16v; + bf16v.emplace_back(1.0); + bf16v.emplace_back(1.0); + bf16v.emplace_back(1.0); + EXPECT_EQ(computeEuclideanChecked(t(bf16v), t(p0)), 3.0); + EXPECT_EQ(computeEuclideanChecked(t(bf16v), t(p1)), 2.0); + EXPECT_EQ(computeEuclideanChecked(t(bf16v), t(p2)), 2.0); + EXPECT_EQ(computeEuclideanChecked(t(bf16v), t(p3)), 2.0); + EXPECT_FLOAT_EQ(computeEuclideanChecked(t(bf16v), t(p4)), 0.5857863); + EXPECT_EQ(computeEuclideanChecked(t(bf16v), t(p5)), 6.0); + EXPECT_EQ(computeEuclideanChecked(t(bf16v), t(p6)), 2.0); } TEST(DistanceFunctionsTest, euclidean_int8_smoketest) @@ -81,14 +146,13 @@ TEST(DistanceFunctionsTest, euclidean_int8_smoketest) std::vector p5{0.0,-1.0, 0.0}; std::vector p7{-1.0, 2.0, -2.0}; - EXPECT_DOUBLE_EQ(1.0, euclid->calc(t(p0), t(p1))); - EXPECT_DOUBLE_EQ(1.0, euclid->calc(t(p0), t(p5))); - EXPECT_DOUBLE_EQ(9.0, euclid->calc(t(p0), t(p7))); - - EXPECT_DOUBLE_EQ(2.0, euclid->calc(t(p1), t(p5))); - EXPECT_DOUBLE_EQ(12.0, euclid->calc(t(p1), t(p7))); - EXPECT_DOUBLE_EQ(14.0, euclid->calc(t(p5), t(p7))); + EXPECT_DOUBLE_EQ(1.0, computeEuclideanChecked(t(p0), t(p1))); + EXPECT_DOUBLE_EQ(1.0, computeEuclideanChecked(t(p0), t(p5))); + EXPECT_DOUBLE_EQ(9.0, computeEuclideanChecked(t(p0), t(p7))); + EXPECT_DOUBLE_EQ(2.0, computeEuclideanChecked(t(p1), t(p5))); + EXPECT_DOUBLE_EQ(12.0, computeEuclideanChecked(t(p1), t(p7))); + EXPECT_DOUBLE_EQ(14.0, computeEuclideanChecked(t(p5), t(p7))); } double computeAngularChecked(TypedCells a, TypedCells b) { -- cgit v1.2.3