diff options
author | Geir Storli <geirst@yahooinc.com> | 2023-04-26 13:59:06 +0200 |
---|---|---|
committer | GitHub <noreply@github.com> | 2023-04-26 13:59:06 +0200 |
commit | 797091a6867be9543c6f1d08f0189cbe7c12e0b3 (patch) | |
tree | 5fbb7b6277c31a1ed46532108c28c7f1d3f74fda /searchlib/src/tests/tensor/distance_functions | |
parent | 01cc25458c74d2902879087919f67622600ffc65 (diff) | |
parent | b2401a91381d1f66ef316d850d469181f06f0d36 (diff) |
Merge pull request #26849 from vespa-engine/arnej/add-bound-hamming
add bound hamming, geo distance
Diffstat (limited to 'searchlib/src/tests/tensor/distance_functions')
-rw-r--r-- | searchlib/src/tests/tensor/distance_functions/distance_functions_test.cpp | 200 |
1 files changed, 103 insertions, 97 deletions
diff --git a/searchlib/src/tests/tensor/distance_functions/distance_functions_test.cpp b/searchlib/src/tests/tensor/distance_functions/distance_functions_test.cpp index 600c5ae0646..9b8ad0d26ce 100644 --- a/searchlib/src/tests/tensor/distance_functions/distance_functions_test.cpp +++ b/searchlib/src/tests/tensor/distance_functions/distance_functions_test.cpp @@ -18,14 +18,17 @@ using search::attribute::DistanceMetric; template <typename T> TypedCells t(const std::vector<T> &v) { return TypedCells(v); } -void verify_geo_miles(const DistanceFunction *dist_fun, - const std::vector<double> &p1, +void verify_geo_miles(const std::vector<double> &p1, const std::vector<double> &p2, double exp_miles) { + static GeoDistanceFunctionFactory dff; TypedCells t1(p1); TypedCells t2(p2); - double abstract_distance = dist_fun->calc(t1, t2); + auto dist_fun = dff.for_query_vector(t1); + double abstract_distance = dist_fun->calc(t2); + EXPECT_EQ(dff.for_insertion_vector(t1)->calc(t2), abstract_distance); + EXPECT_FLOAT_EQ(dff.for_query_vector(t2)->calc(t1), abstract_distance); double raw_score = dist_fun->to_rawscore(abstract_distance); double km = ((1.0/raw_score)-1.0); double d_miles = km / 1.609344; @@ -391,6 +394,7 @@ TEST(DistanceFunctionsTest, innerproduct_gives_expected_score) TEST(DistanceFunctionsTest, hamming_gives_expected_score) { + static HammingDistanceFunctionFactory<Int8Float> dff; auto ct = vespalib::eval::CellType::DOUBLE; auto hamming = make_distance_function(DistanceMetric::Hamming, ct); @@ -407,6 +411,9 @@ TEST(DistanceFunctionsTest, hamming_gives_expected_score) double h0 = hamming->calc(t(p), t(p)); EXPECT_EQ(h0, 0.0); EXPECT_EQ(hamming->to_rawscore(h0), 1.0); + auto dist_fun = dff.for_query_vector(t(p)); + EXPECT_EQ(dist_fun->calc(t(p)), 0.0); + EXPECT_EQ(dist_fun->to_rawscore(h0), 1.0); } double d12 = hamming->calc(t(points[1]), t(points[2])); EXPECT_EQ(d12, 3.0); @@ -439,13 +446,12 @@ TEST(DistanceFunctionsTest, hamming_gives_expected_score) std::vector<Int8Float> bytes_b = { 1, 2, 2, 4, 8, 16, 32, 65, -128, 0, 1, 0, 4, 8, 16, 32, 64, -128, 0, 1, -1 }; // expect diff: 1 2 1 1 7 EXPECT_EQ(hamming->calc(TypedCells(bytes_a), TypedCells(bytes_b)), 12.0); + auto dist_fun = dff.for_query_vector(TypedCells(bytes_a)); + EXPECT_EQ(dist_fun->calc(TypedCells(bytes_b)), 12.0); } TEST(GeoDegreesTest, gives_expected_score) { - auto ct = vespalib::eval::CellType::DOUBLE; - auto geodeg = make_distance_function(DistanceMetric::GeoDegrees, ct); - std::vector<double> g1_sfo{37.61, -122.38}; std::vector<double> g2_lhr{51.47, -0.46}; std::vector<double> g3_osl{60.20, 11.08}; @@ -456,7 +462,8 @@ TEST(GeoDegreesTest, gives_expected_score) std::vector<double> g8_lax{33.94, -118.41}; std::vector<double> g9_jfk{40.64, -73.78}; - double g63_a = geodeg->calc(t(g6_trd), t(g3_osl)); + auto geodeg = GeoDistanceFunctionFactory().for_query_vector(t(g6_trd)); + double g63_a = geodeg->calc(t(g3_osl)); double g63_r = geodeg->to_rawscore(g63_a); double g63_km = ((1.0/g63_r)-1.0); EXPECT_GT(g63_km, 350); @@ -466,96 +473,95 @@ TEST(GeoDegreesTest, gives_expected_score) // Great Circle Mapper for airports using // a more accurate formula - we should agree // with < 1.0% deviation - verify_geo_miles(geodeg.get(), g1_sfo, g1_sfo, 0); - verify_geo_miles(geodeg.get(), g1_sfo, g2_lhr, 5367); - verify_geo_miles(geodeg.get(), g1_sfo, g3_osl, 5196); - verify_geo_miles(geodeg.get(), g1_sfo, g4_gig, 6604); - verify_geo_miles(geodeg.get(), g1_sfo, g5_hkg, 6927); - verify_geo_miles(geodeg.get(), g1_sfo, g6_trd, 5012); - verify_geo_miles(geodeg.get(), g1_sfo, g7_syd, 7417); - verify_geo_miles(geodeg.get(), g1_sfo, g8_lax, 337); - verify_geo_miles(geodeg.get(), g1_sfo, g9_jfk, 2586); - - verify_geo_miles(geodeg.get(), g2_lhr, g1_sfo, 5367); - verify_geo_miles(geodeg.get(), g2_lhr, g2_lhr, 0); - verify_geo_miles(geodeg.get(), g2_lhr, g3_osl, 750); - verify_geo_miles(geodeg.get(), g2_lhr, g4_gig, 5734); - verify_geo_miles(geodeg.get(), g2_lhr, g5_hkg, 5994); - verify_geo_miles(geodeg.get(), g2_lhr, g6_trd, 928); - verify_geo_miles(geodeg.get(), g2_lhr, g7_syd, 10573); - verify_geo_miles(geodeg.get(), g2_lhr, g8_lax, 5456); - verify_geo_miles(geodeg.get(), g2_lhr, g9_jfk, 3451); - - verify_geo_miles(geodeg.get(), g3_osl, g1_sfo, 5196); - verify_geo_miles(geodeg.get(), g3_osl, g2_lhr, 750); - verify_geo_miles(geodeg.get(), g3_osl, g3_osl, 0); - verify_geo_miles(geodeg.get(), g3_osl, g4_gig, 6479); - verify_geo_miles(geodeg.get(), g3_osl, g5_hkg, 5319); - verify_geo_miles(geodeg.get(), g3_osl, g6_trd, 226); - verify_geo_miles(geodeg.get(), g3_osl, g7_syd, 9888); - verify_geo_miles(geodeg.get(), g3_osl, g8_lax, 5345); - verify_geo_miles(geodeg.get(), g3_osl, g9_jfk, 3687); - - verify_geo_miles(geodeg.get(), g4_gig, g1_sfo, 6604); - verify_geo_miles(geodeg.get(), g4_gig, g2_lhr, 5734); - verify_geo_miles(geodeg.get(), g4_gig, g3_osl, 6479); - verify_geo_miles(geodeg.get(), g4_gig, g4_gig, 0); - verify_geo_miles(geodeg.get(), g4_gig, g5_hkg, 10989); - verify_geo_miles(geodeg.get(), g4_gig, g6_trd, 6623); - verify_geo_miles(geodeg.get(), g4_gig, g7_syd, 8414); - verify_geo_miles(geodeg.get(), g4_gig, g8_lax, 6294); - verify_geo_miles(geodeg.get(), g4_gig, g9_jfk, 4786); - - verify_geo_miles(geodeg.get(), g5_hkg, g1_sfo, 6927); - verify_geo_miles(geodeg.get(), g5_hkg, g2_lhr, 5994); - verify_geo_miles(geodeg.get(), g5_hkg, g3_osl, 5319); - verify_geo_miles(geodeg.get(), g5_hkg, g4_gig, 10989); - verify_geo_miles(geodeg.get(), g5_hkg, g5_hkg, 0); - verify_geo_miles(geodeg.get(), g5_hkg, g6_trd, 5240); - verify_geo_miles(geodeg.get(), g5_hkg, g7_syd, 4581); - verify_geo_miles(geodeg.get(), g5_hkg, g8_lax, 7260); - verify_geo_miles(geodeg.get(), g5_hkg, g9_jfk, 8072); - - verify_geo_miles(geodeg.get(), g6_trd, g1_sfo, 5012); - verify_geo_miles(geodeg.get(), g6_trd, g2_lhr, 928); - verify_geo_miles(geodeg.get(), g6_trd, g3_osl, 226); - verify_geo_miles(geodeg.get(), g6_trd, g4_gig, 6623); - verify_geo_miles(geodeg.get(), g6_trd, g5_hkg, 5240); - verify_geo_miles(geodeg.get(), g6_trd, g6_trd, 0); - verify_geo_miles(geodeg.get(), g6_trd, g7_syd, 9782); - verify_geo_miles(geodeg.get(), g6_trd, g8_lax, 5171); - verify_geo_miles(geodeg.get(), g6_trd, g9_jfk, 3611); - - verify_geo_miles(geodeg.get(), g7_syd, g1_sfo, 7417); - verify_geo_miles(geodeg.get(), g7_syd, g2_lhr, 10573); - verify_geo_miles(geodeg.get(), g7_syd, g3_osl, 9888); - verify_geo_miles(geodeg.get(), g7_syd, g4_gig, 8414); - verify_geo_miles(geodeg.get(), g7_syd, g5_hkg, 4581); - verify_geo_miles(geodeg.get(), g7_syd, g6_trd, 9782); - verify_geo_miles(geodeg.get(), g7_syd, g7_syd, 0); - verify_geo_miles(geodeg.get(), g7_syd, g8_lax, 7488); - verify_geo_miles(geodeg.get(), g7_syd, g9_jfk, 9950); - - verify_geo_miles(geodeg.get(), g8_lax, g1_sfo, 337); - verify_geo_miles(geodeg.get(), g8_lax, g2_lhr, 5456); - verify_geo_miles(geodeg.get(), g8_lax, g3_osl, 5345); - verify_geo_miles(geodeg.get(), g8_lax, g4_gig, 6294); - verify_geo_miles(geodeg.get(), g8_lax, g5_hkg, 7260); - verify_geo_miles(geodeg.get(), g8_lax, g6_trd, 5171); - verify_geo_miles(geodeg.get(), g8_lax, g7_syd, 7488); - verify_geo_miles(geodeg.get(), g8_lax, g8_lax, 0); - verify_geo_miles(geodeg.get(), g8_lax, g9_jfk, 2475); - - verify_geo_miles(geodeg.get(), g9_jfk, g1_sfo, 2586); - verify_geo_miles(geodeg.get(), g9_jfk, g2_lhr, 3451); - verify_geo_miles(geodeg.get(), g9_jfk, g3_osl, 3687); - verify_geo_miles(geodeg.get(), g9_jfk, g4_gig, 4786); - verify_geo_miles(geodeg.get(), g9_jfk, g5_hkg, 8072); - verify_geo_miles(geodeg.get(), g9_jfk, g6_trd, 3611); - verify_geo_miles(geodeg.get(), g9_jfk, g7_syd, 9950); - verify_geo_miles(geodeg.get(), g9_jfk, g8_lax, 2475); - verify_geo_miles(geodeg.get(), g9_jfk, g9_jfk, 0); - + verify_geo_miles(g1_sfo, g1_sfo, 0); + verify_geo_miles(g1_sfo, g2_lhr, 5367); + verify_geo_miles(g1_sfo, g3_osl, 5196); + verify_geo_miles(g1_sfo, g4_gig, 6604); + verify_geo_miles(g1_sfo, g5_hkg, 6927); + verify_geo_miles(g1_sfo, g6_trd, 5012); + verify_geo_miles(g1_sfo, g7_syd, 7417); + verify_geo_miles(g1_sfo, g8_lax, 337); + verify_geo_miles(g1_sfo, g9_jfk, 2586); + + verify_geo_miles(g2_lhr, g1_sfo, 5367); + verify_geo_miles(g2_lhr, g2_lhr, 0); + verify_geo_miles(g2_lhr, g3_osl, 750); + verify_geo_miles(g2_lhr, g4_gig, 5734); + verify_geo_miles(g2_lhr, g5_hkg, 5994); + verify_geo_miles(g2_lhr, g6_trd, 928); + verify_geo_miles(g2_lhr, g7_syd, 10573); + verify_geo_miles(g2_lhr, g8_lax, 5456); + verify_geo_miles(g2_lhr, g9_jfk, 3451); + + verify_geo_miles(g3_osl, g1_sfo, 5196); + verify_geo_miles(g3_osl, g2_lhr, 750); + verify_geo_miles(g3_osl, g3_osl, 0); + verify_geo_miles(g3_osl, g4_gig, 6479); + verify_geo_miles(g3_osl, g5_hkg, 5319); + verify_geo_miles(g3_osl, g6_trd, 226); + verify_geo_miles(g3_osl, g7_syd, 9888); + verify_geo_miles(g3_osl, g8_lax, 5345); + verify_geo_miles(g3_osl, g9_jfk, 3687); + + verify_geo_miles(g4_gig, g1_sfo, 6604); + verify_geo_miles(g4_gig, g2_lhr, 5734); + verify_geo_miles(g4_gig, g3_osl, 6479); + verify_geo_miles(g4_gig, g4_gig, 0); + verify_geo_miles(g4_gig, g5_hkg, 10989); + verify_geo_miles(g4_gig, g6_trd, 6623); + verify_geo_miles(g4_gig, g7_syd, 8414); + verify_geo_miles(g4_gig, g8_lax, 6294); + verify_geo_miles(g4_gig, g9_jfk, 4786); + + verify_geo_miles(g5_hkg, g1_sfo, 6927); + verify_geo_miles(g5_hkg, g2_lhr, 5994); + verify_geo_miles(g5_hkg, g3_osl, 5319); + verify_geo_miles(g5_hkg, g4_gig, 10989); + verify_geo_miles(g5_hkg, g5_hkg, 0); + verify_geo_miles(g5_hkg, g6_trd, 5240); + verify_geo_miles(g5_hkg, g7_syd, 4581); + verify_geo_miles(g5_hkg, g8_lax, 7260); + verify_geo_miles(g5_hkg, g9_jfk, 8072); + + verify_geo_miles(g6_trd, g1_sfo, 5012); + verify_geo_miles(g6_trd, g2_lhr, 928); + verify_geo_miles(g6_trd, g3_osl, 226); + verify_geo_miles(g6_trd, g4_gig, 6623); + verify_geo_miles(g6_trd, g5_hkg, 5240); + verify_geo_miles(g6_trd, g6_trd, 0); + verify_geo_miles(g6_trd, g7_syd, 9782); + verify_geo_miles(g6_trd, g8_lax, 5171); + verify_geo_miles(g6_trd, g9_jfk, 3611); + + verify_geo_miles(g7_syd, g1_sfo, 7417); + verify_geo_miles(g7_syd, g2_lhr, 10573); + verify_geo_miles(g7_syd, g3_osl, 9888); + verify_geo_miles(g7_syd, g4_gig, 8414); + verify_geo_miles(g7_syd, g5_hkg, 4581); + verify_geo_miles(g7_syd, g6_trd, 9782); + verify_geo_miles(g7_syd, g7_syd, 0); + verify_geo_miles(g7_syd, g8_lax, 7488); + verify_geo_miles(g7_syd, g9_jfk, 9950); + + verify_geo_miles(g8_lax, g1_sfo, 337); + verify_geo_miles(g8_lax, g2_lhr, 5456); + verify_geo_miles(g8_lax, g3_osl, 5345); + verify_geo_miles(g8_lax, g4_gig, 6294); + verify_geo_miles(g8_lax, g5_hkg, 7260); + verify_geo_miles(g8_lax, g6_trd, 5171); + verify_geo_miles(g8_lax, g7_syd, 7488); + verify_geo_miles(g8_lax, g8_lax, 0); + verify_geo_miles(g8_lax, g9_jfk, 2475); + + verify_geo_miles(g9_jfk, g1_sfo, 2586); + verify_geo_miles(g9_jfk, g2_lhr, 3451); + verify_geo_miles(g9_jfk, g3_osl, 3687); + verify_geo_miles(g9_jfk, g4_gig, 4786); + verify_geo_miles(g9_jfk, g5_hkg, 8072); + verify_geo_miles(g9_jfk, g6_trd, 3611); + verify_geo_miles(g9_jfk, g7_syd, 9950); + verify_geo_miles(g9_jfk, g8_lax, 2475); + verify_geo_miles(g9_jfk, g9_jfk, 0); } GTEST_MAIN_RUN_ALL_TESTS() |