diff options
author | Arne Juul <arnej@yahooinc.com> | 2023-04-25 13:11:39 +0000 |
---|---|---|
committer | Arne Juul <arnej@yahooinc.com> | 2023-04-25 15:56:19 +0000 |
commit | 807aeb083db71c2504c38dcd684e7556952c9047 (patch) | |
tree | 9758b93b2baa3a62e565c7b16781939d1e4fa49f | |
parent | 7d3241ca8b5bd2a4edb0f35d11883dd6d497faa2 (diff) |
add BoundHammingDistance
4 files changed, 82 insertions, 3 deletions
diff --git a/searchlib/src/tests/tensor/distance_functions/distance_functions_test.cpp b/searchlib/src/tests/tensor/distance_functions/distance_functions_test.cpp index 74ef23f6eb8..3b6d3a70628 100644 --- a/searchlib/src/tests/tensor/distance_functions/distance_functions_test.cpp +++ b/searchlib/src/tests/tensor/distance_functions/distance_functions_test.cpp @@ -394,6 +394,7 @@ TEST(DistanceFunctionsTest, innerproduct_gives_expected_score) TEST(DistanceFunctionsTest, hamming_gives_expected_score) { + static HammingDistanceFunctionFactory<Int8Float> dff; auto ct = vespalib::eval::CellType::DOUBLE; auto hamming = make_distance_function(DistanceMetric::Hamming, ct); @@ -410,6 +411,9 @@ TEST(DistanceFunctionsTest, hamming_gives_expected_score) double h0 = hamming->calc(t(p), t(p)); EXPECT_EQ(h0, 0.0); EXPECT_EQ(hamming->to_rawscore(h0), 1.0); + auto dist_fun = dff.for_query_vector(t(p)); + EXPECT_EQ(dist_fun->calc(t(p)), 0.0); + EXPECT_EQ(dist_fun->to_rawscore(h0), 1.0); } double d12 = hamming->calc(t(points[1]), t(points[2])); EXPECT_EQ(d12, 3.0); @@ -442,6 +446,8 @@ TEST(DistanceFunctionsTest, hamming_gives_expected_score) std::vector<Int8Float> bytes_b = { 1, 2, 2, 4, 8, 16, 32, 65, -128, 0, 1, 0, 4, 8, 16, 32, 64, -128, 0, 1, -1 }; // expect diff: 1 2 1 1 7 EXPECT_EQ(hamming->calc(TypedCells(bytes_a), TypedCells(bytes_b)), 12.0); + auto dist_fun = dff.for_query_vector(TypedCells(bytes_a)); + EXPECT_EQ(dist_fun->calc(TypedCells(bytes_b)), 12.0); } TEST(GeoDegreesTest, gives_expected_score) diff --git a/searchlib/src/vespa/searchlib/tensor/distance_function_factory.cpp b/searchlib/src/vespa/searchlib/tensor/distance_function_factory.cpp index 30bca0d4212..8a004e0d3c9 100644 --- a/searchlib/src/vespa/searchlib/tensor/distance_function_factory.cpp +++ b/searchlib/src/vespa/searchlib/tensor/distance_function_factory.cpp @@ -115,11 +115,13 @@ make_distance_function_factory(search::attribute::DistanceMetric variant, if (variant == DistanceMetric::GeoDegrees) { return std::make_unique<GeoDistanceFunctionFactory>(); } - /* if (variant == DistanceMetric::Hamming) { - return std::make_unique<HammingDistanceFunctionFactory>(); + switch (cell_type) { + case CellType::DOUBLE: return std::make_unique<HammingDistanceFunctionFactory<double>>(); + case CellType::INT8: return std::make_unique<HammingDistanceFunctionFactory<vespalib::eval::Int8Float>>(); + default: return std::make_unique<HammingDistanceFunctionFactory<float>>(); + } } - */ auto df = make_distance_function(variant, cell_type); return std::make_unique<SimpleDistanceFunctionFactory>(std::move(df)); } diff --git a/searchlib/src/vespa/searchlib/tensor/hamming_distance.cpp b/searchlib/src/vespa/searchlib/tensor/hamming_distance.cpp index 43596478a6f..f4f6842715f 100644 --- a/searchlib/src/vespa/searchlib/tensor/hamming_distance.cpp +++ b/searchlib/src/vespa/searchlib/tensor/hamming_distance.cpp @@ -1,6 +1,7 @@ // Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. #include "hamming_distance.h" +#include "temporary_vector_store.h" #include <vespa/vespalib/util/binary_hamming_distance.h> using vespalib::typify_invoke; @@ -52,4 +53,63 @@ HammingDistance::calc_with_limit(const vespalib::eval::TypedCells& lhs, return calc(lhs, rhs); } +using vespalib::eval::Int8Float; + +template<typename FloatType> +class BoundHammingDistance : public BoundDistanceFunction { +private: + mutable TemporaryVectorStore<FloatType> _tmpSpace; + const vespalib::ConstArrayRef<FloatType> _lhs_vector; +public: + BoundHammingDistance(const vespalib::eval::TypedCells& lhs) + : _tmpSpace(lhs.size), + _lhs_vector(_tmpSpace.storeLhs(lhs)) + {} + double calc(const vespalib::eval::TypedCells& rhs) const override { + size_t sz = _lhs_vector.size(); + vespalib::ConstArrayRef<FloatType> rhs_vector = _tmpSpace.convertRhs(rhs); + assert(sz == rhs_vector.size()); + auto a = _lhs_vector.data(); + auto b = rhs_vector.data(); + if constexpr (std::is_same<Int8Float, FloatType>::value) { + return (double) vespalib::binary_hamming_distance(a, b, sz); + } else { + size_t sum = 0; + for (size_t i = 0; i < sz; ++i) { + sum += (_lhs_vector[i] == rhs_vector[i]) ? 0 : 1; + } + return (double)sum; + } + } + double convert_threshold(double threshold) const override { + return threshold; + } + double to_rawscore(double distance) const override { + double score = 1.0 / (1.0 + distance); + return score; + } + double calc_with_limit(const vespalib::eval::TypedCells& rhs, double) const override { + // consider optimizing: + return calc(rhs); + } +}; + +template <typename FloatType> +BoundDistanceFunction::UP +HammingDistanceFunctionFactory<FloatType>::for_query_vector(const vespalib::eval::TypedCells& lhs) { + using DFT = BoundHammingDistance<FloatType>; + return std::make_unique<DFT>(lhs); +} + +template <typename FloatType> +BoundDistanceFunction::UP +HammingDistanceFunctionFactory<FloatType>::for_insertion_vector(const vespalib::eval::TypedCells& lhs) { + using DFT = BoundHammingDistance<FloatType>; + return std::make_unique<DFT>(lhs); +} + +template class HammingDistanceFunctionFactory<Int8Float>; +template class HammingDistanceFunctionFactory<float>; +template class HammingDistanceFunctionFactory<double>; + } diff --git a/searchlib/src/vespa/searchlib/tensor/hamming_distance.h b/searchlib/src/vespa/searchlib/tensor/hamming_distance.h index c64fc5b532d..23c855eb137 100644 --- a/searchlib/src/vespa/searchlib/tensor/hamming_distance.h +++ b/searchlib/src/vespa/searchlib/tensor/hamming_distance.h @@ -3,6 +3,7 @@ #pragma once #include "distance_function.h" +#include "distance_function_factory.h" #include <vespa/eval/eval/typed_cells.h> #include <vespa/vespalib/util/typify.h> #include <cmath> @@ -29,4 +30,14 @@ public: double calc_with_limit(const vespalib::eval::TypedCells& lhs, const vespalib::eval::TypedCells& rhs, double) const override; }; +template <typename FloatType> +class HammingDistanceFunctionFactory : public DistanceFunctionFactory { +public: + HammingDistanceFunctionFactory() + : DistanceFunctionFactory(vespalib::eval::get_cell_type<FloatType>()) + {} + BoundDistanceFunction::UP for_query_vector(const vespalib::eval::TypedCells& lhs) override; + BoundDistanceFunction::UP for_insertion_vector(const vespalib::eval::TypedCells& lhs) override; +}; + } |