aboutsummaryrefslogtreecommitdiffstats
path: root/eval/src/vespa/eval/instruction/sparse_merge_function.cpp
blob: 6dff7440d015746133756e3499056175c0abb962 (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
// Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.

#include "sparse_merge_function.h"
#include "generic_merge.h"
#include <vespa/eval/eval/fast_value.hpp>
#include <vespa/vespalib/util/typify.h>

namespace vespalib::eval {

using namespace tensor_function;
using namespace operation;
using namespace instruction;

namespace {

/**
 * Merge two sparse inputs, given as address maps plus cell arrays,
 * into a new FastValue created in the stash.
 *
 * All entries of 'a' are copied into the result first. Each entry of
 * 'b' is then either appended (its address is not yet present in the
 * result) or combined with the cell already stored for that address
 * by calling Fun(existing, b_cell).
 *
 * @tparam CT         cell type shared by both inputs and the result
 * @tparam single_dim selects the path that walks the label lists and
 *                    uses lookup_singledim instead of address+hash lookup
 * @tparam Fun        binary cell function, constructed from params.function
 * @return reference to the stash-created result value
 **/
template <typename CT, bool single_dim, typename Fun>
const Value& my_fast_sparse_merge(const FastAddrMap &a_map, const FastAddrMap &b_map,
                                  const CT *a_cells, const CT *b_cells,
                                  const MergeParam &params,
                                  Stash &stash)
{
    Fun combine(params.function);
    // Sum of both input sizes, handed to the FastValue constructor as the
    // expected size (presumably so push_back_fast never has to grow - confirm).
    const size_t size_hint = a_map.size() + b_map.size();
    auto &merged = stash.create<FastValue<CT,true>>(params.res_type, params.num_mapped_dimensions, 1u, size_hint);
    if constexpr (!single_dim) {
        // generic case: copy every lhs entry using its full address and hash
        a_map.each_map_entry([&](auto a_subspace, auto a_hash)
        {
            merged.add_mapping(a_map.get_addr(a_subspace), a_hash);
            merged.my_cells.push_back_fast(a_cells[a_subspace]);
        });
        // then fold in the rhs entries
        b_map.each_map_entry([&](auto b_subspace, auto b_hash)
        {
            auto b_addr = b_map.get_addr(b_subspace);
            const auto found = merged.my_index.map.lookup(b_addr, b_hash);
            if (found != FastAddrMap::npos()) {
                // address already present: combine in place
                CT *dst = merged.my_cells.get(found);
                *dst = combine(*dst, b_cells[b_subspace]);
            } else {
                // new address: append it
                merged.add_mapping(b_addr, b_hash);
                merged.my_cells.push_back_fast(b_cells[b_subspace]);
            }
        });
    } else {
        // One-element address view over 'label'; assigning 'label'
        // changes the address seen through 'label_ref'.
        string_id label;
        ConstArrayRef<string_id> label_ref(&label, 1);
        const auto &lhs_labels = a_map.labels();
        for (size_t idx = 0; idx < lhs_labels.size(); ++idx) {
            label = lhs_labels[idx];
            merged.add_mapping(label_ref, label.hash());
            merged.my_cells.push_back_fast(a_cells[idx]);
        }
        const auto &rhs_labels = b_map.labels();
        for (size_t idx = 0; idx < rhs_labels.size(); ++idx) {
            label = rhs_labels[idx];
            const auto found = merged.my_index.map.lookup_singledim(label);
            if (found != FastAddrMap::npos()) {
                // label already present: combine in place
                CT *dst = merged.my_cells.get(found);
                *dst = combine(*dst, b_cells[idx]);
            } else {
                // new label: append it
                merged.add_mapping(label_ref, label.hash());
                merged.my_cells.push_back_fast(b_cells[idx]);
            }
        }
    }
    return merged;
}

/**
 * Interpreted-function instruction performing a sparse merge.
 *
 * Reads the two operands from the state (peek(1) is the left operand,
 * peek(0) the right one) and hands the merged value to pop_pop_push.
 * When both indexes pass are_fast(), the specialized
 * my_fast_sparse_merge is used; otherwise generic_mixed_merge is
 * called and its result is moved into a unique_ptr created in the
 * stash before being pushed.
 *
 * @param state    interpreter state holding the operands and the stash
 * @param param_in wrapped MergeParam
 **/
template <typename CT, bool single_dim, typename Fun>
void my_sparse_merge_op(InterpretedFunction::State &state, uint64_t param_in) {
    const auto &param = unwrap_param<MergeParam>(param_in);
    assert(param.dense_subspace_size == 1u);
    const Value &left = state.peek(1);
    const Value &right = state.peek(0);
    const auto &left_index = left.index();
    const auto &right_index = right.index();
    if (__builtin_expect(!are_fast(left_index, right_index), false)) {
        // unexpected case: fall back to the generic implementation
        auto generic_result = generic_mixed_merge<CT,CT,CT,Fun>(left, right, param);
        state.pop_pop_push(*state.stash.create<std::unique_ptr<Value>>(std::move(generic_result)));
        return;
    }
    const auto left_cells = left.cells().typify<CT>();
    const auto right_cells = right.cells().typify<CT>();
    const auto &left_map = as_fast(left_index).map;
    const auto &right_map = as_fast(right_index).map;
    const Value &merged = my_fast_sparse_merge<CT,single_dim,Fun>(left_map, right_map,
                                                                  left_cells.cbegin(), right_cells.cbegin(),
                                                                  param, state.stash);
    state.pop_pop_push(merged);
}

struct SelectSparseMergeOp {
    template <typename R1, typename SINGLE_DIM, typename Fun>
    static auto invoke() {
        using CT = CellValueType<R1::value.cell_type>;        
        return my_sparse_merge_op<CT,SINGLE_DIM::value,Fun>;
    }
};

// Typifiers handed to typify_invoke together with SelectSparseMergeOp;
// they are listed in the order of the values passed there: the result
// cell meta, a bool flag and the binary operation.
using MyTypify = TypifyValue<TypifyCellMeta,TypifyBool,operation::TypifyOp2>;

} // namespace <unnamed>

/**
 * Build the optimized node from an existing Merge node by reusing its
 * result type, both children and the merge function. The operand types
 * are asserted to satisfy compatible_types().
 **/
SparseMergeFunction::SparseMergeFunction(const tensor_function::Merge &original)
  : Merge(original.result_type(), original.lhs(), original.rhs(), original.function())
{
    assert(compatible_types(result_type(),
                            lhs().result_type(),
                            rhs().result_type()));
}

/**
 * Compile this node into a single interpreted instruction.
 *
 * A MergeParam describing the operation is created in the stash, and
 * the concrete instruction function is selected by typify_invoke from
 * the result cell meta (after limit()), whether the result has
 * exactly one mapped dimension, and the merge function.
 *
 * @param factory value builder factory forwarded to MergeParam
 * @param stash   stash in which the MergeParam is created
 * @return instruction combining the selected function and the wrapped param
 **/
InterpretedFunction::Instruction
SparseMergeFunction::compile_self(const ValueBuilderFactory &factory, Stash &stash) const
{
    const ValueType &res_type = result_type();
    const auto &merge_param = stash.create<MergeParam>(res_type,
                                                       lhs().result_type(), rhs().result_type(),
                                                       function(), factory);
    const bool single_dim = (res_type.count_mapped_dimensions() == 1);
    auto merge_op = typify_invoke<3,MyTypify,SelectSparseMergeOp>(res_type.cell_meta().limit(),
                                                                  single_dim,
                                                                  function());
    return InterpretedFunction::Instruction(merge_op, wrap_param<MergeParam>(merge_param));
}

/**
 * Check whether a merge with the given types can use this optimization.
 *
 * Requires that lhs, rhs and the result all have the same cell type,
 * and that lhs has at least one mapped dimension and a dense subspace
 * size of 1. When these hold, the result type is asserted to equal
 * both operand types.
 *
 * @return true if the optimized sparse merge applies, false otherwise
 **/
bool
SparseMergeFunction::compatible_types(const ValueType &res, const ValueType &lhs, const ValueType &rhs)
{
    if (lhs.cell_type() != rhs.cell_type()) {
        return false;
    }
    if (lhs.cell_type() != res.cell_type()) {
        return false;
    }
    if (lhs.count_mapped_dimensions() == 0) {
        return false; // not sparse at all
    }
    if (lhs.dense_subspace_size() != 1) {
        return false; // has a non-trivial dense part
    }
    assert(res == lhs);
    assert(res == rhs);
    return true;
}

/**
 * Replace a Merge node with a SparseMergeFunction when its types
 * satisfy compatible_types(); any other expression is returned as is.
 *
 * @param expr  expression to inspect
 * @param stash stash in which the replacement node is created
 * @return the replacement node, or 'expr' itself when not applicable
 **/
const TensorFunction &
SparseMergeFunction::optimize(const TensorFunction &expr, Stash &stash)
{
    auto merge = as<Merge>(expr);
    if (!merge) {
        return expr;
    }
    const TensorFunction &left_child = merge->lhs();
    const TensorFunction &right_child = merge->rhs();
    if (!compatible_types(expr.result_type(), left_child.result_type(), right_child.result_type())) {
        return expr;
    }
    return stash.create<SparseMergeFunction>(*merge);
}

} // namespace