diff options
Diffstat (limited to 'eval/src/vespa/eval/instruction/mixed_inner_product_function.cpp')
-rw-r--r-- | eval/src/vespa/eval/instruction/mixed_inner_product_function.cpp | 31 |
1 files changed, 3 insertions, 28 deletions
diff --git a/eval/src/vespa/eval/instruction/mixed_inner_product_function.cpp b/eval/src/vespa/eval/instruction/mixed_inner_product_function.cpp index 248f909fcf5..5880a90a2cd 100644 --- a/eval/src/vespa/eval/instruction/mixed_inner_product_function.cpp +++ b/eval/src/vespa/eval/instruction/mixed_inner_product_function.cpp @@ -1,9 +1,8 @@ // Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. #include "mixed_inner_product_function.h" -#include <vespa/eval/eval/operation.h> +#include <vespa/eval/eval/inline_operation.h> #include <vespa/eval/eval/value.h> -#include <cblas.h> namespace vespalib::eval { @@ -12,31 +11,6 @@ using namespace operation; namespace { -template <typename LCT, typename RCT> -struct MyDotProduct { - static double apply(const LCT * lhs, const RCT * rhs, size_t count) { - double result = 0.0; - for (size_t i = 0; i < count; ++i) { - result += lhs[i] * rhs[i]; - } - return result; - } -}; - -template <> -struct MyDotProduct<double,double> { - static double apply(const double * lhs, const double * rhs, size_t count) { - return cblas_ddot(count, lhs, 1, rhs, 1); - } -}; - -template <> -struct MyDotProduct<float,float> { - static float apply(const float * lhs, const float * rhs, size_t count) { - return cblas_sdot(count, lhs, 1, rhs, 1); - } -}; - struct MixedInnerProductParam { ValueType res_type; size_t vector_size; @@ -66,8 +40,9 @@ void my_mixed_inner_product_op(InterpretedFunction::State &state, uint64_t param ArrayRef<OCT> out_cells = state.stash.create_uninitialized_array<OCT>(num_output_cells); const MCT *m_cp = m_cells.begin(); const VCT *v_cp = v_cells.begin(); + using dot_product = DotProduct<MCT,VCT>; for (OCT &out : out_cells) { - out = MyDotProduct<MCT,VCT>::apply(m_cp, v_cp, param.vector_size); + out = dot_product::apply(m_cp, v_cp, param.vector_size); m_cp += param.vector_size; } assert(m_cp == m_cells.end()); |