aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorGeir Storli <geirst@yahooinc.com>2023-04-26 13:59:06 +0200
committerGitHub <noreply@github.com>2023-04-26 13:59:06 +0200
commit797091a6867be9543c6f1d08f0189cbe7c12e0b3 (patch)
tree5fbb7b6277c31a1ed46532108c28c7f1d3f74fda
parent01cc25458c74d2902879087919f67622600ffc65 (diff)
parentb2401a91381d1f66ef316d850d469181f06f0d36 (diff)
Merge pull request #26849 from vespa-engine/arnej/add-bound-hamming
add bound hamming, geo distance
-rw-r--r--searchlib/src/tests/tensor/distance_functions/distance_functions_test.cpp200
-rw-r--r--searchlib/src/vespa/searchlib/tensor/angular_distance.cpp3
-rw-r--r--searchlib/src/vespa/searchlib/tensor/bound_distance_function.h9
-rw-r--r--searchlib/src/vespa/searchlib/tensor/distance_function_factory.cpp59
-rw-r--r--searchlib/src/vespa/searchlib/tensor/euclidean_distance.cpp3
-rw-r--r--searchlib/src/vespa/searchlib/tensor/geo_degrees_distance.cpp61
-rw-r--r--searchlib/src/vespa/searchlib/tensor/geo_degrees_distance.h8
-rw-r--r--searchlib/src/vespa/searchlib/tensor/hamming_distance.cpp60
-rw-r--r--searchlib/src/vespa/searchlib/tensor/hamming_distance.h11
-rw-r--r--searchlib/src/vespa/searchlib/tensor/prenormalized_angular_distance.cpp3
10 files changed, 273 insertions, 144 deletions
diff --git a/searchlib/src/tests/tensor/distance_functions/distance_functions_test.cpp b/searchlib/src/tests/tensor/distance_functions/distance_functions_test.cpp
index 600c5ae0646..9b8ad0d26ce 100644
--- a/searchlib/src/tests/tensor/distance_functions/distance_functions_test.cpp
+++ b/searchlib/src/tests/tensor/distance_functions/distance_functions_test.cpp
@@ -18,14 +18,17 @@ using search::attribute::DistanceMetric;
template <typename T>
TypedCells t(const std::vector<T> &v) { return TypedCells(v); }
-void verify_geo_miles(const DistanceFunction *dist_fun,
- const std::vector<double> &p1,
+void verify_geo_miles(const std::vector<double> &p1,
const std::vector<double> &p2,
double exp_miles)
{
+ static GeoDistanceFunctionFactory dff;
TypedCells t1(p1);
TypedCells t2(p2);
- double abstract_distance = dist_fun->calc(t1, t2);
+ auto dist_fun = dff.for_query_vector(t1);
+ double abstract_distance = dist_fun->calc(t2);
+ EXPECT_EQ(dff.for_insertion_vector(t1)->calc(t2), abstract_distance);
+ EXPECT_FLOAT_EQ(dff.for_query_vector(t2)->calc(t1), abstract_distance);
double raw_score = dist_fun->to_rawscore(abstract_distance);
double km = ((1.0/raw_score)-1.0);
double d_miles = km / 1.609344;
@@ -391,6 +394,7 @@ TEST(DistanceFunctionsTest, innerproduct_gives_expected_score)
TEST(DistanceFunctionsTest, hamming_gives_expected_score)
{
+ static HammingDistanceFunctionFactory<Int8Float> dff;
auto ct = vespalib::eval::CellType::DOUBLE;
auto hamming = make_distance_function(DistanceMetric::Hamming, ct);
@@ -407,6 +411,9 @@ TEST(DistanceFunctionsTest, hamming_gives_expected_score)
double h0 = hamming->calc(t(p), t(p));
EXPECT_EQ(h0, 0.0);
EXPECT_EQ(hamming->to_rawscore(h0), 1.0);
+ auto dist_fun = dff.for_query_vector(t(p));
+ EXPECT_EQ(dist_fun->calc(t(p)), 0.0);
+ EXPECT_EQ(dist_fun->to_rawscore(h0), 1.0);
}
double d12 = hamming->calc(t(points[1]), t(points[2]));
EXPECT_EQ(d12, 3.0);
@@ -439,13 +446,12 @@ TEST(DistanceFunctionsTest, hamming_gives_expected_score)
std::vector<Int8Float> bytes_b = { 1, 2, 2, 4, 8, 16, 32, 65, -128, 0, 1, 0, 4, 8, 16, 32, 64, -128, 0, 1, -1 };
// expect diff: 1 2 1 1 7
EXPECT_EQ(hamming->calc(TypedCells(bytes_a), TypedCells(bytes_b)), 12.0);
+ auto dist_fun = dff.for_query_vector(TypedCells(bytes_a));
+ EXPECT_EQ(dist_fun->calc(TypedCells(bytes_b)), 12.0);
}
TEST(GeoDegreesTest, gives_expected_score)
{
- auto ct = vespalib::eval::CellType::DOUBLE;
- auto geodeg = make_distance_function(DistanceMetric::GeoDegrees, ct);
-
std::vector<double> g1_sfo{37.61, -122.38};
std::vector<double> g2_lhr{51.47, -0.46};
std::vector<double> g3_osl{60.20, 11.08};
@@ -456,7 +462,8 @@ TEST(GeoDegreesTest, gives_expected_score)
std::vector<double> g8_lax{33.94, -118.41};
std::vector<double> g9_jfk{40.64, -73.78};
- double g63_a = geodeg->calc(t(g6_trd), t(g3_osl));
+ auto geodeg = GeoDistanceFunctionFactory().for_query_vector(t(g6_trd));
+ double g63_a = geodeg->calc(t(g3_osl));
double g63_r = geodeg->to_rawscore(g63_a);
double g63_km = ((1.0/g63_r)-1.0);
EXPECT_GT(g63_km, 350);
@@ -466,96 +473,95 @@ TEST(GeoDegreesTest, gives_expected_score)
// Great Circle Mapper for airports using
// a more accurate formula - we should agree
// with < 1.0% deviation
- verify_geo_miles(geodeg.get(), g1_sfo, g1_sfo, 0);
- verify_geo_miles(geodeg.get(), g1_sfo, g2_lhr, 5367);
- verify_geo_miles(geodeg.get(), g1_sfo, g3_osl, 5196);
- verify_geo_miles(geodeg.get(), g1_sfo, g4_gig, 6604);
- verify_geo_miles(geodeg.get(), g1_sfo, g5_hkg, 6927);
- verify_geo_miles(geodeg.get(), g1_sfo, g6_trd, 5012);
- verify_geo_miles(geodeg.get(), g1_sfo, g7_syd, 7417);
- verify_geo_miles(geodeg.get(), g1_sfo, g8_lax, 337);
- verify_geo_miles(geodeg.get(), g1_sfo, g9_jfk, 2586);
-
- verify_geo_miles(geodeg.get(), g2_lhr, g1_sfo, 5367);
- verify_geo_miles(geodeg.get(), g2_lhr, g2_lhr, 0);
- verify_geo_miles(geodeg.get(), g2_lhr, g3_osl, 750);
- verify_geo_miles(geodeg.get(), g2_lhr, g4_gig, 5734);
- verify_geo_miles(geodeg.get(), g2_lhr, g5_hkg, 5994);
- verify_geo_miles(geodeg.get(), g2_lhr, g6_trd, 928);
- verify_geo_miles(geodeg.get(), g2_lhr, g7_syd, 10573);
- verify_geo_miles(geodeg.get(), g2_lhr, g8_lax, 5456);
- verify_geo_miles(geodeg.get(), g2_lhr, g9_jfk, 3451);
-
- verify_geo_miles(geodeg.get(), g3_osl, g1_sfo, 5196);
- verify_geo_miles(geodeg.get(), g3_osl, g2_lhr, 750);
- verify_geo_miles(geodeg.get(), g3_osl, g3_osl, 0);
- verify_geo_miles(geodeg.get(), g3_osl, g4_gig, 6479);
- verify_geo_miles(geodeg.get(), g3_osl, g5_hkg, 5319);
- verify_geo_miles(geodeg.get(), g3_osl, g6_trd, 226);
- verify_geo_miles(geodeg.get(), g3_osl, g7_syd, 9888);
- verify_geo_miles(geodeg.get(), g3_osl, g8_lax, 5345);
- verify_geo_miles(geodeg.get(), g3_osl, g9_jfk, 3687);
-
- verify_geo_miles(geodeg.get(), g4_gig, g1_sfo, 6604);
- verify_geo_miles(geodeg.get(), g4_gig, g2_lhr, 5734);
- verify_geo_miles(geodeg.get(), g4_gig, g3_osl, 6479);
- verify_geo_miles(geodeg.get(), g4_gig, g4_gig, 0);
- verify_geo_miles(geodeg.get(), g4_gig, g5_hkg, 10989);
- verify_geo_miles(geodeg.get(), g4_gig, g6_trd, 6623);
- verify_geo_miles(geodeg.get(), g4_gig, g7_syd, 8414);
- verify_geo_miles(geodeg.get(), g4_gig, g8_lax, 6294);
- verify_geo_miles(geodeg.get(), g4_gig, g9_jfk, 4786);
-
- verify_geo_miles(geodeg.get(), g5_hkg, g1_sfo, 6927);
- verify_geo_miles(geodeg.get(), g5_hkg, g2_lhr, 5994);
- verify_geo_miles(geodeg.get(), g5_hkg, g3_osl, 5319);
- verify_geo_miles(geodeg.get(), g5_hkg, g4_gig, 10989);
- verify_geo_miles(geodeg.get(), g5_hkg, g5_hkg, 0);
- verify_geo_miles(geodeg.get(), g5_hkg, g6_trd, 5240);
- verify_geo_miles(geodeg.get(), g5_hkg, g7_syd, 4581);
- verify_geo_miles(geodeg.get(), g5_hkg, g8_lax, 7260);
- verify_geo_miles(geodeg.get(), g5_hkg, g9_jfk, 8072);
-
- verify_geo_miles(geodeg.get(), g6_trd, g1_sfo, 5012);
- verify_geo_miles(geodeg.get(), g6_trd, g2_lhr, 928);
- verify_geo_miles(geodeg.get(), g6_trd, g3_osl, 226);
- verify_geo_miles(geodeg.get(), g6_trd, g4_gig, 6623);
- verify_geo_miles(geodeg.get(), g6_trd, g5_hkg, 5240);
- verify_geo_miles(geodeg.get(), g6_trd, g6_trd, 0);
- verify_geo_miles(geodeg.get(), g6_trd, g7_syd, 9782);
- verify_geo_miles(geodeg.get(), g6_trd, g8_lax, 5171);
- verify_geo_miles(geodeg.get(), g6_trd, g9_jfk, 3611);
-
- verify_geo_miles(geodeg.get(), g7_syd, g1_sfo, 7417);
- verify_geo_miles(geodeg.get(), g7_syd, g2_lhr, 10573);
- verify_geo_miles(geodeg.get(), g7_syd, g3_osl, 9888);
- verify_geo_miles(geodeg.get(), g7_syd, g4_gig, 8414);
- verify_geo_miles(geodeg.get(), g7_syd, g5_hkg, 4581);
- verify_geo_miles(geodeg.get(), g7_syd, g6_trd, 9782);
- verify_geo_miles(geodeg.get(), g7_syd, g7_syd, 0);
- verify_geo_miles(geodeg.get(), g7_syd, g8_lax, 7488);
- verify_geo_miles(geodeg.get(), g7_syd, g9_jfk, 9950);
-
- verify_geo_miles(geodeg.get(), g8_lax, g1_sfo, 337);
- verify_geo_miles(geodeg.get(), g8_lax, g2_lhr, 5456);
- verify_geo_miles(geodeg.get(), g8_lax, g3_osl, 5345);
- verify_geo_miles(geodeg.get(), g8_lax, g4_gig, 6294);
- verify_geo_miles(geodeg.get(), g8_lax, g5_hkg, 7260);
- verify_geo_miles(geodeg.get(), g8_lax, g6_trd, 5171);
- verify_geo_miles(geodeg.get(), g8_lax, g7_syd, 7488);
- verify_geo_miles(geodeg.get(), g8_lax, g8_lax, 0);
- verify_geo_miles(geodeg.get(), g8_lax, g9_jfk, 2475);
-
- verify_geo_miles(geodeg.get(), g9_jfk, g1_sfo, 2586);
- verify_geo_miles(geodeg.get(), g9_jfk, g2_lhr, 3451);
- verify_geo_miles(geodeg.get(), g9_jfk, g3_osl, 3687);
- verify_geo_miles(geodeg.get(), g9_jfk, g4_gig, 4786);
- verify_geo_miles(geodeg.get(), g9_jfk, g5_hkg, 8072);
- verify_geo_miles(geodeg.get(), g9_jfk, g6_trd, 3611);
- verify_geo_miles(geodeg.get(), g9_jfk, g7_syd, 9950);
- verify_geo_miles(geodeg.get(), g9_jfk, g8_lax, 2475);
- verify_geo_miles(geodeg.get(), g9_jfk, g9_jfk, 0);
-
+ verify_geo_miles(g1_sfo, g1_sfo, 0);
+ verify_geo_miles(g1_sfo, g2_lhr, 5367);
+ verify_geo_miles(g1_sfo, g3_osl, 5196);
+ verify_geo_miles(g1_sfo, g4_gig, 6604);
+ verify_geo_miles(g1_sfo, g5_hkg, 6927);
+ verify_geo_miles(g1_sfo, g6_trd, 5012);
+ verify_geo_miles(g1_sfo, g7_syd, 7417);
+ verify_geo_miles(g1_sfo, g8_lax, 337);
+ verify_geo_miles(g1_sfo, g9_jfk, 2586);
+
+ verify_geo_miles(g2_lhr, g1_sfo, 5367);
+ verify_geo_miles(g2_lhr, g2_lhr, 0);
+ verify_geo_miles(g2_lhr, g3_osl, 750);
+ verify_geo_miles(g2_lhr, g4_gig, 5734);
+ verify_geo_miles(g2_lhr, g5_hkg, 5994);
+ verify_geo_miles(g2_lhr, g6_trd, 928);
+ verify_geo_miles(g2_lhr, g7_syd, 10573);
+ verify_geo_miles(g2_lhr, g8_lax, 5456);
+ verify_geo_miles(g2_lhr, g9_jfk, 3451);
+
+ verify_geo_miles(g3_osl, g1_sfo, 5196);
+ verify_geo_miles(g3_osl, g2_lhr, 750);
+ verify_geo_miles(g3_osl, g3_osl, 0);
+ verify_geo_miles(g3_osl, g4_gig, 6479);
+ verify_geo_miles(g3_osl, g5_hkg, 5319);
+ verify_geo_miles(g3_osl, g6_trd, 226);
+ verify_geo_miles(g3_osl, g7_syd, 9888);
+ verify_geo_miles(g3_osl, g8_lax, 5345);
+ verify_geo_miles(g3_osl, g9_jfk, 3687);
+
+ verify_geo_miles(g4_gig, g1_sfo, 6604);
+ verify_geo_miles(g4_gig, g2_lhr, 5734);
+ verify_geo_miles(g4_gig, g3_osl, 6479);
+ verify_geo_miles(g4_gig, g4_gig, 0);
+ verify_geo_miles(g4_gig, g5_hkg, 10989);
+ verify_geo_miles(g4_gig, g6_trd, 6623);
+ verify_geo_miles(g4_gig, g7_syd, 8414);
+ verify_geo_miles(g4_gig, g8_lax, 6294);
+ verify_geo_miles(g4_gig, g9_jfk, 4786);
+
+ verify_geo_miles(g5_hkg, g1_sfo, 6927);
+ verify_geo_miles(g5_hkg, g2_lhr, 5994);
+ verify_geo_miles(g5_hkg, g3_osl, 5319);
+ verify_geo_miles(g5_hkg, g4_gig, 10989);
+ verify_geo_miles(g5_hkg, g5_hkg, 0);
+ verify_geo_miles(g5_hkg, g6_trd, 5240);
+ verify_geo_miles(g5_hkg, g7_syd, 4581);
+ verify_geo_miles(g5_hkg, g8_lax, 7260);
+ verify_geo_miles(g5_hkg, g9_jfk, 8072);
+
+ verify_geo_miles(g6_trd, g1_sfo, 5012);
+ verify_geo_miles(g6_trd, g2_lhr, 928);
+ verify_geo_miles(g6_trd, g3_osl, 226);
+ verify_geo_miles(g6_trd, g4_gig, 6623);
+ verify_geo_miles(g6_trd, g5_hkg, 5240);
+ verify_geo_miles(g6_trd, g6_trd, 0);
+ verify_geo_miles(g6_trd, g7_syd, 9782);
+ verify_geo_miles(g6_trd, g8_lax, 5171);
+ verify_geo_miles(g6_trd, g9_jfk, 3611);
+
+ verify_geo_miles(g7_syd, g1_sfo, 7417);
+ verify_geo_miles(g7_syd, g2_lhr, 10573);
+ verify_geo_miles(g7_syd, g3_osl, 9888);
+ verify_geo_miles(g7_syd, g4_gig, 8414);
+ verify_geo_miles(g7_syd, g5_hkg, 4581);
+ verify_geo_miles(g7_syd, g6_trd, 9782);
+ verify_geo_miles(g7_syd, g7_syd, 0);
+ verify_geo_miles(g7_syd, g8_lax, 7488);
+ verify_geo_miles(g7_syd, g9_jfk, 9950);
+
+ verify_geo_miles(g8_lax, g1_sfo, 337);
+ verify_geo_miles(g8_lax, g2_lhr, 5456);
+ verify_geo_miles(g8_lax, g3_osl, 5345);
+ verify_geo_miles(g8_lax, g4_gig, 6294);
+ verify_geo_miles(g8_lax, g5_hkg, 7260);
+ verify_geo_miles(g8_lax, g6_trd, 5171);
+ verify_geo_miles(g8_lax, g7_syd, 7488);
+ verify_geo_miles(g8_lax, g8_lax, 0);
+ verify_geo_miles(g8_lax, g9_jfk, 2475);
+
+ verify_geo_miles(g9_jfk, g1_sfo, 2586);
+ verify_geo_miles(g9_jfk, g2_lhr, 3451);
+ verify_geo_miles(g9_jfk, g3_osl, 3687);
+ verify_geo_miles(g9_jfk, g4_gig, 4786);
+ verify_geo_miles(g9_jfk, g5_hkg, 8072);
+ verify_geo_miles(g9_jfk, g6_trd, 3611);
+ verify_geo_miles(g9_jfk, g7_syd, 9950);
+ verify_geo_miles(g9_jfk, g8_lax, 2475);
+ verify_geo_miles(g9_jfk, g9_jfk, 0);
}
GTEST_MAIN_RUN_ALL_TESTS()
diff --git a/searchlib/src/vespa/searchlib/tensor/angular_distance.cpp b/searchlib/src/vespa/searchlib/tensor/angular_distance.cpp
index efc1170baf5..a7ae02bb9f4 100644
--- a/searchlib/src/vespa/searchlib/tensor/angular_distance.cpp
+++ b/searchlib/src/vespa/searchlib/tensor/angular_distance.cpp
@@ -61,8 +61,7 @@ private:
double _lhs_norm_sq;
public:
BoundAngularDistance(const vespalib::eval::TypedCells& lhs)
- : BoundDistanceFunction(vespalib::eval::get_cell_type<FloatType>()),
- _computer(vespalib::hwaccelrated::IAccelrated::getAccelerator()),
+ : _computer(vespalib::hwaccelrated::IAccelrated::getAccelerator()),
_tmpSpace(lhs.size),
_lhs(_tmpSpace.storeLhs(lhs))
{
diff --git a/searchlib/src/vespa/searchlib/tensor/bound_distance_function.h b/searchlib/src/vespa/searchlib/tensor/bound_distance_function.h
index 5d602a52227..c072d6de8e5 100644
--- a/searchlib/src/vespa/searchlib/tensor/bound_distance_function.h
+++ b/searchlib/src/vespa/searchlib/tensor/bound_distance_function.h
@@ -20,20 +20,13 @@ namespace search::tensor {
* mutable temporary storage.
*/
class BoundDistanceFunction : public DistanceConverter {
-private:
- vespalib::eval::CellType _expect_cell_type;
public:
using UP = std::unique_ptr<BoundDistanceFunction>;
- BoundDistanceFunction(vespalib::eval::CellType expected) : _expect_cell_type(expected) {}
+ BoundDistanceFunction() = default;
virtual ~BoundDistanceFunction() = default;
- // input vectors will be converted to this cell type:
- vespalib::eval::CellType expected_cell_type() const {
- return _expect_cell_type;
- }
-
// calculate internal distance (comparable)
virtual double calc(const vespalib::eval::TypedCells& rhs) const = 0;
diff --git a/searchlib/src/vespa/searchlib/tensor/distance_function_factory.cpp b/searchlib/src/vespa/searchlib/tensor/distance_function_factory.cpp
index 4553f39a525..c088d498f0f 100644
--- a/searchlib/src/vespa/searchlib/tensor/distance_function_factory.cpp
+++ b/searchlib/src/vespa/searchlib/tensor/distance_function_factory.cpp
@@ -55,8 +55,7 @@ class SimpleBoundDistanceFunction : public BoundDistanceFunction {
public:
SimpleBoundDistanceFunction(const vespalib::eval::TypedCells& lhs,
const DistanceFunction &df)
- : BoundDistanceFunction(lhs.type),
- _lhs(lhs),
+ : _lhs(lhs),
_df(df)
{}
@@ -94,35 +93,35 @@ std::unique_ptr<DistanceFunctionFactory>
make_distance_function_factory(search::attribute::DistanceMetric variant,
vespalib::eval::CellType cell_type)
{
- if (variant == DistanceMetric::Angular) {
- if (cell_type == CellType::DOUBLE) {
- return std::make_unique<AngularDistanceFunctionFactory<double>>();
- }
- return std::make_unique<AngularDistanceFunctionFactory<float>>();
- }
- if (variant == DistanceMetric::Euclidean) {
- switch (cell_type) {
- case CellType::DOUBLE: return std::make_unique<EuclideanDistanceFunctionFactory<double>>();
- case CellType::INT8: return std::make_unique<EuclideanDistanceFunctionFactory<vespalib::eval::Int8Float>>();
- default: return std::make_unique<EuclideanDistanceFunctionFactory<float>>();
- }
- }
- if (variant == DistanceMetric::PrenormalizedAngular) {
- if (cell_type == CellType::DOUBLE) {
- return std::make_unique<PrenormalizedAngularDistanceFunctionFactory<double>>();
- }
- return std::make_unique<PrenormalizedAngularDistanceFunctionFactory<float>>();
- }
- /*
- if (variant == DistanceMetric::GeoDegrees) {
- return std::make_unique<GeoDistanceFunctionFactory>();
- }
- if (variant == DistanceMetric::Hamming) {
- return std::make_unique<HammingDistanceFunctionFactory>();
+ switch (variant) {
+ case DistanceMetric::Angular:
+ switch (cell_type) {
+ case CellType::DOUBLE: return std::make_unique<AngularDistanceFunctionFactory<double>>();
+ default: return std::make_unique<AngularDistanceFunctionFactory<float>>();
+ }
+ case DistanceMetric::Euclidean:
+ switch (cell_type) {
+ case CellType::DOUBLE: return std::make_unique<EuclideanDistanceFunctionFactory<double>>();
+ case CellType::INT8: return std::make_unique<EuclideanDistanceFunctionFactory<vespalib::eval::Int8Float>>();
+ default: return std::make_unique<EuclideanDistanceFunctionFactory<float>>();
+ }
+ case DistanceMetric::InnerProduct:
+ case DistanceMetric::PrenormalizedAngular:
+ switch (cell_type) {
+ case CellType::DOUBLE: return std::make_unique<PrenormalizedAngularDistanceFunctionFactory<double>>();
+ default: return std::make_unique<PrenormalizedAngularDistanceFunctionFactory<float>>();
+ }
+ case DistanceMetric::GeoDegrees:
+ return std::make_unique<GeoDistanceFunctionFactory>();
+ case DistanceMetric::Hamming:
+ switch (cell_type) {
+ case CellType::DOUBLE: return std::make_unique<HammingDistanceFunctionFactory<double>>();
+ case CellType::INT8: return std::make_unique<HammingDistanceFunctionFactory<vespalib::eval::Int8Float>>();
+ default: return std::make_unique<HammingDistanceFunctionFactory<float>>();
+ }
}
- */
- auto df = make_distance_function(variant, cell_type);
- return std::make_unique<SimpleDistanceFunctionFactory>(std::move(df));
+ // not reached:
+ return {};
}
}
diff --git a/searchlib/src/vespa/searchlib/tensor/euclidean_distance.cpp b/searchlib/src/vespa/searchlib/tensor/euclidean_distance.cpp
index 9c37b191637..7995c87d055 100644
--- a/searchlib/src/vespa/searchlib/tensor/euclidean_distance.cpp
+++ b/searchlib/src/vespa/searchlib/tensor/euclidean_distance.cpp
@@ -62,8 +62,7 @@ private:
static const int8_t *cast(const Int8Float * p) { return reinterpret_cast<const int8_t *>(p); }
public:
BoundEuclideanDistance(const vespalib::eval::TypedCells& lhs)
- : BoundDistanceFunction(vespalib::eval::get_cell_type<FloatType>()),
- _computer(vespalib::hwaccelrated::IAccelrated::getAccelerator()),
+ : _computer(vespalib::hwaccelrated::IAccelrated::getAccelerator()),
_tmpSpace(lhs.size),
_lhs_vector(_tmpSpace.storeLhs(lhs))
{}
diff --git a/searchlib/src/vespa/searchlib/tensor/geo_degrees_distance.cpp b/searchlib/src/vespa/searchlib/tensor/geo_degrees_distance.cpp
index bcce75da3ab..38ba8205c90 100644
--- a/searchlib/src/vespa/searchlib/tensor/geo_degrees_distance.cpp
+++ b/searchlib/src/vespa/searchlib/tensor/geo_degrees_distance.cpp
@@ -1,6 +1,7 @@
// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
#include "geo_degrees_distance.h"
+#include "temporary_vector_store.h"
using vespalib::typify_invoke;
using vespalib::eval::TypifyCellType;
@@ -27,11 +28,11 @@ struct CalcGeoDegrees {
double lat_diff = lat_A - lat_B;
double lon_diff = lon_A - lon_B;
-
+
// haversines of differences:
double hav_lat = GeoDegreesDistance::hav(lat_diff);
double hav_lon = GeoDegreesDistance::hav(lon_diff);
-
+
// haversine of central angle between the two points:
double hav_central_angle = hav_lat + cos(lat_A)*cos(lat_B)*hav_lon;
return hav_central_angle;
@@ -42,9 +43,63 @@ struct CalcGeoDegrees {
double
GeoDegreesDistance::calc(const vespalib::eval::TypedCells& lhs,
- const vespalib::eval::TypedCells& rhs) const
+ const vespalib::eval::TypedCells& rhs) const
{
return typify_invoke<2,TypifyCellType,CalcGeoDegrees>(lhs.type, rhs.type, lhs, rhs);
}
+using vespalib::eval::TypedCells;
+
+class BoundGeoDistance : public BoundDistanceFunction {
+private:
+ mutable TemporaryVectorStore<double> _tmpSpace;
+ const vespalib::ConstArrayRef<double> _lh_vector;
+ static GeoDegreesDistance _g_d_helper;
+public:
+ BoundGeoDistance(const vespalib::eval::TypedCells& lhs)
+ : _tmpSpace(lhs.size),
+ _lh_vector(_tmpSpace.storeLhs(lhs))
+ {}
+ double calc(const vespalib::eval::TypedCells& rhs) const override {
+ vespalib::ConstArrayRef<double> rhs_vector = _tmpSpace.convertRhs(rhs);
+ assert(2 == _lh_vector.size());
+ assert(2 == rhs_vector.size());
+ // convert to radians:
+ double lat_A = _lh_vector[0] * GeoDegreesDistance::degrees_to_radians;
+ double lat_B = rhs_vector[0] * GeoDegreesDistance::degrees_to_radians;
+ double lon_A = _lh_vector[1] * GeoDegreesDistance::degrees_to_radians;
+ double lon_B = rhs_vector[1] * GeoDegreesDistance::degrees_to_radians;
+
+ double lat_diff = lat_A - lat_B;
+ double lon_diff = lon_A - lon_B;
+
+ // haversines of differences:
+ double hav_lat = GeoDegreesDistance::hav(lat_diff);
+ double hav_lon = GeoDegreesDistance::hav(lon_diff);
+
+ // haversine of central angle between the two points:
+ double hav_central_angle = hav_lat + cos(lat_A)*cos(lat_B)*hav_lon;
+ return hav_central_angle;
+ }
+ double convert_threshold(double threshold) const override {
+ return _g_d_helper.convert_threshold(threshold);
+ }
+ double to_rawscore(double distance) const override {
+ return _g_d_helper.to_rawscore(distance);
+ }
+ double calc_with_limit(const vespalib::eval::TypedCells& rhs, double) const override {
+ return calc(rhs);
+ }
+};
+
+BoundDistanceFunction::UP
+GeoDistanceFunctionFactory::for_query_vector(const vespalib::eval::TypedCells& lhs) {
+ return std::make_unique<BoundGeoDistance>(lhs);
+}
+
+BoundDistanceFunction::UP
+GeoDistanceFunctionFactory::for_insertion_vector(const vespalib::eval::TypedCells& lhs) {
+ return std::make_unique<BoundGeoDistance>(lhs);
+}
+
}
diff --git a/searchlib/src/vespa/searchlib/tensor/geo_degrees_distance.h b/searchlib/src/vespa/searchlib/tensor/geo_degrees_distance.h
index 46feee19119..4522bc03c9e 100644
--- a/searchlib/src/vespa/searchlib/tensor/geo_degrees_distance.h
+++ b/searchlib/src/vespa/searchlib/tensor/geo_degrees_distance.h
@@ -3,6 +3,7 @@
#pragma once
#include "distance_function.h"
+#include "distance_function_factory.h"
#include <vespa/eval/eval/typed_cells.h>
#include <vespa/vespalib/hwaccelrated/iaccelrated.h>
#include <vespa/vespalib/util/typify.h>
@@ -50,4 +51,11 @@ public:
}
};
+class GeoDistanceFunctionFactory : public DistanceFunctionFactory {
+public:
+ GeoDistanceFunctionFactory() : DistanceFunctionFactory(vespalib::eval::CellType::DOUBLE) {}
+ BoundDistanceFunction::UP for_query_vector(const vespalib::eval::TypedCells& lhs) override;
+ BoundDistanceFunction::UP for_insertion_vector(const vespalib::eval::TypedCells& lhs) override;
+};
+
}
diff --git a/searchlib/src/vespa/searchlib/tensor/hamming_distance.cpp b/searchlib/src/vespa/searchlib/tensor/hamming_distance.cpp
index 43596478a6f..f4f6842715f 100644
--- a/searchlib/src/vespa/searchlib/tensor/hamming_distance.cpp
+++ b/searchlib/src/vespa/searchlib/tensor/hamming_distance.cpp
@@ -1,6 +1,7 @@
// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
#include "hamming_distance.h"
+#include "temporary_vector_store.h"
#include <vespa/vespalib/util/binary_hamming_distance.h>
using vespalib::typify_invoke;
@@ -52,4 +53,63 @@ HammingDistance::calc_with_limit(const vespalib::eval::TypedCells& lhs,
return calc(lhs, rhs);
}
+using vespalib::eval::Int8Float;
+
+template<typename FloatType>
+class BoundHammingDistance : public BoundDistanceFunction {
+private:
+ mutable TemporaryVectorStore<FloatType> _tmpSpace;
+ const vespalib::ConstArrayRef<FloatType> _lhs_vector;
+public:
+ BoundHammingDistance(const vespalib::eval::TypedCells& lhs)
+ : _tmpSpace(lhs.size),
+ _lhs_vector(_tmpSpace.storeLhs(lhs))
+ {}
+ double calc(const vespalib::eval::TypedCells& rhs) const override {
+ size_t sz = _lhs_vector.size();
+ vespalib::ConstArrayRef<FloatType> rhs_vector = _tmpSpace.convertRhs(rhs);
+ assert(sz == rhs_vector.size());
+ auto a = _lhs_vector.data();
+ auto b = rhs_vector.data();
+ if constexpr (std::is_same<Int8Float, FloatType>::value) {
+ return (double) vespalib::binary_hamming_distance(a, b, sz);
+ } else {
+ size_t sum = 0;
+ for (size_t i = 0; i < sz; ++i) {
+ sum += (_lhs_vector[i] == rhs_vector[i]) ? 0 : 1;
+ }
+ return (double)sum;
+ }
+ }
+ double convert_threshold(double threshold) const override {
+ return threshold;
+ }
+ double to_rawscore(double distance) const override {
+ double score = 1.0 / (1.0 + distance);
+ return score;
+ }
+ double calc_with_limit(const vespalib::eval::TypedCells& rhs, double) const override {
+ // consider optimizing:
+ return calc(rhs);
+ }
+};
+
+template <typename FloatType>
+BoundDistanceFunction::UP
+HammingDistanceFunctionFactory<FloatType>::for_query_vector(const vespalib::eval::TypedCells& lhs) {
+ using DFT = BoundHammingDistance<FloatType>;
+ return std::make_unique<DFT>(lhs);
+}
+
+template <typename FloatType>
+BoundDistanceFunction::UP
+HammingDistanceFunctionFactory<FloatType>::for_insertion_vector(const vespalib::eval::TypedCells& lhs) {
+ using DFT = BoundHammingDistance<FloatType>;
+ return std::make_unique<DFT>(lhs);
+}
+
+template class HammingDistanceFunctionFactory<Int8Float>;
+template class HammingDistanceFunctionFactory<float>;
+template class HammingDistanceFunctionFactory<double>;
+
}
diff --git a/searchlib/src/vespa/searchlib/tensor/hamming_distance.h b/searchlib/src/vespa/searchlib/tensor/hamming_distance.h
index c64fc5b532d..23c855eb137 100644
--- a/searchlib/src/vespa/searchlib/tensor/hamming_distance.h
+++ b/searchlib/src/vespa/searchlib/tensor/hamming_distance.h
@@ -3,6 +3,7 @@
#pragma once
#include "distance_function.h"
+#include "distance_function_factory.h"
#include <vespa/eval/eval/typed_cells.h>
#include <vespa/vespalib/util/typify.h>
#include <cmath>
@@ -29,4 +30,14 @@ public:
double calc_with_limit(const vespalib::eval::TypedCells& lhs, const vespalib::eval::TypedCells& rhs, double) const override;
};
+template <typename FloatType>
+class HammingDistanceFunctionFactory : public DistanceFunctionFactory {
+public:
+ HammingDistanceFunctionFactory()
+ : DistanceFunctionFactory(vespalib::eval::get_cell_type<FloatType>())
+ {}
+ BoundDistanceFunction::UP for_query_vector(const vespalib::eval::TypedCells& lhs) override;
+ BoundDistanceFunction::UP for_insertion_vector(const vespalib::eval::TypedCells& lhs) override;
+};
+
}
diff --git a/searchlib/src/vespa/searchlib/tensor/prenormalized_angular_distance.cpp b/searchlib/src/vespa/searchlib/tensor/prenormalized_angular_distance.cpp
index d2693f9f443..292edc1259d 100644
--- a/searchlib/src/vespa/searchlib/tensor/prenormalized_angular_distance.cpp
+++ b/searchlib/src/vespa/searchlib/tensor/prenormalized_angular_distance.cpp
@@ -17,8 +17,7 @@ private:
double _lhs_norm_sq;
public:
BoundPrenormalizedAngularDistance(const vespalib::eval::TypedCells& lhs)
- : BoundDistanceFunction(vespalib::eval::get_cell_type<FloatType>()),
- _computer(vespalib::hwaccelrated::IAccelrated::getAccelerator()),
+ : _computer(vespalib::hwaccelrated::IAccelrated::getAccelerator()),
_tmpSpace(lhs.size),
_lhs(_tmpSpace.storeLhs(lhs))
{