aboutsummaryrefslogtreecommitdiffstats
path: root/eval
diff options
context:
space:
mode:
authorArne Juul <arnej@verizonmedia.com>2021-01-20 14:27:07 +0000
committerArne Juul <arnej@verizonmedia.com>2021-01-20 14:27:07 +0000
commit51132b40f12e2a764639d8e7bbf0ae8fe75921fc (patch)
tree43d8d639bf6f1daa419400b4d85dbe85fbe30728 /eval
parentd9de256ebd110b89cd3a0fd379a9993da2a1c753 (diff)
rewrite to allow only-float path
Diffstat (limited to 'eval')
-rw-r--r--eval/src/vespa/eval/instruction/mixed_inner_product_function.cpp32
1 files changed, 19 insertions, 13 deletions
diff --git a/eval/src/vespa/eval/instruction/mixed_inner_product_function.cpp b/eval/src/vespa/eval/instruction/mixed_inner_product_function.cpp
index 7a44a8902c9..c8a4df2b82d 100644
--- a/eval/src/vespa/eval/instruction/mixed_inner_product_function.cpp
+++ b/eval/src/vespa/eval/instruction/mixed_inner_product_function.cpp
@@ -13,23 +13,29 @@ using namespace operation;
namespace {
template <typename LCT, typename RCT>
-double my_dot_product(const LCT * lhs, const RCT * rhs, size_t count) {
- double result = 0.0;
- for (size_t i = 0; i < count; ++i) {
- result += lhs[i] * rhs[i];
+struct MyDotProduct {
+ static double apply(const LCT * lhs, const RCT * rhs, size_t count) {
+ double result = 0.0;
+ for (size_t i = 0; i < count; ++i) {
+ result += lhs[i] * rhs[i];
+ }
+ return result;
}
- return result;
-}
+};
template <>
-double my_dot_product<double,double>(const double * lhs, const double * rhs, size_t count) {
- return cblas_ddot(count, lhs, 1, rhs, 1);
-}
+struct MyDotProduct<double,double> {
+ static double apply(const double * lhs, const double * rhs, size_t count) {
+ return cblas_ddot(count, lhs, 1, rhs, 1);
+ }
+};
template <>
-double my_dot_product<float,float>(const float * lhs, const float * rhs, size_t count) {
- return cblas_sdot(count, lhs, 1, rhs, 1);
-}
+struct MyDotProduct<float,float> {
+ static float apply(const float * lhs, const float * rhs, size_t count) {
+ return cblas_sdot(count, lhs, 1, rhs, 1);
+ }
+};
struct MixedInnerProductParam {
ValueType res_type;
@@ -61,7 +67,7 @@ void my_mixed_inner_product_op(InterpretedFunction::State &state, uint64_t param
const MCT *m_cp = m_cells.begin();
const VCT *v_cp = v_cells.begin();
for (OCT &out : out_cells) {
- out = my_dot_product(m_cp, v_cp, param.vector_size);
+ out = MyDotProduct<MCT,VCT>::apply(m_cp, v_cp, param.vector_size);
m_cp += param.vector_size;
}
assert(m_cp == m_cells.end());