aboutsummaryrefslogtreecommitdiffstats
path: root/searchlib
diff options
context:
space:
mode:
authorGeir Storli <geirst@yahooinc.com>2023-05-05 15:22:26 +0000
committerGeir Storli <geirst@yahooinc.com>2023-05-05 15:22:26 +0000
commit36ac8ebea478d02dfbd4e914e85e4f56d3e11cf2 (patch)
tree869f68f84614bd90307b57a3bf24299e56efa8ae /searchlib
parente0ed2a2f68f470a4348b6ee19c77b4d22eded8de (diff)
Make it possible to configure dotproduct distance metric.
Diffstat (limited to 'searchlib')
-rw-r--r--searchlib/src/tests/attribute/attributemanager/attributemanager_test.cpp27
-rw-r--r--searchlib/src/vespa/searchcommon/attribute/distance_metric.h2
-rw-r--r--searchlib/src/vespa/searchlib/attribute/attribute_header.cpp8
-rw-r--r--searchlib/src/vespa/searchlib/attribute/configconverter.cpp7
-rw-r--r--searchlib/src/vespa/searchlib/tensor/distance_function_factory.cpp2
5 files changed, 26 insertions, 20 deletions
diff --git a/searchlib/src/tests/attribute/attributemanager/attributemanager_test.cpp b/searchlib/src/tests/attribute/attributemanager/attributemanager_test.cpp
index ab36213246e..92b310ffe1e 100644
--- a/searchlib/src/tests/attribute/attributemanager/attributemanager_test.cpp
+++ b/searchlib/src/tests/attribute/attributemanager/attributemanager_test.cpp
@@ -180,6 +180,16 @@ assertCollectionType(CollectionType exp, AttributesConfig::Attribute::Collection
EXPECT_EQUAL(exp.createIfNonExistant(), out.collectionType().createIfNonExistant());
}
+void
+expect_distance_metric(AttributesConfig::Attribute::Distancemetric in_metric,
+ DistanceMetric out_metric)
+{
+ AttributesConfig::Attribute a;
+ a.distancemetric = in_metric;
+ auto out = ConfigConverter::convert(a);
+ EXPECT_TRUE(out.distance_metric() == out_metric);
+}
+
TEST("require that config can be converted")
{
@@ -254,16 +264,13 @@ TEST("require that config can be converted")
EXPECT_TRUE(out.distance_metric() == DistanceMetric::Euclidean);
}
{ // distance metric (explicit)
- CACA a;
- a.distancemetric = AttributesConfig::Attribute::Distancemetric::GEODEGREES;
- auto out = ConfigConverter::convert(a);
- EXPECT_TRUE(out.distance_metric() == DistanceMetric::GeoDegrees);
- }
- { // distance metric (explicit)
- CACA a;
- a.distancemetric = AttributesConfig::Attribute::Distancemetric::INNERPRODUCT;
- auto out = ConfigConverter::convert(a);
- EXPECT_TRUE(out.distance_metric() == DistanceMetric::InnerProduct);
+ expect_distance_metric(AttributesConfig::Attribute::Distancemetric::EUCLIDEAN, DistanceMetric::Euclidean);
+ expect_distance_metric(AttributesConfig::Attribute::Distancemetric::ANGULAR, DistanceMetric::Angular);
+ expect_distance_metric(AttributesConfig::Attribute::Distancemetric::GEODEGREES, DistanceMetric::GeoDegrees);
+ expect_distance_metric(AttributesConfig::Attribute::Distancemetric::HAMMING, DistanceMetric::Hamming);
+ expect_distance_metric(AttributesConfig::Attribute::Distancemetric::INNERPRODUCT, DistanceMetric::InnerProduct);
+ expect_distance_metric(AttributesConfig::Attribute::Distancemetric::PRENORMALIZED_ANGULAR, DistanceMetric::PrenormalizedAngular);
+ expect_distance_metric(AttributesConfig::Attribute::Distancemetric::DOTPRODUCT, DistanceMetric::Dotproduct);
}
{ // hnsw index default params (enabled)
CACA a;
diff --git a/searchlib/src/vespa/searchcommon/attribute/distance_metric.h b/searchlib/src/vespa/searchcommon/attribute/distance_metric.h
index c157f6abb28..9f9f45810b9 100644
--- a/searchlib/src/vespa/searchcommon/attribute/distance_metric.h
+++ b/searchlib/src/vespa/searchcommon/attribute/distance_metric.h
@@ -4,6 +4,6 @@
namespace search::attribute {
-enum class DistanceMetric { Euclidean, Angular, GeoDegrees, InnerProduct, Hamming, PrenormalizedAngular, TransformedMips };
+enum class DistanceMetric { Euclidean, Angular, GeoDegrees, InnerProduct, Hamming, PrenormalizedAngular, Dotproduct };
}
diff --git a/searchlib/src/vespa/searchlib/attribute/attribute_header.cpp b/searchlib/src/vespa/searchlib/attribute/attribute_header.cpp
index 0edab90f089..122c2c0c55e 100644
--- a/searchlib/src/vespa/searchlib/attribute/attribute_header.cpp
+++ b/searchlib/src/vespa/searchlib/attribute/attribute_header.cpp
@@ -29,7 +29,7 @@ const vespalib::string angular = "angular";
const vespalib::string geodegrees = "geodegrees";
const vespalib::string innerproduct = "innerproduct";
const vespalib::string prenormalized_angular = "prenormalized_angular";
-const vespalib::string transformed_mips = "transformed_mips";
+const vespalib::string dotproduct = "dotproduct";
const vespalib::string hamming = "hamming";
const vespalib::string doc_id_limit_tag = "docIdLimit";
const vespalib::string enumerated_tag = "enumerated";
@@ -104,7 +104,7 @@ to_string(DistanceMetric metric)
case DistanceMetric::InnerProduct: return innerproduct;
case DistanceMetric::Hamming: return hamming;
case DistanceMetric::PrenormalizedAngular: return prenormalized_angular;
- case DistanceMetric::TransformedMips: return transformed_mips;
+ case DistanceMetric::Dotproduct: return dotproduct;
}
throw vespalib::IllegalArgumentException("Unknown distance metric " + std::to_string(static_cast<int>(metric)));
}
@@ -122,8 +122,8 @@ to_distance_metric(const vespalib::string& metric)
return DistanceMetric::InnerProduct;
} else if (metric == prenormalized_angular) {
return DistanceMetric::PrenormalizedAngular;
- } else if (metric == transformed_mips) {
- return DistanceMetric::TransformedMips;
+ } else if (metric == dotproduct) {
+ return DistanceMetric::Dotproduct;
} else if (metric == hamming) {
return DistanceMetric::Hamming;
} else {
diff --git a/searchlib/src/vespa/searchlib/attribute/configconverter.cpp b/searchlib/src/vespa/searchlib/attribute/configconverter.cpp
index 2119f441a14..7f04efd940b 100644
--- a/searchlib/src/vespa/searchlib/attribute/configconverter.cpp
+++ b/searchlib/src/vespa/searchlib/attribute/configconverter.cpp
@@ -136,10 +136,9 @@ ConfigConverter::convert(const AttributesConfig::Attribute & cfg)
break;
case CfgDm::PRENORMALIZED_ANGULAR:
dm = DistanceMetric::PrenormalizedAngular;
- /*
- case CfgDm::TRANSFORMED_MIPS:
- dm = DistanceMetric::TransformedMips;
- */
+ break;
+ case CfgDm::DOTPRODUCT:
+ dm = DistanceMetric::Dotproduct;
break;
}
retval.set_distance_metric(dm);
diff --git a/searchlib/src/vespa/searchlib/tensor/distance_function_factory.cpp b/searchlib/src/vespa/searchlib/tensor/distance_function_factory.cpp
index a338bf85e43..68988ef6308 100644
--- a/searchlib/src/vespa/searchlib/tensor/distance_function_factory.cpp
+++ b/searchlib/src/vespa/searchlib/tensor/distance_function_factory.cpp
@@ -39,7 +39,7 @@ make_distance_function_factory(search::attribute::DistanceMetric variant,
case CellType::DOUBLE: return std::make_unique<PrenormalizedAngularDistanceFunctionFactory<double>>();
default: return std::make_unique<PrenormalizedAngularDistanceFunctionFactory<float>>();
}
- case DistanceMetric::TransformedMips:
+ case DistanceMetric::Dotproduct:
switch (cell_type) {
case CellType::DOUBLE: return std::make_unique<MipsDistanceFunctionFactory<double>>();
case CellType::INT8: return std::make_unique<MipsDistanceFunctionFactory<vespalib::eval::Int8Float>>();