aboutsummaryrefslogtreecommitdiffstats
path: root/eval/src/vespa/eval/instruction/dense_dot_product_function.cpp
blob: 8395ef6611759bee1f03f7eb99b62b2ac7ad8aca (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
// Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.

#include "dense_dot_product_function.h"
#include <vespa/eval/eval/inline_operation.h>
#include <vespa/eval/eval/value.h>

namespace vespalib::eval {

using namespace tensor_function;
using namespace operation;

namespace {

// Interpreter instruction computing the dot product of the two values on
// top of the stack. The value at peek(1) is treated as the left operand
// (cell type LCT) and the value at peek(0) as the right operand (cell
// type RCT). Both operands are popped and replaced by a single
// DoubleValue allocated in the state's stash.
//
// The element count is taken from the left operand only.
// NOTE(review): presumably both operands hold the same number of cells
// (see compatible_types) - confirm against callers.
template <typename LCT, typename RCT>
void my_dot_product_op(InterpretedFunction::State &state, uint64_t) {
    const auto left = state.peek(1).cells().typify<LCT>();
    const auto right = state.peek(0).cells().typify<RCT>();
    const double dot = DotProduct<LCT,RCT>::apply(left.cbegin(), right.cbegin(), left.size());
    auto &wrapped = state.stash.create<DoubleValue>(dot);
    state.pop_pop_push(wrapped);
}

// Target used with typify_invoke: given the resolved left/right cell
// types, hands back the matching my_dot_product_op instantiation.
struct MyDotProductOp {
    template <typename LCT, typename RCT>
    static auto invoke() {
        return &my_dot_product_op<LCT,RCT>;
    }
};

} // namespace <unnamed>

// Builds a binary tensor function node over the two given children whose
// result type is always double.
DenseDotProductFunction::DenseDotProductFunction(const TensorFunction &lhs_in, const TensorFunction &rhs_in)
    : tensor_function::Op2(ValueType::double_type(), lhs_in, rhs_in) {}

// Produces the interpreter instruction for this node. The cell types of
// the two children are resolved to concrete C++ types, and the matching
// instantiation supplied by MyDotProductOp is wrapped in an Instruction.
// The factory and stash arguments are not used.
InterpretedFunction::Instruction
DenseDotProductFunction::compile_self(const ValueBuilderFactory &, Stash &) const
{
    auto left_cell_type = lhs().result_type().cell_type();
    auto right_cell_type = rhs().result_type().cell_type();
    auto selected = typify_invoke<2,TypifyCellType,MyDotProductOp>(left_cell_type, right_cell_type);
    return InterpretedFunction::Instruction(selected);
}

// Tells whether an operation with the given result and operand types can
// be handled by this function: the result must be a double, the left
// operand must be dense, and the right operand must have exactly the
// same dimensions as the left one.
bool
DenseDotProductFunction::compatible_types(const ValueType &res, const ValueType &lhs, const ValueType &rhs)
{
    if (!res.is_double()) {
        return false;
    }
    if (!lhs.is_dense()) {
        return false;
    }
    return (rhs.dimensions() == lhs.dimensions());
}

// Pattern-matches 'expr' against a SUM reduce whose child is a join using
// multiplication. When the involved types pass compatible_types, a
// DenseDotProductFunction over the join's operands is created in 'stash'
// and returned; in every other case 'expr' is returned unchanged.
const TensorFunction &
DenseDotProductFunction::optimize(const TensorFunction &expr, Stash &stash)
{
    auto sum_node = as<Reduce>(expr);
    if (!sum_node || (sum_node->aggr() != Aggr::SUM)) {
        return expr; // not a sum reduction
    }
    auto mul_node = as<Join>(sum_node->child());
    if (!mul_node || (mul_node->function() != Mul::f)) {
        return expr; // reduced child is not a multiplying join
    }
    const TensorFunction &left = mul_node->lhs();
    const TensorFunction &right = mul_node->rhs();
    if (!compatible_types(expr.result_type(), left.result_type(), right.result_type())) {
        return expr;
    }
    return stash.create<DenseDotProductFunction>(left, right);
}

} // namespace