summaryrefslogtreecommitdiffstats
path: root/eval
diff options
context:
space:
mode:
authorHÃ¥vard Pettersen <3535158+havardpe@users.noreply.github.com>2020-06-13 15:23:38 +0200
committerGitHub <noreply@github.com>2020-06-13 15:23:38 +0200
commit24a8b542e20c8758f2d5973fc21e979b80247dae (patch)
tree2000e83e615a1ebf1f526bc82a3663a6d2da1612 /eval
parentbc6fa135488141f9362896c8e7f4e308e47a6591 (diff)
parent20843210c453422350d42652d4a21ef94cce5ebe (diff)
Merge pull request #13571 from vespa-engine/arnej/use-typify-invoke-for-matmul
use typify_invoke to select matmul implementation
Diffstat (limited to 'eval')
-rw-r--r--eval/src/vespa/eval/tensor/dense/dense_matmul_function.cpp59
1 files changed, 16 insertions, 43 deletions
diff --git a/eval/src/vespa/eval/tensor/dense/dense_matmul_function.cpp b/eval/src/vespa/eval/tensor/dense/dense_matmul_function.cpp
index 695e0fddd08..9c18cf285d4 100644
--- a/eval/src/vespa/eval/tensor/dense/dense_matmul_function.cpp
+++ b/eval/src/vespa/eval/tensor/dense/dense_matmul_function.cpp
@@ -80,47 +80,6 @@ void my_cblas_float_matmul_op(eval::InterpretedFunction::State &state, uint64_t
state.pop_pop_push(state.stash.create<DenseTensorView>(self.result_type, TypedCells(dst_cells)));
}
-template <bool lhs_common_inner, bool rhs_common_inner>
-struct MyMatMulOp {
- template <typename LCT, typename RCT>
- static auto get_fun() { return my_matmul_op<LCT,RCT,lhs_common_inner,rhs_common_inner>; }
-};
-
-template <bool lhs_common_inner, bool rhs_common_inner>
-eval::InterpretedFunction::op_function my_select3(CellType lct, CellType rct)
-{
- if (lct == rct) {
- if (lct == ValueType::CellType::DOUBLE) {
- return my_cblas_double_matmul_op<lhs_common_inner,rhs_common_inner>;
- }
- if (lct == ValueType::CellType::FLOAT) {
- return my_cblas_float_matmul_op<lhs_common_inner,rhs_common_inner>;
- }
- }
- return select_2<MyMatMulOp<lhs_common_inner,rhs_common_inner>>(lct, rct);
-}
-
-template <bool lhs_common_inner>
-eval::InterpretedFunction::op_function my_select2(CellType lct, CellType rct,
- bool rhs_common_inner)
-{
- if (rhs_common_inner) {
- return my_select3<lhs_common_inner,true>(lct, rct);
- } else {
- return my_select3<lhs_common_inner,false>(lct, rct);
- }
-}
-
-eval::InterpretedFunction::op_function my_select(CellType lct, CellType rct,
- bool lhs_common_inner, bool rhs_common_inner)
-{
- if (lhs_common_inner) {
- return my_select2<true>(lct, rct, rhs_common_inner);
- } else {
- return my_select2<false>(lct, rct, rhs_common_inner);
- }
-}
-
bool is_matrix(const ValueType &type) {
return (type.is_dense() && (type.dimensions().size() == 2));
}
@@ -160,6 +119,18 @@ const TensorFunction &create_matmul(const TensorFunction &a, const TensorFunctio
}
}
+struct MyGetFun {
+ template<typename R1, typename R2, typename R3, typename R4> static auto invoke() {
+ if (std::is_same_v<R1,double> && std::is_same_v<R2,double>) {
+ return my_cblas_double_matmul_op<R3::value, R4::value>;
+ } else if (std::is_same_v<R1,float> && std::is_same_v<R2,float>) {
+ return my_cblas_float_matmul_op<R3::value, R4::value>;
+ } else {
+ return my_matmul_op<R1, R2, R3::value, R4::value>;
+ }
+ }
+};
+
} // namespace vespalib::tensor::<unnamed>
DenseMatMulFunction::Self::Self(const eval::ValueType &result_type_in,
@@ -197,9 +168,11 @@ DenseMatMulFunction::~DenseMatMulFunction() = default;
eval::InterpretedFunction::Instruction
DenseMatMulFunction::compile_self(const TensorEngine &, Stash &stash) const
{
+ using MyTypify = TypifyValue<eval::TypifyCellType,TypifyBool>;
Self &self = stash.create<Self>(result_type(), _lhs_size, _common_size, _rhs_size);
- auto op = my_select(lhs().result_type().cell_type(), rhs().result_type().cell_type(),
- _lhs_common_inner, _rhs_common_inner);
+ auto op = typify_invoke<4,MyTypify,MyGetFun>(
+ lhs().result_type().cell_type(), rhs().result_type().cell_type(),
+ _lhs_common_inner, _rhs_common_inner);
return eval::InterpretedFunction::Instruction(op, (uint64_t)(&self));
}