summaryrefslogtreecommitdiffstats
path: root/searchlib
diff options
context:
space:
mode:
authorHenning Baldersheim <balder@yahoo-inc.com>2021-11-25 21:21:46 +0000
committerHenning Baldersheim <balder@yahoo-inc.com>2021-11-26 05:47:48 +0000
commitb54f9353181518054a1aaafc294df03ee15d58de (patch)
tree0e6146a540beaa66e51b70229194dcb15b944cd9 /searchlib
parent727e8c771e96d2d596b48b558baa0bf8fa5b4ab2 (diff)
- Use the optimized int8_t euclidian distance calculation.
- If types are not identical, fall back to general implementation.
Diffstat (limited to 'searchlib')
-rw-r--r--searchlib/src/vespa/searchlib/tensor/distance_function_factory.cpp1
-rw-r--r--searchlib/src/vespa/searchlib/tensor/euclidean_distance.h9
2 files changed, 8 insertions, 2 deletions
diff --git a/searchlib/src/vespa/searchlib/tensor/distance_function_factory.cpp b/searchlib/src/vespa/searchlib/tensor/distance_function_factory.cpp
index 315d4c8535c..96dfc580d87 100644
--- a/searchlib/src/vespa/searchlib/tensor/distance_function_factory.cpp
+++ b/searchlib/src/vespa/searchlib/tensor/distance_function_factory.cpp
@@ -21,6 +21,7 @@ make_distance_function(DistanceMetric variant, CellType cell_type)
switch (cell_type) {
case CellType::FLOAT: return std::make_unique<SquaredEuclideanDistanceHW<float>>();
case CellType::DOUBLE: return std::make_unique<SquaredEuclideanDistanceHW<double>>();
+ case CellType::INT8: return std::make_unique<SquaredEuclideanDistanceHW<vespalib::eval::Int8Float>>();
default: return std::make_unique<SquaredEuclideanDistance>(CellType::FLOAT);
}
case DistanceMetric::Angular:
diff --git a/searchlib/src/vespa/searchlib/tensor/euclidean_distance.h b/searchlib/src/vespa/searchlib/tensor/euclidean_distance.h
index 517ef68511b..df6fe4a6df4 100644
--- a/searchlib/src/vespa/searchlib/tensor/euclidean_distance.h
+++ b/searchlib/src/vespa/searchlib/tensor/euclidean_distance.h
@@ -44,14 +44,19 @@ public:
assert(expected_cell_type() == vespalib::eval::get_cell_type<FloatType>());
}
+ static const double *cast(const double * p) { return p; }
+ static const float *cast(const float * p) { return p; }
+ static const int8_t *cast(const vespalib::eval::Int8Float * p) { return reinterpret_cast<const int8_t *>(p); }
double calc(const vespalib::eval::TypedCells& lhs, const vespalib::eval::TypedCells& rhs) const override {
constexpr vespalib::eval::CellType expected = vespalib::eval::get_cell_type<FloatType>();
- assert(lhs.type == expected && rhs.type == expected);
+ if ((lhs.type != expected) || (rhs.type == expected)) {
+ return SquaredEuclideanDistance::calc(lhs, rhs);
+ }
auto lhs_vector = lhs.typify<FloatType>();
auto rhs_vector = rhs.typify<FloatType>();
size_t sz = lhs_vector.size();
assert(sz == rhs_vector.size());
- return _computer.squaredEuclideanDistance(&lhs_vector[0], &rhs_vector[0], sz);
+ return _computer.squaredEuclideanDistance(cast(&lhs_vector[0]), cast(&rhs_vector[0]), sz);
}
double calc_with_limit(const vespalib::eval::TypedCells& lhs,