diff options
author | Bjørn Christian Seime <bjorn.christian@seime.no> | 2018-09-25 17:31:27 +0200 |
---|---|---|
committer | GitHub <noreply@github.com> | 2018-09-25 17:31:27 +0200 |
commit | f627463a8100090ec109d27c3aeb439a3395a34f (patch) | |
tree | 674025004f825c9cc12a075f992c0b2d1d45509e | |
parent | 87fe13f4f9bf6c2ca9acaab590e74a67cc11eb26 (diff) | |
parent | eaf61679b8989895eb183332f92b430fab9d3dfd (diff) |
Merge pull request #7089 from vespa-engine/havardpe/jrt-tls-mixed-mode
added support for auto-detecting tls for incoming connections
-rw-r--r-- | jrt/src/com/yahoo/jrt/MaybeTlsCryptoEngine.java | 34 | ||||
-rw-r--r-- | jrt/src/com/yahoo/jrt/MaybeTlsCryptoSocket.java | 126 | ||||
-rw-r--r-- | jrt/src/com/yahoo/jrt/TlsCryptoSocket.java | 5 | ||||
-rw-r--r-- | jrt/tests/com/yahoo/jrt/EchoTest.java | 4 | ||||
-rw-r--r-- | jrt/tests/com/yahoo/jrt/SessionTest.java | 2 | ||||
-rw-r--r-- | jrt/tests/com/yahoo/jrt/TlsDetectionTest.java | 95 |
6 files changed, 264 insertions, 2 deletions
diff --git a/jrt/src/com/yahoo/jrt/MaybeTlsCryptoEngine.java b/jrt/src/com/yahoo/jrt/MaybeTlsCryptoEngine.java new file mode 100644 index 00000000000..8cb560246e8 --- /dev/null +++ b/jrt/src/com/yahoo/jrt/MaybeTlsCryptoEngine.java @@ -0,0 +1,34 @@ +// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package com.yahoo.jrt; + +import java.nio.channels.SocketChannel; + +/** + * A crypto engine that supports both tls encrypted connections and + * unencrypted connections. The use of tls for incoming connections is + * auto-detected using clever heuristics. The use of tls for outgoing + * connections is controlled by the useTls flag given to the + * constructor. + **/ +public class MaybeTlsCryptoEngine implements CryptoEngine { + + private final TlsCryptoEngine tlsEngine; + private final boolean useTls; + + public MaybeTlsCryptoEngine(TlsCryptoEngine tlsEngine, boolean useTls) { + this.tlsEngine = tlsEngine; + this.useTls = useTls; + } + + @Override public CryptoSocket createCryptoSocket(SocketChannel channel, boolean isServer) { + if (isServer) { + return new MaybeTlsCryptoSocket(channel, tlsEngine); + } else if (useTls) { + return tlsEngine.createCryptoSocket(channel, false); + } else { + return new NullCryptoSocket(channel); + } + } + + @Override public String toString() { return "MaybeTlsCryptoEngine(useTls:" + useTls + ")"; } +} diff --git a/jrt/src/com/yahoo/jrt/MaybeTlsCryptoSocket.java b/jrt/src/com/yahoo/jrt/MaybeTlsCryptoSocket.java new file mode 100644 index 00000000000..7cedbcda9a1 --- /dev/null +++ b/jrt/src/com/yahoo/jrt/MaybeTlsCryptoSocket.java @@ -0,0 +1,126 @@ +// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package com.yahoo.jrt; + +import java.io.IOException; +import java.nio.ByteBuffer; +import java.nio.channels.SocketChannel; + +/** + * A crypto socket for the server side of a connection that + * auto-detects whether the connection is tls encrypted or unencrypted + * using clever heuristics. The assumption is that the client side + * will send at least one RPC request before expecting anything from + * the server. The first 9 bytes are inspected to see if they look + * like part of a tls handshake or not (RPC packet headers are 12 + * bytes). + **/ +public class MaybeTlsCryptoSocket implements CryptoSocket { + + private static final int snoop_size = 9; + + private CryptoSocket socket; + + // 'data' is the first 9 bytes received from the client + public static boolean looksLikeTlsToMe(byte[] data) { + if (data.length != snoop_size) { + return false; // wrong data size for tls detection + } + if (data[0] != 22) { + return false; // not tagged as tls handshake + } + if (data[1] != 3) { + return false; // unknown major version + } + if ((data[2] != 1) && (data[2] != 3)) { + return false; // unknown minor version + } + int frame_len = (data[3] & 0xff); + frame_len = ((frame_len << 8) | (data[4] & 0xff)); + if (frame_len > (16384 + 2048)) { + return false; // frame too large + } + if (frame_len < 4) { + return false; // frame too small + } + if (data[5] != 0x1) { + return false; // not tagges as client hello + } + int hello_len = (data[6] & 0xff); + hello_len = ((hello_len << 8) | (data[7] & 0xff)); + hello_len = ((hello_len << 8) | (data[8] & 0xff)); + if ((frame_len - 4) != hello_len) { + return false; // inconsistent sizes; frame vs client hello + } + return true; + } + + private class MyCryptoSocket extends NullCryptoSocket { + + private TlsCryptoEngine factory; + private Buffer buffer; + + MyCryptoSocket(SocketChannel channel, TlsCryptoEngine factory) { + super(channel); + this.factory = factory; + this.buffer = new Buffer(4096); + } + + @Override public HandshakeResult handshake() throws IOException { + if (factory != null) { + channel().read(buffer.getWritable(snoop_size)); + if (buffer.bytes() < snoop_size) { + return HandshakeResult.NEED_READ; + } + byte[] data = new byte[snoop_size]; + ByteBuffer src = buffer.getReadable(); + for (int i = 0; i < snoop_size; i++) { + data[i] = src.get(i); + } + if (looksLikeTlsToMe(data)) { + TlsCryptoSocket tlsSocket = factory.createCryptoSocket(channel(), true); + tlsSocket.injectReadData(buffer); + socket = tlsSocket; + return socket.handshake(); + } else { + factory = null; + } + } + return HandshakeResult.DONE; + } + + @Override public int read(ByteBuffer dst) throws IOException { + int drainResult = drain(dst); + if (drainResult != 0) { + return drainResult; + } + return super.read(dst); + } + + @Override public int drain(ByteBuffer dst) throws IOException { + int cnt = 0; + if (buffer != null) { + ByteBuffer src = buffer.getReadable(); + while (src.hasRemaining() && dst.hasRemaining()) { + dst.put(src.get()); + cnt++; + } + if (buffer.bytes() == 0) { + buffer = null; + } + } + return cnt; + } + } + + public MaybeTlsCryptoSocket(SocketChannel channel, TlsCryptoEngine factory) { + this.socket = new MyCryptoSocket(channel, factory); + } + + @Override public SocketChannel channel() { return socket.channel(); } + @Override public HandshakeResult handshake() throws IOException { return socket.handshake(); } + @Override public int getMinimumReadBufferSize() { return socket.getMinimumReadBufferSize(); } + @Override public int read(ByteBuffer dst) throws IOException { return socket.read(dst); } + @Override public int drain(ByteBuffer dst) throws IOException { return socket.drain(dst); } + @Override public int write(ByteBuffer src) throws IOException { return socket.write(src); } + @Override public FlushResult flush() throws IOException { return socket.flush(); } +} diff --git a/jrt/src/com/yahoo/jrt/TlsCryptoSocket.java b/jrt/src/com/yahoo/jrt/TlsCryptoSocket.java index 3db54811f9e..96aca622af4 100644 --- a/jrt/src/com/yahoo/jrt/TlsCryptoSocket.java +++ b/jrt/src/com/yahoo/jrt/TlsCryptoSocket.java @@ -46,6 +46,11 @@ public class TlsCryptoSocket implements CryptoSocket { this.handshakeState = HandshakeState.NOT_STARTED; } + // inject pre-read data into the read pipeline (typically called by MaybeTlsCryptoSocket) + public void injectReadData(Buffer data) { + unwrapBuffer.getWritable(data.bytes()).put(data.getReadable()); + } + @Override public SocketChannel channel() { return channel; diff --git a/jrt/tests/com/yahoo/jrt/EchoTest.java b/jrt/tests/com/yahoo/jrt/EchoTest.java index a91ac117f41..ff036af183b 100644 --- a/jrt/tests/com/yahoo/jrt/EchoTest.java +++ b/jrt/tests/com/yahoo/jrt/EchoTest.java @@ -23,7 +23,9 @@ public class EchoTest { @Parameter public CryptoEngine crypto; @Parameters(name = "{0}") public static Object[] engines() { - return new Object[] { CryptoEngine.createDefault(), new XorCryptoEngine(), new TlsCryptoEngine(createTestSslContext()) }; + return new Object[] { new NullCryptoEngine(), new XorCryptoEngine(), new TlsCryptoEngine(createTestSslContext()), + new MaybeTlsCryptoEngine(new TlsCryptoEngine(createTestSslContext()), false), + new MaybeTlsCryptoEngine(new TlsCryptoEngine(createTestSslContext()), true) }; } @Before diff --git a/jrt/tests/com/yahoo/jrt/SessionTest.java b/jrt/tests/com/yahoo/jrt/SessionTest.java index 63d14601b6e..2d8f9188623 100644 --- a/jrt/tests/com/yahoo/jrt/SessionTest.java +++ b/jrt/tests/com/yahoo/jrt/SessionTest.java @@ -19,7 +19,7 @@ public class SessionTest implements SessionHandler { @Parameter public CryptoEngine crypto; @Parameters(name = "{0}") public static Object[] engines() { - return new Object[] { CryptoEngine.createDefault(), new XorCryptoEngine(), new TlsCryptoEngine(createTestSslContext()) }; + return new Object[] { new NullCryptoEngine(), new XorCryptoEngine(), new TlsCryptoEngine(createTestSslContext()) }; } private static class Session { diff --git a/jrt/tests/com/yahoo/jrt/TlsDetectionTest.java b/jrt/tests/com/yahoo/jrt/TlsDetectionTest.java new file mode 100644 index 00000000000..9bd37e25772 --- /dev/null +++ b/jrt/tests/com/yahoo/jrt/TlsDetectionTest.java @@ -0,0 +1,95 @@ +// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package com.yahoo.jrt; + +public class TlsDetectionTest { + + static private String message(byte[] data, boolean actual) { + String msg = "["; + String delimiter = ""; + for (byte b: data) { + msg += delimiter + (b & 0xff); + delimiter = ", "; + } + if (actual) { + msg += "] wrongfully detected as tls"; + } else { + msg += "] wrongfully rejected as not tls"; + } + return msg; + } + + static private void checkTls(boolean expect, int ... values) { + byte[] data = new byte[values.length]; + for (int i = 0; i < data.length; i++) { + data[i] = (byte) values[i]; + } + boolean actual = MaybeTlsCryptoSocket.looksLikeTlsToMe(data); + if(actual != expect) { + throw new AssertionError(message(data, actual)); + } + } + + @org.junit.Test public void testValidHandshake() { + checkTls(true, 22, 3, 1, 10, 255, 1, 0, 10, 251); + checkTls(true, 22, 3, 3, 10, 255, 1, 0, 10, 251); + } + + @org.junit.Test public void testDataOfWrongSize() { + checkTls(false, 22, 3, 1, 10, 255, 1, 0, 10); + checkTls(false, 22, 3, 1, 10, 255, 1, 0, 10, 251, 0); + } + + @org.junit.Test public void testDataNotTaggedAsHandshake() { + checkTls(false, 23, 3, 1, 10, 255, 1, 0, 10, 251); + } + + @org.junit.Test public void testDataWithBadMajorVersion() { + checkTls(false, 22, 0, 1, 10, 255, 1, 0, 10, 251); + checkTls(false, 22, 1, 1, 10, 255, 1, 0, 10, 251); + checkTls(false, 22, 2, 1, 10, 255, 1, 0, 10, 251); + checkTls(false, 22, 4, 1, 10, 255, 1, 0, 10, 251); + checkTls(false, 22, 5, 1, 10, 255, 1, 0, 10, 251); + } + + @org.junit.Test public void testDataWithBadMinorVersion() { + checkTls(false, 22, 3, 0, 10, 255, 1, 0, 10, 251); + checkTls(false, 22, 3, 2, 10, 255, 1, 0, 10, 251); + checkTls(false, 22, 3, 4, 10, 255, 1, 0, 10, 251); + checkTls(false, 22, 3, 5, 10, 255, 1, 0, 10, 251); + } + + @org.junit.Test public void testDataNotTaggedAsClientHello() { + checkTls(false, 22, 3, 1, 10, 255, 0, 0, 10, 251); + checkTls(false, 22, 3, 1, 10, 255, 2, 0, 10, 251); + } + + @org.junit.Test public void testFrameSizeLimits() { + checkTls(false, 22, 3, 1, 255, 255, 1, 0, 255, 251); // max + checkTls(false, 22, 3, 1, 72, 1, 1, 0, 71, 253); // 18k + 1 + checkTls(true, 22, 3, 1, 72, 0, 1, 0, 71, 252); // 18k + checkTls(true, 22, 3, 1, 0, 4, 1, 0, 0, 0); // 4 + checkTls(false, 22, 3, 1, 0, 3, 1, 0, 0, 0); // 3 - capped + checkTls(false, 22, 3, 1, 0, 3, 1, 255, 255, 255); // 3 - wrapped + } + + @org.junit.Test public void testFrameAndClientHelloSizeRelationship() { + checkTls(true, 22, 3, 1, 10, 255, 1, 0, 10, 251); + checkTls(false, 22, 3, 1, 10, 255, 1, 1, 10, 251); + checkTls(false, 22, 3, 1, 10, 255, 1, 2, 10, 251); + + checkTls(false, 22, 3, 1, 10, 5, 1, 0, 10, 0); + checkTls(true, 22, 3, 1, 10, 5, 1, 0, 10, 1); + checkTls(false, 22, 3, 1, 10, 5, 1, 0, 10, 2); + + checkTls(false, 22, 3, 1, 10, 5, 1, 0, 9, 1); + checkTls(true, 22, 3, 1, 10, 5, 1, 0, 10, 1); + checkTls(false, 22, 3, 1, 10, 5, 1, 0, 11, 1); + + checkTls(true, 22, 3, 1, 10, 5, 1, 0, 10, 1); + checkTls(true, 22, 3, 1, 10, 4, 1, 0, 10, 0); + checkTls(true, 22, 3, 1, 10, 3, 1, 0, 9, 255); + checkTls(true, 22, 3, 1, 10, 2, 1, 0, 9, 254); + checkTls(true, 22, 3, 1, 10, 1, 1, 0, 9, 253); + checkTls(true, 22, 3, 1, 10, 0, 1, 0, 9, 252); + } +} |