about | summary | refs | log | tree | commit | diff | stats
diff options
context:
space:
mode:
author: Arne Juul <arnej@yahooinc.com> 2023-04-25 11:51:03 +0000
committer: Arne Juul <arnej@yahooinc.com> 2023-04-25 15:56:19 +0000
commit: 6201b1979bbc7fd0c156bf77af855b09237a7045 (patch)
tree: 2739897b18d2b4d061eea5f35d6211c692e4883d
parent: b552ce33c561eef8b4440bb2ddc93b24afb8d16a (diff)
add BoundGeoDistance
-rw-r--r-- searchlib/src/tests/tensor/distance_functions/distance_functions_test.cpp | 194
-rw-r--r-- searchlib/src/vespa/searchlib/tensor/distance_function_factory.cpp | 2
-rw-r--r-- searchlib/src/vespa/searchlib/tensor/geo_degrees_distance.cpp | 62
-rw-r--r-- searchlib/src/vespa/searchlib/tensor/geo_degrees_distance.h | 8
4 files changed, 165 insertions, 101 deletions
diff --git a/searchlib/src/tests/tensor/distance_functions/distance_functions_test.cpp b/searchlib/src/tests/tensor/distance_functions/distance_functions_test.cpp
index e5bed3ebae5..74ef23f6eb8 100644
--- a/searchlib/src/tests/tensor/distance_functions/distance_functions_test.cpp
+++ b/searchlib/src/tests/tensor/distance_functions/distance_functions_test.cpp
@@ -18,14 +18,17 @@ using search::attribute::DistanceMetric;
template <typename T>
TypedCells t(const std::vector<T> &v) { return TypedCells(v); }
-void verify_geo_miles(const DistanceFunction *dist_fun,
- const std::vector<double> &p1,
+void verify_geo_miles(const std::vector<double> &p1,
const std::vector<double> &p2,
double exp_miles)
{
+ static GeoDistanceFunctionFactory dff;
TypedCells t1(p1);
TypedCells t2(p2);
- double abstract_distance = dist_fun->calc(t1, t2);
+ auto dist_fun = dff.for_query_vector(t1);
+ double abstract_distance = dist_fun->calc(t2);
+ EXPECT_EQ(dff.for_insertion_vector(t1)->calc(t2), abstract_distance);
+ EXPECT_FLOAT_EQ(dff.for_query_vector(t2)->calc(t1), abstract_distance);
double raw_score = dist_fun->to_rawscore(abstract_distance);
double km = ((1.0/raw_score)-1.0);
double d_miles = km / 1.609344;
@@ -443,9 +446,6 @@ TEST(DistanceFunctionsTest, hamming_gives_expected_score)
TEST(GeoDegreesTest, gives_expected_score)
{
- auto ct = vespalib::eval::CellType::DOUBLE;
- auto geodeg = make_distance_function(DistanceMetric::GeoDegrees, ct);
-
std::vector<double> g1_sfo{37.61, -122.38};
std::vector<double> g2_lhr{51.47, -0.46};
std::vector<double> g3_osl{60.20, 11.08};
@@ -456,7 +456,8 @@ TEST(GeoDegreesTest, gives_expected_score)
std::vector<double> g8_lax{33.94, -118.41};
std::vector<double> g9_jfk{40.64, -73.78};
- double g63_a = geodeg->calc(t(g6_trd), t(g3_osl));
+ auto geodeg = GeoDistanceFunctionFactory().for_query_vector(t(g6_trd));
+ double g63_a = geodeg->calc(t(g3_osl));
double g63_r = geodeg->to_rawscore(g63_a);
double g63_km = ((1.0/g63_r)-1.0);
EXPECT_GT(g63_km, 350);
@@ -466,96 +467,95 @@ TEST(GeoDegreesTest, gives_expected_score)
// Great Circle Mapper for airports using
// a more accurate formula - we should agree
// with < 1.0% deviation
- verify_geo_miles(geodeg.get(), g1_sfo, g1_sfo, 0);
- verify_geo_miles(geodeg.get(), g1_sfo, g2_lhr, 5367);
- verify_geo_miles(geodeg.get(), g1_sfo, g3_osl, 5196);
- verify_geo_miles(geodeg.get(), g1_sfo, g4_gig, 6604);
- verify_geo_miles(geodeg.get(), g1_sfo, g5_hkg, 6927);
- verify_geo_miles(geodeg.get(), g1_sfo, g6_trd, 5012);
- verify_geo_miles(geodeg.get(), g1_sfo, g7_syd, 7417);
- verify_geo_miles(geodeg.get(), g1_sfo, g8_lax, 337);
- verify_geo_miles(geodeg.get(), g1_sfo, g9_jfk, 2586);
-
- verify_geo_miles(geodeg.get(), g2_lhr, g1_sfo, 5367);
- verify_geo_miles(geodeg.get(), g2_lhr, g2_lhr, 0);
- verify_geo_miles(geodeg.get(), g2_lhr, g3_osl, 750);
- verify_geo_miles(geodeg.get(), g2_lhr, g4_gig, 5734);
- verify_geo_miles(geodeg.get(), g2_lhr, g5_hkg, 5994);
- verify_geo_miles(geodeg.get(), g2_lhr, g6_trd, 928);
- verify_geo_miles(geodeg.get(), g2_lhr, g7_syd, 10573);
- verify_geo_miles(geodeg.get(), g2_lhr, g8_lax, 5456);
- verify_geo_miles(geodeg.get(), g2_lhr, g9_jfk, 3451);
-
- verify_geo_miles(geodeg.get(), g3_osl, g1_sfo, 5196);
- verify_geo_miles(geodeg.get(), g3_osl, g2_lhr, 750);
- verify_geo_miles(geodeg.get(), g3_osl, g3_osl, 0);
- verify_geo_miles(geodeg.get(), g3_osl, g4_gig, 6479);
- verify_geo_miles(geodeg.get(), g3_osl, g5_hkg, 5319);
- verify_geo_miles(geodeg.get(), g3_osl, g6_trd, 226);
- verify_geo_miles(geodeg.get(), g3_osl, g7_syd, 9888);
- verify_geo_miles(geodeg.get(), g3_osl, g8_lax, 5345);
- verify_geo_miles(geodeg.get(), g3_osl, g9_jfk, 3687);
-
- verify_geo_miles(geodeg.get(), g4_gig, g1_sfo, 6604);
- verify_geo_miles(geodeg.get(), g4_gig, g2_lhr, 5734);
- verify_geo_miles(geodeg.get(), g4_gig, g3_osl, 6479);
- verify_geo_miles(geodeg.get(), g4_gig, g4_gig, 0);
- verify_geo_miles(geodeg.get(), g4_gig, g5_hkg, 10989);
- verify_geo_miles(geodeg.get(), g4_gig, g6_trd, 6623);
- verify_geo_miles(geodeg.get(), g4_gig, g7_syd, 8414);
- verify_geo_miles(geodeg.get(), g4_gig, g8_lax, 6294);
- verify_geo_miles(geodeg.get(), g4_gig, g9_jfk, 4786);
-
- verify_geo_miles(geodeg.get(), g5_hkg, g1_sfo, 6927);
- verify_geo_miles(geodeg.get(), g5_hkg, g2_lhr, 5994);
- verify_geo_miles(geodeg.get(), g5_hkg, g3_osl, 5319);
- verify_geo_miles(geodeg.get(), g5_hkg, g4_gig, 10989);
- verify_geo_miles(geodeg.get(), g5_hkg, g5_hkg, 0);
- verify_geo_miles(geodeg.get(), g5_hkg, g6_trd, 5240);
- verify_geo_miles(geodeg.get(), g5_hkg, g7_syd, 4581);
- verify_geo_miles(geodeg.get(), g5_hkg, g8_lax, 7260);
- verify_geo_miles(geodeg.get(), g5_hkg, g9_jfk, 8072);
-
- verify_geo_miles(geodeg.get(), g6_trd, g1_sfo, 5012);
- verify_geo_miles(geodeg.get(), g6_trd, g2_lhr, 928);
- verify_geo_miles(geodeg.get(), g6_trd, g3_osl, 226);
- verify_geo_miles(geodeg.get(), g6_trd, g4_gig, 6623);
- verify_geo_miles(geodeg.get(), g6_trd, g5_hkg, 5240);
- verify_geo_miles(geodeg.get(), g6_trd, g6_trd, 0);
- verify_geo_miles(geodeg.get(), g6_trd, g7_syd, 9782);
- verify_geo_miles(geodeg.get(), g6_trd, g8_lax, 5171);
- verify_geo_miles(geodeg.get(), g6_trd, g9_jfk, 3611);
-
- verify_geo_miles(geodeg.get(), g7_syd, g1_sfo, 7417);
- verify_geo_miles(geodeg.get(), g7_syd, g2_lhr, 10573);
- verify_geo_miles(geodeg.get(), g7_syd, g3_osl, 9888);
- verify_geo_miles(geodeg.get(), g7_syd, g4_gig, 8414);
- verify_geo_miles(geodeg.get(), g7_syd, g5_hkg, 4581);
- verify_geo_miles(geodeg.get(), g7_syd, g6_trd, 9782);
- verify_geo_miles(geodeg.get(), g7_syd, g7_syd, 0);
- verify_geo_miles(geodeg.get(), g7_syd, g8_lax, 7488);
- verify_geo_miles(geodeg.get(), g7_syd, g9_jfk, 9950);
-
- verify_geo_miles(geodeg.get(), g8_lax, g1_sfo, 337);
- verify_geo_miles(geodeg.get(), g8_lax, g2_lhr, 5456);
- verify_geo_miles(geodeg.get(), g8_lax, g3_osl, 5345);
- verify_geo_miles(geodeg.get(), g8_lax, g4_gig, 6294);
- verify_geo_miles(geodeg.get(), g8_lax, g5_hkg, 7260);
- verify_geo_miles(geodeg.get(), g8_lax, g6_trd, 5171);
- verify_geo_miles(geodeg.get(), g8_lax, g7_syd, 7488);
- verify_geo_miles(geodeg.get(), g8_lax, g8_lax, 0);
- verify_geo_miles(geodeg.get(), g8_lax, g9_jfk, 2475);
-
- verify_geo_miles(geodeg.get(), g9_jfk, g1_sfo, 2586);
- verify_geo_miles(geodeg.get(), g9_jfk, g2_lhr, 3451);
- verify_geo_miles(geodeg.get(), g9_jfk, g3_osl, 3687);
- verify_geo_miles(geodeg.get(), g9_jfk, g4_gig, 4786);
- verify_geo_miles(geodeg.get(), g9_jfk, g5_hkg, 8072);
- verify_geo_miles(geodeg.get(), g9_jfk, g6_trd, 3611);
- verify_geo_miles(geodeg.get(), g9_jfk, g7_syd, 9950);
- verify_geo_miles(geodeg.get(), g9_jfk, g8_lax, 2475);
- verify_geo_miles(geodeg.get(), g9_jfk, g9_jfk, 0);
-
+ verify_geo_miles(g1_sfo, g1_sfo, 0);
+ verify_geo_miles(g1_sfo, g2_lhr, 5367);
+ verify_geo_miles(g1_sfo, g3_osl, 5196);
+ verify_geo_miles(g1_sfo, g4_gig, 6604);
+ verify_geo_miles(g1_sfo, g5_hkg, 6927);
+ verify_geo_miles(g1_sfo, g6_trd, 5012);
+ verify_geo_miles(g1_sfo, g7_syd, 7417);
+ verify_geo_miles(g1_sfo, g8_lax, 337);
+ verify_geo_miles(g1_sfo, g9_jfk, 2586);
+
+ verify_geo_miles(g2_lhr, g1_sfo, 5367);
+ verify_geo_miles(g2_lhr, g2_lhr, 0);
+ verify_geo_miles(g2_lhr, g3_osl, 750);
+ verify_geo_miles(g2_lhr, g4_gig, 5734);
+ verify_geo_miles(g2_lhr, g5_hkg, 5994);
+ verify_geo_miles(g2_lhr, g6_trd, 928);
+ verify_geo_miles(g2_lhr, g7_syd, 10573);
+ verify_geo_miles(g2_lhr, g8_lax, 5456);
+ verify_geo_miles(g2_lhr, g9_jfk, 3451);
+
+ verify_geo_miles(g3_osl, g1_sfo, 5196);
+ verify_geo_miles(g3_osl, g2_lhr, 750);
+ verify_geo_miles(g3_osl, g3_osl, 0);
+ verify_geo_miles(g3_osl, g4_gig, 6479);
+ verify_geo_miles(g3_osl, g5_hkg, 5319);
+ verify_geo_miles(g3_osl, g6_trd, 226);
+ verify_geo_miles(g3_osl, g7_syd, 9888);
+ verify_geo_miles(g3_osl, g8_lax, 5345);
+ verify_geo_miles(g3_osl, g9_jfk, 3687);
+
+ verify_geo_miles(g4_gig, g1_sfo, 6604);
+ verify_geo_miles(g4_gig, g2_lhr, 5734);
+ verify_geo_miles(g4_gig, g3_osl, 6479);
+ verify_geo_miles(g4_gig, g4_gig, 0);
+ verify_geo_miles(g4_gig, g5_hkg, 10989);
+ verify_geo_miles(g4_gig, g6_trd, 6623);
+ verify_geo_miles(g4_gig, g7_syd, 8414);
+ verify_geo_miles(g4_gig, g8_lax, 6294);
+ verify_geo_miles(g4_gig, g9_jfk, 4786);
+
+ verify_geo_miles(g5_hkg, g1_sfo, 6927);
+ verify_geo_miles(g5_hkg, g2_lhr, 5994);
+ verify_geo_miles(g5_hkg, g3_osl, 5319);
+ verify_geo_miles(g5_hkg, g4_gig, 10989);
+ verify_geo_miles(g5_hkg, g5_hkg, 0);
+ verify_geo_miles(g5_hkg, g6_trd, 5240);
+ verify_geo_miles(g5_hkg, g7_syd, 4581);
+ verify_geo_miles(g5_hkg, g8_lax, 7260);
+ verify_geo_miles(g5_hkg, g9_jfk, 8072);
+
+ verify_geo_miles(g6_trd, g1_sfo, 5012);
+ verify_geo_miles(g6_trd, g2_lhr, 928);
+ verify_geo_miles(g6_trd, g3_osl, 226);
+ verify_geo_miles(g6_trd, g4_gig, 6623);
+ verify_geo_miles(g6_trd, g5_hkg, 5240);
+ verify_geo_miles(g6_trd, g6_trd, 0);
+ verify_geo_miles(g6_trd, g7_syd, 9782);
+ verify_geo_miles(g6_trd, g8_lax, 5171);
+ verify_geo_miles(g6_trd, g9_jfk, 3611);
+
+ verify_geo_miles(g7_syd, g1_sfo, 7417);
+ verify_geo_miles(g7_syd, g2_lhr, 10573);
+ verify_geo_miles(g7_syd, g3_osl, 9888);
+ verify_geo_miles(g7_syd, g4_gig, 8414);
+ verify_geo_miles(g7_syd, g5_hkg, 4581);
+ verify_geo_miles(g7_syd, g6_trd, 9782);
+ verify_geo_miles(g7_syd, g7_syd, 0);
+ verify_geo_miles(g7_syd, g8_lax, 7488);
+ verify_geo_miles(g7_syd, g9_jfk, 9950);
+
+ verify_geo_miles(g8_lax, g1_sfo, 337);
+ verify_geo_miles(g8_lax, g2_lhr, 5456);
+ verify_geo_miles(g8_lax, g3_osl, 5345);
+ verify_geo_miles(g8_lax, g4_gig, 6294);
+ verify_geo_miles(g8_lax, g5_hkg, 7260);
+ verify_geo_miles(g8_lax, g6_trd, 5171);
+ verify_geo_miles(g8_lax, g7_syd, 7488);
+ verify_geo_miles(g8_lax, g8_lax, 0);
+ verify_geo_miles(g8_lax, g9_jfk, 2475);
+
+ verify_geo_miles(g9_jfk, g1_sfo, 2586);
+ verify_geo_miles(g9_jfk, g2_lhr, 3451);
+ verify_geo_miles(g9_jfk, g3_osl, 3687);
+ verify_geo_miles(g9_jfk, g4_gig, 4786);
+ verify_geo_miles(g9_jfk, g5_hkg, 8072);
+ verify_geo_miles(g9_jfk, g6_trd, 3611);
+ verify_geo_miles(g9_jfk, g7_syd, 9950);
+ verify_geo_miles(g9_jfk, g8_lax, 2475);
+ verify_geo_miles(g9_jfk, g9_jfk, 0);
}
GTEST_MAIN_RUN_ALL_TESTS()
diff --git a/searchlib/src/vespa/searchlib/tensor/distance_function_factory.cpp b/searchlib/src/vespa/searchlib/tensor/distance_function_factory.cpp
index 4553f39a525..5473d7db6f5 100644
--- a/searchlib/src/vespa/searchlib/tensor/distance_function_factory.cpp
+++ b/searchlib/src/vespa/searchlib/tensor/distance_function_factory.cpp
@@ -113,10 +113,10 @@ make_distance_function_factory(search::attribute::DistanceMetric variant,
}
return std::make_unique<PrenormalizedAngularDistanceFunctionFactory<float>>();
}
- /*
if (variant == DistanceMetric::GeoDegrees) {
return std::make_unique<GeoDistanceFunctionFactory>();
}
+ /*
if (variant == DistanceMetric::Hamming) {
return std::make_unique<HammingDistanceFunctionFactory>();
}
diff --git a/searchlib/src/vespa/searchlib/tensor/geo_degrees_distance.cpp b/searchlib/src/vespa/searchlib/tensor/geo_degrees_distance.cpp
index bcce75da3ab..973d50ef98d 100644
--- a/searchlib/src/vespa/searchlib/tensor/geo_degrees_distance.cpp
+++ b/searchlib/src/vespa/searchlib/tensor/geo_degrees_distance.cpp
@@ -1,6 +1,7 @@
// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
#include "geo_degrees_distance.h"
+#include "temporary_vector_store.h"
using vespalib::typify_invoke;
using vespalib::eval::TypifyCellType;
@@ -27,11 +28,11 @@ struct CalcGeoDegrees {
double lat_diff = lat_A - lat_B;
double lon_diff = lon_A - lon_B;
-
+
// haversines of differences:
double hav_lat = GeoDegreesDistance::hav(lat_diff);
double hav_lon = GeoDegreesDistance::hav(lon_diff);
-
+
// haversine of central angle between the two points:
double hav_central_angle = hav_lat + cos(lat_A)*cos(lat_B)*hav_lon;
return hav_central_angle;
@@ -42,9 +43,64 @@ struct CalcGeoDegrees {
double
GeoDegreesDistance::calc(const vespalib::eval::TypedCells& lhs,
- const vespalib::eval::TypedCells& rhs) const
+ const vespalib::eval::TypedCells& rhs) const
{
return typify_invoke<2,TypifyCellType,CalcGeoDegrees>(lhs.type, rhs.type, lhs, rhs);
}
+using vespalib::eval::TypedCells;
+
+class BoundGeoDistance : public BoundDistanceFunction {
+private:
+ mutable TemporaryVectorStore<double> _tmpSpace;
+ const vespalib::ConstArrayRef<double> _lh_vector;
+ static GeoDegreesDistance _g_d_helper;
+public:
+ BoundGeoDistance(const vespalib::eval::TypedCells& lhs)
+ : BoundDistanceFunction(vespalib::eval::CellType::DOUBLE),
+ _tmpSpace(lhs.size),
+ _lh_vector(_tmpSpace.storeLhs(lhs))
+ {}
+ double calc(const vespalib::eval::TypedCells& rhs) const override {
+ vespalib::ConstArrayRef<double> rhs_vector = _tmpSpace.convertRhs(rhs);
+ assert(2 == _lh_vector.size());
+ assert(2 == rhs_vector.size());
+ // convert to radians:
+ double lat_A = _lh_vector[0] * GeoDegreesDistance::degrees_to_radians;
+ double lat_B = rhs_vector[0] * GeoDegreesDistance::degrees_to_radians;
+ double lon_A = _lh_vector[1] * GeoDegreesDistance::degrees_to_radians;
+ double lon_B = rhs_vector[1] * GeoDegreesDistance::degrees_to_radians;
+
+ double lat_diff = lat_A - lat_B;
+ double lon_diff = lon_A - lon_B;
+
+ // haversines of differences:
+ double hav_lat = GeoDegreesDistance::hav(lat_diff);
+ double hav_lon = GeoDegreesDistance::hav(lon_diff);
+
+ // haversine of central angle between the two points:
+ double hav_central_angle = hav_lat + cos(lat_A)*cos(lat_B)*hav_lon;
+ return hav_central_angle;
+ }
+ double convert_threshold(double threshold) const override {
+ return _g_d_helper.convert_threshold(threshold);
+ }
+ double to_rawscore(double distance) const override {
+ return _g_d_helper.to_rawscore(distance);
+ }
+ double calc_with_limit(const vespalib::eval::TypedCells& rhs, double) const override {
+ return calc(rhs);
+ }
+};
+
+BoundDistanceFunction::UP
+GeoDistanceFunctionFactory::for_query_vector(const vespalib::eval::TypedCells& lhs) {
+ return std::make_unique<BoundGeoDistance>(lhs);
+}
+
+BoundDistanceFunction::UP
+GeoDistanceFunctionFactory::for_insertion_vector(const vespalib::eval::TypedCells& lhs) {
+ return std::make_unique<BoundGeoDistance>(lhs);
+}
+
}
diff --git a/searchlib/src/vespa/searchlib/tensor/geo_degrees_distance.h b/searchlib/src/vespa/searchlib/tensor/geo_degrees_distance.h
index 46feee19119..4522bc03c9e 100644
--- a/searchlib/src/vespa/searchlib/tensor/geo_degrees_distance.h
+++ b/searchlib/src/vespa/searchlib/tensor/geo_degrees_distance.h
@@ -3,6 +3,7 @@
#pragma once
#include "distance_function.h"
+#include "distance_function_factory.h"
#include <vespa/eval/eval/typed_cells.h>
#include <vespa/vespalib/hwaccelrated/iaccelrated.h>
#include <vespa/vespalib/util/typify.h>
@@ -50,4 +51,11 @@ public:
}
};
+class GeoDistanceFunctionFactory : public DistanceFunctionFactory {
+public:
+ GeoDistanceFunctionFactory() : DistanceFunctionFactory(vespalib::eval::CellType::DOUBLE) {}
+ BoundDistanceFunction::UP for_query_vector(const vespalib::eval::TypedCells& lhs) override;
+ BoundDistanceFunction::UP for_insertion_vector(const vespalib::eval::TypedCells& lhs) override;
+};
+
}