diff options
Diffstat (limited to 'eval/src/vespa/eval/instruction/dense_dot_product_function.cpp')
-rw-r--r-- | eval/src/vespa/eval/instruction/dense_dot_product_function.cpp | 40 |
1 files changed, 4 insertions, 36 deletions
diff --git a/eval/src/vespa/eval/instruction/dense_dot_product_function.cpp b/eval/src/vespa/eval/instruction/dense_dot_product_function.cpp index a2048707685..de9e029f377 100644 --- a/eval/src/vespa/eval/instruction/dense_dot_product_function.cpp +++ b/eval/src/vespa/eval/instruction/dense_dot_product_function.cpp @@ -1,9 +1,8 @@ // Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. #include "dense_dot_product_function.h" -#include <vespa/eval/eval/operation.h> +#include <vespa/eval/eval/inline_operation.h> #include <vespa/eval/eval/value.h> -#include <cblas.h> namespace vespalib::eval { @@ -16,26 +15,7 @@ template <typename LCT, typename RCT> void my_dot_product_op(InterpretedFunction::State &state, uint64_t) { auto lhs_cells = state.peek(1).cells().typify<LCT>(); auto rhs_cells = state.peek(0).cells().typify<RCT>(); - double result = 0.0; - const LCT *lhs = lhs_cells.cbegin(); - const RCT *rhs = rhs_cells.cbegin(); - for (size_t i = 0; i < lhs_cells.size(); ++i) { - result += ((*lhs++) * (*rhs++)); - } - state.pop_pop_push(state.stash.create<DoubleValue>(result)); -} - -void my_cblas_double_dot_product_op(InterpretedFunction::State &state, uint64_t) { - auto lhs_cells = state.peek(1).cells().typify<double>(); - auto rhs_cells = state.peek(0).cells().typify<double>(); - double result = cblas_ddot(lhs_cells.size(), lhs_cells.cbegin(), 1, rhs_cells.cbegin(), 1); - state.pop_pop_push(state.stash.create<DoubleValue>(result)); -} - -void my_cblas_float_dot_product_op(InterpretedFunction::State &state, uint64_t) { - auto lhs_cells = state.peek(1).cells().typify<float>(); - auto rhs_cells = state.peek(0).cells().typify<float>(); - double result = cblas_sdot(lhs_cells.size(), lhs_cells.cbegin(), 1, rhs_cells.cbegin(), 1); + double result = DotProduct<LCT,RCT>::apply(lhs_cells.cbegin(), rhs_cells.cbegin(), lhs_cells.size()); state.pop_pop_push(state.stash.create<DoubleValue>(result)); } @@ -44,19 +24,6 @@ struct MyDotProductOp { static auto invoke() { return my_dot_product_op<LCT,RCT>; } }; -InterpretedFunction::op_function my_select(CellType lct, CellType rct) { - if (lct == rct) { - if (lct == CellType::DOUBLE) { - return my_cblas_double_dot_product_op; - } - if (lct == CellType::FLOAT) { - return my_cblas_float_dot_product_op; - } - } - using MyTypify = TypifyCellType; - return typify_invoke<2,MyTypify,MyDotProductOp>(lct, rct); -} - } // namespace <unnamed> DenseDotProductFunction::DenseDotProductFunction(const TensorFunction &lhs_in, @@ -68,7 +35,8 @@ DenseDotProductFunction::DenseDotProductFunction(const TensorFunction &lhs_in, InterpretedFunction::Instruction DenseDotProductFunction::compile_self(const ValueBuilderFactory &, Stash &) const { - auto op = my_select(lhs().result_type().cell_type(), rhs().result_type().cell_type()); + auto op = typify_invoke<2,TypifyCellType,MyDotProductOp>(lhs().result_type().cell_type(), + rhs().result_type().cell_type()); return InterpretedFunction::Instruction(op); } |