aboutsummaryrefslogtreecommitdiffstats
path: root/container-search/src/test/java/com/yahoo/search/federation/HitCountTestCase.java
blob: 28dac10a22e5a57bb1c698469bcf2b1099c6bb28 (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
// Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
package com.yahoo.search.federation;

import com.yahoo.search.Query;
import com.yahoo.search.Result;
import com.yahoo.search.result.Hit;
import com.yahoo.search.result.HitGroup;
import org.junit.jupiter.api.Test;

import java.util.ArrayList;
import java.util.List;

import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertTrue;

/**
 * Tests how federating over several search chains treats the requested hit window
 * (offset/hits), the merged total and deep hit counts, and the count fields read
 * from the first meta hit of each chain's hit group.
 *
 * @author Tony Vaagenes
 */
public class HitCountTestCase {

    /**
     * Requests two consecutive pages of five hits from two federated chains and
     * checks that the first page holds only chain2 hits and the second only chain1 hits.
     */
    @Test
    void require_that_offset_and_hits_are_adjusted_when_federating() {
        // chain2 gets the larger relevance multiplier; the first page is expected to be all chain2.
        final int chain1RelevanceMultiplier = 1;
        final int chain2RelevanceMultiplier = 10;

        FederationTester tester = new FederationTester();
        tester.addSearchChain("chain1", new AddHitsWithRelevanceSearcher("chain1", chain1RelevanceMultiplier));
        tester.addSearchChain("chain2", new AddHitsWithRelevanceSearcher("chain2", chain2RelevanceMultiplier));

        Query query = new Query();
        query.setHits(5);

        query.setOffset(0);
        assertAllHitsFrom("chain2", flattenAndTrim(tester.search(query)));

        query.setOffset(5);
        assertAllHitsFrom("chain1", flattenAndTrim(tester.search(query)));
    }

    /**
     * Checks that the federated result reports the sum of the per-chain total hit counts
     * and the sum of the per-chain deep hit counts.
     */
    @Test
    void require_that_hit_counts_are_merged() {
        final long chain1TotalHitCount = 3;
        final long chain1DeepHitCount = 5;

        final long chain2TotalHitCount = 7;
        final long chain2DeepHitCount = 11;

        FederationTester tester = new FederationTester();
        tester.addSearchChain("chain1", new SetHitCountsSearcher(chain1TotalHitCount, chain1DeepHitCount));
        tester.addSearchChain("chain2", new SetHitCountsSearcher(chain2TotalHitCount, chain2DeepHitCount));

        Result result = tester.searchAndFill();

        // JUnit expects (expected, actual); keeping that order makes failure messages read correctly.
        assertEquals(chain1TotalHitCount + chain2TotalHitCount, result.getTotalHitCount());
        assertEquals(chain1DeepHitCount + chain2DeepHitCount, result.getDeepHitCount());
    }

    /**
     * Checks the count fields found on the first meta hit of each chain's hit group
     * after a federated search.
     */
    @Test
    void require_that_logging_hit_is_populated_with_result_count() {
        final long chain1TotalHitCount = 9;
        final long chain1DeepHitCount = 14;

        final long chain2TotalHitCount = 11;
        final long chain2DeepHitCount = 15;

        FederationTester tester = new FederationTester();
        tester.addSearchChain("chain1",
                new SetHitCountsSearcher(chain1TotalHitCount, chain1DeepHitCount));

        tester.addSearchChain("chain2",
                new SetHitCountsSearcher(chain2TotalHitCount, chain2DeepHitCount),
                new AddHitsWithRelevanceSearcher("chain1", 2));

        // NOTE(review): a Query with offset 2 and hits 7 used to be built here but was never
        // handed to the tester, so it was dead code. If paging is meant to be covered, pass a
        // query to tester.search(query) and revisit the count_first/count_last expectations.
        Result result = tester.search();
        List<Hit> metaHits = getFirstMetaHitInEachGroup(result);

        // Fail with a readable message rather than an IndexOutOfBoundsException below.
        assertTrue(metaHits.size() >= 2,
                "Expected a meta hit for each of the two chains, but found " + metaHits.size());

        // The deep counts configured above were previously never verified (count_total was
        // asserted twice). Assumes the deep count is exposed as "count_deep" next to
        // "count_total" - TODO confirm against the searcher that populates these fields.
        Hit first = metaHits.get(0);
        assertEquals(chain1TotalHitCount, first.getField("count_total"));
        assertEquals(chain1DeepHitCount, first.getField("count_deep"));
        assertEquals(1, first.getField("count_first"));
        assertEquals(0, first.getField("count_last"));

        Hit second = metaHits.get(1);
        assertEquals(chain2TotalHitCount, second.getField("count_total"));
        assertEquals(chain2DeepHitCount, second.getField("count_deep"));
        assertEquals(1, second.getField("count_first"));
        assertEquals(AddHitsWithRelevanceSearcher.numHitsAdded, second.getField("count_last"));
    }

    /**
     * Collects, for every top-level hit that is a {@link HitGroup}, the first hit in that
     * group reporting itself as meta. Groups without a meta hit contribute nothing.
     *
     * @param result the result whose top-level hits are inspected
     * @return the meta hits found, in the iteration order of the top-level groups
     */
    private List<Hit> getFirstMetaHitInEachGroup(Result result) {
        List<Hit> metaHits = new ArrayList<>();
        for (Hit topLevelHit : result.hits()) {
            if (topLevelHit instanceof HitGroup) {
                for (Hit hit : (HitGroup) topLevelHit) {
                    if (hit.isMeta()) {
                        metaHits.add(hit);
                        break; // only the first meta hit of each group is wanted
                    }
                }
            }
        }
        return metaHits;
    }

    /**
     * Asserts that the id of every hit in the group starts with the given chain name.
     * An empty group passes vacuously.
     *
     * @param chainName     the expected id prefix
     * @param flattenedHits the hits to check
     */
    private void assertAllHitsFrom(String chainName, HitGroup flattenedHits) {
        for (Hit hit : flattenedHits) {
            String hitId = hit.getId().toString();
            assertTrue(hitId.startsWith(chainName),
                    "Expected hit '" + hitId + "' to come from " + chainName);
        }
    }

    /**
     * Copies all non-group hits of the result into one flat group and trims that group
     * to the offset/hits window of the result's query.
     *
     * @param result the federated result to flatten
     * @return a new group holding the trimmed, flattened hits
     */
    private HitGroup flattenAndTrim(Result result) {
        HitGroup flattenedHits = new HitGroup();
        Query query = result.getQuery();
        // NOTE(review): re-setting the result's own query looks like a no-op; kept because
        // the setter is not visible here - confirm it has no side effects before removing.
        result.setQuery(query);
        flatten(result.hits(), flattenedHits);

        flattenedHits.trim(query.getOffset(), query.getHits());
        return flattenedHits;
    }

    /**
     * Recursively adds every hit that is not itself a {@link HitGroup} to the target group,
     * descending into nested groups.
     *
     * @param hits          the group to walk
     * @param flattenedHits the group receiving the leaf hits
     */
    private void flatten(HitGroup hits, HitGroup flattenedHits) {
        for (Hit hit : hits) {
            if (hit instanceof HitGroup) {
                flatten((HitGroup) hit, flattenedHits);
            } else {
                flattenedHits.add(hit);
            }
        }
    }
}