aboutsummaryrefslogtreecommitdiffstats
path: root/container-search/src/test/java/com/yahoo/search/dispatch/rpc/RpcSearchInvokerTest.java
blob: 45ad361a21434d59da6d850ea7f69da0b7d8618b (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
// Copyright 2019 Oath Inc. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.

package com.yahoo.search.dispatch.rpc;

import ai.vespa.searchlib.searchprotocol.protobuf.SearchProtocol;
import com.google.common.collect.ImmutableMap;
import com.yahoo.compress.CompressionType;
import com.yahoo.prelude.fastsearch.FastHit;
import com.yahoo.prelude.fastsearch.VespaBackEndSearcher;
import com.yahoo.search.Query;
import com.yahoo.search.Result;
import com.yahoo.search.dispatch.searchcluster.Node;
import com.yahoo.search.searchchain.Execution;
import org.junit.Test;

import java.io.IOException;
import java.util.List;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.concurrent.atomic.AtomicReference;

import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertSame;
import static org.junit.Assert.assertTrue;
import static org.junit.Assert.fail;

/**
 * Tests that {@link RpcSearchInvoker} turns a {@link Query} into a compressed protobuf search request
 * and passes it to the node connection registered in the {@link RpcResourcePool}.
 *
 * @author ollivir
 */
public class RpcSearchInvokerTest {

    @Test
    public void testProtobufSerialization() throws IOException {
        AtomicReference<CompressionType> sentCompression = new AtomicReference<>();
        AtomicReference<byte[]> sentPayload = new AtomicReference<>();
        AtomicInteger sentLength = new AtomicInteger();
        Client client = capturingClient(sentCompression, sentPayload, sentLength);
        RpcResourcePool pool = new RpcResourcePool(ImmutableMap.of(7, client.createConnection("foo", 123)));
        RpcSearchInvoker firstInvoker = new RpcSearchInvoker(rejectingSearcher(), new Node(7, "seven", 1), pool, 1000);

        Query query = new Query("search/?query=test&hits=10&offset=3");
        var firstContext = (RpcSearchInvoker.RpcContext) firstInvoker.sendSearchRequest(query, null);
        // The context must describe exactly the payload that reached the connection.
        assertEquals(sentLength.get(), firstContext.compressedPayload.uncompressedSize());
        assertSame(firstContext.compressedPayload.data(), sentPayload.get());

        SearchProtocol.SearchRequest decoded = decodeSentRequest(pool, sentCompression, sentPayload, sentLength);
        assertEquals(10, decoded.getHits());
        assertEquals(3, decoded.getOffset());
        assertTrue(decoded.getQueryTreeBlob().size() > 0);

        // Only key 7 is registered in the pool, so node 8 has no connection of its own. Passing the
        // first context back in is expected to yield that very same context, and the captured
        // payload should still be the one belonging to it.
        var secondInvoker = new RpcSearchInvoker(rejectingSearcher(), new Node(8, "eight", 1), pool, 1000);
        var secondContext = (RpcSearchInvoker.RpcContext) secondInvoker.sendSearchRequest(query, firstContext);
        assertSame(firstContext, secondContext);
        assertEquals(sentLength.get(), firstContext.compressedPayload.uncompressedSize());
        assertSame(firstContext.compressedPayload.data(), sentPayload.get());
    }

    @Test
    public void testProtobufSerializationWithMaxHitsSet() throws IOException {
        final int maxHits = 5;
        AtomicReference<CompressionType> sentCompression = new AtomicReference<>();
        AtomicReference<byte[]> sentPayload = new AtomicReference<>();
        AtomicInteger sentLength = new AtomicInteger();
        Client client = capturingClient(sentCompression, sentPayload, sentLength);
        RpcResourcePool pool = new RpcResourcePool(ImmutableMap.of(7, client.createConnection("foo", 123)));
        RpcSearchInvoker cappedInvoker = new RpcSearchInvoker(rejectingSearcher(), new Node(7, "seven", 1), pool, maxHits);

        cappedInvoker.sendSearchRequest(new Query("search/?query=test&hits=10&offset=3"), null);

        // The query asks for 10 hits, but the invoker was constructed with a limit of maxHits.
        assertEquals(maxHits, decodeSentRequest(pool, sentCompression, sentPayload, sentLength).getHits());
    }

    /**
     * Decompresses the most recently captured payload with the pool's compressor and parses it as a
     * protobuf search request.
     */
    private static SearchProtocol.SearchRequest decodeSentRequest(RpcResourcePool pool,
                                                                  AtomicReference<CompressionType> compression,
                                                                  AtomicReference<byte[]> payload,
                                                                  AtomicInteger length) throws IOException {
        var raw = pool.compressor().decompress(payload.get(), compression.get(), length.get());
        return SearchProtocol.SearchRequest.newBuilder().mergeFrom(raw).build();
    }

    /**
     * Builds a client whose connections send nothing. Each request only records its compression
     * type, compressed payload and uncompressed length into the given holders, and the response
     * receiver is never invoked.
     */
    private static Client capturingClient(AtomicReference<CompressionType> compressionSink,
                                          AtomicReference<byte[]> payloadSink,
                                          AtomicInteger lengthSink) {
        return new Client() {
            @Override
            public void close() { }

            @Override
            public NodeConnection createConnection(String host, int portNumber) {
                return new NodeConnection() {
                    @Override
                    public void request(String method, CompressionType compressionType, int rawLength, byte[] payloadBytes,
                                        ResponseReceiver receiver, double timeout) {
                        compressionSink.set(compressionType);
                        payloadSink.set(payloadBytes);
                        lengthSink.set(rawLength);
                    }

                    @Override
                    public void close() { }
                };
            }
        };
    }

    /** Returns a searcher that fails the test if its search or fill entry points are ever reached. */
    private static VespaBackEndSearcher rejectingSearcher() {
        return new VespaBackEndSearcher() {
            @Override
            protected Result doSearch2(Query unusedQuery, Execution unusedExecution) {
                fail("Unexpected call");
                return null;
            }

            @Override
            protected void doPartialFill(Result unusedResult, String unusedSummaryClass) {
                fail("Unexpected call");
            }
        };
    }

}