summaryrefslogtreecommitdiffstats
path: root/searchlib/src/tests/tensor/distance_functions/distance_functions_benchmark.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'searchlib/src/tests/tensor/distance_functions/distance_functions_benchmark.cpp')
-rw-r--r--searchlib/src/tests/tensor/distance_functions/distance_functions_benchmark.cpp12
1 files changed, 6 insertions, 6 deletions
diff --git a/searchlib/src/tests/tensor/distance_functions/distance_functions_benchmark.cpp b/searchlib/src/tests/tensor/distance_functions/distance_functions_benchmark.cpp
index 15d6040a11a..04a2fa1cf2f 100644
--- a/searchlib/src/tests/tensor/distance_functions/distance_functions_benchmark.cpp
+++ b/searchlib/src/tests/tensor/distance_functions/distance_functions_benchmark.cpp
@@ -58,12 +58,12 @@ void benchmark(size_t iterations, size_t elems) __attribute_noinline__;
template<typename T>
void benchmark(size_t iterations, size_t elems, const DistanceFunctionFactory & df) {
std::vector<T> av, bv;
- srand(7);
+ srandom(7);
av.reserve(elems);
bv.reserve(elems);
for (size_t i(0); i < elems; i++) {
- av.push_back(rand());
- bv.push_back(rand());
+ av.push_back(random()%128);
+ bv.push_back(random()%128);
}
TypedCells a_cells(av), b_cells(bv);
@@ -78,17 +78,17 @@ void benchmark(size_t iterations, size_t elems, const std::string & dist_functio
benchmark<T>(iterations, elems, EuclideanDistanceFunctionFactory<T>());
}
if (dist_functions.find("angular") != npos) {
- if (std::is_same<T, double>() || std::is_same<T, float>()) {
+ if ( ! std::is_same<T, BFloat16>()) {
benchmark<T>(iterations, elems, AngularDistanceFunctionFactory<T>());
}
}
if (dist_functions.find("prenorm") != npos) {
- if (std::is_same<T, double>() || std::is_same<T, float>()) {
+ if ( ! std::is_same<T, BFloat16>()) {
benchmark<T>(iterations, elems, PrenormalizedAngularDistanceFunctionFactory<T>());
}
}
if (dist_functions.find("mips") != npos) {
- if (std::is_same<T, double>() || std::is_same<T, float>() || std::is_same<T, Int8Float>()) {
+ if ( !std::is_same<T, BFloat16>()) {
benchmark<T>(iterations, elems, MipsDistanceFunctionFactory<T>());
}
}