diff options
author | Arne Juul <arnej@yahooinc.com> | 2023-04-25 11:51:03 +0000 |
---|---|---|
committer | Arne Juul <arnej@yahooinc.com> | 2023-04-25 15:56:19 +0000 |
commit | 6201b1979bbc7fd0c156bf77af855b09237a7045 (patch) | |
tree | 2739897b18d2b4d061eea5f35d6211c692e4883d | |
parent | b552ce33c561eef8b4440bb2ddc93b24afb8d16a (diff) |
add BoundGeoDistance
4 files changed, 165 insertions, 101 deletions
diff --git a/searchlib/src/tests/tensor/distance_functions/distance_functions_test.cpp b/searchlib/src/tests/tensor/distance_functions/distance_functions_test.cpp index e5bed3ebae5..74ef23f6eb8 100644 --- a/searchlib/src/tests/tensor/distance_functions/distance_functions_test.cpp +++ b/searchlib/src/tests/tensor/distance_functions/distance_functions_test.cpp @@ -18,14 +18,17 @@ using search::attribute::DistanceMetric; template <typename T> TypedCells t(const std::vector<T> &v) { return TypedCells(v); } -void verify_geo_miles(const DistanceFunction *dist_fun, - const std::vector<double> &p1, +void verify_geo_miles(const std::vector<double> &p1, const std::vector<double> &p2, double exp_miles) { + static GeoDistanceFunctionFactory dff; TypedCells t1(p1); TypedCells t2(p2); - double abstract_distance = dist_fun->calc(t1, t2); + auto dist_fun = dff.for_query_vector(t1); + double abstract_distance = dist_fun->calc(t2); + EXPECT_EQ(dff.for_insertion_vector(t1)->calc(t2), abstract_distance); + EXPECT_FLOAT_EQ(dff.for_query_vector(t2)->calc(t1), abstract_distance); double raw_score = dist_fun->to_rawscore(abstract_distance); double km = ((1.0/raw_score)-1.0); double d_miles = km / 1.609344; @@ -443,9 +446,6 @@ TEST(DistanceFunctionsTest, hamming_gives_expected_score) TEST(GeoDegreesTest, gives_expected_score) { - auto ct = vespalib::eval::CellType::DOUBLE; - auto geodeg = make_distance_function(DistanceMetric::GeoDegrees, ct); - std::vector<double> g1_sfo{37.61, -122.38}; std::vector<double> g2_lhr{51.47, -0.46}; std::vector<double> g3_osl{60.20, 11.08}; @@ -456,7 +456,8 @@ TEST(GeoDegreesTest, gives_expected_score) std::vector<double> g8_lax{33.94, -118.41}; std::vector<double> g9_jfk{40.64, -73.78}; - double g63_a = geodeg->calc(t(g6_trd), t(g3_osl)); + auto geodeg = GeoDistanceFunctionFactory().for_query_vector(t(g6_trd)); + double g63_a = geodeg->calc(t(g3_osl)); double g63_r = geodeg->to_rawscore(g63_a); double g63_km = ((1.0/g63_r)-1.0); EXPECT_GT(g63_km, 350); @@ -466,96 +467,95 @@ TEST(GeoDegreesTest, gives_expected_score) // Great Circle Mapper for airports using // a more accurate formula - we should agree // with < 1.0% deviation - verify_geo_miles(geodeg.get(), g1_sfo, g1_sfo, 0); - verify_geo_miles(geodeg.get(), g1_sfo, g2_lhr, 5367); - verify_geo_miles(geodeg.get(), g1_sfo, g3_osl, 5196); - verify_geo_miles(geodeg.get(), g1_sfo, g4_gig, 6604); - verify_geo_miles(geodeg.get(), g1_sfo, g5_hkg, 6927); - verify_geo_miles(geodeg.get(), g1_sfo, g6_trd, 5012); - verify_geo_miles(geodeg.get(), g1_sfo, g7_syd, 7417); - verify_geo_miles(geodeg.get(), g1_sfo, g8_lax, 337); - verify_geo_miles(geodeg.get(), g1_sfo, g9_jfk, 2586); - - verify_geo_miles(geodeg.get(), g2_lhr, g1_sfo, 5367); - verify_geo_miles(geodeg.get(), g2_lhr, g2_lhr, 0); - verify_geo_miles(geodeg.get(), g2_lhr, g3_osl, 750); - verify_geo_miles(geodeg.get(), g2_lhr, g4_gig, 5734); - verify_geo_miles(geodeg.get(), g2_lhr, g5_hkg, 5994); - verify_geo_miles(geodeg.get(), g2_lhr, g6_trd, 928); - verify_geo_miles(geodeg.get(), g2_lhr, g7_syd, 10573); - verify_geo_miles(geodeg.get(), g2_lhr, g8_lax, 5456); - verify_geo_miles(geodeg.get(), g2_lhr, g9_jfk, 3451); - - verify_geo_miles(geodeg.get(), g3_osl, g1_sfo, 5196); - verify_geo_miles(geodeg.get(), g3_osl, g2_lhr, 750); - verify_geo_miles(geodeg.get(), g3_osl, g3_osl, 0); - verify_geo_miles(geodeg.get(), g3_osl, g4_gig, 6479); - verify_geo_miles(geodeg.get(), g3_osl, g5_hkg, 5319); - verify_geo_miles(geodeg.get(), g3_osl, g6_trd, 226); - verify_geo_miles(geodeg.get(), g3_osl, g7_syd, 9888); - verify_geo_miles(geodeg.get(), g3_osl, g8_lax, 5345); - verify_geo_miles(geodeg.get(), g3_osl, g9_jfk, 3687); - - verify_geo_miles(geodeg.get(), g4_gig, g1_sfo, 6604); - verify_geo_miles(geodeg.get(), g4_gig, g2_lhr, 5734); - verify_geo_miles(geodeg.get(), g4_gig, g3_osl, 6479); - verify_geo_miles(geodeg.get(), g4_gig, g4_gig, 0); - verify_geo_miles(geodeg.get(), g4_gig, g5_hkg, 10989); - verify_geo_miles(geodeg.get(), g4_gig, g6_trd, 6623); - verify_geo_miles(geodeg.get(), g4_gig, g7_syd, 8414); - verify_geo_miles(geodeg.get(), g4_gig, g8_lax, 6294); - verify_geo_miles(geodeg.get(), g4_gig, g9_jfk, 4786); - - verify_geo_miles(geodeg.get(), g5_hkg, g1_sfo, 6927); - verify_geo_miles(geodeg.get(), g5_hkg, g2_lhr, 5994); - verify_geo_miles(geodeg.get(), g5_hkg, g3_osl, 5319); - verify_geo_miles(geodeg.get(), g5_hkg, g4_gig, 10989); - verify_geo_miles(geodeg.get(), g5_hkg, g5_hkg, 0); - verify_geo_miles(geodeg.get(), g5_hkg, g6_trd, 5240); - verify_geo_miles(geodeg.get(), g5_hkg, g7_syd, 4581); - verify_geo_miles(geodeg.get(), g5_hkg, g8_lax, 7260); - verify_geo_miles(geodeg.get(), g5_hkg, g9_jfk, 8072); - - verify_geo_miles(geodeg.get(), g6_trd, g1_sfo, 5012); - verify_geo_miles(geodeg.get(), g6_trd, g2_lhr, 928); - verify_geo_miles(geodeg.get(), g6_trd, g3_osl, 226); - verify_geo_miles(geodeg.get(), g6_trd, g4_gig, 6623); - verify_geo_miles(geodeg.get(), g6_trd, g5_hkg, 5240); - verify_geo_miles(geodeg.get(), g6_trd, g6_trd, 0); - verify_geo_miles(geodeg.get(), g6_trd, g7_syd, 9782); - verify_geo_miles(geodeg.get(), g6_trd, g8_lax, 5171); - verify_geo_miles(geodeg.get(), g6_trd, g9_jfk, 3611); - - verify_geo_miles(geodeg.get(), g7_syd, g1_sfo, 7417); - verify_geo_miles(geodeg.get(), g7_syd, g2_lhr, 10573); - verify_geo_miles(geodeg.get(), g7_syd, g3_osl, 9888); - verify_geo_miles(geodeg.get(), g7_syd, g4_gig, 8414); - verify_geo_miles(geodeg.get(), g7_syd, g5_hkg, 4581); - verify_geo_miles(geodeg.get(), g7_syd, g6_trd, 9782); - verify_geo_miles(geodeg.get(), g7_syd, g7_syd, 0); - verify_geo_miles(geodeg.get(), g7_syd, g8_lax, 7488); - verify_geo_miles(geodeg.get(), g7_syd, g9_jfk, 9950); - - verify_geo_miles(geodeg.get(), g8_lax, g1_sfo, 337); - verify_geo_miles(geodeg.get(), g8_lax, g2_lhr, 5456); - verify_geo_miles(geodeg.get(), g8_lax, g3_osl, 5345); - verify_geo_miles(geodeg.get(), g8_lax, g4_gig, 6294); - verify_geo_miles(geodeg.get(), g8_lax, g5_hkg, 7260); - verify_geo_miles(geodeg.get(), g8_lax, g6_trd, 5171); - verify_geo_miles(geodeg.get(), g8_lax, g7_syd, 7488); - verify_geo_miles(geodeg.get(), g8_lax, g8_lax, 0); - verify_geo_miles(geodeg.get(), g8_lax, g9_jfk, 2475); - - verify_geo_miles(geodeg.get(), g9_jfk, g1_sfo, 2586); - verify_geo_miles(geodeg.get(), g9_jfk, g2_lhr, 3451); - verify_geo_miles(geodeg.get(), g9_jfk, g3_osl, 3687); - verify_geo_miles(geodeg.get(), g9_jfk, g4_gig, 4786); - verify_geo_miles(geodeg.get(), g9_jfk, g5_hkg, 8072); - verify_geo_miles(geodeg.get(), g9_jfk, g6_trd, 3611); - verify_geo_miles(geodeg.get(), g9_jfk, g7_syd, 9950); - verify_geo_miles(geodeg.get(), g9_jfk, g8_lax, 2475); - verify_geo_miles(geodeg.get(), g9_jfk, g9_jfk, 0); - + verify_geo_miles(g1_sfo, g1_sfo, 0); + verify_geo_miles(g1_sfo, g2_lhr, 5367); + verify_geo_miles(g1_sfo, g3_osl, 5196); + verify_geo_miles(g1_sfo, g4_gig, 6604); + verify_geo_miles(g1_sfo, g5_hkg, 6927); + verify_geo_miles(g1_sfo, g6_trd, 5012); + verify_geo_miles(g1_sfo, g7_syd, 7417); + verify_geo_miles(g1_sfo, g8_lax, 337); + verify_geo_miles(g1_sfo, g9_jfk, 2586); + + verify_geo_miles(g2_lhr, g1_sfo, 5367); + verify_geo_miles(g2_lhr, g2_lhr, 0); + verify_geo_miles(g2_lhr, g3_osl, 750); + verify_geo_miles(g2_lhr, g4_gig, 5734); + verify_geo_miles(g2_lhr, g5_hkg, 5994); + verify_geo_miles(g2_lhr, g6_trd, 928); + verify_geo_miles(g2_lhr, g7_syd, 10573); + verify_geo_miles(g2_lhr, g8_lax, 5456); + verify_geo_miles(g2_lhr, g9_jfk, 3451); + + verify_geo_miles(g3_osl, g1_sfo, 5196); + verify_geo_miles(g3_osl, g2_lhr, 750); + verify_geo_miles(g3_osl, g3_osl, 0); + verify_geo_miles(g3_osl, g4_gig, 6479); + verify_geo_miles(g3_osl, g5_hkg, 5319); + verify_geo_miles(g3_osl, g6_trd, 226); + verify_geo_miles(g3_osl, g7_syd, 9888); + verify_geo_miles(g3_osl, g8_lax, 5345); + verify_geo_miles(g3_osl, g9_jfk, 3687); + + verify_geo_miles(g4_gig, g1_sfo, 6604); + verify_geo_miles(g4_gig, g2_lhr, 5734); + verify_geo_miles(g4_gig, g3_osl, 6479); + verify_geo_miles(g4_gig, g4_gig, 0); + verify_geo_miles(g4_gig, g5_hkg, 10989); + verify_geo_miles(g4_gig, g6_trd, 6623); + verify_geo_miles(g4_gig, g7_syd, 8414); + verify_geo_miles(g4_gig, g8_lax, 6294); + verify_geo_miles(g4_gig, g9_jfk, 4786); + + verify_geo_miles(g5_hkg, g1_sfo, 6927); + verify_geo_miles(g5_hkg, g2_lhr, 5994); + verify_geo_miles(g5_hkg, g3_osl, 5319); + verify_geo_miles(g5_hkg, g4_gig, 10989); + verify_geo_miles(g5_hkg, g5_hkg, 0); + verify_geo_miles(g5_hkg, g6_trd, 5240); + verify_geo_miles(g5_hkg, g7_syd, 4581); + verify_geo_miles(g5_hkg, g8_lax, 7260); + verify_geo_miles(g5_hkg, g9_jfk, 8072); + + verify_geo_miles(g6_trd, g1_sfo, 5012); + verify_geo_miles(g6_trd, g2_lhr, 928); + verify_geo_miles(g6_trd, g3_osl, 226); + verify_geo_miles(g6_trd, g4_gig, 6623); + verify_geo_miles(g6_trd, g5_hkg, 5240); + verify_geo_miles(g6_trd, g6_trd, 0); + verify_geo_miles(g6_trd, g7_syd, 9782); + verify_geo_miles(g6_trd, g8_lax, 5171); + verify_geo_miles(g6_trd, g9_jfk, 3611); + + verify_geo_miles(g7_syd, g1_sfo, 7417); + verify_geo_miles(g7_syd, g2_lhr, 10573); + verify_geo_miles(g7_syd, g3_osl, 9888); + verify_geo_miles(g7_syd, g4_gig, 8414); + verify_geo_miles(g7_syd, g5_hkg, 4581); + verify_geo_miles(g7_syd, g6_trd, 9782); + verify_geo_miles(g7_syd, g7_syd, 0); + verify_geo_miles(g7_syd, g8_lax, 7488); + verify_geo_miles(g7_syd, g9_jfk, 9950); + + verify_geo_miles(g8_lax, g1_sfo, 337); + verify_geo_miles(g8_lax, g2_lhr, 5456); + verify_geo_miles(g8_lax, g3_osl, 5345); + verify_geo_miles(g8_lax, g4_gig, 6294); + verify_geo_miles(g8_lax, g5_hkg, 7260); + verify_geo_miles(g8_lax, g6_trd, 5171); + verify_geo_miles(g8_lax, g7_syd, 7488); + verify_geo_miles(g8_lax, g8_lax, 0); + verify_geo_miles(g8_lax, g9_jfk, 2475); + + verify_geo_miles(g9_jfk, g1_sfo, 2586); + verify_geo_miles(g9_jfk, g2_lhr, 3451); + verify_geo_miles(g9_jfk, g3_osl, 3687); + verify_geo_miles(g9_jfk, g4_gig, 4786); + verify_geo_miles(g9_jfk, g5_hkg, 8072); + verify_geo_miles(g9_jfk, g6_trd, 3611); + verify_geo_miles(g9_jfk, g7_syd, 9950); + verify_geo_miles(g9_jfk, g8_lax, 2475); + verify_geo_miles(g9_jfk, g9_jfk, 0); } GTEST_MAIN_RUN_ALL_TESTS() diff --git a/searchlib/src/vespa/searchlib/tensor/distance_function_factory.cpp b/searchlib/src/vespa/searchlib/tensor/distance_function_factory.cpp index 4553f39a525..5473d7db6f5 100644 --- a/searchlib/src/vespa/searchlib/tensor/distance_function_factory.cpp +++ b/searchlib/src/vespa/searchlib/tensor/distance_function_factory.cpp @@ -113,10 +113,10 @@ make_distance_function_factory(search::attribute::DistanceMetric variant, } return std::make_unique<PrenormalizedAngularDistanceFunctionFactory<float>>(); } - /* if (variant == DistanceMetric::GeoDegrees) { return std::make_unique<GeoDistanceFunctionFactory>(); } + /* if (variant == DistanceMetric::Hamming) { return std::make_unique<HammingDistanceFunctionFactory>(); } diff --git a/searchlib/src/vespa/searchlib/tensor/geo_degrees_distance.cpp b/searchlib/src/vespa/searchlib/tensor/geo_degrees_distance.cpp index bcce75da3ab..973d50ef98d 100644 --- a/searchlib/src/vespa/searchlib/tensor/geo_degrees_distance.cpp +++ b/searchlib/src/vespa/searchlib/tensor/geo_degrees_distance.cpp @@ -1,6 +1,7 @@ // Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. #include "geo_degrees_distance.h" +#include "temporary_vector_store.h" using vespalib::typify_invoke; using vespalib::eval::TypifyCellType; @@ -27,11 +28,11 @@ struct CalcGeoDegrees { double lat_diff = lat_A - lat_B; double lon_diff = lon_A - lon_B; - + // haversines of differences: double hav_lat = GeoDegreesDistance::hav(lat_diff); double hav_lon = GeoDegreesDistance::hav(lon_diff); - + // haversine of central angle between the two points: double hav_central_angle = hav_lat + cos(lat_A)*cos(lat_B)*hav_lon; return hav_central_angle; @@ -42,9 +43,64 @@ struct CalcGeoDegrees { double GeoDegreesDistance::calc(const vespalib::eval::TypedCells& lhs, - const vespalib::eval::TypedCells& rhs) const + const vespalib::eval::TypedCells& rhs) const { return typify_invoke<2,TypifyCellType,CalcGeoDegrees>(lhs.type, rhs.type, lhs, rhs); } +using vespalib::eval::TypedCells; + +class BoundGeoDistance : public BoundDistanceFunction { +private: + mutable TemporaryVectorStore<double> _tmpSpace; + const vespalib::ConstArrayRef<double> _lh_vector; + static GeoDegreesDistance _g_d_helper; +public: + BoundGeoDistance(const vespalib::eval::TypedCells& lhs) + : BoundDistanceFunction(vespalib::eval::CellType::DOUBLE), + _tmpSpace(lhs.size), + _lh_vector(_tmpSpace.storeLhs(lhs)) + {} + double calc(const vespalib::eval::TypedCells& rhs) const override { + vespalib::ConstArrayRef<double> rhs_vector = _tmpSpace.convertRhs(rhs); + assert(2 == _lh_vector.size()); + assert(2 == rhs_vector.size()); + // convert to radians: + double lat_A = _lh_vector[0] * GeoDegreesDistance::degrees_to_radians; + double lat_B = rhs_vector[0] * GeoDegreesDistance::degrees_to_radians; + double lon_A = _lh_vector[1] * GeoDegreesDistance::degrees_to_radians; + double lon_B = rhs_vector[1] * GeoDegreesDistance::degrees_to_radians; + + double lat_diff = lat_A - lat_B; + double lon_diff = lon_A - lon_B; + + // haversines of differences: + double hav_lat = GeoDegreesDistance::hav(lat_diff); + double hav_lon = GeoDegreesDistance::hav(lon_diff); + + // haversine of central angle between the two points: + double hav_central_angle = hav_lat + cos(lat_A)*cos(lat_B)*hav_lon; + return hav_central_angle; + } + double convert_threshold(double threshold) const override { + return _g_d_helper.convert_threshold(threshold); + } + double to_rawscore(double distance) const override { + return _g_d_helper.to_rawscore(distance); + } + double calc_with_limit(const vespalib::eval::TypedCells& rhs, double) const override { + return calc(rhs); + } +}; + +BoundDistanceFunction::UP +GeoDistanceFunctionFactory::for_query_vector(const vespalib::eval::TypedCells& lhs) { + return std::make_unique<BoundGeoDistance>(lhs); +} + +BoundDistanceFunction::UP +GeoDistanceFunctionFactory::for_insertion_vector(const vespalib::eval::TypedCells& lhs) { + return std::make_unique<BoundGeoDistance>(lhs); +} + } diff --git a/searchlib/src/vespa/searchlib/tensor/geo_degrees_distance.h b/searchlib/src/vespa/searchlib/tensor/geo_degrees_distance.h index 46feee19119..4522bc03c9e 100644 --- a/searchlib/src/vespa/searchlib/tensor/geo_degrees_distance.h +++ b/searchlib/src/vespa/searchlib/tensor/geo_degrees_distance.h @@ -3,6 +3,7 @@ #pragma once #include "distance_function.h" +#include "distance_function_factory.h" #include <vespa/eval/eval/typed_cells.h> #include <vespa/vespalib/hwaccelrated/iaccelrated.h> #include <vespa/vespalib/util/typify.h> @@ -50,4 +51,11 @@ public: } }; +class GeoDistanceFunctionFactory : public DistanceFunctionFactory { +public: + GeoDistanceFunctionFactory() : DistanceFunctionFactory(vespalib::eval::CellType::DOUBLE) {} + BoundDistanceFunction::UP for_query_vector(const vespalib::eval::TypedCells& lhs) override; + BoundDistanceFunction::UP for_insertion_vector(const vespalib::eval::TypedCells& lhs) override; +}; + } |