diff options
author | Arne Juul <arnej@yahooinc.com> | 2023-04-26 13:45:54 +0000 |
---|---|---|
committer | Arne Juul <arnej@yahooinc.com> | 2023-04-26 13:45:54 +0000 |
commit | 354d69153e8feb26316604693aebe672289e3bea (patch) | |
tree | d80f28b0cc5813b0ff4c2642f23005682ab5cd60 /searchlib | |
parent | 49aba5432a2965b9ab4792e4b38445f6f3289074 (diff) |
avoid looping twice for calc_with_limit with BFloat16 cells
Diffstat (limited to 'searchlib')
-rw-r--r-- | searchlib/src/vespa/searchlib/tensor/distance_function_factory.cpp | 1 | ||||
-rw-r--r-- | searchlib/src/vespa/searchlib/tensor/euclidean_distance.cpp | 8 |
2 files changed, 7 insertions, 2 deletions
diff --git a/searchlib/src/vespa/searchlib/tensor/distance_function_factory.cpp b/searchlib/src/vespa/searchlib/tensor/distance_function_factory.cpp index c088d498f0f..42a1b7b01e8 100644 --- a/searchlib/src/vespa/searchlib/tensor/distance_function_factory.cpp +++ b/searchlib/src/vespa/searchlib/tensor/distance_function_factory.cpp @@ -103,6 +103,7 @@ make_distance_function_factory(search::attribute::DistanceMetric variant, switch (cell_type) { case CellType::DOUBLE: return std::make_unique<EuclideanDistanceFunctionFactory<double>>(); case CellType::INT8: return std::make_unique<EuclideanDistanceFunctionFactory<vespalib::eval::Int8Float>>(); + case CellType::BFLOAT16: return std::make_unique<EuclideanDistanceFunctionFactory<vespalib::BFloat16>>(); default: return std::make_unique<EuclideanDistanceFunctionFactory<float>>(); } case DistanceMetric::InnerProduct: diff --git a/searchlib/src/vespa/searchlib/tensor/euclidean_distance.cpp b/searchlib/src/vespa/searchlib/tensor/euclidean_distance.cpp index 7995c87d055..a98b37cb6cc 100644 --- a/searchlib/src/vespa/searchlib/tensor/euclidean_distance.cpp +++ b/searchlib/src/vespa/searchlib/tensor/euclidean_distance.cpp @@ -50,9 +50,11 @@ template class SquaredEuclideanDistanceHW<float>; template class SquaredEuclideanDistanceHW<double>; using vespalib::eval::Int8Float; +using vespalib::BFloat16; -template<typename FloatType> +template<typename AttributeCellType> class BoundEuclideanDistance : public BoundDistanceFunction { + using FloatType = std::conditional<std::is_same<AttributeCellType,BFloat16>::value,float,AttributeCellType>::type; private: const vespalib::hwaccelrated::IAccelrated & _computer; mutable TemporaryVectorStore<FloatType> _tmpSpace; @@ -83,7 +85,7 @@ public: return score; } double calc_with_limit(const vespalib::eval::TypedCells& rhs, double limit) const override { - vespalib::ConstArrayRef<FloatType> rhs_vector = _tmpSpace.convertRhs(rhs); + vespalib::ConstArrayRef<AttributeCellType> rhs_vector = rhs.typify<AttributeCellType>(); double sum = 0.0; size_t sz = _lhs_vector.size(); assert(sz == rhs_vector.size()); @@ -96,6 +98,7 @@ public: }; template class BoundEuclideanDistance<Int8Float>; +template class BoundEuclideanDistance<BFloat16>; template class BoundEuclideanDistance<float>; template class BoundEuclideanDistance<double>; @@ -114,6 +117,7 @@ EuclideanDistanceFunctionFactory<FloatType>::for_insertion_vector(const vespalib } template class EuclideanDistanceFunctionFactory<Int8Float>; +template class EuclideanDistanceFunctionFactory<BFloat16>; template class EuclideanDistanceFunctionFactory<float>; template class EuclideanDistanceFunctionFactory<double>; |