diff options
author | Henning Baldersheim <balder@yahoo-inc.com> | 2024-05-29 12:05:49 +0000 |
---|---|---|
committer | Henning Baldersheim <balder@yahoo-inc.com> | 2024-05-29 12:05:49 +0000 |
commit | 81f623fc5bbf4f1fa49c87bacb1f9c6dee3f41b2 (patch) | |
tree | f37d6a363244963134a6ee07c61ef6b3e68f33f0 /eval | |
parent | 8276854275b3daaf79d6774062bce6279981c4ba (diff) |
Use optimized bfloat16 to float conversion.
Diffstat (limited to 'eval')
-rw-r--r-- | eval/src/vespa/eval/instruction/generic_cell_cast.cpp | 18 |
1 files changed, 17 insertions, 1 deletions
diff --git a/eval/src/vespa/eval/instruction/generic_cell_cast.cpp b/eval/src/vespa/eval/instruction/generic_cell_cast.cpp index 16f47a56222..2ee0245a978 100644 --- a/eval/src/vespa/eval/instruction/generic_cell_cast.cpp +++ b/eval/src/vespa/eval/instruction/generic_cell_cast.cpp @@ -5,6 +5,7 @@ #include <vespa/eval/eval/wrap_param.h> #include <vespa/vespalib/util/stash.h> #include <vespa/vespalib/util/typify.h> +#include <vespa/vespalib/hwaccelrated/iaccelrated.h> #include <cassert> using namespace vespalib::eval::tensor_function; @@ -16,6 +17,8 @@ using Instruction = InterpretedFunction::Instruction; namespace { +using hwaccelrated::IAccelrated; + template <typename ICT, typename OCT> void my_generic_cell_cast_op(State &state, uint64_t param_in) { const auto &res_type = unwrap_param<ValueType>(param_in); @@ -31,6 +34,19 @@ void my_generic_cell_cast_op(State &state, uint64_t param_in) { state.pop_push(result_ref); } +template <> +void my_generic_cell_cast_op<BFloat16, float>(State &state, uint64_t param_in) { + const auto &res_type = unwrap_param<ValueType>(param_in); + const Value &a = state.peek(0); + auto input_cells = a.cells().typify<BFloat16>(); + auto output_cells = state.stash.create_uninitialized_array<float>(input_cells.size()); + static const IAccelrated & accelrator = IAccelrated::getAccelerator(); + accelrator.convert_bfloat16_to_float(reinterpret_cast<const uint16_t *>(input_cells.begin()), + output_cells.data(), output_cells.size()); + Value &result_ref = state.stash.create<ValueView>(res_type, a.index(), TypedCells(output_cells)); + state.pop_push(result_ref); +} + struct SelectGenericCellCastOp { template <typename ICT, typename OCT> static InterpretedFunction::op_function invoke() { @@ -60,7 +76,7 @@ GenericCellCast::make_instruction(const ValueType &result_type, assert(!input_type.is_double()); auto ¶m = stash.create<ValueType>(result_type); auto op = typify_invoke<2,TypifyCellType,SelectGenericCellCastOp>(from, to); - return Instruction(op, wrap_param<ValueType>(param)); + return {op, wrap_param<ValueType>(param)}; } } |