diff options
author | Arne Juul <arnej@verizonmedia.com> | 2021-01-20 14:27:07 +0000 |
---|---|---|
committer | Arne Juul <arnej@verizonmedia.com> | 2021-01-20 14:27:07 +0000 |
commit | 51132b40f12e2a764639d8e7bbf0ae8fe75921fc (patch) | |
tree | 43d8d639bf6f1daa419400b4d85dbe85fbe30728 | |
parent | d9de256ebd110b89cd3a0fd379a9993da2a1c753 (diff) |
rewrite to allow only-float path
-rw-r--r-- | eval/src/vespa/eval/instruction/mixed_inner_product_function.cpp | 32 |
1 files changed, 19 insertions, 13 deletions
diff --git a/eval/src/vespa/eval/instruction/mixed_inner_product_function.cpp b/eval/src/vespa/eval/instruction/mixed_inner_product_function.cpp index 7a44a8902c9..c8a4df2b82d 100644 --- a/eval/src/vespa/eval/instruction/mixed_inner_product_function.cpp +++ b/eval/src/vespa/eval/instruction/mixed_inner_product_function.cpp @@ -13,23 +13,29 @@ using namespace operation; namespace { template <typename LCT, typename RCT> -double my_dot_product(const LCT * lhs, const RCT * rhs, size_t count) { - double result = 0.0; - for (size_t i = 0; i < count; ++i) { - result += lhs[i] * rhs[i]; +struct MyDotProduct { + static double apply(const LCT * lhs, const RCT * rhs, size_t count) { + double result = 0.0; + for (size_t i = 0; i < count; ++i) { + result += lhs[i] * rhs[i]; + } + return result; } - return result; -} +}; template <> -double my_dot_product<double,double>(const double * lhs, const double * rhs, size_t count) { - return cblas_ddot(count, lhs, 1, rhs, 1); -} +struct MyDotProduct<double,double> { + static double apply(const double * lhs, const double * rhs, size_t count) { + return cblas_ddot(count, lhs, 1, rhs, 1); + } +}; template <> -double my_dot_product<float,float>(const float * lhs, const float * rhs, size_t count) { - return cblas_sdot(count, lhs, 1, rhs, 1); -} +struct MyDotProduct<float,float> { + static float apply(const float * lhs, const float * rhs, size_t count) { + return cblas_sdot(count, lhs, 1, rhs, 1); + } +}; struct MixedInnerProductParam { ValueType res_type; @@ -61,7 +67,7 @@ void my_mixed_inner_product_op(InterpretedFunction::State &state, uint64_t param const MCT *m_cp = m_cells.begin(); const VCT *v_cp = v_cells.begin(); for (OCT &out : out_cells) { - out = my_dot_product(m_cp, v_cp, param.vector_size); + out = MyDotProduct<MCT,VCT>::apply(m_cp, v_cp, param.vector_size); m_cp += param.vector_size; } assert(m_cp == m_cells.end()); |