diff options
author | Arne H Juul <arnej27959@users.noreply.github.com> | 2021-03-15 11:35:37 +0100 |
---|---|---|
committer | GitHub <noreply@github.com> | 2021-03-15 11:35:37 +0100 |
commit | 61338b390a89aa9d7c5f9d7c3767e847c5a8a1c8 (patch) | |
tree | 7adf037a290daf88322e025af23b03afbe7ff3f2 /eval | |
parent | ebbb3bc331924854ba1b6f2ee7c871ee0170f00a (diff) | |
parent | ee41dcf75bf47d2686f14c7e256e73b9e1cc09fb (diff) |
Merge pull request #16946 from vespa-engine/arnej/unify-via-cell-meta
use CellMeta::join and reduce to compute result cell type
Diffstat (limited to 'eval')
5 files changed, 73 insertions, 57 deletions
diff --git a/eval/src/vespa/eval/eval/cell_type.h b/eval/src/vespa/eval/eval/cell_type.h index 9036188bad3..4eec639478d 100644 --- a/eval/src/vespa/eval/eval/cell_type.h +++ b/eval/src/vespa/eval/eval/cell_type.h @@ -124,6 +124,9 @@ struct CellMeta { static constexpr CellMeta reduce(CellType input_cell_type, bool output_is_scalar) { return normalize(input_cell_type, output_is_scalar).decay(); } + constexpr CellMeta reduce(bool output_is_scalar) const { + return CellMeta::reduce(cell_type, output_is_scalar); + } static constexpr CellMeta join(CellMeta a, CellMeta b) { return unify(a, b).decay(); } static constexpr CellMeta merge(CellMeta a, CellMeta b) { return unify(a, b).decay(); } static constexpr CellMeta concat(CellMeta a, CellMeta b) { return unify(a, b); } @@ -133,12 +136,6 @@ struct CellMeta { constexpr CellMeta rename() const { return self(); } }; -template <typename A, typename B> constexpr auto unify_cell_types() { - constexpr CellMeta a(get_cell_type<A>(), false); - constexpr CellMeta b(get_cell_type<B>(), false); - return get_cell_value<CellMeta::unify(a, b).cell_type>(); -} - struct TypifyCellType { template <typename T> using Result = TypifyResultType<T>; template <typename F> static decltype(auto) resolve(CellType value, F &&f) { diff --git a/eval/src/vespa/eval/instruction/dense_matmul_function.cpp b/eval/src/vespa/eval/instruction/dense_matmul_function.cpp index 11ad646d0f5..509a25d28a2 100644 --- a/eval/src/vespa/eval/instruction/dense_matmul_function.cpp +++ b/eval/src/vespa/eval/instruction/dense_matmul_function.cpp @@ -14,9 +14,9 @@ using namespace operation; namespace { -template <typename LCT, typename RCT, bool lhs_common_inner, bool rhs_common_inner> -double my_dot_product(const LCT *lhs, const RCT *rhs, size_t lhs_size, size_t common_size, size_t rhs_size) { - double result = 0.0; +template <typename LCT, typename RCT, typename OCT, bool lhs_common_inner, bool rhs_common_inner> +OCT my_dot_product(const LCT *lhs, const RCT *rhs, size_t lhs_size, size_t common_size, size_t rhs_size) { + OCT result = 0.0; for (size_t i = 0; i < common_size; ++i) { result += ((*lhs) * (*rhs)); lhs += (lhs_common_inner ? 1 : lhs_size); @@ -25,10 +25,9 @@ double my_dot_product(const LCT *lhs, const RCT *rhs, size_t lhs_size, size_t co return result; } -template <typename LCT, typename RCT, bool lhs_common_inner, bool rhs_common_inner> +template <typename LCT, typename RCT, typename OCT, bool lhs_common_inner, bool rhs_common_inner> void my_matmul_op(InterpretedFunction::State &state, uint64_t param) { const DenseMatMulFunction::Self &self = unwrap_param<DenseMatMulFunction::Self>(param); - using OCT = decltype(unify_cell_types<LCT,RCT>()); auto lhs_cells = state.peek(1).cells().typify<LCT>(); auto rhs_cells = state.peek(0).cells().typify<RCT>(); auto dst_cells = state.stash.create_uninitialized_array<OCT>(self.lhs_size * self.rhs_size); @@ -37,7 +36,8 @@ void my_matmul_op(InterpretedFunction::State &state, uint64_t param) { for (size_t i = 0; i < self.lhs_size; ++i) { const RCT *rhs = rhs_cells.cbegin(); for (size_t j = 0; j < self.rhs_size; ++j) { - *dst++ = my_dot_product<LCT,RCT,lhs_common_inner,rhs_common_inner>(lhs, rhs, self.lhs_size, self.common_size, self.rhs_size); + *dst++ = my_dot_product<LCT,RCT,OCT,lhs_common_inner,rhs_common_inner>(lhs, rhs, + self.lhs_size, self.common_size, self.rhs_size); rhs += (rhs_common_inner ? self.common_size : 1); } lhs += (lhs_common_inner ? self.common_size : 1); @@ -112,14 +112,18 @@ const TensorFunction &create_matmul(const TensorFunction &a, const TensorFunctio } } -struct MyGetFun { - template<typename R1, typename R2, typename R3, typename R4> static auto invoke() { - if (std::is_same_v<R1,double> && std::is_same_v<R2,double>) { - return my_cblas_double_matmul_op<R3::value, R4::value>; - } else if (std::is_same_v<R1,float> && std::is_same_v<R2,float>) { - return my_cblas_float_matmul_op<R3::value, R4::value>; +struct SelectDenseMatmul { + template<typename LCM, typename RCM, typename LhsCommonInner, typename RhsCommonInner> static auto invoke() { + constexpr CellMeta ocm = CellMeta::join(LCM::value, RCM::value).reduce(false); + using LCT = CellValueType<LCM::value.cell_type>; + using RCT = CellValueType<RCM::value.cell_type>; + using OCT = CellValueType<ocm.cell_type>; + if (std::is_same_v<LCT,double> && std::is_same_v<RCT,double>) { + return my_cblas_double_matmul_op<LhsCommonInner::value, RhsCommonInner::value>; + } else if (std::is_same_v<LCT,float> && std::is_same_v<RCT,float>) { + return my_cblas_float_matmul_op<LhsCommonInner::value, RhsCommonInner::value>; } else { - return my_matmul_op<R1, R2, R3::value, R4::value>; + return my_matmul_op<LCT, RCT, OCT, LhsCommonInner::value, RhsCommonInner::value>; } } }; @@ -161,11 +165,12 @@ DenseMatMulFunction::~DenseMatMulFunction() = default; InterpretedFunction::Instruction DenseMatMulFunction::compile_self(const ValueBuilderFactory &, Stash &stash) const { - using MyTypify = TypifyValue<TypifyCellType,TypifyBool>; + using MyTypify = TypifyValue<TypifyCellMeta,TypifyBool>; Self &self = stash.create<Self>(result_type(), _lhs_size, _common_size, _rhs_size); - auto op = typify_invoke<4,MyTypify,MyGetFun>( - lhs().result_type().cell_type(), rhs().result_type().cell_type(), - _lhs_common_inner, _rhs_common_inner); + auto op = typify_invoke<4,MyTypify,SelectDenseMatmul>( + lhs().result_type().cell_meta().not_scalar(), + rhs().result_type().cell_meta().not_scalar(), + _lhs_common_inner, _rhs_common_inner); return InterpretedFunction::Instruction(op, wrap_param<DenseMatMulFunction::Self>(self)); } diff --git a/eval/src/vespa/eval/instruction/dense_simple_expand_function.cpp b/eval/src/vespa/eval/instruction/dense_simple_expand_function.cpp index 55c6760a391..2815da40d0a 100644 --- a/eval/src/vespa/eval/instruction/dense_simple_expand_function.cpp +++ b/eval/src/vespa/eval/instruction/dense_simple_expand_function.cpp @@ -32,11 +32,10 @@ struct ExpandParams { : result_type(result_type_in), result_size(result_size_in), function(function_in) {} }; -template <typename LCT, typename RCT, typename Fun, bool rhs_inner> +template <typename LCT, typename RCT, typename DCT, typename Fun, bool rhs_inner> void my_simple_expand_op(State &state, uint64_t param) { using ICT = typename std::conditional<rhs_inner,RCT,LCT>::type; using OCT = typename std::conditional<rhs_inner,LCT,RCT>::type; - using DCT = decltype(unify_cell_types<LCT,RCT>()); using OP = typename std::conditional<rhs_inner,SwapArgs2<Fun>,Fun>::type; const ExpandParams ¶ms = unwrap_param<ExpandParams>(param); OP my_op(params.function); @@ -53,13 +52,18 @@ void my_simple_expand_op(State &state, uint64_t param) { //----------------------------------------------------------------------------- -struct MyGetFun { - template <typename R1, typename R2, typename R3, typename R4> static auto invoke() { - return my_simple_expand_op<R1, R2, R3, R4::value>; +struct SelectDenseSimpleExpand { + template<typename LCM, typename RCM, typename Fun, typename RhsInner> + static auto invoke() { + constexpr CellMeta ocm = CellMeta::join(LCM::value, RCM::value); + using LCT = CellValueType<LCM::value.cell_type>; + using RCT = CellValueType<RCM::value.cell_type>; + using OCT = CellValueType<ocm.cell_type>; + return my_simple_expand_op<LCT, RCT, OCT, Fun, RhsInner::value>; } }; -using MyTypify = TypifyValue<TypifyCellType,TypifyOp2,TypifyBool>; +using MyTypify = TypifyValue<TypifyCellMeta,TypifyOp2,TypifyBool>; //----------------------------------------------------------------------------- @@ -98,9 +102,9 @@ DenseSimpleExpandFunction::compile_self(const ValueBuilderFactory &, Stash &stas { size_t result_size = result_type().dense_subspace_size(); const ExpandParams ¶ms = stash.create<ExpandParams>(result_type(), result_size, function()); - auto op = typify_invoke<4,MyTypify,MyGetFun>(lhs().result_type().cell_type(), - rhs().result_type().cell_type(), - function(), (_inner == Inner::RHS)); + auto op = typify_invoke<4,MyTypify,SelectDenseSimpleExpand>(lhs().result_type().cell_meta().not_scalar(), + rhs().result_type().cell_meta().not_scalar(), + function(), (_inner == Inner::RHS)); return Instruction(op, wrap_param<ExpandParams>(params)); } diff --git a/eval/src/vespa/eval/instruction/dense_xw_product_function.cpp b/eval/src/vespa/eval/instruction/dense_xw_product_function.cpp index b68a3a87ef1..371a5767382 100644 --- a/eval/src/vespa/eval/instruction/dense_xw_product_function.cpp +++ b/eval/src/vespa/eval/instruction/dense_xw_product_function.cpp @@ -15,9 +15,9 @@ using namespace operation; namespace { -template <typename LCT, typename RCT, bool common_inner> -double my_dot_product(const LCT *lhs, const RCT *rhs, size_t vector_size, size_t result_size) { - double result = 0.0; +template <typename LCT, typename RCT, typename OCT, bool common_inner> +OCT my_dot_product(const LCT *lhs, const RCT *rhs, size_t vector_size, size_t result_size) { + OCT result = 0.0; for (size_t i = 0; i < vector_size; ++i) { result += ((*lhs) * (*rhs)); ++lhs; @@ -26,17 +26,16 @@ double my_dot_product(const LCT *lhs, const RCT *rhs, size_t vector_size, size_t return result; } -template <typename LCT, typename RCT, bool common_inner> +template <typename LCT, typename RCT, typename OCT, bool common_inner> void my_xw_product_op(InterpretedFunction::State &state, uint64_t param) { const DenseXWProductFunction::Self &self = unwrap_param<DenseXWProductFunction::Self>(param); - using OCT = decltype(unify_cell_types<LCT,RCT>()); auto vector_cells = state.peek(1).cells().typify<LCT>(); auto matrix_cells = state.peek(0).cells().typify<RCT>(); auto dst_cells = state.stash.create_uninitialized_array<OCT>(self.result_size); OCT *dst = dst_cells.begin(); const RCT *matrix = matrix_cells.cbegin(); for (size_t i = 0; i < self.result_size; ++i) { - *dst++ = my_dot_product<LCT,RCT,common_inner>(vector_cells.cbegin(), matrix, self.vector_size, self.result_size); + *dst++ = my_dot_product<LCT,RCT,OCT,common_inner>(vector_cells.cbegin(), matrix, self.vector_size, self.result_size); matrix += (common_inner ? self.vector_size : 1); } state.pop_pop_push(state.stash.create<DenseValueView>(self.result_type, TypedCells(dst_cells))); @@ -100,13 +99,19 @@ const TensorFunction &createDenseXWProduct(const ValueType &res, const TensorFun } struct MyXWProductOp { - template<typename R1, typename R2, typename R3> static auto invoke() { - if (std::is_same_v<R1,double> && std::is_same_v<R2,double>) { - return my_cblas_double_xw_product_op<R3::value>; - } else if (std::is_same_v<R1,float> && std::is_same_v<R2,float>) { - return my_cblas_float_xw_product_op<R3::value>; + template<typename LCM, typename RCM, typename CommonInner> static auto invoke() { + constexpr CellMeta ocm = CellMeta::join(LCM::value, RCM::value).reduce(false); + using LCT = CellValueType<LCM::value.cell_type>; + using RCT = CellValueType<RCM::value.cell_type>; + using OCT = CellValueType<ocm.cell_type>; + if (std::is_same_v<LCT,double> && std::is_same_v<RCT,double>) { + assert((std::is_same_v<OCT,double>)); + return my_cblas_double_xw_product_op<CommonInner::value>; + } else if (std::is_same_v<LCT,float> && std::is_same_v<RCT,float>) { + assert((std::is_same_v<OCT,float>)); + return my_cblas_float_xw_product_op<CommonInner::value>; } else { - return my_xw_product_op<R1, R2, R3::value>; + return my_xw_product_op<LCT, RCT, OCT, CommonInner::value>; } } }; @@ -139,9 +144,10 @@ InterpretedFunction::Instruction DenseXWProductFunction::compile_self(const ValueBuilderFactory &, Stash &stash) const { Self &self = stash.create<Self>(result_type(), _vector_size, _result_size); - using MyTypify = TypifyValue<TypifyCellType,vespalib::TypifyBool>; - auto op = typify_invoke<3,MyTypify,MyXWProductOp>(lhs().result_type().cell_type(), - rhs().result_type().cell_type(), + assert(self.result_type.cell_meta().is_scalar == false); + using MyTypify = TypifyValue<TypifyCellMeta,vespalib::TypifyBool>; + auto op = typify_invoke<3,MyTypify,MyXWProductOp>(lhs().result_type().cell_meta().not_scalar(), + rhs().result_type().cell_meta().not_scalar(), _common_inner); return InterpretedFunction::Instruction(op, wrap_param<DenseXWProductFunction::Self>(self)); } diff --git a/eval/src/vespa/eval/instruction/mixed_simple_join_function.cpp b/eval/src/vespa/eval/instruction/mixed_simple_join_function.cpp index d487ab42d26..21c6f945609 100644 --- a/eval/src/vespa/eval/instruction/mixed_simple_join_function.cpp +++ b/eval/src/vespa/eval/instruction/mixed_simple_join_function.cpp @@ -55,11 +55,10 @@ ArrayRef<OCT> make_dst_cells(ConstArrayRef<PCT> pri_cells, Stash &stash) { } } -template <typename LCT, typename RCT, typename Fun, bool swap, Overlap overlap, bool pri_mut> +template <typename LCT, typename RCT, typename OCT, typename Fun, bool swap, Overlap overlap, bool pri_mut> void my_simple_join_op(State &state, uint64_t param) { using PCT = typename std::conditional<swap,RCT,LCT>::type; using SCT = typename std::conditional<swap,LCT,RCT>::type; - using OCT = decltype(unify_cell_types<LCT,RCT>()); using OP = typename std::conditional<swap,SwapArgs2<Fun>,Fun>::type; const JoinParams ¶ms = unwrap_param<JoinParams>(param); OP my_op(params.function); @@ -94,13 +93,18 @@ void my_simple_join_op(State &state, uint64_t param) { //----------------------------------------------------------------------------- -struct MyGetFun { - template <typename R1, typename R2, typename R3, typename R4, typename R5, typename R6> static auto invoke() { - return my_simple_join_op<R1, R2, R3, R4::value, R5::value, R6::value>; +struct SelectMixedSimpleJoin { + template<typename LCM, typename RCM, typename R3, typename R4, typename R5, typename R6> + static auto invoke() { + constexpr CellMeta ocm = CellMeta::join(LCM::value, RCM::value); + using LCT = CellValueType<LCM::value.cell_type>; + using RCT = CellValueType<RCM::value.cell_type>; + using OCT = CellValueType<ocm.cell_type>; + return my_simple_join_op<LCT, RCT, OCT, R3, R4::value, R5::value, R6::value>; } }; -using MyTypify = TypifyValue<TypifyCellType,TypifyOp2,TypifyBool,TypifyOverlap>; +using MyTypify = TypifyValue<TypifyCellMeta,TypifyOp2,TypifyBool,TypifyOverlap>; //----------------------------------------------------------------------------- @@ -197,10 +201,10 @@ Instruction MixedSimpleJoinFunction::compile_self(const ValueBuilderFactory &, Stash &stash) const { const JoinParams ¶ms = stash.create<JoinParams>(result_type(), factor(), function()); - auto op = typify_invoke<6,MyTypify,MyGetFun>(lhs().result_type().cell_type(), - rhs().result_type().cell_type(), - function(), (_primary == Primary::RHS), - _overlap, primary_is_mutable()); + auto op = typify_invoke<6,MyTypify,SelectMixedSimpleJoin>(lhs().result_type().cell_meta().not_scalar(), + rhs().result_type().cell_meta().not_scalar(), + function(), (_primary == Primary::RHS), + _overlap, primary_is_mutable()); return Instruction(op, wrap_param<JoinParams>(params)); } |