diff options
author | Håvard Pettersen <3535158+havardpe@users.noreply.github.com> | 2020-06-13 15:23:38 +0200 |
---|---|---|
committer | GitHub <noreply@github.com> | 2020-06-13 15:23:38 +0200 |
commit | 24a8b542e20c8758f2d5973fc21e979b80247dae (patch) | |
tree | 2000e83e615a1ebf1f526bc82a3663a6d2da1612 /eval | |
parent | bc6fa135488141f9362896c8e7f4e308e47a6591 (diff) | |
parent | 20843210c453422350d42652d4a21ef94cce5ebe (diff) |
Merge pull request #13571 from vespa-engine/arnej/use-typify-invoke-for-matmul
use typify_invoke to select matmul implementation
Diffstat (limited to 'eval')
-rw-r--r-- | eval/src/vespa/eval/tensor/dense/dense_matmul_function.cpp | 59 |
1 files changed, 16 insertions, 43 deletions
diff --git a/eval/src/vespa/eval/tensor/dense/dense_matmul_function.cpp b/eval/src/vespa/eval/tensor/dense/dense_matmul_function.cpp index 695e0fddd08..9c18cf285d4 100644 --- a/eval/src/vespa/eval/tensor/dense/dense_matmul_function.cpp +++ b/eval/src/vespa/eval/tensor/dense/dense_matmul_function.cpp @@ -80,47 +80,6 @@ void my_cblas_float_matmul_op(eval::InterpretedFunction::State &state, uint64_t state.pop_pop_push(state.stash.create<DenseTensorView>(self.result_type, TypedCells(dst_cells))); } -template <bool lhs_common_inner, bool rhs_common_inner> -struct MyMatMulOp { - template <typename LCT, typename RCT> - static auto get_fun() { return my_matmul_op<LCT,RCT,lhs_common_inner,rhs_common_inner>; } -}; - -template <bool lhs_common_inner, bool rhs_common_inner> -eval::InterpretedFunction::op_function my_select3(CellType lct, CellType rct) -{ - if (lct == rct) { - if (lct == ValueType::CellType::DOUBLE) { - return my_cblas_double_matmul_op<lhs_common_inner,rhs_common_inner>; - } - if (lct == ValueType::CellType::FLOAT) { - return my_cblas_float_matmul_op<lhs_common_inner,rhs_common_inner>; - } - } - return select_2<MyMatMulOp<lhs_common_inner,rhs_common_inner>>(lct, rct); -} - -template <bool lhs_common_inner> -eval::InterpretedFunction::op_function my_select2(CellType lct, CellType rct, - bool rhs_common_inner) -{ - if (rhs_common_inner) { - return my_select3<lhs_common_inner,true>(lct, rct); - } else { - return my_select3<lhs_common_inner,false>(lct, rct); - } -} - -eval::InterpretedFunction::op_function my_select(CellType lct, CellType rct, - bool lhs_common_inner, bool rhs_common_inner) -{ - if (lhs_common_inner) { - return my_select2<true>(lct, rct, rhs_common_inner); - } else { - return my_select2<false>(lct, rct, rhs_common_inner); - } -} - bool is_matrix(const ValueType &type) { return (type.is_dense() && (type.dimensions().size() == 2)); } @@ -160,6 +119,18 @@ const TensorFunction &create_matmul(const TensorFunction &a, const TensorFunctio } } +struct 
MyGetFun { + template<typename R1, typename R2, typename R3, typename R4> static auto invoke() { + if (std::is_same_v<R1,double> && std::is_same_v<R2,double>) { + return my_cblas_double_matmul_op<R3::value, R4::value>; + } else if (std::is_same_v<R1,float> && std::is_same_v<R2,float>) { + return my_cblas_float_matmul_op<R3::value, R4::value>; + } else { + return my_matmul_op<R1, R2, R3::value, R4::value>; + } + } +}; + } // namespace vespalib::tensor::<unnamed> DenseMatMulFunction::Self::Self(const eval::ValueType &result_type_in, @@ -197,9 +168,11 @@ DenseMatMulFunction::~DenseMatMulFunction() = default; eval::InterpretedFunction::Instruction DenseMatMulFunction::compile_self(const TensorEngine &, Stash &stash) const { + using MyTypify = TypifyValue<eval::TypifyCellType,TypifyBool>; Self &self = stash.create<Self>(result_type(), _lhs_size, _common_size, _rhs_size); - auto op = my_select(lhs().result_type().cell_type(), rhs().result_type().cell_type(), - _lhs_common_inner, _rhs_common_inner); + auto op = typify_invoke<4,MyTypify,MyGetFun>( + lhs().result_type().cell_type(), rhs().result_type().cell_type(), + _lhs_common_inner, _rhs_common_inner); return eval::InterpretedFunction::Instruction(op, (uint64_t)(&self)); } |