diff options
author | Arne Juul <arnej@verizonmedia.com> | 2020-06-13 05:43:59 +0000 |
---|---|---|
committer | Arne Juul <arnej@verizonmedia.com> | 2020-06-13 12:56:14 +0000 |
commit | 5cb5db5d44fcd4fb271e3b32a7f3f0384e68e497 (patch) | |
tree | 2b59b89d53b39437ef35f9b137d5b46b337e8e92 /eval | |
parent | b88fe841de3fd25b28392d1bb1b0d17f45e209f7 (diff) |
typify_invoke instead of multiple select levels
* follow pattern suggested in matmul PR
Diffstat (limited to 'eval')
-rw-r--r-- | eval/src/vespa/eval/tensor/dense/dense_xw_product_function.cpp | 47 |
1 files changed, 16 insertions, 31 deletions
diff --git a/eval/src/vespa/eval/tensor/dense/dense_xw_product_function.cpp b/eval/src/vespa/eval/tensor/dense/dense_xw_product_function.cpp index a8a896a893a..968308d69c9 100644 --- a/eval/src/vespa/eval/tensor/dense/dense_xw_product_function.cpp +++ b/eval/src/vespa/eval/tensor/dense/dense_xw_product_function.cpp @@ -76,35 +76,6 @@ void my_cblas_float_xw_product_op(eval::InterpretedFunction::State &state, uint6 state.pop_pop_push(state.stash.create<DenseTensorView>(self.result_type, TypedCells(dst_cells))); } -template <bool common_inner> -struct MyXWProductOp { - template <typename LCT, typename RCT> - static auto invoke() { return my_xw_product_op<LCT,RCT,common_inner>; } -}; - -template <bool common_inner> -eval::InterpretedFunction::op_function my_select2(CellType lct, CellType rct) { - if (lct == rct) { - if (lct == ValueType::CellType::DOUBLE) { - return my_cblas_double_xw_product_op<common_inner>; - } - if (lct == ValueType::CellType::FLOAT) { - return my_cblas_float_xw_product_op<common_inner>; - } - } - using Target = MyXWProductOp<common_inner>; - using MyTypify = eval::TypifyCellType; - return typify_invoke<2,MyTypify,Target>(lct, rct); -} - -eval::InterpretedFunction::op_function my_select(CellType lct, CellType rct, bool common_inner) { - if (common_inner) { - return my_select2<true>(lct, rct); - } else { - return my_select2<false>(lct, rct); - } -} - bool isDenseTensor(const ValueType &type, size_t d) { return (type.is_dense() && (type.dimensions().size() == d)); } @@ -134,6 +105,18 @@ const TensorFunction &createDenseXWProduct(const ValueType &res, const TensorFun common_inner); } +struct MyXWProductOp { + template<typename R1, typename R2, typename R3> static auto invoke() { + if (std::is_same_v<R1,double> && std::is_same_v<R2,double>) { + return my_cblas_double_xw_product_op<R3::value>; + } else if (std::is_same_v<R1,float> && std::is_same_v<R2,float>) { + return my_cblas_float_xw_product_op<R3::value>; + } else { + return my_xw_product_op<R1, R2, R3::value>; + } + } +}; + } // namespace vespalib::tensor::<unnamed> DenseXWProductFunction::Self::Self(const eval::ValueType &result_type_in, @@ -162,8 +145,10 @@ eval::InterpretedFunction::Instruction DenseXWProductFunction::compile_self(const TensorEngine &, Stash &stash) const { Self &self = stash.create<Self>(result_type(), _vector_size, _result_size); - auto op = my_select(lhs().result_type().cell_type(), - rhs().result_type().cell_type(), _common_inner); + using MyTypify = TypifyValue<eval::TypifyCellType,vespalib::TypifyBool>; + auto op = typify_invoke<3,MyTypify,MyXWProductOp>(lhs().result_type().cell_type(), + rhs().result_type().cell_type(), + _common_inner); return eval::InterpretedFunction::Instruction(op, (uint64_t)(&self)); } |