aboutsummaryrefslogtreecommitdiffstats
path: root/eval/src/vespa/eval/instruction/dense_dot_product_function.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'eval/src/vespa/eval/instruction/dense_dot_product_function.cpp')
-rw-r--r--eval/src/vespa/eval/instruction/dense_dot_product_function.cpp40
1 files changed, 4 insertions, 36 deletions
diff --git a/eval/src/vespa/eval/instruction/dense_dot_product_function.cpp b/eval/src/vespa/eval/instruction/dense_dot_product_function.cpp
index a2048707685..de9e029f377 100644
--- a/eval/src/vespa/eval/instruction/dense_dot_product_function.cpp
+++ b/eval/src/vespa/eval/instruction/dense_dot_product_function.cpp
@@ -1,9 +1,8 @@
// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
#include "dense_dot_product_function.h"
-#include <vespa/eval/eval/operation.h>
+#include <vespa/eval/eval/inline_operation.h>
#include <vespa/eval/eval/value.h>
-#include <cblas.h>
namespace vespalib::eval {
@@ -16,26 +15,7 @@ template <typename LCT, typename RCT>
void my_dot_product_op(InterpretedFunction::State &state, uint64_t) {
auto lhs_cells = state.peek(1).cells().typify<LCT>();
auto rhs_cells = state.peek(0).cells().typify<RCT>();
- double result = 0.0;
- const LCT *lhs = lhs_cells.cbegin();
- const RCT *rhs = rhs_cells.cbegin();
- for (size_t i = 0; i < lhs_cells.size(); ++i) {
- result += ((*lhs++) * (*rhs++));
- }
- state.pop_pop_push(state.stash.create<DoubleValue>(result));
-}
-
-void my_cblas_double_dot_product_op(InterpretedFunction::State &state, uint64_t) {
- auto lhs_cells = state.peek(1).cells().typify<double>();
- auto rhs_cells = state.peek(0).cells().typify<double>();
- double result = cblas_ddot(lhs_cells.size(), lhs_cells.cbegin(), 1, rhs_cells.cbegin(), 1);
- state.pop_pop_push(state.stash.create<DoubleValue>(result));
-}
-
-void my_cblas_float_dot_product_op(InterpretedFunction::State &state, uint64_t) {
- auto lhs_cells = state.peek(1).cells().typify<float>();
- auto rhs_cells = state.peek(0).cells().typify<float>();
- double result = cblas_sdot(lhs_cells.size(), lhs_cells.cbegin(), 1, rhs_cells.cbegin(), 1);
+ double result = DotProduct<LCT,RCT>::apply(lhs_cells.cbegin(), rhs_cells.cbegin(), lhs_cells.size());
state.pop_pop_push(state.stash.create<DoubleValue>(result));
}
@@ -44,19 +24,6 @@ struct MyDotProductOp {
static auto invoke() { return my_dot_product_op<LCT,RCT>; }
};
-InterpretedFunction::op_function my_select(CellType lct, CellType rct) {
- if (lct == rct) {
- if (lct == CellType::DOUBLE) {
- return my_cblas_double_dot_product_op;
- }
- if (lct == CellType::FLOAT) {
- return my_cblas_float_dot_product_op;
- }
- }
- using MyTypify = TypifyCellType;
- return typify_invoke<2,MyTypify,MyDotProductOp>(lct, rct);
-}
-
} // namespace <unnamed>
DenseDotProductFunction::DenseDotProductFunction(const TensorFunction &lhs_in,
@@ -68,7 +35,8 @@ DenseDotProductFunction::DenseDotProductFunction(const TensorFunction &lhs_in,
InterpretedFunction::Instruction
DenseDotProductFunction::compile_self(const ValueBuilderFactory &, Stash &) const
{
- auto op = my_select(lhs().result_type().cell_type(), rhs().result_type().cell_type());
+ auto op = typify_invoke<2,TypifyCellType,MyDotProductOp>(lhs().result_type().cell_type(),
+ rhs().result_type().cell_type());
return InterpretedFunction::Instruction(op);
}