// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. package com.yahoo.jrt; import com.yahoo.security.tls.ConnectionAuthContext; import com.yahoo.security.tls.PeerAuthorizationFailedException; import com.yahoo.security.tls.TransportSecurityUtils; import javax.net.ssl.SSLEngine; import javax.net.ssl.SSLEngineResult; import javax.net.ssl.SSLEngineResult.HandshakeStatus; import javax.net.ssl.SSLException; import javax.net.ssl.SSLHandshakeException; import javax.net.ssl.SSLSession; import java.io.IOException; import java.nio.ByteBuffer; import java.nio.channels.ClosedChannelException; import java.nio.channels.SocketChannel; import java.util.Objects; import java.util.logging.Logger; import static javax.net.ssl.SSLEngineResult.Status; /** * A {@link CryptoSocket} using TLS ({@link SSLEngine}) * * @author bjorncs */ public class TlsCryptoSocket implements CryptoSocket { private static final ByteBuffer NULL_BUFFER = ByteBuffer.allocate(0); private static final Logger log = Logger.getLogger(TlsCryptoSocket.class.getName()); private enum HandshakeState { NOT_STARTED, NEED_READ, NEED_WRITE, NEED_WORK, COMPLETED } private final TransportMetrics metrics = TransportMetrics.getInstance(); private final SocketChannel channel; private final SSLEngine sslEngine; private final Buffer wrapBuffer; private final Buffer unwrapBuffer; private int sessionPacketBufferSize; private int sessionApplicationBufferSize; private ByteBuffer handshakeDummyBuffer; private HandshakeState handshakeState; private ConnectionAuthContext authContext; public TlsCryptoSocket(SocketChannel channel, SSLEngine sslEngine) { this.channel = channel; this.sslEngine = sslEngine; this.wrapBuffer = new Buffer(0); this.unwrapBuffer = new Buffer(0); SSLSession nullSession = sslEngine.getSession(); sessionApplicationBufferSize = nullSession.getApplicationBufferSize(); sessionPacketBufferSize = nullSession.getPacketBufferSize(); // Note: Dummy buffer as unwrap requires a full size application buffer even though no application data is unwrapped this.handshakeDummyBuffer = ByteBuffer.allocate(sessionApplicationBufferSize); this.handshakeState = HandshakeState.NOT_STARTED; log.fine(() -> "Initialized with " + sslEngine.toString()); } // inject pre-read data into the read pipeline (typically called by MaybeTlsCryptoSocket) public void injectReadData(Buffer data) { unwrapBuffer.getWritable(data.bytes()).put(data.getReadable()); } @Override public SocketChannel channel() { return channel; } @Override public HandshakeResult handshake() throws IOException { HandshakeState newHandshakeState = processHandshakeState(this.handshakeState); log.fine(() -> String.format("Handshake state '%s -> %s'", this.handshakeState, newHandshakeState)); this.handshakeState = newHandshakeState; return toHandshakeResult(newHandshakeState); } @Override public void doHandshakeWork() { Runnable task; while ((task = sslEngine.getDelegatedTask()) != null) { task.run(); } } private HandshakeState processHandshakeState(HandshakeState state) throws IOException { try { switch (state) { case NOT_STARTED: log.fine(() -> "Initiating handshake"); sslEngine.beginHandshake(); break; case NEED_WRITE: channelWrite(); break; case NEED_READ: channelRead(); break; case NEED_WORK: break; case COMPLETED: return HandshakeState.COMPLETED; default: throw unhandledStateException(state); } while (true) { log.fine(() -> "SSLEngine.getHandshakeStatus(): " + sslEngine.getHandshakeStatus()); switch (sslEngine.getHandshakeStatus()) { case NOT_HANDSHAKING: if (wrapBuffer.bytes() > 0) return HandshakeState.NEED_WRITE; sslEngine.setEnableSessionCreation(false); // disable renegotiation handshakeDummyBuffer = null; SSLSession session = sslEngine.getSession(); sessionApplicationBufferSize = session.getApplicationBufferSize(); sessionPacketBufferSize = session.getPacketBufferSize(); authContext = TransportSecurityUtils.getConnectionAuthContext(session).orElseThrow(); if (!authContext.authorized()) { metrics.incrementPeerAuthorizationFailures(); } log.fine(() -> String.format("Handshake complete: protocol=%s, cipherSuite=%s", session.getProtocol(), session.getCipherSuite())); if (sslEngine.getUseClientMode()) { metrics.incrementClientTlsConnectionsEstablished(); } else { metrics.incrementServerTlsConnectionsEstablished(); } return HandshakeState.COMPLETED; case NEED_TASK: return HandshakeState.NEED_WORK; case NEED_UNWRAP: if (wrapBuffer.bytes() > 0) return HandshakeState.NEED_WRITE; if (!handshakeUnwrap()) return HandshakeState.NEED_READ; break; case NEED_WRAP: if (!handshakeWrap()) return HandshakeState.NEED_WRITE; break; default: throw new IllegalStateException("Unexpected handshake status: " + sslEngine.getHandshakeStatus()); } } } catch (SSLHandshakeException e) { if (!(e.getCause() instanceof PeerAuthorizationFailedException)) { metrics.incrementTlsCertificateVerificationFailures(); } throw e; } } private static HandshakeResult toHandshakeResult(HandshakeState state) { switch (state) { case NEED_READ: return HandshakeResult.NEED_READ; case NEED_WRITE: return HandshakeResult.NEED_WRITE; case NEED_WORK: return HandshakeResult.NEED_WORK; case COMPLETED: return HandshakeResult.DONE; default: throw unhandledStateException(state); } } @Override public int getMinimumReadBufferSize() { return sessionApplicationBufferSize; } @Override public int read(ByteBuffer dst) throws IOException { verifyHandshakeCompleted(); int bytesUnwrapped = drain(dst); if (bytesUnwrapped > 0) return bytesUnwrapped; int bytesRead = channelRead(); if (bytesRead == 0) return 0; return drain(dst); } @Override public int drain(ByteBuffer dst) throws IOException { verifyHandshakeCompleted(); int totalBytesUnwrapped = 0; while (true) { int result = applicationDataUnwrap(dst); if (result < 0) return totalBytesUnwrapped; totalBytesUnwrapped += result; } } @Override public int write(ByteBuffer src) throws IOException { verifyHandshakeCompleted(); if (flush() == FlushResult.NEED_WRITE) return 0; int totalBytesWrapped = 0; int bytesWrapped; do { bytesWrapped = applicationDataWrap(src); totalBytesWrapped += bytesWrapped; } while (bytesWrapped > 0 && wrapBuffer.bytes() < sessionPacketBufferSize); return totalBytesWrapped; } @Override public FlushResult flush() throws IOException { verifyHandshakeCompleted(); channelWrite(); return wrapBuffer.bytes() > 0 ? FlushResult.NEED_WRITE : FlushResult.DONE; } @Override public void dropEmptyBuffers() { wrapBuffer.shrink(0); unwrapBuffer.shrink(0); } @Override public ConnectionAuthContext connectionAuthContext() { if (handshakeState != HandshakeState.COMPLETED) throw new IllegalStateException("Handshake not complete"); return Objects.requireNonNull(authContext); } private boolean handshakeWrap() throws IOException { SSLEngineResult result = sslEngineWrap(NULL_BUFFER); switch (result.getStatus()) { case OK: return true; case BUFFER_OVERFLOW: // This is to ensure we have large enough buffer during handshake phase too. sessionPacketBufferSize = sslEngine.getSession().getPacketBufferSize(); return false; default: throw unexpectedStatusException(result.getStatus()); } } private int applicationDataWrap(ByteBuffer src) throws IOException { SSLEngineResult result = sslEngineWrap(src); failIfRenegotiationDetected(result); switch (result.getStatus()) { case OK: return result.bytesConsumed(); case BUFFER_OVERFLOW: return 0; default: throw unexpectedStatusException(result.getStatus()); } } private SSLEngineResult sslEngineWrap(ByteBuffer src) throws IOException { SSLEngineResult result = sslEngine.wrap(src, wrapBuffer.getWritable(sessionPacketBufferSize)); failIfCloseSignalDetected(result); return result; } private boolean handshakeUnwrap() throws IOException { SSLEngineResult result = sslEngineUnwrap(handshakeDummyBuffer); switch (result.getStatus()) { case OK: if (result.bytesProduced() > 0) throw new SSLException("Got application data in handshake unwrap"); return true; case BUFFER_UNDERFLOW: return false; default: throw unexpectedStatusException(result.getStatus()); } } private int applicationDataUnwrap(ByteBuffer dst) throws IOException { SSLEngineResult result = sslEngineUnwrap(dst); failIfRenegotiationDetected(result); switch (result.getStatus()) { case OK: return result.bytesProduced(); case BUFFER_OVERFLOW: case BUFFER_UNDERFLOW: return -1; default: throw unexpectedStatusException(result.getStatus()); } } private SSLEngineResult sslEngineUnwrap(ByteBuffer dst) throws IOException { SSLEngineResult result = sslEngine.unwrap(unwrapBuffer.getReadable(), dst); failIfCloseSignalDetected(result); return result; } // returns number of bytes read private int channelRead() throws IOException { int read = channel.read(unwrapBuffer.getWritable(sessionPacketBufferSize)); if (read == -1) throw new ClosedChannelException(); return read; } // returns number of bytes written private int channelWrite() throws IOException { return channel.write(wrapBuffer.getReadable()); } private static void failIfCloseSignalDetected(SSLEngineResult result) throws ClosedChannelException { if (result.getStatus() == Status.CLOSED) throw new ClosedChannelException(); } private static void failIfRenegotiationDetected(SSLEngineResult result) throws SSLException { if (result.getHandshakeStatus() != HandshakeStatus.NOT_HANDSHAKING && result.getHandshakeStatus() != HandshakeStatus.FINISHED) { throw new SSLException("Renegotiation detected"); } } private static IllegalStateException unhandledStateException(HandshakeState state) { return new IllegalStateException("Unhandled state: " + state); } private static IllegalStateException unexpectedStatusException(Status status) { return new IllegalStateException("Unexpected status: " + status); } private void verifyHandshakeCompleted() throws SSLException { if (handshakeState != HandshakeState.COMPLETED) throw new SSLException("Handshake not completed: handshakeState=" + handshakeState); } }