diff options
Diffstat (limited to 'eval/src/vespa/eval')
-rw-r--r-- | eval/src/vespa/eval/eval/gbdt.cpp | 3 | ||||
-rw-r--r-- | eval/src/vespa/eval/eval/int8float.h | 2 | ||||
-rw-r--r-- | eval/src/vespa/eval/instruction/generic_cell_cast.cpp | 18 |
3 files changed, 21 insertions, 2 deletions
diff --git a/eval/src/vespa/eval/eval/gbdt.cpp b/eval/src/vespa/eval/eval/gbdt.cpp index 3422228b03c..7ab4c4ae822 100644 --- a/eval/src/vespa/eval/eval/gbdt.cpp +++ b/eval/src/vespa/eval/eval/gbdt.cpp @@ -154,6 +154,9 @@ Optimize::select_best(const ForestStats &stats, if ((stats.tree_sizes.back().size > 12) && (path_len > 2500.0)) { return apply_chain(VMForest::optimize_chain, stats, trees); } + if (stats.total_size > 25000) { + return apply_chain(VMForest::optimize_chain, stats, trees); + } return Optimize::Result(); } diff --git a/eval/src/vespa/eval/eval/int8float.h b/eval/src/vespa/eval/eval/int8float.h index b751b2eb8ad..d28446a8524 100644 --- a/eval/src/vespa/eval/eval/int8float.h +++ b/eval/src/vespa/eval/eval/int8float.h @@ -32,7 +32,7 @@ public: constexpr float to_float() const noexcept { return _bits; } constexpr void assign(float value) noexcept { _bits = value; } - constexpr int8_t get_bits() const { return _bits; } + constexpr int8_t get_bits() const noexcept { return _bits; } constexpr void assign_bits(int8_t value) noexcept { _bits = value; } }; diff --git a/eval/src/vespa/eval/instruction/generic_cell_cast.cpp b/eval/src/vespa/eval/instruction/generic_cell_cast.cpp index 16f47a56222..2ee0245a978 100644 --- a/eval/src/vespa/eval/instruction/generic_cell_cast.cpp +++ b/eval/src/vespa/eval/instruction/generic_cell_cast.cpp @@ -5,6 +5,7 @@ #include <vespa/eval/eval/wrap_param.h> #include <vespa/vespalib/util/stash.h> #include <vespa/vespalib/util/typify.h> +#include <vespa/vespalib/hwaccelrated/iaccelrated.h> #include <cassert> using namespace vespalib::eval::tensor_function; @@ -16,6 +17,8 @@ using Instruction = InterpretedFunction::Instruction; namespace { +using hwaccelrated::IAccelrated; + template <typename ICT, typename OCT> void my_generic_cell_cast_op(State &state, uint64_t param_in) { const auto &res_type = unwrap_param<ValueType>(param_in); @@ -31,6 +34,19 @@ void my_generic_cell_cast_op(State &state, uint64_t param_in) { 
state.pop_push(result_ref); } +template <> +void my_generic_cell_cast_op<BFloat16, float>(State &state, uint64_t param_in) { + const auto &res_type = unwrap_param<ValueType>(param_in); + const Value &a = state.peek(0); + auto input_cells = a.cells().typify<BFloat16>(); + auto output_cells = state.stash.create_uninitialized_array<float>(input_cells.size()); + static const IAccelrated & accelrator = IAccelrated::getAccelerator(); + accelrator.convert_bfloat16_to_float(reinterpret_cast<const uint16_t *>(input_cells.begin()), + output_cells.data(), output_cells.size()); + Value &result_ref = state.stash.create<ValueView>(res_type, a.index(), TypedCells(output_cells)); + state.pop_push(result_ref); +} + struct SelectGenericCellCastOp { template <typename ICT, typename OCT> static InterpretedFunction::op_function invoke() { @@ -60,7 +76,7 @@ GenericCellCast::make_instruction(const ValueType &result_type, assert(!input_type.is_double()); auto &param = stash.create<ValueType>(result_type); auto op = typify_invoke<2,TypifyCellType,SelectGenericCellCastOp>(from, to); - return Instruction(op, wrap_param<ValueType>(param)); + return {op, wrap_param<ValueType>(param)}; } } |