diff options
author | Henning Baldersheim <balder@yahoo-inc.com> | 2021-11-25 21:21:46 +0000 |
---|---|---|
committer | Henning Baldersheim <balder@yahoo-inc.com> | 2021-11-26 05:47:48 +0000 |
commit | b54f9353181518054a1aaafc294df03ee15d58de (patch) | |
tree | 0e6146a540beaa66e51b70229194dcb15b944cd9 /searchlib | |
parent | 727e8c771e96d2d596b48b558baa0bf8fa5b4ab2 (diff) |
- Use the optimized int8_t euclidian distance calculation.
- If types are not identical, fall back to general implementation.
Diffstat (limited to 'searchlib')
-rw-r--r-- | searchlib/src/vespa/searchlib/tensor/distance_function_factory.cpp | 1 | ||||
-rw-r--r-- | searchlib/src/vespa/searchlib/tensor/euclidean_distance.h | 9 |
2 files changed, 8 insertions, 2 deletions
diff --git a/searchlib/src/vespa/searchlib/tensor/distance_function_factory.cpp b/searchlib/src/vespa/searchlib/tensor/distance_function_factory.cpp index 315d4c8535c..96dfc580d87 100644 --- a/searchlib/src/vespa/searchlib/tensor/distance_function_factory.cpp +++ b/searchlib/src/vespa/searchlib/tensor/distance_function_factory.cpp @@ -21,6 +21,7 @@ make_distance_function(DistanceMetric variant, CellType cell_type) switch (cell_type) { case CellType::FLOAT: return std::make_unique<SquaredEuclideanDistanceHW<float>>(); case CellType::DOUBLE: return std::make_unique<SquaredEuclideanDistanceHW<double>>(); + case CellType::INT8: return std::make_unique<SquaredEuclideanDistanceHW<vespalib::eval::Int8Float>>(); default: return std::make_unique<SquaredEuclideanDistance>(CellType::FLOAT); } case DistanceMetric::Angular: diff --git a/searchlib/src/vespa/searchlib/tensor/euclidean_distance.h b/searchlib/src/vespa/searchlib/tensor/euclidean_distance.h index 517ef68511b..df6fe4a6df4 100644 --- a/searchlib/src/vespa/searchlib/tensor/euclidean_distance.h +++ b/searchlib/src/vespa/searchlib/tensor/euclidean_distance.h @@ -44,14 +44,19 @@ public: assert(expected_cell_type() == vespalib::eval::get_cell_type<FloatType>()); } + static const double *cast(const double * p) { return p; } + static const float *cast(const float * p) { return p; } + static const int8_t *cast(const vespalib::eval::Int8Float * p) { return reinterpret_cast<const int8_t *>(p); } double calc(const vespalib::eval::TypedCells& lhs, const vespalib::eval::TypedCells& rhs) const override { constexpr vespalib::eval::CellType expected = vespalib::eval::get_cell_type<FloatType>(); - assert(lhs.type == expected && rhs.type == expected); + if ((lhs.type != expected) || (rhs.type == expected)) { + return SquaredEuclideanDistance::calc(lhs, rhs); + } auto lhs_vector = lhs.typify<FloatType>(); auto rhs_vector = rhs.typify<FloatType>(); size_t sz = lhs_vector.size(); assert(sz == rhs_vector.size()); - return _computer.squaredEuclideanDistance(&lhs_vector[0], &rhs_vector[0], sz); + return _computer.squaredEuclideanDistance(cast(&lhs_vector[0]), cast(&rhs_vector[0]), sz); } double calc_with_limit(const vespalib::eval::TypedCells& lhs, |