aboutsummaryrefslogtreecommitdiffstats
path: root/eval/src/vespa/eval/instruction/mixed_inner_product_function.cpp
blob: 5880a90a2cd75752ac8dba62c02d05bbc92a30ed (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.

#include "mixed_inner_product_function.h"
#include <vespa/eval/eval/inline_operation.h>
#include <vespa/eval/eval/value.h>

namespace vespalib::eval {

using namespace tensor_function;
using namespace operation;

namespace {

struct MixedInnerProductParam {
    ValueType res_type;
    size_t vector_size;
    size_t out_subspace_size;

    MixedInnerProductParam(const ValueType &res_type_in,
                           const ValueType &mix_type,
                           const ValueType &vec_type)
      : res_type(res_type_in),
        vector_size(vec_type.dense_subspace_size()),
        out_subspace_size(res_type.dense_subspace_size())
    {
        assert(vector_size * out_subspace_size == mix_type.dense_subspace_size());
    }
};

/**
 * Kernel for the mixed inner product instruction.
 *
 * Stack on entry: peek(1) is the mixed operand and peek(0) is the dense vector
 * operand. The mixed cells are consumed in consecutive runs of
 * param.vector_size cells, and each run produces one output cell (its dot
 * product with the vector cells). The result reuses the mixed operand's
 * index, so it has the same subspaces, each holding param.out_subspace_size cells.
 *
 * Template parameters: MCT/VCT/OCT are the cell types of the mixed operand,
 * the vector operand and the output, respectively.
 */
template <typename MCT, typename VCT, typename OCT>
void my_mixed_inner_product_op(InterpretedFunction::State &state, uint64_t param_in) {
    const auto &param = unwrap_param<MixedInnerProductParam>(param_in);
    const auto &mixed = state.peek(1);
    const auto &vector = state.peek(0);
    const auto lhs_cells = mixed.cells().typify<MCT>();
    const auto rhs_cells = vector.cells().typify<VCT>();
    const auto &index = mixed.index();
    const size_t cell_count = index.size() * param.out_subspace_size;
    ArrayRef<OCT> dst = state.stash.create_uninitialized_array<OCT>(cell_count);
    const VCT *rhs = rhs_cells.begin();   // the same vector is reused for every run
    const MCT *lhs = lhs_cells.begin();
    for (size_t i = 0; i < cell_count; ++i, lhs += param.vector_size) {
        dst[i] = DotProduct<MCT,VCT>::apply(lhs, rhs, param.vector_size);
    }
    // all mixed cells must have been consumed exactly once
    assert(lhs == lhs_cells.end());
    state.pop_pop_push(state.stash.create<ValueView>(param.res_type, index, TypedCells(dst)));
}

/**
 * Selector used with typify_invoke: maps a (mixed, vector, output) cell type
 * triple to the matching instantiation of the kernel above.
 */
struct SelectMixedInnerProduct {
    template <typename MixedCell, typename VectorCell, typename OutCell>
    static auto invoke() {
        return &my_mixed_inner_product_op<MixedCell,VectorCell,OutCell>;
    }
};

} // namespace <unnamed>

/**
 * Creates the optimized node.
 *
 * mixed_child becomes the left child and vector_child the right child of the
 * underlying binary node; compile_self relies on this order.
 */
MixedInnerProductFunction::MixedInnerProductFunction(const ValueType &res_type_in,
                                                     const TensorFunction &mixed_child,
                                                     const TensorFunction &vector_child)
    : tensor_function::Op2(res_type_in, mixed_child, vector_child)
{}

/**
 * Compiles this node into a single interpreted instruction.
 *
 * The parameter block is allocated in the stash. The kernel instantiation is
 * selected from the cell types of the mixed child (lhs), the vector child
 * (rhs) and the result.
 */
InterpretedFunction::Instruction
MixedInnerProductFunction::compile_self(const ValueBuilderFactory &, Stash &stash) const
{
    const ValueType &mixed_type = lhs().result_type();
    const ValueType &vector_type = rhs().result_type();
    const ValueType &out_type = result_type();
    auto &param = stash.create<MixedInnerProductParam>(out_type, mixed_type, vector_type);
    using CellTypify = TypifyValue<TypifyCellType>;
    auto kernel = typify_invoke<3,CellTypify,SelectMixedInnerProduct>(
            mixed_type.cell_type(), vector_type.cell_type(), out_type.cell_type());
    return InterpretedFunction::Instruction(kernel, wrap_param<MixedInnerProductParam>(param));
}

/**
 * Decides whether reduce(join(mixed, vector, mul), sum, ...) producing `res`
 * can be computed as a mixed inner product.
 *
 * Returns true only when all of the following hold:
 *  - `vector` is dense and `res` is not a scalar (double);
 *  - the vector's non-trivial indexed dimensions match, by name, the trailing
 *    entries of the mixed operand's non-trivial indexed dimensions, and none
 *    of them occurs in `res` (they are the ones being summed away);
 *  - every other non-trivial indexed dimension of `mixed` occurs in `res`;
 *  - `res` has exactly the same mapped dimensions as `mixed`.
 */
bool
MixedInnerProductFunction::compatible_types(const ValueType &res, const ValueType &mixed, const ValueType &vector)
{
    if (res.is_double() || !vector.is_dense()) {
        return false;
    }
    const auto vec_dims = vector.nontrivial_indexed_dimensions();
    const auto mix_dims = mixed.nontrivial_indexed_dimensions();
    if (vec_dims.size() > mix_dims.size()) {
        return false;
    }
    // vector dimensions must line up with the tail of the mixed dimensions
    const size_t offset = mix_dims.size() - vec_dims.size();
    for (size_t i = 0; i < vec_dims.size(); ++i) {
        const auto &name = vec_dims[i].name;
        if ((name != mix_dims[offset + i].name) ||
            (res.dimension_index(name) != ValueType::Dimension::npos))
        {
            return false;
        }
    }
    // the leading (non-reduced) mixed dimensions must survive into the result
    for (size_t i = 0; i < offset; ++i) {
        if (res.dimension_index(mix_dims[i].name) == ValueType::Dimension::npos) {
            return false;
        }
    }
    return (res.mapped_dimensions() == mixed.mapped_dimensions());
}

/**
 * Tries to replace `expr` with a MixedInnerProductFunction.
 *
 * Matches reduce(join(a, b, mul), sum, ...) whose result is not a scalar.
 * Either join operand may serve as the mixed side. `a` is tried as mixed
 * first, then `b`. Returns `expr` unchanged when the pattern or the types do
 * not fit. The replacement node, if any, is allocated in `stash`.
 */
const TensorFunction &
MixedInnerProductFunction::optimize(const TensorFunction &expr, Stash &stash)
{
    const ValueType &res_type = expr.result_type();
    auto reduce = as<Reduce>(expr);
    if (!reduce || (reduce->aggr() != Aggr::SUM) || res_type.is_double()) {
        return expr;
    }
    auto join = as<Join>(reduce->child());
    if (!join || (join->function() != Mul::f)) {
        return expr;
    }
    const TensorFunction &a = join->lhs();
    const TensorFunction &b = join->rhs();
    if (compatible_types(res_type, a.result_type(), b.result_type())) {
        return stash.create<MixedInnerProductFunction>(res_type, a, b);
    }
    if (compatible_types(res_type, b.result_type(), a.result_type())) {
        return stash.create<MixedInnerProductFunction>(res_type, b, a);
    }
    return expr;
}

} // namespace