diff options
author | Arne Juul <arnej@verizonmedia.com> | 2020-06-12 13:15:36 +0000 |
---|---|---|
committer | Arne Juul <arnej@verizonmedia.com> | 2020-06-12 13:50:27 +0000 |
commit | 5233b53188a40bfc51f38990173d5d906534499a (patch) | |
tree | 30a75ccee0a98aeea2995533bb2e7804581fc6eb /eval | |
parent | 015d3f1afd813dd738432c017db0644882dd30de (diff) |
use typify_invoke to select matmul implementation
* instead of a chain of select functions, use typify_invoke;
  use partial specialization on a struct to handle the
  `double*double` and `float*float` cases specially.
Diffstat (limited to 'eval')
-rw-r--r-- | eval/src/vespa/eval/tensor/dense/dense_matmul_function.cpp | 76 |
1 file changed, 33 insertions, 43 deletions
diff --git a/eval/src/vespa/eval/tensor/dense/dense_matmul_function.cpp b/eval/src/vespa/eval/tensor/dense/dense_matmul_function.cpp index 695e0fddd08..9a43423f13b 100644 --- a/eval/src/vespa/eval/tensor/dense/dense_matmul_function.cpp +++ b/eval/src/vespa/eval/tensor/dense/dense_matmul_function.cpp @@ -80,47 +80,6 @@ void my_cblas_float_matmul_op(eval::InterpretedFunction::State &state, uint64_t state.pop_pop_push(state.stash.create<DenseTensorView>(self.result_type, TypedCells(dst_cells))); } -template <bool lhs_common_inner, bool rhs_common_inner> -struct MyMatMulOp { - template <typename LCT, typename RCT> - static auto get_fun() { return my_matmul_op<LCT,RCT,lhs_common_inner,rhs_common_inner>; } -}; - -template <bool lhs_common_inner, bool rhs_common_inner> -eval::InterpretedFunction::op_function my_select3(CellType lct, CellType rct) -{ - if (lct == rct) { - if (lct == ValueType::CellType::DOUBLE) { - return my_cblas_double_matmul_op<lhs_common_inner,rhs_common_inner>; - } - if (lct == ValueType::CellType::FLOAT) { - return my_cblas_float_matmul_op<lhs_common_inner,rhs_common_inner>; - } - } - return select_2<MyMatMulOp<lhs_common_inner,rhs_common_inner>>(lct, rct); -} - -template <bool lhs_common_inner> -eval::InterpretedFunction::op_function my_select2(CellType lct, CellType rct, - bool rhs_common_inner) -{ - if (rhs_common_inner) { - return my_select3<lhs_common_inner,true>(lct, rct); - } else { - return my_select3<lhs_common_inner,false>(lct, rct); - } -} - -eval::InterpretedFunction::op_function my_select(CellType lct, CellType rct, - bool lhs_common_inner, bool rhs_common_inner) -{ - if (lhs_common_inner) { - return my_select2<true>(lct, rct, rhs_common_inner); - } else { - return my_select2<false>(lct, rct, rhs_common_inner); - } -} - bool is_matrix(const ValueType &type) { return (type.is_dense() && (type.dimensions().size() == 2)); } @@ -160,6 +119,35 @@ const TensorFunction &create_matmul(const TensorFunction &a, const TensorFunctio } } +template 
<typename LCT, typename RCT, bool lhs_common_inner, bool rhs_common_inner> +struct MyMatMulOp { + static auto get_fun() { + return my_matmul_op<LCT, RCT, lhs_common_inner, rhs_common_inner>; + } +}; + +template <bool lhs_common_inner, bool rhs_common_inner> +struct MyMatMulOp<double, double, lhs_common_inner, rhs_common_inner> { + static auto get_fun() { + return my_cblas_double_matmul_op<lhs_common_inner, rhs_common_inner>; + } +}; + +template <bool lhs_common_inner, bool rhs_common_inner> +struct MyMatMulOp<float, float, lhs_common_inner, rhs_common_inner> { + static auto get_fun() { + return my_cblas_float_matmul_op<lhs_common_inner, rhs_common_inner>; + } +}; + +struct MyTarget { + template<typename R1, typename R2, typename R3, typename R4> + static auto invoke() { + using MyOp = MyMatMulOp<R1, R2, R3::value, R4::value>; + return MyOp::get_fun(); + } +}; + } // namespace vespalib::tensor::<unnamed> DenseMatMulFunction::Self::Self(const eval::ValueType &result_type_in, @@ -197,9 +185,11 @@ DenseMatMulFunction::~DenseMatMulFunction() = default; eval::InterpretedFunction::Instruction DenseMatMulFunction::compile_self(const TensorEngine &, Stash &stash) const { + using MyTypify = TypifyValue<eval::TypifyCellType,vespalib::TypifyBool>; Self &self = stash.create<Self>(result_type(), _lhs_size, _common_size, _rhs_size); - auto op = my_select(lhs().result_type().cell_type(), rhs().result_type().cell_type(), - _lhs_common_inner, _rhs_common_inner); + auto op = typify_invoke<4,MyTypify,MyTarget>( + lhs().result_type().cell_type(), rhs().result_type().cell_type(), + _lhs_common_inner, _rhs_common_inner); return eval::InterpretedFunction::Instruction(op, (uint64_t)(&self)); } |