aboutsummaryrefslogtreecommitdiffstats
path: root/searchlib/src/test/java/com/yahoo/searchlib/aggregation/GroupingSerializationTest.java
blob: bf95fd4d0ce3f1c8c75ed89a67e1ccb1db6b70e7 (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
package com.yahoo.searchlib.aggregation;

import com.yahoo.document.DocumentId;
import com.yahoo.document.GlobalId;
import com.yahoo.io.GrowableByteBuffer;
import com.yahoo.searchlib.aggregation.hll.SparseSketch;
import com.yahoo.searchlib.expression.AddFunctionNode;
import com.yahoo.searchlib.expression.AttributeNode;
import com.yahoo.searchlib.expression.CatFunctionNode;
import com.yahoo.searchlib.expression.ConstantNode;
import com.yahoo.searchlib.expression.DebugWaitFunctionNode;
import com.yahoo.searchlib.expression.DivideFunctionNode;
import com.yahoo.searchlib.expression.DocumentFieldNode;
import com.yahoo.searchlib.expression.ExpressionNode;
import com.yahoo.searchlib.expression.FixedWidthBucketFunctionNode;
import com.yahoo.searchlib.expression.FloatBucketResultNode;
import com.yahoo.searchlib.expression.FloatBucketResultNodeVector;
import com.yahoo.searchlib.expression.FloatResultNode;
import com.yahoo.searchlib.expression.GetDocIdNamespaceSpecificFunctionNode;
import com.yahoo.searchlib.expression.IntegerBucketResultNode;
import com.yahoo.searchlib.expression.IntegerBucketResultNodeVector;
import com.yahoo.searchlib.expression.IntegerResultNode;
import com.yahoo.searchlib.expression.MD5BitFunctionNode;
import com.yahoo.searchlib.expression.MaxFunctionNode;
import com.yahoo.searchlib.expression.MinFunctionNode;
import com.yahoo.searchlib.expression.ModuloFunctionNode;
import com.yahoo.searchlib.expression.MultiplyFunctionNode;
import com.yahoo.searchlib.expression.NegateFunctionNode;
import com.yahoo.searchlib.expression.NormalizeSubjectFunctionNode;
import com.yahoo.searchlib.expression.RangeBucketPreDefFunctionNode;
import com.yahoo.searchlib.expression.RawBucketResultNode;
import com.yahoo.searchlib.expression.RawBucketResultNodeVector;
import com.yahoo.searchlib.expression.RawResultNode;
import com.yahoo.searchlib.expression.ReverseFunctionNode;
import com.yahoo.searchlib.expression.SortFunctionNode;
import com.yahoo.searchlib.expression.StringBucketResultNode;
import com.yahoo.searchlib.expression.StringBucketResultNodeVector;
import com.yahoo.searchlib.expression.StringResultNode;
import com.yahoo.searchlib.expression.TimeStampFunctionNode;
import com.yahoo.searchlib.expression.XorBitFunctionNode;
import com.yahoo.searchlib.expression.XorFunctionNode;
import com.yahoo.searchlib.expression.ZCurveFunctionNode;
import com.yahoo.vespa.objects.BufferSerializer;
import com.yahoo.vespa.objects.Identifiable;
import com.yahoo.vespa.objects.ObjectDumper;
import org.junit.BeforeClass;
import org.junit.Test;

import java.io.BufferedInputStream;
import java.io.DataInputStream;
import java.io.File;
import java.io.FileInputStream;
import java.io.IOException;
import java.nio.ByteBuffer;
import java.nio.ByteOrder;
import java.nio.charset.StandardCharsets;
import java.util.Arrays;

import static org.junit.Assert.fail;

/**
 * Tests serialization compatibility across Java and C++. The comparison is performed by comparing serialized Java
 * object graphs with the content of specific binary files. The C++ unit tests serialize
 * identical data structures into these files.
 * Note: This test relies heavily on proper implementation of {@link Object#equals(Object)}!
 */
public class GroupingSerializationTest {

    /**
     * Runs once before any test in this class and invokes the {@code ForceLoad.forceLoad()} entry points
     * of the aggregation and expression packages.
     * NOTE(review): presumably this makes the serializable classes known to {@code Identifiable.create}
     * before the first record is deserialized - confirm against the ForceLoad classes.
     */
    @BeforeClass
    public static void forceLoadingOfSerializableClasses() {
        com.yahoo.searchlib.aggregation.ForceLoad.forceLoad();
        com.yahoo.searchlib.expression.ForceLoad.forceLoad();
    }

    /**
     * Verifies the records of the {@code testResultTypes} file: plain result nodes, bucket result nodes
     * and bucket result node vectors. Every {@code assertMatch} call consumes the next record of the
     * file, so the calls must stay in this order.
     */
    @Test
    public void testResultTypes() throws IOException {
        try (SerializationTester tester = new SerializationTester("testResultTypes")) {
            // Single-value result nodes.
            tester.assertMatch(new IntegerResultNode(7));
            tester.assertMatch(new FloatResultNode(7.3));
            tester.assertMatch(new StringResultNode("7.3"));
            byte[] multiByteUtf8 = {(byte)0xe5, (byte)0xa6, (byte)0x82, (byte)0xe6, (byte)0x9e, (byte)0x9c};
            tester.assertMatch(new StringResultNode(new String(multiByteUtf8, StandardCharsets.UTF_8)));
            tester.assertMatch(new RawResultNode(new byte[]{'7', '.', '4'}));

            // Bucket result nodes: default-constructed first, then with explicit limits.
            tester.assertMatch(new IntegerBucketResultNode());
            tester.assertMatch(new FloatBucketResultNode());
            tester.assertMatch(new IntegerBucketResultNode(10, 20));
            tester.assertMatch(new FloatBucketResultNode(10, 20));
            tester.assertMatch(new StringBucketResultNode("10.0", "20.0"));
            RawResultNode rawFrom = new RawResultNode(new byte[]{1, 0, 0});
            RawResultNode rawTo = new RawResultNode(new byte[]{1, 1, 0});
            tester.assertMatch(new RawBucketResultNode(rawFrom, rawTo));

            // Vectors holding a single bucket each.
            tester.assertMatch(
                    new IntegerBucketResultNodeVector().add(new IntegerBucketResultNode(878, 3246823)));
            tester.assertMatch(
                    new FloatBucketResultNodeVector().add(new FloatBucketResultNode(878, 3246823)));
            tester.assertMatch(
                    new StringBucketResultNodeVector().add(new StringBucketResultNode("878", "3246823")));
            tester.assertMatch(
                    new RawBucketResultNodeVector().add(
                            new RawBucketResultNode(
                                    new RawResultNode(new byte[]{1, 0, 0}),
                                    new RawResultNode(new byte[]{1, 1, 0}))));
        }
    }

    /**
     * Verifies the records of the {@code testSpecialNodes} file: an attribute node, a document field
     * node and a doc-id-namespace-specific function node, in that order.
     */
    @Test
    public void testSpecialNodes() throws IOException {
        try (SerializationTester tester = new SerializationTester("testSpecialNodes")) {
            // assertMatch returns the tester itself, so the record checks can be chained in file order.
            tester.assertMatch(new AttributeNode("testattribute"))
                  .assertMatch(new DocumentFieldNode("testdocumentfield"))
                  .assertMatch(new GetDocIdNamespaceSpecificFunctionNode(new IntegerResultNode(7)));
        }
    }

    /**
     * Verifies the records of the {@code testFunctionNodes} file, one function node per record.
     * Every {@code assertMatch} call consumes the next record of the file, so the calls must stay
     * in this order.
     */
    @Test
    public void testFunctionNodes() throws IOException {
        try (SerializationTester tester = new SerializationTester("testFunctionNodes")) {
            // Multi-argument arithmetic and comparison functions, all over the constants 7, 8 and 9.
            tester.assertMatch(new AddFunctionNode()
                    .addArg(integerConstant(7))
                    .addArg(integerConstant(8))
                    .addArg(integerConstant(9)));
            tester.assertMatch(new XorFunctionNode()
                    .addArg(integerConstant(7))
                    .addArg(integerConstant(8))
                    .addArg(integerConstant(9)));
            tester.assertMatch(new MultiplyFunctionNode()
                    .addArg(integerConstant(7))
                    .addArg(integerConstant(8))
                    .addArg(integerConstant(9)));
            tester.assertMatch(new DivideFunctionNode()
                    .addArg(integerConstant(7))
                    .addArg(integerConstant(8))
                    .addArg(integerConstant(9)));
            tester.assertMatch(new ModuloFunctionNode()
                    .addArg(integerConstant(7))
                    .addArg(integerConstant(8))
                    .addArg(integerConstant(9)));
            tester.assertMatch(new MinFunctionNode()
                    .addArg(integerConstant(7))
                    .addArg(integerConstant(8))
                    .addArg(integerConstant(9)));
            tester.assertMatch(new MaxFunctionNode()
                    .addArg(integerConstant(7))
                    .addArg(integerConstant(8))
                    .addArg(integerConstant(9)));

            // Single-argument functions.
            tester.assertMatch(
                    new TimeStampFunctionNode(integerConstant(7), TimeStampFunctionNode.TimePart.Hour, true));
            tester.assertMatch(new ZCurveFunctionNode(integerConstant(7), ZCurveFunctionNode.Dimension.X));
            tester.assertMatch(new ZCurveFunctionNode(integerConstant(7), ZCurveFunctionNode.Dimension.Y));
            tester.assertMatch(new NegateFunctionNode(integerConstant(7)));
            tester.assertMatch(new SortFunctionNode(integerConstant(7)));
            tester.assertMatch(new NormalizeSubjectFunctionNode(new ConstantNode(new StringResultNode("foo"))));
            tester.assertMatch(new ReverseFunctionNode(integerConstant(7)));
            tester.assertMatch(new MD5BitFunctionNode(integerConstant(7), 64));
            tester.assertMatch(new XorBitFunctionNode(integerConstant(7), 64));
            tester.assertMatch(new CatFunctionNode()
                    .addArg(integerConstant(7))
                    .addArg(integerConstant(8))
                    .addArg(integerConstant(9)));

            // Bucketing functions.
            tester.assertMatch(new FixedWidthBucketFunctionNode());
            tester.assertMatch(new FixedWidthBucketFunctionNode().addArg(new AttributeNode("foo")));
            tester.assertMatch(new FixedWidthBucketFunctionNode(new IntegerResultNode(10), new AttributeNode("foo")));
            tester.assertMatch(new FixedWidthBucketFunctionNode(new FloatResultNode(10.0), new AttributeNode("foo")));
            tester.assertMatch(new RangeBucketPreDefFunctionNode());
            tester.assertMatch(new RangeBucketPreDefFunctionNode().addArg(new AttributeNode("foo")));

            tester.assertMatch(new DebugWaitFunctionNode(integerConstant(5), 3.3, false));
        }
    }

    /** Creates a new {@link ConstantNode} holding a new {@link IntegerResultNode} built from the given value. */
    private static ConstantNode integerConstant(int value) {
        return new ConstantNode(new IntegerResultNode(value));
    }

    /**
     * Verifies the records of the {@code testAggregatorResults} file, one aggregation result per record.
     * Every {@code assertMatch} call consumes the next record of the file, so the calls must stay
     * in this order.
     */
    @Test
    public void testAggregatorResults() throws IOException {
        final String attributeName = "attributeA";
        try (SerializationTester tester = new SerializationTester("testAggregatorResults")) {
            tester.assertMatch(
                    new SumAggregationResult(new IntegerResultNode(7))
                            .setExpression(new AttributeNode(attributeName)));
            tester.assertMatch(
                    new XorAggregationResult()
                            .setXor(7)
                            .setExpression(new AttributeNode(attributeName)));
            tester.assertMatch(
                    new CountAggregationResult()
                            .setCount(7)
                            .setExpression(new AttributeNode(attributeName)));
            tester.assertMatch(
                    new MinAggregationResult(new IntegerResultNode(7))
                            .setExpression(new AttributeNode(attributeName)));
            tester.assertMatch(
                    new MaxAggregationResult(new IntegerResultNode(7))
                            .setExpression(new AttributeNode(attributeName)));
            tester.assertMatch(
                    new AverageAggregationResult(new IntegerResultNode(7), 0)
                            .setExpression(new AttributeNode(attributeName)));

            // The expression-count result is backed by a sketch that has aggregated a single value.
            SparseSketch singleValueSketch = new SparseSketch();
            singleValueSketch.aggregate(1955583074);
            tester.assertMatch(
                    new ExpressionCountAggregationResult(singleValueSketch, ignored -> 42)
                            .setExpression(new ConstantNode(new IntegerResultNode(67))));
            tester.assertMatch(
                    new StandardDeviationAggregationResult(1, 67, 67 * 67)
                            .setExpression(new ConstantNode(new IntegerResultNode(67))));
        }
    }

    /**
     * Verifies the records of the {@code testHitCollection} file: individual FS4 and VDS hits followed
     * by hit aggregation results. Every {@code assertMatch} call consumes the next record of the file,
     * so the calls must stay in this order.
     *
     * Summary payloads are encoded with an explicit UTF-8 charset so the expected bytes do not depend
     * on the platform default encoding of the machine running the test.
     */
    @Test
    public void testHitCollection() throws IOException {
        try (SerializationTester t = new SerializationTester("testHitCollection")) {
            t.assertMatch(new FS4Hit(0, new GlobalId(new byte[GlobalId.LENGTH]), 0, -1));
            t.assertMatch(new FS4Hit(0, createGlobalId(100), 50.0, -1));
            t.assertMatch(new VdsHit());
            //TODO Verify the two structures below
            t.assertMatch(new VdsHit("100", new byte[0], 50.0));
            t.assertMatch(new VdsHit("100", "rawsummary".getBytes(StandardCharsets.UTF_8), 50.0));
            t.assertMatch(new HitsAggregationResult());
            t.assertMatch(new HitsAggregationResult()
                    .setMaxHits(5)
                    .addHit(new FS4Hit(0, createGlobalId(10), 1.0, -1))
                    .addHit(new FS4Hit(0, createGlobalId(20), 2.0, -1))
                    .addHit(new FS4Hit(0, createGlobalId(30), 3.0, -1))
                    .addHit(new FS4Hit(0, createGlobalId(40), 4.0, -1))
                    .addHit(new FS4Hit(0, createGlobalId(50), 5.0, -1))
                    .setExpression(new ConstantNode(new IntegerResultNode(5))));
            t.assertMatch(new HitsAggregationResult()
                    .setMaxHits(3)
                    .addHit(new FS4Hit(0, createGlobalId(10), 1.0, 100))
                    .addHit(new FS4Hit(0, createGlobalId(20), 2.0, 200))
                    .addHit(new FS4Hit(0, createGlobalId(30), 3.0, 300))
                    .setExpression(new ConstantNode(new IntegerResultNode(5))));
            //TODO Verify content
            t.assertMatch(new HitsAggregationResult()
                    .setMaxHits(3)
                    .addHit(new VdsHit("10", "100".getBytes(StandardCharsets.UTF_8), 1.0))
                    .addHit(new VdsHit("20", "200".getBytes(StandardCharsets.UTF_8), 2.0))
                    .addHit(new VdsHit("30", "300".getBytes(StandardCharsets.UTF_8), 3.0))
                    .setExpression(new ConstantNode(new IntegerResultNode(5))));
        }
    }

    /**
     * Verifies the single record of the {@code testGroupingLevel} file: a grouping level with a group
     * limit, a grouping expression and one sum aggregation on its group prototype.
     */
    @Test
    public void testGroupingLevel() throws IOException {
        try (SerializationTester tester = new SerializationTester("testGroupingLevel")) {
            GroupingLevel level = new GroupingLevel();
            Group prototype = level.setMaxGroups(100)
                                   .setExpression(createDummyExpression())
                                   .getGroupPrototype();
            prototype.addAggregationResult(new SumAggregationResult().setExpression(createDummyExpression()));
            tester.assertMatch(level);
        }
    }

    /**
     * Verifies the records of the {@code testGroup} file: an empty group, a ranked group, and a group
     * tree with three children, one of which carries aggregation results and one a nested child.
     */
    @Test
    public void testGroup() throws IOException {
        try (SerializationTester tester = new SerializationTester("testGroup")) {
            tester.assertMatch(new Group());
            tester.assertMatch(new Group().setId(new IntegerResultNode(50)).setRank(10));

            // Children of the group tree, named here and attached below in the same order as before.
            Group plainChild = new Group().setId(new IntegerResultNode(110));
            Group aggregatingChild = new Group().setId(new IntegerResultNode(120))
                    .setRank(20.5)
                    .addAggregationResult(new SumAggregationResult().setExpression(createDummyExpression()))
                    .addAggregationResult(new SumAggregationResult().setExpression(createDummyExpression()));
            Group nestingChild = new Group().setId(new IntegerResultNode(130))
                    .addChild(new Group().setId(new IntegerResultNode(131)));

            tester.assertMatch(new Group().setId(new IntegerResultNode(100))
                    .addChild(plainChild)
                    .addChild(aggregatingChild)
                    .addChild(nestingChild));
        }
    }

    /**
     * Verifies the records of the {@code testGrouping} file: an empty grouping, a grouping with two
     * levels, and a single-level grouping whose prototype aggregates over nested function expressions.
     */
    @Test
    public void testGrouping() throws IOException {
        try (SerializationTester tester = new SerializationTester("testGrouping")) {
            tester.assertMatch(new Grouping());

            // Two-level grouping: one sum aggregation on the first level, two on the second.
            GroupingLevel outerLevel = new GroupingLevel();
            Group outerPrototype = outerLevel.setMaxGroups(100)
                                             .setExpression(createDummyExpression())
                                             .getGroupPrototype();
            outerPrototype.addAggregationResult(
                    new SumAggregationResult().setExpression(createDummyExpression()));

            GroupingLevel innerLevel = new GroupingLevel();
            Group innerPrototype = innerLevel.setMaxGroups(10)
                                             .setExpression(createDummyExpression())
                                             .getGroupPrototype();
            innerPrototype
                    .addAggregationResult(new SumAggregationResult().setExpression(createDummyExpression()))
                    .addAggregationResult(new SumAggregationResult().setExpression(createDummyExpression()));

            tester.assertMatch(new Grouping().addLevel(outerLevel).addLevel(innerLevel));

            // Single level grouped on the "folder" attribute with three aggregations.
            GroupingLevel folderLevel = new GroupingLevel();
            Group folderPrototype = folderLevel.setExpression(new AttributeNode("folder"))
                                               .getGroupPrototype();
            folderPrototype
                    .addAggregationResult(
                            new XorAggregationResult()
                                    .setExpression(new MD5BitFunctionNode(new AttributeNode("docid"), 64)))
                    .addAggregationResult(
                            new SumAggregationResult()
                                    .setExpression(new MinFunctionNode()
                                            .addArg(new AttributeNode("attribute1"))
                                            .addArg(new AttributeNode("attribute2"))))
                    .addAggregationResult(
                            new XorAggregationResult()
                                    .setExpression(
                                            new XorBitFunctionNode(
                                                    new CatFunctionNode()
                                                            .addArg(new GetDocIdNamespaceSpecificFunctionNode(
                                                                    new StringResultNode("")))
                                                            .addArg(new DocumentFieldNode("folder"))
                                                            .addArg(new DocumentFieldNode("flags")),
                                                    64)));
            tester.assertMatch(new Grouping().addLevel(folderLevel));
        }
    }


    /**
     * Builds the global id of the document {@code id:test:type::<docId>}.
     *
     * @param docId number used as the user-specific part of the document id
     * @return a new global id created from the document id's global id bytes
     */
    private static GlobalId createGlobalId(int docId) {
        String documentIdText = String.format("id:test:type::%d", docId);
        DocumentId documentId = new DocumentId(documentIdText);
        return new GlobalId(documentId.getGlobalId());
    }

    /**
     * Creates a fresh placeholder expression: an add function over the two integer constants 2 and 2.
     *
     * @return a new expression tree on every call
     */
    private static ExpressionNode createDummyExpression() {
        ConstantNode firstOperand = new ConstantNode(new IntegerResultNode(2));
        ConstantNode secondOperand = new ConstantNode(new IntegerResultNode(2));
        return new AddFunctionNode().addArg(firstOperand).addArg(secondOperand);
    }

    private static class SerializationTester implements AutoCloseable {

        private static final String FILE_PATH = "src/test/files";

        private final DataInputStream in;
        private final String fileName;

        public SerializationTester(String fileName) throws IOException {
            this.fileName = fileName;
            this.in = new DataInputStream(
                    new BufferedInputStream(
                            new FileInputStream(
                                    new File(FILE_PATH, fileName))));
        }

        /**
         * Reads the next record from the reference file and checks it in two ways: the deserialized
         * object must equal {@code expectedObject}, and serializing that deserialized object again must
         * reproduce the record's bytes exactly.
         *
         * @param expectedObject the object the next record is expected to deserialize to
         * @return this tester, so calls can be chained
         * @throws IOException if the record cannot be read in full
         */
        public SerializationTester assertMatch(Identifiable expectedObject) throws IOException {
            final String separator = "==================================================\n";

            // A record is a little-endian length prefix followed by that many payload bytes.
            byte[] recordBytes = new byte[readLittleEndianInt(in)];
            in.readFully(recordBytes);
            Identifiable actualObject = Identifiable.create(new BufferSerializer(recordBytes));

            boolean equalsExpected = actualObject.equals(expectedObject);
            if (!equalsExpected) {
                String template = "Serialized object in file '%s' does not equal expected values.\n" +
                                  separator +
                                  "Expected:\n" +
                                  separator +
                                  "%s\n" +
                                  separator +
                                  "Actual:\n" +
                                  separator +
                                  "%s\n" +
                                  separator;
                fail(String.format(template, fileName, dumpObject(expectedObject), dumpObject(actualObject)));
            }

            // Round trip: serialize the deserialized object and compare with the bytes from the file.
            GrowableByteBuffer scratch = new GrowableByteBuffer(1024 * 8);
            actualObject.serializeWithId(new BufferSerializer(scratch));
            scratch.flip();
            byte[] reserialized = new byte[scratch.limit()];
            scratch.get(reserialized);

            if (Arrays.equals(reserialized, recordBytes)) {
                return this;
            }
            String template = "Serialized object data does not match the original serialized data from file.\n" +
                              separator +
                              "Original:\n" +
                              separator +
                              "%s\n" +
                              separator +
                              "Serialized:\n" +
                              separator +
                              "%s\n" +
                              separator;
            fail(String.format(template, toHexString(recordBytes), toHexString(reserialized)));
            return this;
        }

        /**
         * Reads four bytes from the stream and interprets them as a little-endian 32-bit integer.
         *
         * @param in stream positioned at the first of the four bytes
         * @return the decoded integer
         * @throws IOException if fewer than four bytes are available or reading fails
         */
        private static int readLittleEndianInt(DataInputStream in) throws IOException {
            // readInt() decodes big-endian, so swapping the byte order yields the little-endian value.
            int bigEndianValue = in.readInt();
            return Integer.reverseBytes(bigEndianValue);
        }

        /**
         * Renders the members of the given object as text for use in failure messages.
         *
         * @param obj the object whose members are visited
         * @return the text produced by the {@link ObjectDumper} after the visit
         */
        private static String dumpObject(Identifiable obj) {
            ObjectDumper memberVisitor = new ObjectDumper();
            obj.visitMembers(memberVisitor);
            return memberVisitor.toString();
        }

        /**
         * Drains and closes the reference file, then fails the test if any bytes were left unread,
         * since leftover bytes mean the file holds records the Java test never checked.
         *
         * @throws IOException if reading the remaining bytes or closing the stream fails
         */
        @Override
        public void close() throws IOException {
            int bytesLeft = 0;
            // try-with-resources closes the stream even when draining it throws, so the file handle
            // is not leaked on a read error.
            try (DataInputStream remaining = in) {
                while (remaining.read() != -1) {
                    bytesLeft++;
                }
            }
            if (bytesLeft > 0) {
                fail(FILE_PATH + "/" + fileName + " has " + bytesLeft + " bytes left. " +
                     "Did you forget to deserialize an object on Java side?");
            }
        }

        /**
         * Formats a byte array as an upper-case hex dump for failure messages: a {@code (N bytes)}
         * header followed by lines of up to 16 two-digit values, each value followed by a space.
         *
         * @param data the bytes to format
         * @return the hex dump text
         */
        private static String toHexString(byte[] data) {
            final String hexDigits = "0123456789ABCDEF";
            StringBuilder dump = new StringBuilder();
            dump.append("(").append(data.length).append(" bytes)");
            int column = 0;
            for (byte value : data) {
                // Start a new line before the first byte and after every 16th byte.
                if (column == 0) {
                    dump.append("\n");
                }
                int unsigned = value & 0xff;
                dump.append(hexDigits.charAt(unsigned >>> 4));
                dump.append(hexDigits.charAt(unsigned & 0xf));
                dump.append(" ");
                column = (column + 1) % 16;
            }
            return dump.toString();
        }


    }

}