From 354d69153e8feb26316604693aebe672289e3bea Mon Sep 17 00:00:00 2001 From: Arne Juul Date: Wed, 26 Apr 2023 13:45:54 +0000 Subject: avoid looping twice for calc_with_limit with BFloat16 cells --- .../src/vespa/searchlib/tensor/distance_function_factory.cpp | 1 + searchlib/src/vespa/searchlib/tensor/euclidean_distance.cpp | 8 ++++++-- 2 files changed, 7 insertions(+), 2 deletions(-) diff --git a/searchlib/src/vespa/searchlib/tensor/distance_function_factory.cpp b/searchlib/src/vespa/searchlib/tensor/distance_function_factory.cpp index c088d498f0f..42a1b7b01e8 100644 --- a/searchlib/src/vespa/searchlib/tensor/distance_function_factory.cpp +++ b/searchlib/src/vespa/searchlib/tensor/distance_function_factory.cpp @@ -103,6 +103,7 @@ make_distance_function_factory(search::attribute::DistanceMetric variant, switch (cell_type) { case CellType::DOUBLE: return std::make_unique>(); case CellType::INT8: return std::make_unique>(); + case CellType::BFLOAT16: return std::make_unique>(); default: return std::make_unique>(); } case DistanceMetric::InnerProduct: diff --git a/searchlib/src/vespa/searchlib/tensor/euclidean_distance.cpp b/searchlib/src/vespa/searchlib/tensor/euclidean_distance.cpp index 7995c87d055..a98b37cb6cc 100644 --- a/searchlib/src/vespa/searchlib/tensor/euclidean_distance.cpp +++ b/searchlib/src/vespa/searchlib/tensor/euclidean_distance.cpp @@ -50,9 +50,11 @@ template class SquaredEuclideanDistanceHW; template class SquaredEuclideanDistanceHW; using vespalib::eval::Int8Float; +using vespalib::BFloat16; -template +template class BoundEuclideanDistance : public BoundDistanceFunction { + using FloatType = std::conditional::value,float,AttributeCellType>::type; private: const vespalib::hwaccelrated::IAccelrated & _computer; mutable TemporaryVectorStore _tmpSpace; @@ -83,7 +85,7 @@ public: return score; } double calc_with_limit(const vespalib::eval::TypedCells& rhs, double limit) const override { - vespalib::ConstArrayRef rhs_vector = _tmpSpace.convertRhs(rhs); + vespalib::ConstArrayRef rhs_vector = rhs.typify(); double sum = 0.0; size_t sz = _lhs_vector.size(); assert(sz == rhs_vector.size()); @@ -96,6 +98,7 @@ public: }; template class BoundEuclideanDistance; +template class BoundEuclideanDistance; template class BoundEuclideanDistance; template class BoundEuclideanDistance; @@ -114,6 +117,7 @@ EuclideanDistanceFunctionFactory::for_insertion_vector(const vespalib } template class EuclideanDistanceFunctionFactory; +template class EuclideanDistanceFunctionFactory; template class EuclideanDistanceFunctionFactory; template class EuclideanDistanceFunctionFactory; -- cgit v1.2.3