diff options
author | Håvard Pettersen <havardpe@oath.com> | 2020-06-13 10:32:54 +0000 |
---|---|---|
committer | Håvard Pettersen <havardpe@oath.com> | 2020-06-13 10:32:54 +0000 |
commit | 20843210c453422350d42652d4a21ef94cce5ebe (patch) | |
tree | 3977bf48eebb754802147c6e982e274a29f7e02f /eval/src | |
parent | 5233b53188a40bfc51f38990173d5d906534499a (diff) |
replace template magic with if statement
Diffstat (limited to 'eval/src')
-rw-r--r-- | eval/src/vespa/eval/tensor/dense/dense_matmul_function.cpp | 39 |
1 files changed, 11 insertions, 28 deletions
diff --git a/eval/src/vespa/eval/tensor/dense/dense_matmul_function.cpp b/eval/src/vespa/eval/tensor/dense/dense_matmul_function.cpp index 9a43423f13b..9c18cf285d4 100644 --- a/eval/src/vespa/eval/tensor/dense/dense_matmul_function.cpp +++ b/eval/src/vespa/eval/tensor/dense/dense_matmul_function.cpp @@ -119,32 +119,15 @@ const TensorFunction &create_matmul(const TensorFunction &a, const TensorFunctio } } -template <typename LCT, typename RCT, bool lhs_common_inner, bool rhs_common_inner> -struct MyMatMulOp { - static auto get_fun() { - return my_matmul_op<LCT, RCT, lhs_common_inner, rhs_common_inner>; - } -}; - -template <bool lhs_common_inner, bool rhs_common_inner> -struct MyMatMulOp<double, double, lhs_common_inner, rhs_common_inner> { - static auto get_fun() { - return my_cblas_double_matmul_op<lhs_common_inner, rhs_common_inner>; - } -}; - -template <bool lhs_common_inner, bool rhs_common_inner> -struct MyMatMulOp<float, float, lhs_common_inner, rhs_common_inner> { - static auto get_fun() { - return my_cblas_float_matmul_op<lhs_common_inner, rhs_common_inner>; - } -}; - -struct MyTarget { - template<typename R1, typename R2, typename R3, typename R4> - static auto invoke() { - using MyOp = MyMatMulOp<R1, R2, R3::value, R4::value>; - return MyOp::get_fun(); +struct MyGetFun { + template<typename R1, typename R2, typename R3, typename R4> static auto invoke() { + if (std::is_same_v<R1,double> && std::is_same_v<R2,double>) { + return my_cblas_double_matmul_op<R3::value, R4::value>; + } else if (std::is_same_v<R1,float> && std::is_same_v<R2,float>) { + return my_cblas_float_matmul_op<R3::value, R4::value>; + } else { + return my_matmul_op<R1, R2, R3::value, R4::value>; + } } }; @@ -185,9 +168,9 @@ DenseMatMulFunction::~DenseMatMulFunction() = default; eval::InterpretedFunction::Instruction DenseMatMulFunction::compile_self(const TensorEngine &, Stash &stash) const { - using MyTypify = TypifyValue<eval::TypifyCellType,vespalib::TypifyBool>; + using MyTypify = TypifyValue<eval::TypifyCellType,TypifyBool>; Self &self = stash.create<Self>(result_type(), _lhs_size, _common_size, _rhs_size); - auto op = typify_invoke<4,MyTypify,MyTarget>( + auto op = typify_invoke<4,MyTypify,MyGetFun>( lhs().result_type().cell_type(), rhs().result_type().cell_type(), _lhs_common_inner, _rhs_common_inner); return eval::InterpretedFunction::Instruction(op, (uint64_t)(&self)); |