aboutsummaryrefslogtreecommitdiffstats
path: root/searchlib/src/vespa/searchlib/grouping/collect.cpp
blob: 9085b08d0f2e07f2ba27fe30d3c605a75bdb5619 (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.

#include "collect.h"
#include <vespa/vespalib/util/array.hpp>
#include <cassert>
#include <utility>

using namespace search::expression;
using namespace search::aggregation;

namespace search::grouping {

/**
 * Binds an accessor to one aggregator of a group.
 *
 * @param aggregator  blueprint aggregation result. Only its address is stored
 *                    (_bluePrint), so it must outlive this accessor
 *                    (NOTE(review): lifetime requirement inferred from the
 *                    non-owning pointer — confirm against collect.h).
 * @param offset      byte offset of this aggregator's raw storage inside one
 *                    group's slot (callers pass the running sum of the raw
 *                    sizes of the preceding aggregators).
 *
 * The private working copy is cloned from the argument itself rather than via
 * _bluePrint. Members are initialized in declaration order, not initializer-list
 * order. This way correctness no longer depends on _bluePrint being declared
 * before _aggregator in the header.
 */
Collect::ResultAccessor::ResultAccessor(const AggregationResult & aggregator, size_t offset) :
    _bluePrint(&aggregator),
    _aggregator(aggregator.clone()),
    _offset(offset)
{
}

/**
 * Initializes this aggregator's raw storage inside a group's slot.
 *
 * @param base  start of the group's slot; this aggregator's bytes begin at
 *              base + _offset.
 *
 * The storage is first created through the working copy's result. The
 * blueprint's result is then encoded into the same location.
 */
void Collect::ResultAccessor::create(uint8_t * base)
{
    uint8_t * const slot = base + _offset;
    _aggregator->getResult().create(slot);
    _bluePrint->getResult()->encode(slot);
}

/**
 * Builds the collector layout for the aggregators of group gp.
 *
 * Each aggregation result of gp gets a ResultAccessor. Its offset is the sum of
 * the raw byte sizes of the accessors before it, so one group's slot is
 * _aggregatorSize bytes. Each accessor must report a non-zero raw size.
 *
 * _sortInfo receives one entry per order-by clause. The clause value is a signed
 * number whose absolute value is a 1-based index into gp's expressions
 * (presumably the sign encodes sort direction — TODO confirm against Group).
 */
Collect::Collect(const Group & gp) :
    _aggregatorSize(0),
    _aggregator(),
    _aggrBacking()
{
    const size_t numAggregators = gp.getAggrSize();
    _aggregator.reserve(numAggregators);
    for (size_t i(0); i < numAggregators; i++) {
        // ResultAccessor takes a const reference, so no const_cast is required.
        ResultAccessor accessor(gp.getAggregationResult(i), _aggregatorSize);
        const size_t rawSize = accessor.getRawByteSize();
        assert(rawSize > 0);
        _aggregatorSize += rawSize;
        // Move rather than copy: a copy may clone the accessor's aggregator again.
        _aggregator.push_back(std::move(accessor));
    }
    _sortInfo.resize(gp.getOrderBySize());
    for(size_t i(0); i < _sortInfo.size(); i++) {
        const uint32_t index = std::abs(gp.getOrderBy(i)) - 1;
        const uint32_t z(gp.getExpr(index));
        _sortInfo[i] = SortInfo(z, gp.getOrderBy(i));
    }
}

/**
 * Calls destroy() on every aggregator in every group slot held in _aggrBacking.
 *
 * The backing store is expected to hold a whole number of slots of
 * _aggregatorSize bytes each. When there are no aggregators (_aggregatorSize == 0),
 * there is nothing to tear down.
 */
Collect::~Collect()
{
    if (_aggregatorSize == 0) {
        return;
    }
    assert((_aggrBacking.size() % _aggregatorSize) == 0);
    const size_t slotCount = _aggrBacking.size() / _aggregatorSize;
    for (size_t slot = 0; slot < slotCount; ++slot) {
        uint8_t * slotBase = &_aggrBacking[slot * _aggregatorSize];
        for (ResultAccessor & accessor : _aggregator) {
            accessor.destroy(slotBase);
        }
    }
}

/**
 * Copies the collected raw results for group `ref` back into the aggregation
 * results of `g`. postMerge() is then called on each of them.
 *
 * If the group's slot lies beyond the backing store (nothing stored for it),
 * `g` is left untouched.
 */
void
Collect::getCollectors(GroupRef ref, Group & g) const
{
    const size_t offset = getAggrBase(ref);
    if (offset >= _aggrBacking.size()) {
        return;
    }
    const uint8_t * slotBase = &_aggrBacking[offset];
    for (size_t idx = 0; idx < _aggregator.size(); ++idx) {
        auto & target = g.getAggregationResult(idx);
        _aggregator[idx].getResult(target.getResult(), slotBase);
        target.postMerge();
    }
}

/**
 * Feeds document `docId` with `rank` into every aggregator of group `gr`.
 *
 * With no aggregators, the backing store is empty. The early return keeps the
 * code from indexing an empty container, which the original did and which is
 * undefined behavior.
 */
void
Collect::collect(GroupRef gr, uint32_t docId, double rank)
{
    if (_aggregator.empty()) {
        return;
    }
    uint8_t * base(&_aggrBacking[getAggrBase(gr)]);
    for (ResultAccessor & accessor : _aggregator) {
        accessor.aggregate(base, docId, rank);
    }
}

/**
 * Allocates and initializes the slot for group `gr` when it is the next slot
 * to append.
 *
 * The slot is created only when its offset equals the current size of the
 * backing store. Otherwise the call does nothing
 * (NOTE(review): an offset beyond the end is also silently ignored — confirm
 * callers always create groups in order).
 *
 * With no aggregators, no storage is needed. Returning early avoids taking
 * &_aggrBacking[0] on an empty container, which is undefined behavior.
 */
void
Collect::createCollectors(GroupRef gr)
{
    if (_aggregator.empty()) {
        return;
    }
    const size_t offset(getAggrBase(gr));
    if (offset != _aggrBacking.size()) {
        return;
    }
    _aggrBacking.resize(getAggrBase(GroupRef(gr.getRef() + 1)));
    uint8_t * base(&_aggrBacking[offset]);
    for (ResultAccessor & accessor : _aggregator) {
        accessor.create(base);
    }
}

/**
 * Seeds the slot of group `gr` with the current aggregation results of `g`.
 *
 * Does nothing when `gr` is not valid. It also does nothing when there are no
 * aggregators; in that case the backing store holds no bytes and indexing it
 * would be undefined behavior.
 */
void
Collect::preFill(GroupRef gr, const Group & g)
{
    if (!gr.valid() || _aggregator.empty()) {
        return;
    }
    uint8_t * base(&_aggrBacking[getAggrBase(gr)]);
    for (size_t i(0), m(_aggregator.size()); i < m; i++) {
        _aggregator[i].setResult(*g.getAggregationResult(i).getResult(), base);
    }
}

}

// Intentionally empty; generated by ../../forcelink.sh.
// NOTE(review): presumably referenced elsewhere so the linker keeps this
// translation unit — confirm against forcelink.sh.
void forcelink_file_searchlib_grouping_collect()
{
}