diff options
author | Geir Storli <geirst@yahooinc.com> | 2023-05-05 15:22:26 +0000 |
---|---|---|
committer | Geir Storli <geirst@yahooinc.com> | 2023-05-05 15:22:26 +0000 |
commit | 36ac8ebea478d02dfbd4e914e85e4f56d3e11cf2 (patch) | |
tree | 869f68f84614bd90307b57a3bf24299e56efa8ae /searchlib | |
parent | e0ed2a2f68f470a4348b6ee19c77b4d22eded8de (diff) |
Make it possible to configure dotproduct distance metric.
Diffstat (limited to 'searchlib')
5 files changed, 26 insertions, 20 deletions
diff --git a/searchlib/src/tests/attribute/attributemanager/attributemanager_test.cpp b/searchlib/src/tests/attribute/attributemanager/attributemanager_test.cpp index ab36213246e..92b310ffe1e 100644 --- a/searchlib/src/tests/attribute/attributemanager/attributemanager_test.cpp +++ b/searchlib/src/tests/attribute/attributemanager/attributemanager_test.cpp @@ -180,6 +180,16 @@ assertCollectionType(CollectionType exp, AttributesConfig::Attribute::Collection EXPECT_EQUAL(exp.createIfNonExistant(), out.collectionType().createIfNonExistant()); } +void +expect_distance_metric(AttributesConfig::Attribute::Distancemetric in_metric, + DistanceMetric out_metric) +{ + AttributesConfig::Attribute a; + a.distancemetric = in_metric; + auto out = ConfigConverter::convert(a); + EXPECT_TRUE(out.distance_metric() == out_metric); +} + TEST("require that config can be converted") { @@ -254,16 +264,13 @@ TEST("require that config can be converted") EXPECT_TRUE(out.distance_metric() == DistanceMetric::Euclidean); } { // distance metric (explicit) - CACA a; - a.distancemetric = AttributesConfig::Attribute::Distancemetric::GEODEGREES; - auto out = ConfigConverter::convert(a); - EXPECT_TRUE(out.distance_metric() == DistanceMetric::GeoDegrees); - } - { // distance metric (explicit) - CACA a; - a.distancemetric = AttributesConfig::Attribute::Distancemetric::INNERPRODUCT; - auto out = ConfigConverter::convert(a); - EXPECT_TRUE(out.distance_metric() == DistanceMetric::InnerProduct); + expect_distance_metric(AttributesConfig::Attribute::Distancemetric::EUCLIDEAN, DistanceMetric::Euclidean); + expect_distance_metric(AttributesConfig::Attribute::Distancemetric::ANGULAR, DistanceMetric::Angular); + expect_distance_metric(AttributesConfig::Attribute::Distancemetric::GEODEGREES, DistanceMetric::GeoDegrees); + expect_distance_metric(AttributesConfig::Attribute::Distancemetric::HAMMING, DistanceMetric::Hamming); + expect_distance_metric(AttributesConfig::Attribute::Distancemetric::INNERPRODUCT, DistanceMetric::InnerProduct); + expect_distance_metric(AttributesConfig::Attribute::Distancemetric::PRENORMALIZED_ANGULAR, DistanceMetric::PrenormalizedAngular); + expect_distance_metric(AttributesConfig::Attribute::Distancemetric::DOTPRODUCT, DistanceMetric::Dotproduct); } { // hnsw index default params (enabled) CACA a; diff --git a/searchlib/src/vespa/searchcommon/attribute/distance_metric.h b/searchlib/src/vespa/searchcommon/attribute/distance_metric.h index c157f6abb28..9f9f45810b9 100644 --- a/searchlib/src/vespa/searchcommon/attribute/distance_metric.h +++ b/searchlib/src/vespa/searchcommon/attribute/distance_metric.h @@ -4,6 +4,6 @@ namespace search::attribute { -enum class DistanceMetric { Euclidean, Angular, GeoDegrees, InnerProduct, Hamming, PrenormalizedAngular, TransformedMips }; +enum class DistanceMetric { Euclidean, Angular, GeoDegrees, InnerProduct, Hamming, PrenormalizedAngular, Dotproduct }; } diff --git a/searchlib/src/vespa/searchlib/attribute/attribute_header.cpp b/searchlib/src/vespa/searchlib/attribute/attribute_header.cpp index 0edab90f089..122c2c0c55e 100644 --- a/searchlib/src/vespa/searchlib/attribute/attribute_header.cpp +++ b/searchlib/src/vespa/searchlib/attribute/attribute_header.cpp @@ -29,7 +29,7 @@ const vespalib::string angular = "angular"; const vespalib::string geodegrees = "geodegrees"; const vespalib::string innerproduct = "innerproduct"; const vespalib::string prenormalized_angular = "prenormalized_angular"; -const vespalib::string transformed_mips = "transformed_mips"; +const vespalib::string dotproduct = "dotproduct"; const vespalib::string hamming = "hamming"; const vespalib::string doc_id_limit_tag = "docIdLimit"; const vespalib::string enumerated_tag = "enumerated"; @@ -104,7 +104,7 @@ to_string(DistanceMetric metric) case DistanceMetric::InnerProduct: return innerproduct; case DistanceMetric::Hamming: return hamming; case DistanceMetric::PrenormalizedAngular: return prenormalized_angular; - case DistanceMetric::TransformedMips: return transformed_mips; + case DistanceMetric::Dotproduct: return dotproduct; } throw vespalib::IllegalArgumentException("Unknown distance metric " + std::to_string(static_cast<int>(metric))); } @@ -122,8 +122,8 @@ to_distance_metric(const vespalib::string& metric) return DistanceMetric::InnerProduct; } else if (metric == prenormalized_angular) { return DistanceMetric::PrenormalizedAngular; - } else if (metric == transformed_mips) { - return DistanceMetric::TransformedMips; + } else if (metric == dotproduct) { + return DistanceMetric::Dotproduct; } else if (metric == hamming) { return DistanceMetric::Hamming; } else { diff --git a/searchlib/src/vespa/searchlib/attribute/configconverter.cpp b/searchlib/src/vespa/searchlib/attribute/configconverter.cpp index 2119f441a14..7f04efd940b 100644 --- a/searchlib/src/vespa/searchlib/attribute/configconverter.cpp +++ b/searchlib/src/vespa/searchlib/attribute/configconverter.cpp @@ -136,10 +136,9 @@ ConfigConverter::convert(const AttributesConfig::Attribute & cfg) break; case CfgDm::PRENORMALIZED_ANGULAR: dm = DistanceMetric::PrenormalizedAngular; - /* - case CfgDm::TRANSFORMED_MIPS: - dm = DistanceMetric::TransformedMips; - */ + break; + case CfgDm::DOTPRODUCT: + dm = DistanceMetric::Dotproduct; break; } retval.set_distance_metric(dm); diff --git a/searchlib/src/vespa/searchlib/tensor/distance_function_factory.cpp b/searchlib/src/vespa/searchlib/tensor/distance_function_factory.cpp index a338bf85e43..68988ef6308 100644 --- a/searchlib/src/vespa/searchlib/tensor/distance_function_factory.cpp +++ b/searchlib/src/vespa/searchlib/tensor/distance_function_factory.cpp @@ -39,7 +39,7 @@ make_distance_function_factory(search::attribute::DistanceMetric variant, case CellType::DOUBLE: return std::make_unique<PrenormalizedAngularDistanceFunctionFactory<double>>(); default: return std::make_unique<PrenormalizedAngularDistanceFunctionFactory<float>>(); } - case DistanceMetric::TransformedMips: + case DistanceMetric::Dotproduct: switch (cell_type) { case CellType::DOUBLE: return std::make_unique<MipsDistanceFunctionFactory<double>>(); case CellType::INT8: return std::make_unique<MipsDistanceFunctionFactory<vespalib::eval::Int8Float>>(); |