summary | refs | log | tree | commit | diff | stats
path: root/searchlib
diff options
context:
space:
mode:
author: Henning Baldersheim <balder@yahoo-inc.com> 2024-05-14 23:04:19 +0000
committer: Henning Baldersheim <balder@yahoo-inc.com> 2024-05-14 23:04:19 +0000
commit: 51dd79b028db920f0749dd183200455f2f7a1f71 (patch)
tree: aba8e53c1d17ce107a0d9719d63515d2896dd116 /searchlib
parent: cf84c1de017cc9e3cfd1b8859ddfbfba41a350e5 (diff)
Speed up bfloat16 to float conversion
Diffstat (limited to 'searchlib')
-rw-r--r-- searchlib/src/vespa/searchlib/tensor/temporary_vector_store.cpp | 21
1 files changed, 17 insertions, 4 deletions
diff --git a/searchlib/src/vespa/searchlib/tensor/temporary_vector_store.cpp b/searchlib/src/vespa/searchlib/tensor/temporary_vector_store.cpp
index 4753e9d7c87..5d29f38cf2a 100644
--- a/searchlib/src/vespa/searchlib/tensor/temporary_vector_store.cpp
+++ b/searchlib/src/vespa/searchlib/tensor/temporary_vector_store.cpp
@@ -1,11 +1,13 @@
// Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
#include "temporary_vector_store.h"
+#include <vespa/vespalib/hwaccelrated/iaccelrated.h>
using vespalib::ConstArrayRef;
using vespalib::ArrayRef;
using vespalib::eval::CellType;
using vespalib::eval::TypedCells;
+using vespalib::hwaccelrated::IAccelrated;
namespace search::tensor {
@@ -13,18 +15,29 @@ namespace {
template<typename FromType, typename ToType>
ConstArrayRef<ToType>
+convert_cells(ArrayRef<ToType> space, TypedCells cells) noexcept __attribute_noinline__;
+
+template<typename FromType, typename ToType>
+ConstArrayRef<ToType>
convert_cells(ArrayRef<ToType> space, TypedCells cells) noexcept
{
- assert(cells.size == space.size());
- auto old_cells = cells.typify<FromType>();
+ auto old_cells = cells.unsafe_typify<FromType>();
ToType *p = space.data();
for (FromType value : old_cells) {
- ToType conv(value);
- *p++ = conv;
+ *p++ = value;
}
return space;
}
+template<>
+ConstArrayRef<float>
+convert_cells<vespalib::BFloat16, float>(ArrayRef<float> space, TypedCells cells) noexcept
+{
+ static const IAccelrated & accelrator = IAccelrated::getAccelerator();
+ accelrator.convert_bfloat16_to_float(reinterpret_cast<const uint16_t *>(cells.data), space.data(), space.size());
+ return space;
+}
+
template <typename ToType>
struct ConvertCellsSelector
{