summaryrefslogtreecommitdiffstats
path: root/jrt
diff options
context:
space:
mode:
authorBjørn Christian Seime <bjorncs@oath.com>2018-09-04 12:43:05 +0200
committerBjørn Christian Seime <bjorncs@oath.com>2018-09-04 12:43:05 +0200
commit84aa3d47006366a91f6c8a604bf9c19718b43c66 (patch)
tree4654a1e0668a8d69b9b54f1397ec9e7fa4e963db /jrt
parentf7ab4bc80b354d5db31119772a801034d72e0ec6 (diff)
Rewrite wrap+unwrap to remove use of magic return values
- Wrap/unwrap for handshake returns true for success, false otherwise - Wrap/unwrap for application data returns bytes consumed/produced - Do not throw exception on overflow for unwrap - Misc changes to reduce code duplication
Diffstat (limited to 'jrt')
-rw-r--r--jrt/src/com/yahoo/jrt/TlsCryptoSocket.java118
1 files changed, 71 insertions, 47 deletions
diff --git a/jrt/src/com/yahoo/jrt/TlsCryptoSocket.java b/jrt/src/com/yahoo/jrt/TlsCryptoSocket.java
index 36d470bf67b..eaa28b37aab 100644
--- a/jrt/src/com/yahoo/jrt/TlsCryptoSocket.java
+++ b/jrt/src/com/yahoo/jrt/TlsCryptoSocket.java
@@ -20,6 +20,8 @@ import static javax.net.ssl.SSLEngineResult.*;
*/
public class TlsCryptoSocket implements CryptoSocket {
+ private static final ByteBuffer NULL_BUFFER = ByteBuffer.allocate(0);
+
private static final Logger log = Logger.getLogger(TlsCryptoSocket.class.getName());
private enum HandshakeState { NOT_STARTED, NEED_READ, NEED_WRITE, COMPLETED }
@@ -69,7 +71,7 @@ public class TlsCryptoSocket implements CryptoSocket {
case COMPLETED:
return HandshakeState.COMPLETED;
default:
- throw new IllegalStateException("Unhandled state: " + state);
+ throw unhandledStateException(state);
}
while (true) {
@@ -84,13 +86,10 @@ public class TlsCryptoSocket implements CryptoSocket {
break;
case NEED_UNWRAP:
if (hasWrapBufferMoreData()) return HandshakeState.NEED_WRITE;
- int bytesUnwrapped = sslEngineUnwrap(handshakeDummyBuffer);
- if (handshakeDummyBuffer.position() > 0) throw new SSLException("Got application data in handshake unwrap: " + handshakeDummyBuffer);
- if (bytesUnwrapped == -1) return HandshakeState.NEED_READ;
+ if (!handshakeUnwrap()) return HandshakeState.NEED_READ;
break;
case NEED_WRAP:
- int bytesWrapped = sslEngineWrap(handshakeDummyBuffer);
- if (bytesWrapped == -1) return HandshakeState.NEED_WRITE;
+ if (!handshakeWrap()) return HandshakeState.NEED_WRITE;
break;
default:
throw new IllegalStateException("Unexpected handshake status: " + sslEngine.getHandshakeStatus());
@@ -107,7 +106,7 @@ public class TlsCryptoSocket implements CryptoSocket {
case COMPLETED:
return HandshakeResult.DONE;
default:
- throw new IllegalStateException("Unhandled state: " + state);
+ throw unhandledStateException(state);
}
}
@@ -119,7 +118,7 @@ public class TlsCryptoSocket implements CryptoSocket {
@Override
public int read(ByteBuffer dst) throws IOException {
verifyHandshakeCompleted();
- int bytesUnwrapped = sslEngineAppDataUnwrap(dst);
+ int bytesUnwrapped = applicationDataUnwrap(dst);
if (bytesUnwrapped > 0) return bytesUnwrapped;
int bytesRead = channelRead();
@@ -133,26 +132,18 @@ public class TlsCryptoSocket implements CryptoSocket {
int totalBytesUnwrapped = 0;
int bytesUnwrapped;
do {
- bytesUnwrapped = sslEngineAppDataUnwrap(dst);
+ bytesUnwrapped = applicationDataUnwrap(dst);
totalBytesUnwrapped += bytesUnwrapped;
} while (bytesUnwrapped > 0);
return totalBytesUnwrapped;
}
- private int sslEngineAppDataUnwrap(ByteBuffer dst) throws IOException {
- int bytesUnwrapped = sslEngineUnwrap(dst);
- if (bytesUnwrapped == 0) throw new SSLException("Got handshake data in application data unwrap");
- if (bytesUnwrapped == -1) return 0;
- return bytesUnwrapped;
- }
-
@Override
public int write(ByteBuffer src) throws IOException {
- FlushResult flushResult = flush();
- if (flushResult == FlushResult.NEED_WRITE) return 0;
+ if (flush() == FlushResult.NEED_WRITE) return 0;
int totalBytesWrapped = 0;
while (src.hasRemaining()) {
- int bytesWrapped = sslEngineAppDataWrap(src);
+ int bytesWrapped = applicationDataWrap(src);
if (bytesWrapped == 0) break;
totalBytesWrapped += bytesWrapped;
}
@@ -165,49 +156,74 @@ public class TlsCryptoSocket implements CryptoSocket {
return hasWrapBufferMoreData() ? FlushResult.NEED_WRITE : FlushResult.DONE;
}
- private int sslEngineAppDataWrap(ByteBuffer src) throws IOException {
- int bytesWrapped = sslEngineWrap(src);
- if (bytesWrapped == 0) throw new SSLException("Got handshake data in application data wrap");
- if (bytesWrapped == -1) return 0;
- return bytesWrapped;
+ private boolean handshakeWrap() throws IOException {
+ SSLEngineResult result = sslEngineWrap(NULL_BUFFER);
+ switch (result.getStatus()) {
+ case OK:
+ return true;
+ case BUFFER_OVERFLOW:
+ return false;
+ default:
+ throw unexpectedStatusException(result.getStatus());
+ }
}
- // returns number of bytes produced or -1 if unwrap buffer does not contain a full ssl frame
- private int sslEngineUnwrap(ByteBuffer dst) throws IOException {
- unwrapBuffer.flip();
- SSLEngineResult result = sslEngine.unwrap(unwrapBuffer, dst);
- unwrapBuffer.compact();
- Status status = result.getStatus();
- switch (status) {
+ private int applicationDataWrap(ByteBuffer src) throws IOException {
+ SSLEngineResult result = sslEngineWrap(src);
+ switch (result.getStatus()) {
case OK:
- return result.bytesProduced();
+ int bytesConsumed = result.bytesConsumed();
+ if (bytesConsumed == 0) throw new SSLException("Got handshake data in application data wrap");
+ return bytesConsumed;
case BUFFER_OVERFLOW:
- throw new SSLException("Cannot unwrap - remaining capacity too small: " + dst);
- case BUFFER_UNDERFLOW:
- return -1;
- case CLOSED:
- throw new ClosedChannelException();
+ return 0;
default:
- throw new IllegalStateException("Unexpected status: " + status);
+ throw unexpectedStatusException(result.getStatus());
}
}
- // returns number of bytes consumed or -1 if wrap buffer remaining capacity is too small
- private int sslEngineWrap(ByteBuffer src) throws IOException {
+ private SSLEngineResult sslEngineWrap(ByteBuffer src) throws IOException {
SSLEngineResult result = sslEngine.wrap(src, wrapBuffer);
- Status status = result.getStatus();
- switch (status) {
+ if (result.getStatus() == Status.CLOSED) throw new ClosedChannelException();
+ return result;
+ }
+
+ private boolean handshakeUnwrap() throws IOException {
+ SSLEngineResult result = sslEngineUnwrap(handshakeDummyBuffer);
+ switch (result.getStatus()) {
case OK:
- return result.bytesConsumed();
+ if (result.bytesProduced() > 0) throw new SSLException("Got application data in handshake unwrap");
+ return true;
+ case BUFFER_UNDERFLOW:
+ return false;
+ default:
+ throw unexpectedStatusException(result.getStatus());
+ }
+ }
+
+ private int applicationDataUnwrap(ByteBuffer dst) throws IOException {
+ SSLEngineResult result = sslEngineUnwrap(dst);
+ switch (result.getStatus()) {
+ case OK:
+ int bytesProduced = result.bytesProduced();
+ if (bytesProduced == 0) throw new SSLException("Got handshake data in application data unwrap");
+ return bytesProduced;
case BUFFER_OVERFLOW:
- return -1;
- case CLOSED:
- throw new ClosedChannelException();
+ case BUFFER_UNDERFLOW:
+ return 0;
default:
- throw new IllegalStateException("Unexpected status: " + status);
+ throw unexpectedStatusException(result.getStatus());
}
}
+ private SSLEngineResult sslEngineUnwrap(ByteBuffer dst) throws IOException {
+ unwrapBuffer.flip();
+ SSLEngineResult result = sslEngine.unwrap(unwrapBuffer, dst);
+ unwrapBuffer.compact();
+ if (result.getStatus() == Status.CLOSED) throw new ClosedChannelException();
+ return result;
+ }
+
// returns number of bytes read
private int channelRead() throws IOException {
int read = channel.read(unwrapBuffer);
@@ -223,6 +239,14 @@ public class TlsCryptoSocket implements CryptoSocket {
return written;
}
+ private static IllegalStateException unhandledStateException(HandshakeState state) {
+ return new IllegalStateException("Unhandled state: " + state);
+ }
+
+ private static IllegalStateException unexpectedStatusException(Status status) {
+ return new IllegalStateException("Unexpected status: " + status);
+ }
+
private void verifyHandshakeCompleted() throws SSLException {
if (handshakeState != HandshakeState.COMPLETED)
throw new SSLException("Handshake not completed: handshakeState=" + handshakeState);