diff options
author | Bjørn Christian Seime <bjorncs@oath.com> | 2018-09-04 12:43:05 +0200 |
---|---|---|
committer | Bjørn Christian Seime <bjorncs@oath.com> | 2018-09-04 12:43:05 +0200 |
commit | 84aa3d47006366a91f6c8a604bf9c19718b43c66 (patch) | |
tree | 4654a1e0668a8d69b9b54f1397ec9e7fa4e963db /jrt | |
parent | f7ab4bc80b354d5db31119772a801034d72e0ec6 (diff) |
Rewrite wrap+unwrap to remove use of magic return values
- Wrap/unwrap for handshake returns true for success, false otherwise
- Wrap/unwrap for application data returns bytes consumed/produced
- Do not throw exception on overflow for unwrap
- Misc changes to reduce code duplication
Diffstat (limited to 'jrt')
-rw-r--r-- | jrt/src/com/yahoo/jrt/TlsCryptoSocket.java | 118 |
1 files changed, 71 insertions, 47 deletions
diff --git a/jrt/src/com/yahoo/jrt/TlsCryptoSocket.java b/jrt/src/com/yahoo/jrt/TlsCryptoSocket.java index 36d470bf67b..eaa28b37aab 100644 --- a/jrt/src/com/yahoo/jrt/TlsCryptoSocket.java +++ b/jrt/src/com/yahoo/jrt/TlsCryptoSocket.java @@ -20,6 +20,8 @@ import static javax.net.ssl.SSLEngineResult.*; */ public class TlsCryptoSocket implements CryptoSocket { + private static final ByteBuffer NULL_BUFFER = ByteBuffer.allocate(0); + private static final Logger log = Logger.getLogger(TlsCryptoSocket.class.getName()); private enum HandshakeState { NOT_STARTED, NEED_READ, NEED_WRITE, COMPLETED } @@ -69,7 +71,7 @@ public class TlsCryptoSocket implements CryptoSocket { case COMPLETED: return HandshakeState.COMPLETED; default: - throw new IllegalStateException("Unhandled state: " + state); + throw unhandledStateException(state); } while (true) { @@ -84,13 +86,10 @@ public class TlsCryptoSocket implements CryptoSocket { break; case NEED_UNWRAP: if (hasWrapBufferMoreData()) return HandshakeState.NEED_WRITE; - int bytesUnwrapped = sslEngineUnwrap(handshakeDummyBuffer); - if (handshakeDummyBuffer.position() > 0) throw new SSLException("Got application data in handshake unwrap: " + handshakeDummyBuffer); - if (bytesUnwrapped == -1) return HandshakeState.NEED_READ; + if (!handshakeUnwrap()) return HandshakeState.NEED_READ; break; case NEED_WRAP: - int bytesWrapped = sslEngineWrap(handshakeDummyBuffer); - if (bytesWrapped == -1) return HandshakeState.NEED_WRITE; + if (!handshakeWrap()) return HandshakeState.NEED_WRITE; break; default: throw new IllegalStateException("Unexpected handshake status: " + sslEngine.getHandshakeStatus()); @@ -107,7 +106,7 @@ public class TlsCryptoSocket implements CryptoSocket { case COMPLETED: return HandshakeResult.DONE; default: - throw new IllegalStateException("Unhandled state: " + state); + throw unhandledStateException(state); } } @@ -119,7 +118,7 @@ public class TlsCryptoSocket implements CryptoSocket { @Override public int read(ByteBuffer dst) throws IOException { verifyHandshakeCompleted(); - int bytesUnwrapped = sslEngineAppDataUnwrap(dst); + int bytesUnwrapped = applicationDataUnwrap(dst); if (bytesUnwrapped > 0) return bytesUnwrapped; int bytesRead = channelRead(); @@ -133,26 +132,18 @@ public class TlsCryptoSocket implements CryptoSocket { int totalBytesUnwrapped = 0; int bytesUnwrapped; do { - bytesUnwrapped = sslEngineAppDataUnwrap(dst); + bytesUnwrapped = applicationDataUnwrap(dst); totalBytesUnwrapped += bytesUnwrapped; } while (bytesUnwrapped > 0); return totalBytesUnwrapped; } - private int sslEngineAppDataUnwrap(ByteBuffer dst) throws IOException { - int bytesUnwrapped = sslEngineUnwrap(dst); - if (bytesUnwrapped == 0) throw new SSLException("Got handshake data in application data unwrap"); - if (bytesUnwrapped == -1) return 0; - return bytesUnwrapped; - } - @Override public int write(ByteBuffer src) throws IOException { - FlushResult flushResult = flush(); - if (flushResult == FlushResult.NEED_WRITE) return 0; + if (flush() == FlushResult.NEED_WRITE) return 0; int totalBytesWrapped = 0; while (src.hasRemaining()) { - int bytesWrapped = sslEngineAppDataWrap(src); + int bytesWrapped = applicationDataWrap(src); if (bytesWrapped == 0) break; totalBytesWrapped += bytesWrapped; } @@ -165,49 +156,74 @@ public class TlsCryptoSocket implements CryptoSocket { return hasWrapBufferMoreData() ? FlushResult.NEED_WRITE : FlushResult.DONE; } - private int sslEngineAppDataWrap(ByteBuffer src) throws IOException { - int bytesWrapped = sslEngineWrap(src); - if (bytesWrapped == 0) throw new SSLException("Got handshake data in application data wrap"); - if (bytesWrapped == -1) return 0; - return bytesWrapped; + private boolean handshakeWrap() throws IOException { + SSLEngineResult result = sslEngineWrap(NULL_BUFFER); + switch (result.getStatus()) { + case OK: + return true; + case BUFFER_OVERFLOW: + return false; + default: + throw unexpectedStatusException(result.getStatus()); + } } - // returns number of bytes produced or -1 if unwrap buffer does not contain a full ssl frame - private int sslEngineUnwrap(ByteBuffer dst) throws IOException { - unwrapBuffer.flip(); - SSLEngineResult result = sslEngine.unwrap(unwrapBuffer, dst); - unwrapBuffer.compact(); - Status status = result.getStatus(); - switch (status) { + private int applicationDataWrap(ByteBuffer src) throws IOException { + SSLEngineResult result = sslEngineWrap(src); + switch (result.getStatus()) { case OK: - return result.bytesProduced(); + int bytesConsumed = result.bytesConsumed(); + if (bytesConsumed == 0) throw new SSLException("Got handshake data in application data wrap"); + return bytesConsumed; case BUFFER_OVERFLOW: - throw new SSLException("Cannot unwrap - remaining capacity too small: " + dst); - case BUFFER_UNDERFLOW: - return -1; - case CLOSED: - throw new ClosedChannelException(); + return 0; default: - throw new IllegalStateException("Unexpected status: " + status); + throw unexpectedStatusException(result.getStatus()); } } - // returns number of bytes consumed or -1 if wrap buffer remaining capacity is too small - private int sslEngineWrap(ByteBuffer src) throws IOException { + private SSLEngineResult sslEngineWrap(ByteBuffer src) throws IOException { SSLEngineResult result = sslEngine.wrap(src, wrapBuffer); - Status status = result.getStatus(); - switch (status) { + if (result.getStatus() == Status.CLOSED) throw new ClosedChannelException(); + return result; + } + + private boolean handshakeUnwrap() throws IOException { + SSLEngineResult result = sslEngineUnwrap(handshakeDummyBuffer); + switch (result.getStatus()) { case OK: - return result.bytesConsumed(); + if (result.bytesProduced() > 0) throw new SSLException("Got application data in handshake unwrap"); + return true; + case BUFFER_UNDERFLOW: + return false; + default: + throw unexpectedStatusException(result.getStatus()); + } + } + + private int applicationDataUnwrap(ByteBuffer dst) throws IOException { + SSLEngineResult result = sslEngineUnwrap(dst); + switch (result.getStatus()) { + case OK: + int bytesProduced = result.bytesProduced(); + if (bytesProduced == 0) throw new SSLException("Got handshake data in application data unwrap"); + return bytesProduced; case BUFFER_OVERFLOW: - return -1; - case CLOSED: - throw new ClosedChannelException(); + case BUFFER_UNDERFLOW: + return 0; default: - throw new IllegalStateException("Unexpected status: " + status); + throw unexpectedStatusException(result.getStatus()); } } + private SSLEngineResult sslEngineUnwrap(ByteBuffer dst) throws IOException { + unwrapBuffer.flip(); + SSLEngineResult result = sslEngine.unwrap(unwrapBuffer, dst); + unwrapBuffer.compact(); + if (result.getStatus() == Status.CLOSED) throw new ClosedChannelException(); + return result; + } + // returns number of bytes read private int channelRead() throws IOException { int read = channel.read(unwrapBuffer); @@ -223,6 +239,14 @@ public class TlsCryptoSocket implements CryptoSocket { return written; } + private static IllegalStateException unhandledStateException(HandshakeState state) { + return new IllegalStateException("Unhandled state: " + state); + } + + private static IllegalStateException unexpectedStatusException(Status status) { + return new IllegalStateException("Unexpected status: " + status); + } + private void verifyHandshakeCompleted() throws SSLException { if (handshakeState != HandshakeState.COMPLETED) throw new SSLException("Handshake not completed: handshakeState=" + handshakeState); |