diff options
author | Geir Storli <geirst@yahooinc.com> | 2023-04-26 13:59:06 +0200 |
---|---|---|
committer | GitHub <noreply@github.com> | 2023-04-26 13:59:06 +0200 |
commit | 797091a6867be9543c6f1d08f0189cbe7c12e0b3 (patch) | |
tree | 5fbb7b6277c31a1ed46532108c28c7f1d3f74fda | |
parent | 01cc25458c74d2902879087919f67622600ffc65 (diff) | |
parent | b2401a91381d1f66ef316d850d469181f06f0d36 (diff) |
Merge pull request #26849 from vespa-engine/arnej/add-bound-hamming
add bound hamming, geo distance
10 files changed, 273 insertions, 144 deletions
diff --git a/searchlib/src/tests/tensor/distance_functions/distance_functions_test.cpp b/searchlib/src/tests/tensor/distance_functions/distance_functions_test.cpp index 600c5ae0646..9b8ad0d26ce 100644 --- a/searchlib/src/tests/tensor/distance_functions/distance_functions_test.cpp +++ b/searchlib/src/tests/tensor/distance_functions/distance_functions_test.cpp @@ -18,14 +18,17 @@ using search::attribute::DistanceMetric; template <typename T> TypedCells t(const std::vector<T> &v) { return TypedCells(v); } -void verify_geo_miles(const DistanceFunction *dist_fun, - const std::vector<double> &p1, +void verify_geo_miles(const std::vector<double> &p1, const std::vector<double> &p2, double exp_miles) { + static GeoDistanceFunctionFactory dff; TypedCells t1(p1); TypedCells t2(p2); - double abstract_distance = dist_fun->calc(t1, t2); + auto dist_fun = dff.for_query_vector(t1); + double abstract_distance = dist_fun->calc(t2); + EXPECT_EQ(dff.for_insertion_vector(t1)->calc(t2), abstract_distance); + EXPECT_FLOAT_EQ(dff.for_query_vector(t2)->calc(t1), abstract_distance); double raw_score = dist_fun->to_rawscore(abstract_distance); double km = ((1.0/raw_score)-1.0); double d_miles = km / 1.609344; @@ -391,6 +394,7 @@ TEST(DistanceFunctionsTest, innerproduct_gives_expected_score) TEST(DistanceFunctionsTest, hamming_gives_expected_score) { + static HammingDistanceFunctionFactory<Int8Float> dff; auto ct = vespalib::eval::CellType::DOUBLE; auto hamming = make_distance_function(DistanceMetric::Hamming, ct); @@ -407,6 +411,9 @@ TEST(DistanceFunctionsTest, hamming_gives_expected_score) double h0 = hamming->calc(t(p), t(p)); EXPECT_EQ(h0, 0.0); EXPECT_EQ(hamming->to_rawscore(h0), 1.0); + auto dist_fun = dff.for_query_vector(t(p)); + EXPECT_EQ(dist_fun->calc(t(p)), 0.0); + EXPECT_EQ(dist_fun->to_rawscore(h0), 1.0); } double d12 = hamming->calc(t(points[1]), t(points[2])); EXPECT_EQ(d12, 3.0); @@ -439,13 +446,12 @@ TEST(DistanceFunctionsTest, hamming_gives_expected_score) std::vector<Int8Float> bytes_b = { 1, 2, 2, 4, 8, 16, 32, 65, -128, 0, 1, 0, 4, 8, 16, 32, 64, -128, 0, 1, -1 }; // expect diff: 1 2 1 1 7 EXPECT_EQ(hamming->calc(TypedCells(bytes_a), TypedCells(bytes_b)), 12.0); + auto dist_fun = dff.for_query_vector(TypedCells(bytes_a)); + EXPECT_EQ(dist_fun->calc(TypedCells(bytes_b)), 12.0); } TEST(GeoDegreesTest, gives_expected_score) { - auto ct = vespalib::eval::CellType::DOUBLE; - auto geodeg = make_distance_function(DistanceMetric::GeoDegrees, ct); - std::vector<double> g1_sfo{37.61, -122.38}; std::vector<double> g2_lhr{51.47, -0.46}; std::vector<double> g3_osl{60.20, 11.08}; @@ -456,7 +462,8 @@ TEST(GeoDegreesTest, gives_expected_score) std::vector<double> g8_lax{33.94, -118.41}; std::vector<double> g9_jfk{40.64, -73.78}; - double g63_a = geodeg->calc(t(g6_trd), t(g3_osl)); + auto geodeg = GeoDistanceFunctionFactory().for_query_vector(t(g6_trd)); + double g63_a = geodeg->calc(t(g3_osl)); double g63_r = geodeg->to_rawscore(g63_a); double g63_km = ((1.0/g63_r)-1.0); EXPECT_GT(g63_km, 350); @@ -466,96 +473,95 @@ TEST(GeoDegreesTest, gives_expected_score) // Great Circle Mapper for airports using // a more accurate formula - we should agree // with < 1.0% deviation - verify_geo_miles(geodeg.get(), g1_sfo, g1_sfo, 0); - verify_geo_miles(geodeg.get(), g1_sfo, g2_lhr, 5367); - verify_geo_miles(geodeg.get(), g1_sfo, g3_osl, 5196); - verify_geo_miles(geodeg.get(), g1_sfo, g4_gig, 6604); - verify_geo_miles(geodeg.get(), g1_sfo, g5_hkg, 6927); - verify_geo_miles(geodeg.get(), g1_sfo, g6_trd, 5012); - verify_geo_miles(geodeg.get(), g1_sfo, g7_syd, 7417); - verify_geo_miles(geodeg.get(), g1_sfo, g8_lax, 337); - verify_geo_miles(geodeg.get(), g1_sfo, g9_jfk, 2586); - - verify_geo_miles(geodeg.get(), g2_lhr, g1_sfo, 5367); - verify_geo_miles(geodeg.get(), g2_lhr, g2_lhr, 0); - verify_geo_miles(geodeg.get(), g2_lhr, g3_osl, 750); - verify_geo_miles(geodeg.get(), g2_lhr, g4_gig, 5734); - verify_geo_miles(geodeg.get(), g2_lhr, g5_hkg, 5994); - verify_geo_miles(geodeg.get(), g2_lhr, g6_trd, 928); - verify_geo_miles(geodeg.get(), g2_lhr, g7_syd, 10573); - verify_geo_miles(geodeg.get(), g2_lhr, g8_lax, 5456); - verify_geo_miles(geodeg.get(), g2_lhr, g9_jfk, 3451); - - verify_geo_miles(geodeg.get(), g3_osl, g1_sfo, 5196); - verify_geo_miles(geodeg.get(), g3_osl, g2_lhr, 750); - verify_geo_miles(geodeg.get(), g3_osl, g3_osl, 0); - verify_geo_miles(geodeg.get(), g3_osl, g4_gig, 6479); - verify_geo_miles(geodeg.get(), g3_osl, g5_hkg, 5319); - verify_geo_miles(geodeg.get(), g3_osl, g6_trd, 226); - verify_geo_miles(geodeg.get(), g3_osl, g7_syd, 9888); - verify_geo_miles(geodeg.get(), g3_osl, g8_lax, 5345); - verify_geo_miles(geodeg.get(), g3_osl, g9_jfk, 3687); - - verify_geo_miles(geodeg.get(), g4_gig, g1_sfo, 6604); - verify_geo_miles(geodeg.get(), g4_gig, g2_lhr, 5734); - verify_geo_miles(geodeg.get(), g4_gig, g3_osl, 6479); - verify_geo_miles(geodeg.get(), g4_gig, g4_gig, 0); - verify_geo_miles(geodeg.get(), g4_gig, g5_hkg, 10989); - verify_geo_miles(geodeg.get(), g4_gig, g6_trd, 6623); - verify_geo_miles(geodeg.get(), g4_gig, g7_syd, 8414); - verify_geo_miles(geodeg.get(), g4_gig, g8_lax, 6294); - verify_geo_miles(geodeg.get(), g4_gig, g9_jfk, 4786); - - verify_geo_miles(geodeg.get(), g5_hkg, g1_sfo, 6927); - verify_geo_miles(geodeg.get(), g5_hkg, g2_lhr, 5994); - verify_geo_miles(geodeg.get(), g5_hkg, g3_osl, 5319); - verify_geo_miles(geodeg.get(), g5_hkg, g4_gig, 10989); - verify_geo_miles(geodeg.get(), g5_hkg, g5_hkg, 0); - verify_geo_miles(geodeg.get(), g5_hkg, g6_trd, 5240); - verify_geo_miles(geodeg.get(), g5_hkg, g7_syd, 4581); - verify_geo_miles(geodeg.get(), g5_hkg, g8_lax, 7260); - verify_geo_miles(geodeg.get(), g5_hkg, g9_jfk, 8072); - - verify_geo_miles(geodeg.get(), g6_trd, g1_sfo, 5012); - verify_geo_miles(geodeg.get(), g6_trd, g2_lhr, 928); - verify_geo_miles(geodeg.get(), g6_trd, g3_osl, 226); - verify_geo_miles(geodeg.get(), g6_trd, g4_gig, 6623); - verify_geo_miles(geodeg.get(), g6_trd, g5_hkg, 5240); - verify_geo_miles(geodeg.get(), g6_trd, g6_trd, 0); - verify_geo_miles(geodeg.get(), g6_trd, g7_syd, 9782); - verify_geo_miles(geodeg.get(), g6_trd, g8_lax, 5171); - verify_geo_miles(geodeg.get(), g6_trd, g9_jfk, 3611); - - verify_geo_miles(geodeg.get(), g7_syd, g1_sfo, 7417); - verify_geo_miles(geodeg.get(), g7_syd, g2_lhr, 10573); - verify_geo_miles(geodeg.get(), g7_syd, g3_osl, 9888); - verify_geo_miles(geodeg.get(), g7_syd, g4_gig, 8414); - verify_geo_miles(geodeg.get(), g7_syd, g5_hkg, 4581); - verify_geo_miles(geodeg.get(), g7_syd, g6_trd, 9782); - verify_geo_miles(geodeg.get(), g7_syd, g7_syd, 0); - verify_geo_miles(geodeg.get(), g7_syd, g8_lax, 7488); - verify_geo_miles(geodeg.get(), g7_syd, g9_jfk, 9950); - - verify_geo_miles(geodeg.get(), g8_lax, g1_sfo, 337); - verify_geo_miles(geodeg.get(), g8_lax, g2_lhr, 5456); - verify_geo_miles(geodeg.get(), g8_lax, g3_osl, 5345); - verify_geo_miles(geodeg.get(), g8_lax, g4_gig, 6294); - verify_geo_miles(geodeg.get(), g8_lax, g5_hkg, 7260); - verify_geo_miles(geodeg.get(), g8_lax, g6_trd, 5171); - verify_geo_miles(geodeg.get(), g8_lax, g7_syd, 7488); - verify_geo_miles(geodeg.get(), g8_lax, g8_lax, 0); - verify_geo_miles(geodeg.get(), g8_lax, g9_jfk, 2475); - - verify_geo_miles(geodeg.get(), g9_jfk, g1_sfo, 2586); - verify_geo_miles(geodeg.get(), g9_jfk, g2_lhr, 3451); - verify_geo_miles(geodeg.get(), g9_jfk, g3_osl, 3687); - verify_geo_miles(geodeg.get(), g9_jfk, g4_gig, 4786); - verify_geo_miles(geodeg.get(), g9_jfk, g5_hkg, 8072); - verify_geo_miles(geodeg.get(), g9_jfk, g6_trd, 3611); - verify_geo_miles(geodeg.get(), g9_jfk, g7_syd, 9950); - verify_geo_miles(geodeg.get(), g9_jfk, g8_lax, 2475); - verify_geo_miles(geodeg.get(), g9_jfk, g9_jfk, 0); - + verify_geo_miles(g1_sfo, g1_sfo, 0); + verify_geo_miles(g1_sfo, g2_lhr, 5367); + verify_geo_miles(g1_sfo, g3_osl, 5196); + verify_geo_miles(g1_sfo, g4_gig, 6604); + verify_geo_miles(g1_sfo, g5_hkg, 6927); + verify_geo_miles(g1_sfo, g6_trd, 5012); + verify_geo_miles(g1_sfo, g7_syd, 7417); + verify_geo_miles(g1_sfo, g8_lax, 337); + verify_geo_miles(g1_sfo, g9_jfk, 2586); + + verify_geo_miles(g2_lhr, g1_sfo, 5367); + verify_geo_miles(g2_lhr, g2_lhr, 0); + verify_geo_miles(g2_lhr, g3_osl, 750); + verify_geo_miles(g2_lhr, g4_gig, 5734); + verify_geo_miles(g2_lhr, g5_hkg, 5994); + verify_geo_miles(g2_lhr, g6_trd, 928); + verify_geo_miles(g2_lhr, g7_syd, 10573); + verify_geo_miles(g2_lhr, g8_lax, 5456); + verify_geo_miles(g2_lhr, g9_jfk, 3451); + + verify_geo_miles(g3_osl, g1_sfo, 5196); + verify_geo_miles(g3_osl, g2_lhr, 750); + verify_geo_miles(g3_osl, g3_osl, 0); + verify_geo_miles(g3_osl, g4_gig, 6479); + verify_geo_miles(g3_osl, g5_hkg, 5319); + verify_geo_miles(g3_osl, g6_trd, 226); + verify_geo_miles(g3_osl, g7_syd, 9888); + verify_geo_miles(g3_osl, g8_lax, 5345); + verify_geo_miles(g3_osl, g9_jfk, 3687); + + verify_geo_miles(g4_gig, g1_sfo, 6604); + verify_geo_miles(g4_gig, g2_lhr, 5734); + verify_geo_miles(g4_gig, g3_osl, 6479); + verify_geo_miles(g4_gig, g4_gig, 0); + verify_geo_miles(g4_gig, g5_hkg, 10989); + verify_geo_miles(g4_gig, g6_trd, 6623); + verify_geo_miles(g4_gig, g7_syd, 8414); + verify_geo_miles(g4_gig, g8_lax, 6294); + verify_geo_miles(g4_gig, g9_jfk, 4786); + + verify_geo_miles(g5_hkg, g1_sfo, 6927); + verify_geo_miles(g5_hkg, g2_lhr, 5994); + verify_geo_miles(g5_hkg, g3_osl, 5319); + verify_geo_miles(g5_hkg, g4_gig, 10989); + verify_geo_miles(g5_hkg, g5_hkg, 0); + verify_geo_miles(g5_hkg, g6_trd, 5240); + verify_geo_miles(g5_hkg, g7_syd, 4581); + verify_geo_miles(g5_hkg, g8_lax, 7260); + verify_geo_miles(g5_hkg, g9_jfk, 8072); + + verify_geo_miles(g6_trd, g1_sfo, 5012); + verify_geo_miles(g6_trd, g2_lhr, 928); + verify_geo_miles(g6_trd, g3_osl, 226); + verify_geo_miles(g6_trd, g4_gig, 6623); + verify_geo_miles(g6_trd, g5_hkg, 5240); + verify_geo_miles(g6_trd, g6_trd, 0); + verify_geo_miles(g6_trd, g7_syd, 9782); + verify_geo_miles(g6_trd, g8_lax, 5171); + verify_geo_miles(g6_trd, g9_jfk, 3611); + + verify_geo_miles(g7_syd, g1_sfo, 7417); + verify_geo_miles(g7_syd, g2_lhr, 10573); + verify_geo_miles(g7_syd, g3_osl, 9888); + verify_geo_miles(g7_syd, g4_gig, 8414); + verify_geo_miles(g7_syd, g5_hkg, 4581); + verify_geo_miles(g7_syd, g6_trd, 9782); + verify_geo_miles(g7_syd, g7_syd, 0); + verify_geo_miles(g7_syd, g8_lax, 7488); + verify_geo_miles(g7_syd, g9_jfk, 9950); + + verify_geo_miles(g8_lax, g1_sfo, 337); + verify_geo_miles(g8_lax, g2_lhr, 5456); + verify_geo_miles(g8_lax, g3_osl, 5345); + verify_geo_miles(g8_lax, g4_gig, 6294); + verify_geo_miles(g8_lax, g5_hkg, 7260); + verify_geo_miles(g8_lax, g6_trd, 5171); + verify_geo_miles(g8_lax, g7_syd, 7488); + verify_geo_miles(g8_lax, g8_lax, 0); + verify_geo_miles(g8_lax, g9_jfk, 2475); + + verify_geo_miles(g9_jfk, g1_sfo, 2586); + verify_geo_miles(g9_jfk, g2_lhr, 3451); + verify_geo_miles(g9_jfk, g3_osl, 3687); + verify_geo_miles(g9_jfk, g4_gig, 4786); + verify_geo_miles(g9_jfk, g5_hkg, 8072); + verify_geo_miles(g9_jfk, g6_trd, 3611); + verify_geo_miles(g9_jfk, g7_syd, 9950); + verify_geo_miles(g9_jfk, g8_lax, 2475); + verify_geo_miles(g9_jfk, g9_jfk, 0); } GTEST_MAIN_RUN_ALL_TESTS() diff --git a/searchlib/src/vespa/searchlib/tensor/angular_distance.cpp b/searchlib/src/vespa/searchlib/tensor/angular_distance.cpp index efc1170baf5..a7ae02bb9f4 100644 --- a/searchlib/src/vespa/searchlib/tensor/angular_distance.cpp +++ b/searchlib/src/vespa/searchlib/tensor/angular_distance.cpp @@ -61,8 +61,7 @@ private: double _lhs_norm_sq; public: BoundAngularDistance(const vespalib::eval::TypedCells& lhs) - : BoundDistanceFunction(vespalib::eval::get_cell_type<FloatType>()), - _computer(vespalib::hwaccelrated::IAccelrated::getAccelerator()), + : _computer(vespalib::hwaccelrated::IAccelrated::getAccelerator()), _tmpSpace(lhs.size), _lhs(_tmpSpace.storeLhs(lhs)) { diff --git a/searchlib/src/vespa/searchlib/tensor/bound_distance_function.h b/searchlib/src/vespa/searchlib/tensor/bound_distance_function.h index 5d602a52227..c072d6de8e5 100644 --- a/searchlib/src/vespa/searchlib/tensor/bound_distance_function.h +++ b/searchlib/src/vespa/searchlib/tensor/bound_distance_function.h @@ -20,20 +20,13 @@ namespace search::tensor { * mutable temporary storage. */ class BoundDistanceFunction : public DistanceConverter { -private: - vespalib::eval::CellType _expect_cell_type; public: using UP = std::unique_ptr<BoundDistanceFunction>; - BoundDistanceFunction(vespalib::eval::CellType expected) : _expect_cell_type(expected) {} + BoundDistanceFunction() = default; virtual ~BoundDistanceFunction() = default; - // input vectors will be converted to this cell type: - vespalib::eval::CellType expected_cell_type() const { - return _expect_cell_type; - } - // calculate internal distance (comparable) virtual double calc(const vespalib::eval::TypedCells& rhs) const = 0; diff --git a/searchlib/src/vespa/searchlib/tensor/distance_function_factory.cpp b/searchlib/src/vespa/searchlib/tensor/distance_function_factory.cpp index 4553f39a525..c088d498f0f 100644 --- a/searchlib/src/vespa/searchlib/tensor/distance_function_factory.cpp +++ b/searchlib/src/vespa/searchlib/tensor/distance_function_factory.cpp @@ -55,8 +55,7 @@ class SimpleBoundDistanceFunction : public BoundDistanceFunction { public: SimpleBoundDistanceFunction(const vespalib::eval::TypedCells& lhs, const DistanceFunction &df) - : BoundDistanceFunction(lhs.type), - _lhs(lhs), + : _lhs(lhs), _df(df) {} @@ -94,35 +93,35 @@ std::unique_ptr<DistanceFunctionFactory> make_distance_function_factory(search::attribute::DistanceMetric variant, vespalib::eval::CellType cell_type) { - if (variant == DistanceMetric::Angular) { - if (cell_type == CellType::DOUBLE) { - return std::make_unique<AngularDistanceFunctionFactory<double>>(); - } - return std::make_unique<AngularDistanceFunctionFactory<float>>(); - } - if (variant == DistanceMetric::Euclidean) { - switch (cell_type) { - case CellType::DOUBLE: return std::make_unique<EuclideanDistanceFunctionFactory<double>>(); - case CellType::INT8: return std::make_unique<EuclideanDistanceFunctionFactory<vespalib::eval::Int8Float>>(); - default: return std::make_unique<EuclideanDistanceFunctionFactory<float>>(); - } - } - if (variant == DistanceMetric::PrenormalizedAngular) { - if (cell_type == CellType::DOUBLE) { - return std::make_unique<PrenormalizedAngularDistanceFunctionFactory<double>>(); - } - return std::make_unique<PrenormalizedAngularDistanceFunctionFactory<float>>(); - } - /* - if (variant == DistanceMetric::GeoDegrees) { - return std::make_unique<GeoDistanceFunctionFactory>(); - } - if (variant == DistanceMetric::Hamming) { - return std::make_unique<HammingDistanceFunctionFactory>(); + switch (variant) { + case DistanceMetric::Angular: + switch (cell_type) { + case CellType::DOUBLE: return std::make_unique<AngularDistanceFunctionFactory<double>>(); + default: return std::make_unique<AngularDistanceFunctionFactory<float>>(); + } + case DistanceMetric::Euclidean: + switch (cell_type) { + case CellType::DOUBLE: return std::make_unique<EuclideanDistanceFunctionFactory<double>>(); + case CellType::INT8: return std::make_unique<EuclideanDistanceFunctionFactory<vespalib::eval::Int8Float>>(); + default: return std::make_unique<EuclideanDistanceFunctionFactory<float>>(); + } + case DistanceMetric::InnerProduct: + case DistanceMetric::PrenormalizedAngular: + switch (cell_type) { + case CellType::DOUBLE: return std::make_unique<PrenormalizedAngularDistanceFunctionFactory<double>>(); + default: return std::make_unique<PrenormalizedAngularDistanceFunctionFactory<float>>(); + } + case DistanceMetric::GeoDegrees: + return std::make_unique<GeoDistanceFunctionFactory>(); + case DistanceMetric::Hamming: + switch (cell_type) { + case CellType::DOUBLE: return std::make_unique<HammingDistanceFunctionFactory<double>>(); + case CellType::INT8: return std::make_unique<HammingDistanceFunctionFactory<vespalib::eval::Int8Float>>(); + default: return std::make_unique<HammingDistanceFunctionFactory<float>>(); + } } - */ - auto df = make_distance_function(variant, cell_type); - return std::make_unique<SimpleDistanceFunctionFactory>(std::move(df)); + // not reached: + return {}; } } diff --git a/searchlib/src/vespa/searchlib/tensor/euclidean_distance.cpp b/searchlib/src/vespa/searchlib/tensor/euclidean_distance.cpp index 9c37b191637..7995c87d055 100644 --- a/searchlib/src/vespa/searchlib/tensor/euclidean_distance.cpp +++ b/searchlib/src/vespa/searchlib/tensor/euclidean_distance.cpp @@ -62,8 +62,7 @@ private: static const int8_t *cast(const Int8Float * p) { return reinterpret_cast<const int8_t *>(p); } public: BoundEuclideanDistance(const vespalib::eval::TypedCells& lhs) - : BoundDistanceFunction(vespalib::eval::get_cell_type<FloatType>()), - _computer(vespalib::hwaccelrated::IAccelrated::getAccelerator()), + : _computer(vespalib::hwaccelrated::IAccelrated::getAccelerator()), _tmpSpace(lhs.size), _lhs_vector(_tmpSpace.storeLhs(lhs)) {} diff --git a/searchlib/src/vespa/searchlib/tensor/geo_degrees_distance.cpp b/searchlib/src/vespa/searchlib/tensor/geo_degrees_distance.cpp index bcce75da3ab..38ba8205c90 100644 --- a/searchlib/src/vespa/searchlib/tensor/geo_degrees_distance.cpp +++ b/searchlib/src/vespa/searchlib/tensor/geo_degrees_distance.cpp @@ -1,6 +1,7 @@ // Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. #include "geo_degrees_distance.h" +#include "temporary_vector_store.h" using vespalib::typify_invoke; using vespalib::eval::TypifyCellType; @@ -27,11 +28,11 @@ struct CalcGeoDegrees { double lat_diff = lat_A - lat_B; double lon_diff = lon_A - lon_B; - + // haversines of differences: double hav_lat = GeoDegreesDistance::hav(lat_diff); double hav_lon = GeoDegreesDistance::hav(lon_diff); - + // haversine of central angle between the two points: double hav_central_angle = hav_lat + cos(lat_A)*cos(lat_B)*hav_lon; return hav_central_angle; @@ -42,9 +43,63 @@ struct CalcGeoDegrees { double GeoDegreesDistance::calc(const vespalib::eval::TypedCells& lhs, - const vespalib::eval::TypedCells& rhs) const + const vespalib::eval::TypedCells& rhs) const { return typify_invoke<2,TypifyCellType,CalcGeoDegrees>(lhs.type, rhs.type, lhs, rhs); } +using vespalib::eval::TypedCells; + +class BoundGeoDistance : public BoundDistanceFunction { +private: + mutable TemporaryVectorStore<double> _tmpSpace; + const vespalib::ConstArrayRef<double> _lh_vector; + static GeoDegreesDistance _g_d_helper; +public: + BoundGeoDistance(const vespalib::eval::TypedCells& lhs) + : _tmpSpace(lhs.size), + _lh_vector(_tmpSpace.storeLhs(lhs)) + {} + double calc(const vespalib::eval::TypedCells& rhs) const override { + vespalib::ConstArrayRef<double> rhs_vector = _tmpSpace.convertRhs(rhs); + assert(2 == _lh_vector.size()); + assert(2 == rhs_vector.size()); + // convert to radians: + double lat_A = _lh_vector[0] * GeoDegreesDistance::degrees_to_radians; + double lat_B = rhs_vector[0] * GeoDegreesDistance::degrees_to_radians; + double lon_A = _lh_vector[1] * GeoDegreesDistance::degrees_to_radians; + double lon_B = rhs_vector[1] * GeoDegreesDistance::degrees_to_radians; + + double lat_diff = lat_A - lat_B; + double lon_diff = lon_A - lon_B; + + // haversines of differences: + double hav_lat = GeoDegreesDistance::hav(lat_diff); + double hav_lon = GeoDegreesDistance::hav(lon_diff); + + // haversine of central angle between the two points: + double hav_central_angle = hav_lat + cos(lat_A)*cos(lat_B)*hav_lon; + return hav_central_angle; + } + double convert_threshold(double threshold) const override { + return _g_d_helper.convert_threshold(threshold); + } + double to_rawscore(double distance) const override { + return _g_d_helper.to_rawscore(distance); + } + double calc_with_limit(const vespalib::eval::TypedCells& rhs, double) const override { + return calc(rhs); + } +}; + +BoundDistanceFunction::UP +GeoDistanceFunctionFactory::for_query_vector(const vespalib::eval::TypedCells& lhs) { + return std::make_unique<BoundGeoDistance>(lhs); +} + +BoundDistanceFunction::UP +GeoDistanceFunctionFactory::for_insertion_vector(const vespalib::eval::TypedCells& lhs) { + return std::make_unique<BoundGeoDistance>(lhs); +} + } diff --git a/searchlib/src/vespa/searchlib/tensor/geo_degrees_distance.h b/searchlib/src/vespa/searchlib/tensor/geo_degrees_distance.h index 46feee19119..4522bc03c9e 100644 --- a/searchlib/src/vespa/searchlib/tensor/geo_degrees_distance.h +++ b/searchlib/src/vespa/searchlib/tensor/geo_degrees_distance.h @@ -3,6 +3,7 @@ #pragma once #include "distance_function.h" +#include "distance_function_factory.h" #include <vespa/eval/eval/typed_cells.h> #include <vespa/vespalib/hwaccelrated/iaccelrated.h> #include <vespa/vespalib/util/typify.h> @@ -50,4 +51,11 @@ public: } }; +class GeoDistanceFunctionFactory : public DistanceFunctionFactory { +public: + GeoDistanceFunctionFactory() : DistanceFunctionFactory(vespalib::eval::CellType::DOUBLE) {} + BoundDistanceFunction::UP for_query_vector(const vespalib::eval::TypedCells& lhs) override; + BoundDistanceFunction::UP for_insertion_vector(const vespalib::eval::TypedCells& lhs) override; +}; + } diff --git a/searchlib/src/vespa/searchlib/tensor/hamming_distance.cpp b/searchlib/src/vespa/searchlib/tensor/hamming_distance.cpp index 43596478a6f..f4f6842715f 100644 --- a/searchlib/src/vespa/searchlib/tensor/hamming_distance.cpp +++ b/searchlib/src/vespa/searchlib/tensor/hamming_distance.cpp @@ -1,6 +1,7 @@ // Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. #include "hamming_distance.h" +#include "temporary_vector_store.h" #include <vespa/vespalib/util/binary_hamming_distance.h> using vespalib::typify_invoke; @@ -52,4 +53,63 @@ HammingDistance::calc_with_limit(const vespalib::eval::TypedCells& lhs, return calc(lhs, rhs); } +using vespalib::eval::Int8Float; + +template<typename FloatType> +class BoundHammingDistance : public BoundDistanceFunction { +private: + mutable TemporaryVectorStore<FloatType> _tmpSpace; + const vespalib::ConstArrayRef<FloatType> _lhs_vector; +public: + BoundHammingDistance(const vespalib::eval::TypedCells& lhs) + : _tmpSpace(lhs.size), + _lhs_vector(_tmpSpace.storeLhs(lhs)) + {} + double calc(const vespalib::eval::TypedCells& rhs) const override { + size_t sz = _lhs_vector.size(); + vespalib::ConstArrayRef<FloatType> rhs_vector = _tmpSpace.convertRhs(rhs); + assert(sz == rhs_vector.size()); + auto a = _lhs_vector.data(); + auto b = rhs_vector.data(); + if constexpr (std::is_same<Int8Float, FloatType>::value) { + return (double) vespalib::binary_hamming_distance(a, b, sz); + } else { + size_t sum = 0; + for (size_t i = 0; i < sz; ++i) { + sum += (_lhs_vector[i] == rhs_vector[i]) ? 0 : 1; + } + return (double)sum; + } + } + double convert_threshold(double threshold) const override { + return threshold; + } + double to_rawscore(double distance) const override { + double score = 1.0 / (1.0 + distance); + return score; + } + double calc_with_limit(const vespalib::eval::TypedCells& rhs, double) const override { + // consider optimizing: + return calc(rhs); + } +}; + +template <typename FloatType> +BoundDistanceFunction::UP +HammingDistanceFunctionFactory<FloatType>::for_query_vector(const vespalib::eval::TypedCells& lhs) { + using DFT = BoundHammingDistance<FloatType>; + return std::make_unique<DFT>(lhs); +} + +template <typename FloatType> +BoundDistanceFunction::UP +HammingDistanceFunctionFactory<FloatType>::for_insertion_vector(const vespalib::eval::TypedCells& lhs) { + using DFT = BoundHammingDistance<FloatType>; + return std::make_unique<DFT>(lhs); +} + +template class HammingDistanceFunctionFactory<Int8Float>; +template class HammingDistanceFunctionFactory<float>; +template class HammingDistanceFunctionFactory<double>; + } diff --git a/searchlib/src/vespa/searchlib/tensor/hamming_distance.h b/searchlib/src/vespa/searchlib/tensor/hamming_distance.h index c64fc5b532d..23c855eb137 100644 --- a/searchlib/src/vespa/searchlib/tensor/hamming_distance.h +++ b/searchlib/src/vespa/searchlib/tensor/hamming_distance.h @@ -3,6 +3,7 @@ #pragma once #include "distance_function.h" +#include "distance_function_factory.h" #include <vespa/eval/eval/typed_cells.h> #include <vespa/vespalib/util/typify.h> #include <cmath> @@ -29,4 +30,14 @@ public: double calc_with_limit(const vespalib::eval::TypedCells& lhs, const vespalib::eval::TypedCells& rhs, double) const override; }; +template <typename FloatType> +class HammingDistanceFunctionFactory : public DistanceFunctionFactory { +public: + HammingDistanceFunctionFactory() + : DistanceFunctionFactory(vespalib::eval::get_cell_type<FloatType>()) + {} + BoundDistanceFunction::UP for_query_vector(const vespalib::eval::TypedCells& lhs) override; + BoundDistanceFunction::UP for_insertion_vector(const vespalib::eval::TypedCells& lhs) override; +}; + } diff --git a/searchlib/src/vespa/searchlib/tensor/prenormalized_angular_distance.cpp b/searchlib/src/vespa/searchlib/tensor/prenormalized_angular_distance.cpp index d2693f9f443..292edc1259d 100644 --- a/searchlib/src/vespa/searchlib/tensor/prenormalized_angular_distance.cpp +++ b/searchlib/src/vespa/searchlib/tensor/prenormalized_angular_distance.cpp @@ -17,8 +17,7 @@ private: double _lhs_norm_sq; public: BoundPrenormalizedAngularDistance(const vespalib::eval::TypedCells& lhs) - : BoundDistanceFunction(vespalib::eval::get_cell_type<FloatType>()), - _computer(vespalib::hwaccelrated::IAccelrated::getAccelerator()), + : _computer(vespalib::hwaccelrated::IAccelrated::getAccelerator()), _tmpSpace(lhs.size), _lhs(_tmpSpace.storeLhs(lhs)) { |