diff options
Diffstat (limited to 'searchlib/src/tests/tensor/distance_functions/distance_functions_test.cpp')
-rw-r--r-- | searchlib/src/tests/tensor/distance_functions/distance_functions_test.cpp | 34 |
1 files changed, 29 insertions, 5 deletions
diff --git a/searchlib/src/tests/tensor/distance_functions/distance_functions_test.cpp b/searchlib/src/tests/tensor/distance_functions/distance_functions_test.cpp index 391e2d91d08..eeae12e1695 100644 --- a/searchlib/src/tests/tensor/distance_functions/distance_functions_test.cpp +++ b/searchlib/src/tests/tensor/distance_functions/distance_functions_test.cpp @@ -20,6 +20,16 @@ using search::attribute::DistanceMetric; template <typename T> TypedCells t(const std::vector<T> &v) { return TypedCells(v); } +template<typename T> +struct EmptyCells { + explicit EmptyCells(size_t elems) : _zero(elems, 0), cells(_zero) { cells.size = 0; } + std::vector<T> _zero; + TypedCells cells; +}; + +template <typename T> +EmptyCells<T> e(size_t elems) { return EmptyCells<T>(elems); } + void verify_geo_miles(const std::vector<double> &p1, const std::vector<double> &p2, double exp_miles) @@ -49,6 +59,15 @@ void verify_geo_miles(const std::vector<double> &p1, } } +template<typename T> +void verifyInvalidQueryVector(DistanceFunctionFactory & dff, double expected_distance_to_origo) { + std::vector<T> origo = {0,0,0}; + EXPECT_FLOAT_EQ(expected_distance_to_origo, dff.for_query_vector(t(origo))->calc(e<double>(origo.size()).cells)); + EXPECT_FLOAT_EQ(expected_distance_to_origo, dff.for_query_vector(t(origo))->calc(e<float>(origo.size()).cells)); + EXPECT_FLOAT_EQ(expected_distance_to_origo, dff.for_query_vector(t(origo))->calc(e<Int8Float>(origo.size()).cells)); + EXPECT_FLOAT_EQ(expected_distance_to_origo, dff.for_query_vector(t(origo))->calc(e<vespalib::BFloat16>(origo.size()).cells)); +} + double computeEuclideanChecked(TypedCells a, TypedCells b) { static EuclideanDistanceFunctionFactory<Int8Float> i8f_dff; static EuclideanDistanceFunctionFactory<float> flt_dff; @@ -92,6 +111,7 @@ TEST(DistanceFunctionsTest, euclidean_gives_expected_score) EXPECT_EQ(d12, 2.0); EuclideanDistanceFunctionFactory<double> dff; + verifyInvalidQueryVector<double>(dff, 0.0); auto euclid = dff.for_query_vector(t(p0)); EXPECT_DOUBLE_EQ(euclid->to_rawscore(d12), 1.0/(1.0 + sqrt(2.0))); double threshold = euclid->convert_threshold(8.0); @@ -128,10 +148,7 @@ TEST(DistanceFunctionsTest, euclidean_gives_expected_score) EXPECT_EQ(computeEuclideanChecked(t(p6), t(p6)), 0.0); // smoke test for bfloat16: - std::vector<vespalib::BFloat16> bf16v; - bf16v.emplace_back(1.0); - bf16v.emplace_back(1.0); - bf16v.emplace_back(1.0); + std::vector<vespalib::BFloat16> bf16v{1.0, 1.0, 1.0}; EXPECT_EQ(computeEuclideanChecked(t(bf16v), t(p0)), 3.0); EXPECT_EQ(computeEuclideanChecked(t(bf16v), t(p1)), 2.0); EXPECT_EQ(computeEuclideanChecked(t(bf16v), t(p2)), 2.0); @@ -188,6 +205,7 @@ TEST(DistanceFunctionsTest, angular_gives_expected_score) AngularDistanceFunctionFactory<double> dff; auto angular = dff.for_query_vector(t(p0)); + verifyInvalidQueryVector<double>(dff, 1.0); constexpr double pi = 3.14159265358979323846; double a12 = computeAngularChecked(t(p1), t(p2)); double a13 = computeAngularChecked(t(p1), t(p3)); @@ -315,6 +333,7 @@ TEST(DistanceFunctionsTest, prenormalized_angular_gives_expected_score) std::vector<double> p8{3.0, 0.0, 0.0}; PrenormalizedAngularDistanceFunctionFactory<double> dff; + verifyInvalidQueryVector<double>(dff, 1.0); auto pnad = dff.for_query_vector(t(p0)); double i12 = computePrenormalizedAngularChecked(t(p1), t(p2)); @@ -360,7 +379,8 @@ TEST(DistanceFunctionsTest, prenormalized_angular_gives_expected_score) TEST(DistanceFunctionsTest, hamming_gives_expected_score) { - static HammingDistanceFunctionFactory<double> dff; + HammingDistanceFunctionFactory<double> dff; + verifyInvalidQueryVector<double>(dff, 0.0); std::vector<std::vector<double>> points{{0.0, 0.0, 0.0}, {1.0, 0.0, 0.0}, @@ -376,6 +396,7 @@ TEST(DistanceFunctionsTest, hamming_gives_expected_score) EXPECT_EQ(h0, 0.0); EXPECT_EQ(dist_fun->to_rawscore(h0), 1.0); } + double d12 = dff.for_query_vector(t(points[1]))->calc(t(points[2])); EXPECT_EQ(d12, 3.0); EXPECT_DOUBLE_EQ(hamming->to_rawscore(d12), 1.0/(1.0 + 3.0)); @@ -579,6 +600,9 @@ TEST(DistanceFunctionsTest, transformed_mips_basic_scores) std::vector<double> p4{0.5, 0.5, sq_root_half}; std::vector<double> p5{0.0,-1.0, 0.0}; + MipsDistanceFunctionFactory<double> dff; + verifyInvalidQueryVector<double>(dff, 0.0); + double i12 = computeTransformedMipsChecked(t(p1), t(p2)); double i13 = computeTransformedMipsChecked(t(p1), t(p3)); double i23 = computeTransformedMipsChecked(t(p2), t(p3)); |