aboutsummaryrefslogtreecommitdiffstats
path: root/eval/src/vespa/eval/instruction/mixed_inner_product_function.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'eval/src/vespa/eval/instruction/mixed_inner_product_function.cpp')
-rw-r--r--eval/src/vespa/eval/instruction/mixed_inner_product_function.cpp31
1 files changed, 3 insertions, 28 deletions
diff --git a/eval/src/vespa/eval/instruction/mixed_inner_product_function.cpp b/eval/src/vespa/eval/instruction/mixed_inner_product_function.cpp
index 248f909fcf5..5880a90a2cd 100644
--- a/eval/src/vespa/eval/instruction/mixed_inner_product_function.cpp
+++ b/eval/src/vespa/eval/instruction/mixed_inner_product_function.cpp
@@ -1,9 +1,8 @@
// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
#include "mixed_inner_product_function.h"
-#include <vespa/eval/eval/operation.h>
+#include <vespa/eval/eval/inline_operation.h>
#include <vespa/eval/eval/value.h>
-#include <cblas.h>
namespace vespalib::eval {
@@ -12,31 +11,6 @@ using namespace operation;
namespace {
-template <typename LCT, typename RCT>
-struct MyDotProduct {
- static double apply(const LCT * lhs, const RCT * rhs, size_t count) {
- double result = 0.0;
- for (size_t i = 0; i < count; ++i) {
- result += lhs[i] * rhs[i];
- }
- return result;
- }
-};
-
-template <>
-struct MyDotProduct<double,double> {
- static double apply(const double * lhs, const double * rhs, size_t count) {
- return cblas_ddot(count, lhs, 1, rhs, 1);
- }
-};
-
-template <>
-struct MyDotProduct<float,float> {
- static float apply(const float * lhs, const float * rhs, size_t count) {
- return cblas_sdot(count, lhs, 1, rhs, 1);
- }
-};
-
struct MixedInnerProductParam {
ValueType res_type;
size_t vector_size;
@@ -66,8 +40,9 @@ void my_mixed_inner_product_op(InterpretedFunction::State &state, uint64_t param
ArrayRef<OCT> out_cells = state.stash.create_uninitialized_array<OCT>(num_output_cells);
const MCT *m_cp = m_cells.begin();
const VCT *v_cp = v_cells.begin();
+ using dot_product = DotProduct<MCT,VCT>;
for (OCT &out : out_cells) {
- out = MyDotProduct<MCT,VCT>::apply(m_cp, v_cp, param.vector_size);
+ out = dot_product::apply(m_cp, v_cp, param.vector_size);
m_cp += param.vector_size;
}
assert(m_cp == m_cells.end());