summaryrefslogtreecommitdiffstats
path: root/eval
diff options
context:
space:
mode:
authorArne Juul <arnej@verizonmedia.com>2020-06-12 13:15:36 +0000
committerArne Juul <arnej@verizonmedia.com>2020-06-12 13:50:27 +0000
commit5233b53188a40bfc51f38990173d5d906534499a (patch)
tree30a75ccee0a98aeea2995533bb2e7804581fc6eb /eval
parent015d3f1afd813dd738432c017db0644882dd30de (diff)
use typify_invoke to select matmul implementation
* instead of a chain of select functions, use typify_invoke; use partial specialization on a struct to handle the double*double and float*float cases specially.
Diffstat (limited to 'eval')
-rw-r--r--eval/src/vespa/eval/tensor/dense/dense_matmul_function.cpp76
1 files changed, 33 insertions, 43 deletions
diff --git a/eval/src/vespa/eval/tensor/dense/dense_matmul_function.cpp b/eval/src/vespa/eval/tensor/dense/dense_matmul_function.cpp
index 695e0fddd08..9a43423f13b 100644
--- a/eval/src/vespa/eval/tensor/dense/dense_matmul_function.cpp
+++ b/eval/src/vespa/eval/tensor/dense/dense_matmul_function.cpp
@@ -80,47 +80,6 @@ void my_cblas_float_matmul_op(eval::InterpretedFunction::State &state, uint64_t
state.pop_pop_push(state.stash.create<DenseTensorView>(self.result_type, TypedCells(dst_cells)));
}
-template <bool lhs_common_inner, bool rhs_common_inner>
-struct MyMatMulOp {
- template <typename LCT, typename RCT>
- static auto get_fun() { return my_matmul_op<LCT,RCT,lhs_common_inner,rhs_common_inner>; }
-};
-
-template <bool lhs_common_inner, bool rhs_common_inner>
-eval::InterpretedFunction::op_function my_select3(CellType lct, CellType rct)
-{
- if (lct == rct) {
- if (lct == ValueType::CellType::DOUBLE) {
- return my_cblas_double_matmul_op<lhs_common_inner,rhs_common_inner>;
- }
- if (lct == ValueType::CellType::FLOAT) {
- return my_cblas_float_matmul_op<lhs_common_inner,rhs_common_inner>;
- }
- }
- return select_2<MyMatMulOp<lhs_common_inner,rhs_common_inner>>(lct, rct);
-}
-
-template <bool lhs_common_inner>
-eval::InterpretedFunction::op_function my_select2(CellType lct, CellType rct,
- bool rhs_common_inner)
-{
- if (rhs_common_inner) {
- return my_select3<lhs_common_inner,true>(lct, rct);
- } else {
- return my_select3<lhs_common_inner,false>(lct, rct);
- }
-}
-
-eval::InterpretedFunction::op_function my_select(CellType lct, CellType rct,
- bool lhs_common_inner, bool rhs_common_inner)
-{
- if (lhs_common_inner) {
- return my_select2<true>(lct, rct, rhs_common_inner);
- } else {
- return my_select2<false>(lct, rct, rhs_common_inner);
- }
-}
-
bool is_matrix(const ValueType &type) {
return (type.is_dense() && (type.dimensions().size() == 2));
}
@@ -160,6 +119,35 @@ const TensorFunction &create_matmul(const TensorFunction &a, const TensorFunctio
}
}
+template <typename LCT, typename RCT, bool lhs_common_inner, bool rhs_common_inner>
+struct MyMatMulOp {
+ static auto get_fun() {
+ return my_matmul_op<LCT, RCT, lhs_common_inner, rhs_common_inner>;
+ }
+};
+
+template <bool lhs_common_inner, bool rhs_common_inner>
+struct MyMatMulOp<double, double, lhs_common_inner, rhs_common_inner> {
+ static auto get_fun() {
+ return my_cblas_double_matmul_op<lhs_common_inner, rhs_common_inner>;
+ }
+};
+
+template <bool lhs_common_inner, bool rhs_common_inner>
+struct MyMatMulOp<float, float, lhs_common_inner, rhs_common_inner> {
+ static auto get_fun() {
+ return my_cblas_float_matmul_op<lhs_common_inner, rhs_common_inner>;
+ }
+};
+
+struct MyTarget {
+ template<typename R1, typename R2, typename R3, typename R4>
+ static auto invoke() {
+ using MyOp = MyMatMulOp<R1, R2, R3::value, R4::value>;
+ return MyOp::get_fun();
+ }
+};
+
} // namespace vespalib::tensor::<unnamed>
DenseMatMulFunction::Self::Self(const eval::ValueType &result_type_in,
@@ -197,9 +185,11 @@ DenseMatMulFunction::~DenseMatMulFunction() = default;
eval::InterpretedFunction::Instruction
DenseMatMulFunction::compile_self(const TensorEngine &, Stash &stash) const
{
+ using MyTypify = TypifyValue<eval::TypifyCellType,vespalib::TypifyBool>;
Self &self = stash.create<Self>(result_type(), _lhs_size, _common_size, _rhs_size);
- auto op = my_select(lhs().result_type().cell_type(), rhs().result_type().cell_type(),
- _lhs_common_inner, _rhs_common_inner);
+ auto op = typify_invoke<4,MyTypify,MyTarget>(
+ lhs().result_type().cell_type(), rhs().result_type().cell_type(),
+ _lhs_common_inner, _rhs_common_inner);
return eval::InterpretedFunction::Instruction(op, (uint64_t)(&self));
}