about | summary | refs | log | tree | commit | diff | stats
path: root/eval
diff options
context:
space:
mode:
author: Henning Baldersheim <balder@yahoo-inc.com> 2024-05-29 12:05:49 +0000
committer: Henning Baldersheim <balder@yahoo-inc.com> 2024-05-29 12:05:49 +0000
commit: 81f623fc5bbf4f1fa49c87bacb1f9c6dee3f41b2 (patch)
tree: f37d6a363244963134a6ee07c61ef6b3e68f33f0 /eval
parent: 8276854275b3daaf79d6774062bce6279981c4ba (diff)
Use optimized bfloat16 to float conversion.
Diffstat (limited to 'eval')
-rw-r--r-- eval/src/vespa/eval/instruction/generic_cell_cast.cpp | 18
1 files changed, 17 insertions, 1 deletions
diff --git a/eval/src/vespa/eval/instruction/generic_cell_cast.cpp b/eval/src/vespa/eval/instruction/generic_cell_cast.cpp
index 16f47a56222..2ee0245a978 100644
--- a/eval/src/vespa/eval/instruction/generic_cell_cast.cpp
+++ b/eval/src/vespa/eval/instruction/generic_cell_cast.cpp
@@ -5,6 +5,7 @@
#include <vespa/eval/eval/wrap_param.h>
#include <vespa/vespalib/util/stash.h>
#include <vespa/vespalib/util/typify.h>
+#include <vespa/vespalib/hwaccelrated/iaccelrated.h>
#include <cassert>
using namespace vespalib::eval::tensor_function;
@@ -16,6 +17,8 @@ using Instruction = InterpretedFunction::Instruction;
namespace {
+using hwaccelrated::IAccelrated;
+
template <typename ICT, typename OCT>
void my_generic_cell_cast_op(State &state, uint64_t param_in) {
const auto &res_type = unwrap_param<ValueType>(param_in);
@@ -31,6 +34,19 @@ void my_generic_cell_cast_op(State &state, uint64_t param_in) {
state.pop_push(result_ref);
}
+template <> // Explicit specialization for BFloat16 -> float that hands the conversion to IAccelrated.
+void my_generic_cell_cast_op<BFloat16, float>(State &state, uint64_t param_in) {
+ const auto &res_type = unwrap_param<ValueType>(param_in); // param_in wraps the result ValueType
+ const Value &a = state.peek(0); // input value whose cells are converted
+ auto input_cells = a.cells().typify<BFloat16>();
+ auto output_cells = state.stash.create_uninitialized_array<float>(input_cells.size()); // one float slot per input cell, allocated from the stash
+ static const IAccelrated & accelrator = IAccelrated::getAccelerator(); // function-local static: initialized on first call, reused afterwards
+ accelrator.convert_bfloat16_to_float(reinterpret_cast<const uint16_t *>(input_cells.begin()), // NOTE(review): assumes BFloat16 is layout-compatible with uint16_t - confirm
+ output_cells.data(), output_cells.size());
+ Value &result_ref = state.stash.create<ValueView>(res_type, a.index(), TypedCells(output_cells)); // result view is built with the input's index and the new float cells
+ state.pop_push(result_ref);
+}
+
struct SelectGenericCellCastOp {
template <typename ICT, typename OCT>
static InterpretedFunction::op_function invoke() {
@@ -60,7 +76,7 @@ GenericCellCast::make_instruction(const ValueType &result_type,
assert(!input_type.is_double());
auto &param = stash.create<ValueType>(result_type);
auto op = typify_invoke<2,TypifyCellType,SelectGenericCellCastOp>(from, to);
- return Instruction(op, wrap_param<ValueType>(param));
+ return {op, wrap_param<ValueType>(param)};
}
}