aboutsummaryrefslogtreecommitdiffstats
path: root/searchlib
diff options
context:
space:
mode:
authorArne Juul <arnej@yahooinc.com>2023-04-26 13:45:54 +0000
committerArne Juul <arnej@yahooinc.com>2023-04-26 13:45:54 +0000
commit354d69153e8feb26316604693aebe672289e3bea (patch)
treed80f28b0cc5813b0ff4c2642f23005682ab5cd60 /searchlib
parent49aba5432a2965b9ab4792e4b38445f6f3289074 (diff)
avoid looping twice for calc_with_limit with BFloat16 cells
Diffstat (limited to 'searchlib')
-rw-r--r--searchlib/src/vespa/searchlib/tensor/distance_function_factory.cpp1
-rw-r--r--searchlib/src/vespa/searchlib/tensor/euclidean_distance.cpp8
2 files changed, 7 insertions, 2 deletions
diff --git a/searchlib/src/vespa/searchlib/tensor/distance_function_factory.cpp b/searchlib/src/vespa/searchlib/tensor/distance_function_factory.cpp
index c088d498f0f..42a1b7b01e8 100644
--- a/searchlib/src/vespa/searchlib/tensor/distance_function_factory.cpp
+++ b/searchlib/src/vespa/searchlib/tensor/distance_function_factory.cpp
@@ -103,6 +103,7 @@ make_distance_function_factory(search::attribute::DistanceMetric variant,
switch (cell_type) {
case CellType::DOUBLE: return std::make_unique<EuclideanDistanceFunctionFactory<double>>();
case CellType::INT8: return std::make_unique<EuclideanDistanceFunctionFactory<vespalib::eval::Int8Float>>();
+ case CellType::BFLOAT16: return std::make_unique<EuclideanDistanceFunctionFactory<vespalib::BFloat16>>();
default: return std::make_unique<EuclideanDistanceFunctionFactory<float>>();
}
case DistanceMetric::InnerProduct:
diff --git a/searchlib/src/vespa/searchlib/tensor/euclidean_distance.cpp b/searchlib/src/vespa/searchlib/tensor/euclidean_distance.cpp
index 7995c87d055..a98b37cb6cc 100644
--- a/searchlib/src/vespa/searchlib/tensor/euclidean_distance.cpp
+++ b/searchlib/src/vespa/searchlib/tensor/euclidean_distance.cpp
@@ -50,9 +50,11 @@ template class SquaredEuclideanDistanceHW<float>;
template class SquaredEuclideanDistanceHW<double>;
using vespalib::eval::Int8Float;
+using vespalib::BFloat16;
-template<typename FloatType>
+template<typename AttributeCellType>
class BoundEuclideanDistance : public BoundDistanceFunction {
+ using FloatType = std::conditional<std::is_same<AttributeCellType,BFloat16>::value,float,AttributeCellType>::type;
private:
const vespalib::hwaccelrated::IAccelrated & _computer;
mutable TemporaryVectorStore<FloatType> _tmpSpace;
@@ -83,7 +85,7 @@ public:
return score;
}
double calc_with_limit(const vespalib::eval::TypedCells& rhs, double limit) const override {
- vespalib::ConstArrayRef<FloatType> rhs_vector = _tmpSpace.convertRhs(rhs);
+ vespalib::ConstArrayRef<AttributeCellType> rhs_vector = rhs.typify<AttributeCellType>();
double sum = 0.0;
size_t sz = _lhs_vector.size();
assert(sz == rhs_vector.size());
@@ -96,6 +98,7 @@ public:
};
template class BoundEuclideanDistance<Int8Float>;
+template class BoundEuclideanDistance<BFloat16>;
template class BoundEuclideanDistance<float>;
template class BoundEuclideanDistance<double>;
@@ -114,6 +117,7 @@ EuclideanDistanceFunctionFactory<FloatType>::for_insertion_vector(const vespalib
}
template class EuclideanDistanceFunctionFactory<Int8Float>;
+template class EuclideanDistanceFunctionFactory<BFloat16>;
template class EuclideanDistanceFunctionFactory<float>;
template class EuclideanDistanceFunctionFactory<double>;