summaryrefslogtreecommitdiffstats
path: root/eval
diff options
context:
space:
mode:
authorHåvard Pettersen <havardpe@oath.com>2020-06-13 10:32:54 +0000
committerHåvard Pettersen <havardpe@oath.com>2020-06-13 10:32:54 +0000
commit20843210c453422350d42652d4a21ef94cce5ebe (patch)
tree3977bf48eebb754802147c6e982e274a29f7e02f /eval
parent5233b53188a40bfc51f38990173d5d906534499a (diff)
replace template magic with if statement
Diffstat (limited to 'eval')
-rw-r--r--eval/src/vespa/eval/tensor/dense/dense_matmul_function.cpp39
1 files changed, 11 insertions, 28 deletions
diff --git a/eval/src/vespa/eval/tensor/dense/dense_matmul_function.cpp b/eval/src/vespa/eval/tensor/dense/dense_matmul_function.cpp
index 9a43423f13b..9c18cf285d4 100644
--- a/eval/src/vespa/eval/tensor/dense/dense_matmul_function.cpp
+++ b/eval/src/vespa/eval/tensor/dense/dense_matmul_function.cpp
@@ -119,32 +119,15 @@ const TensorFunction &create_matmul(const TensorFunction &a, const TensorFunctio
}
}
-template <typename LCT, typename RCT, bool lhs_common_inner, bool rhs_common_inner>
-struct MyMatMulOp {
- static auto get_fun() {
- return my_matmul_op<LCT, RCT, lhs_common_inner, rhs_common_inner>;
- }
-};
-
-template <bool lhs_common_inner, bool rhs_common_inner>
-struct MyMatMulOp<double, double, lhs_common_inner, rhs_common_inner> {
- static auto get_fun() {
- return my_cblas_double_matmul_op<lhs_common_inner, rhs_common_inner>;
- }
-};
-
-template <bool lhs_common_inner, bool rhs_common_inner>
-struct MyMatMulOp<float, float, lhs_common_inner, rhs_common_inner> {
- static auto get_fun() {
- return my_cblas_float_matmul_op<lhs_common_inner, rhs_common_inner>;
- }
-};
-
-struct MyTarget {
- template<typename R1, typename R2, typename R3, typename R4>
- static auto invoke() {
- using MyOp = MyMatMulOp<R1, R2, R3::value, R4::value>;
- return MyOp::get_fun();
+struct MyGetFun {
+ template<typename R1, typename R2, typename R3, typename R4> static auto invoke() {
+ if (std::is_same_v<R1,double> && std::is_same_v<R2,double>) {
+ return my_cblas_double_matmul_op<R3::value, R4::value>;
+ } else if (std::is_same_v<R1,float> && std::is_same_v<R2,float>) {
+ return my_cblas_float_matmul_op<R3::value, R4::value>;
+ } else {
+ return my_matmul_op<R1, R2, R3::value, R4::value>;
+ }
}
};
@@ -185,9 +168,9 @@ DenseMatMulFunction::~DenseMatMulFunction() = default;
eval::InterpretedFunction::Instruction
DenseMatMulFunction::compile_self(const TensorEngine &, Stash &stash) const
{
- using MyTypify = TypifyValue<eval::TypifyCellType,vespalib::TypifyBool>;
+ using MyTypify = TypifyValue<eval::TypifyCellType,TypifyBool>;
Self &self = stash.create<Self>(result_type(), _lhs_size, _common_size, _rhs_size);
- auto op = typify_invoke<4,MyTypify,MyTarget>(
+ auto op = typify_invoke<4,MyTypify,MyGetFun>(
lhs().result_type().cell_type(), rhs().result_type().cell_type(),
_lhs_common_inner, _rhs_common_inner);
return eval::InterpretedFunction::Instruction(op, (uint64_t)(&self));