aboutsummaryrefslogtreecommitdiffstats
path: root/jrt/src/com/yahoo/jrt/MaybeTlsCryptoSocket.java
blob: df01f4f2fa7d75ff5f7aadddd28db80b273e9e82 (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
package com.yahoo.jrt;

import java.io.IOException;
import java.nio.ByteBuffer;
import java.nio.channels.SocketChannel;
import java.util.Optional;

/**
 * Server-side crypto socket that figures out, per connection, whether
 * the peer speaks tls or plain text. The client is assumed to send at
 * least one RPC request before it expects anything back from the
 * server, so the server can peek at the first 9 bytes it receives and
 * check whether they resemble the start of a tls handshake (RPC
 * packet headers are 12 bytes). Until that decision is made, all
 * calls are served by an internal snooping socket; if tls is detected
 * the snooping socket is swapped for a tls socket that is handed the
 * bytes already read.
 **/
public class MaybeTlsCryptoSocket implements CryptoSocket {

    // number of leading bytes inspected when classifying a connection
    private static final int SNOOP_SIZE = 9;

    // the socket all calls are forwarded to; starts out as the
    // snooping socket and is replaced if tls is detected
    private CryptoSocket socket;

    /**
     * Checks whether the given bytes look like the beginning of a tls
     * client hello.
     *
     * @param data the first 9 bytes received from the client
     * @return true if the bytes are consistent with a tls handshake
     *         record carrying a client hello, false otherwise
     *         (including when 'data' does not hold exactly 9 bytes)
     **/
    public static boolean looksLikeTlsToMe(byte[] data) {
        if (data.length != SNOOP_SIZE) {
            return false; // wrong data size for tls detection
        }
        // record header: handshake tag (22), major version 3, minor version 1 or 3
        boolean taggedAsHandshake = (data[0] == 22);
        boolean knownMajorVersion = (data[1] == 3);
        boolean knownMinorVersion = (data[2] == 1) || (data[2] == 3);
        if (!taggedAsHandshake || !knownMajorVersion || !knownMinorVersion) {
            return false;
        }
        // 16-bit frame length; must be within accepted bounds
        int frameLength = bigEndian(data, 3, 2);
        if ((frameLength < 4) || (frameLength > (16384 + 2048))) {
            return false; // frame too small or too large
        }
        if (data[5] != 0x1) {
            return false; // not tagged as client hello
        }
        // 24-bit client hello length; the frame must contain exactly
        // the 4 byte handshake header followed by the client hello
        int helloLength = bigEndian(data, 6, 3);
        return (frameLength - 4) == helloLength;
    }

    // Decodes 'count' bytes starting at 'offset' as an unsigned big-endian integer.
    private static int bigEndian(byte[] data, int offset, int count) {
        int value = 0;
        for (int i = 0; i < count; i++) {
            value = (value << 8) | (data[offset + i] & 0xff);
        }
        return value;
    }

    /**
     * The socket used until the connection has been classified. It
     * collects the first bytes from the peer, inspects them, and
     * either hands over to a tls socket or keeps acting as a plain
     * text socket that first serves the bytes it already consumed.
     **/
    private class MyCryptoSocket extends NullCryptoSocket {

        private final TransportMetrics metrics = TransportMetrics.getInstance();
        // non-null while the connection is still unclassified
        private TlsCryptoEngine engine;
        // bytes read while snooping; set to null once fully drained
        private Buffer snooped;

        MyCryptoSocket(SocketChannel channel, TlsCryptoEngine factory) {
            super(channel, true);
            this.engine = factory;
            this.snooped = new Buffer(4096);
        }

        /**
         * Reads from the channel until at least 9 bytes are
         * available, then classifies the connection. A tls
         * connection is delegated to a newly created tls socket
         * (which also replaces this socket in the enclosing object);
         * a plain text connection is counted in the metrics and the
         * handshake completes immediately.
         *
         * @throws IOException if the peer closes the connection
         *         before enough bytes have been received
         **/
        @Override public HandshakeResult handshake() throws IOException {
            if (engine == null) {
                return HandshakeResult.DONE; // already classified as plain text
            }
            int readResult = channel().read(snooped.getWritable(SNOOP_SIZE));
            if (readResult == -1) {
                throw new IOException("jrt: Connection closed by peer during tls detection");
            }
            if (snooped.bytes() < SNOOP_SIZE) {
                return HandshakeResult.NEED_READ;
            }
            // peek using absolute reads so that no buffered data is consumed
            ByteBuffer peek = snooped.getReadable();
            byte[] prefix = new byte[SNOOP_SIZE];
            for (int idx = 0; idx < prefix.length; idx++) {
                prefix[idx] = peek.get(idx);
            }
            if (!looksLikeTlsToMe(prefix)) {
                metrics.incrementServerUnencryptedConnectionsEstablished();
                engine = null;
                return HandshakeResult.DONE;
            }
            TlsCryptoSocket tlsSocket = engine.createServerCryptoSocket(channel());
            tlsSocket.injectReadData(snooped);
            socket = tlsSocket;
            return socket.handshake();
        }

        /**
         * Serves bytes captured during snooping first; only when
         * none were delivered is the read passed on to the
         * underlying plain text socket.
         **/
        @Override public int read(ByteBuffer dst) throws IOException {
            int drained = drain(dst);
            return (drained != 0) ? drained : super.read(dst);
        }

        /**
         * Moves as many snooped bytes as possible into 'dst'.
         *
         * @return the number of bytes moved
         **/
        @Override public int drain(ByteBuffer dst) throws IOException {
            if (snooped == null) {
                return 0;
            }
            ByteBuffer src = snooped.getReadable();
            int moved = 0;
            for (; src.hasRemaining() && dst.hasRemaining(); moved++) {
                dst.put(src.get());
            }
            if (snooped.bytes() == 0) {
                snooped = null; // nothing left to serve; release the buffer
            }
            return moved;
        }
    }

    /**
     * Creates a socket that starts out snooping on the given channel.
     *
     * @param channel the channel of the accepted connection
     * @param factory used to create the server-side tls socket if tls is detected
     **/
    public MaybeTlsCryptoSocket(SocketChannel channel, TlsCryptoEngine factory) {
        this.socket = new MyCryptoSocket(channel, factory);
    }

    // Everything below simply forwards to whichever socket is currently active.

    @Override
    public SocketChannel channel() {
        return socket.channel();
    }

    @Override
    public HandshakeResult handshake() throws IOException {
        return socket.handshake();
    }

    @Override
    public void doHandshakeWork() {
        socket.doHandshakeWork();
    }

    @Override
    public int getMinimumReadBufferSize() {
        return socket.getMinimumReadBufferSize();
    }

    @Override
    public int read(ByteBuffer dst) throws IOException {
        return socket.read(dst);
    }

    @Override
    public int drain(ByteBuffer dst) throws IOException {
        return socket.drain(dst);
    }

    @Override
    public int write(ByteBuffer src) throws IOException {
        return socket.write(src);
    }

    @Override
    public FlushResult flush() throws IOException {
        return socket.flush();
    }

    @Override
    public void dropEmptyBuffers() {
        socket.dropEmptyBuffers();
    }

    @Override
    public Optional<SecurityContext> getSecurityContext() {
        return Optional.ofNullable(socket).flatMap(active -> active.getSecurityContext());
    }
}