aboutsummaryrefslogtreecommitdiffstats
path: root/searchlib/src/main/java/com/yahoo/searchlib/aggregation/HitsAggregationResult.java
blob: ea39dde92e156ccb388e9d4d87c3f49f99d84e9b (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
package com.yahoo.searchlib.aggregation;

import com.yahoo.searchlib.expression.FloatResultNode;
import com.yahoo.searchlib.expression.ResultNode;
import com.yahoo.text.Utf8;
import com.yahoo.vespa.objects.*;

import java.util.ArrayList;
import java.util.Collections;
import java.util.Comparator;
import java.util.List;

/**
 * This is an aggregated result holding the top n hits for a single group.
 * <p>
 * Hits are accumulated through {@link #addHit(Hit)} and by merging other results; {@link #postMerge()}
 * orders them by descending rank and truncates the list to {@link #getMaxHits()} entries when that limit
 * is non-negative.
 *
 * @author havardpe
 * @author baldersheim
 * @author Simon Thoresen Hult
 */
public class HitsAggregationResult extends AggregationResult {

    public static final int classId = registerClass(0x4000 + 87, HitsAggregationResult.class);

    /** Name of the summary class requested for the collected hits. */
    private String summaryClass = "default";

    /** Maximum number of hits kept by {@link #postMerge()}; a negative value means no limit. */
    private int maxHits = -1;

    /** The collected hits. Only guaranteed to be ordered by descending rank after {@link #postMerge()}. */
    private List<Hit> hits = new ArrayList<>();

    /**
     * Constructs an empty result node with the "default" summary class and no hit limit.
     */
    public HitsAggregationResult() {
        // empty
    }

    /**
     * Create a hits aggregation result that will collect the given number of hits.
     *
     * @param maxHits maximum number of hits to collect; a negative value means no limit
     */
    public HitsAggregationResult(int maxHits) {
        this.maxHits = maxHits;
    }

    /**
     * Create a hits aggregation result that will collect the given number of hits of the summaryClass asked.
     *
     * @param maxHits      maximum number of hits to collect; a negative value means no limit
     * @param summaryClass SummaryClass to use for hits to collect
     */
    public HitsAggregationResult(int maxHits, String summaryClass) {
        this.summaryClass = summaryClass;
        this.maxHits = maxHits;
    }

    /**
     * Obtain the summary class used to collect the hits.
     *
     * @return the summary class name
     */
    public String getSummaryClass() {
        return summaryClass;
    }

    /**
     * Obtain the maximum number of hits to collect.
     *
     * @return max number of hits to collect; a negative value means no limit
     */
    public int getMaxHits() {
        return maxHits;
    }

    /**
     * Sets the summary class of hits to collect.
     *
     * @param summaryClass the summary class to collect.
     * @return this, to allow chaining.
     */
    public HitsAggregationResult setSummaryClass(String summaryClass) {
        this.summaryClass = summaryClass;
        return this;
    }

    /**
     * Sets the maximum number of hits to collect.
     *
     * @param maxHits the number of hits to collect; a negative value means no limit.
     * @return this, to allow chaining.
     */
    public HitsAggregationResult setMaxHits(int maxHits) {
        this.maxHits = maxHits;
        return this;
    }

    /**
     * Obtain the hits collected by this aggregation result.
     * <p>
     * The returned list is the internal, live list (not a copy).
     *
     * @return collected hits
     */
    public List<Hit> getHits() {
        return hits;
    }

    /**
     * Adds a hit to this aggregation result. No ordering or limit is applied until {@link #postMerge()}.
     *
     * @param h the hit
     * @return this object
     */
    public HitsAggregationResult addHit(Hit h) {
        hits.add(h);
        return this;
    }

    /**
     * Returns the rank of the first hit in the list, which is the highest-ranked hit once
     * {@link #postMerge()} has sorted the list, or 0 when no hits have been collected.
     */
    @Override
    public ResultNode getRank() {
        if (hits.isEmpty()) {
            return new FloatResultNode(0);
        }
        return new FloatResultNode(hits.get(0).getRank());
    }

    @Override
    protected int onGetClassId() {
        return classId;
    }

    /**
     * Writes, after the superclass state: the summary class as an int byte length followed by its UTF-8 bytes,
     * then {@code maxHits}, then the hit count, then each hit via {@code serializeOptional}.
     */
    @Override
    protected void onSerialize(Serializer buf) {
        super.onSerialize(buf);
        byte[] raw = Utf8.toBytes(summaryClass);
        buf.putInt(null, raw.length);
        buf.put(null, raw);

        buf.putInt(null, maxHits);
        buf.putInt(null, hits.size());
        for (Hit h : hits) {
            serializeOptional(buf, h);
        }
    }

    /**
     * Reads the state written by {@link #onSerialize(Serializer)}, replacing any hits already held.
     */
    @Override
    protected void onDeserialize(Deserializer buf) {
        super.onDeserialize(buf);
        summaryClass = getUtf8(buf);
        maxHits = buf.getInt(null);
        int numHits = buf.getInt(null);
        // The other fields are overwritten, so the hit list must be reset as well rather than appended to.
        hits.clear();
        for (int i = 0; i < numHits; i++) {
            hits.add((Hit)deserializeOptional(buf));
        }
    }

    /** Appends all hits of the given result; ordering and truncation are deferred to {@link #postMerge()}. */
    @Override
    protected void onMerge(AggregationResult result) {
        hits.addAll(((HitsAggregationResult)result).hits);
    }

    /**
     * Sorts the collected hits by descending rank and, when {@code maxHits} is non-negative, drops every hit
     * beyond the first {@code maxHits}.
     */
    @Override
    public void postMerge() {
        // Highest rank first.
        hits.sort((lhs, rhs) -> Double.compare(rhs.getRank(), lhs.getRank()));
        if (maxHits >= 0 && hits.size() > maxHits) {
            // Remove the tail in place. Replacing the list with a subList view would keep the discarded
            // hits reachable and nest another view on every subsequent call.
            hits.subList(maxHits, hits.size()).clear();
        }
    }

    /** Equal when summary class, hit limit and the hit lists (in order) are all equal. */
    @Override
    protected boolean equalsAggregation(AggregationResult obj) {
        HitsAggregationResult rhs = (HitsAggregationResult)obj;
        if ( ! summaryClass.equals(rhs.summaryClass)) return false;
        if (maxHits != rhs.maxHits) return false;
        if ( ! hits.equals(rhs.hits)) return false;
        return true;
    }

    @Override
    public int hashCode() {
        return super.hashCode() + summaryClass.hashCode() + maxHits + hits.hashCode();
    }

    /** Returns a copy whose hit list is a new list holding a clone of every hit. */
    @Override
    public HitsAggregationResult clone() {
        HitsAggregationResult obj = (HitsAggregationResult)super.clone();
        obj.summaryClass = summaryClass;
        obj.maxHits = maxHits;
        obj.hits = new ArrayList<>(hits.size());
        for (Hit hit : hits) {
            obj.hits.add((Hit)hit.clone());
        }
        return obj;
    }

    @Override
    public void visitMembers(ObjectVisitor visitor) {
        super.visitMembers(visitor);
        visitor.visit("summaryClass", summaryClass);
        visitor.visit("maxHits", maxHits);
        visitor.visit("hits", hits);
    }

    /**
     * Applies the selection to every collected hit.
     * <p>
     * NOTE(review): unlike {@link #visitMembers(ObjectVisitor)}, this does not delegate to the superclass;
     * confirm that skipping {@code super.selectMembers} is intentional.
     */
    @Override
    public void selectMembers(ObjectPredicate predicate, ObjectOperation operation) {
        for (Hit hit : hits) {
            hit.select(predicate, operation);
        }
    }

}