summaryrefslogtreecommitdiffstats
path: root/security-utils/src/main/java/com/yahoo/security/tls/DefaultTlsContext.java
blob: a01283318b61f5e8041e6755b4e39ba093fdf25a (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
package com.yahoo.security.tls;

import com.yahoo.security.SslContextBuilder;
import com.yahoo.security.tls.authz.PeerAuthorizerTrustManager;
import com.yahoo.security.tls.policy.AuthorizedPeers;

import javax.net.ssl.SSLContext;
import javax.net.ssl.SSLEngine;
import javax.net.ssl.SSLParameters;
import java.security.PrivateKey;
import java.security.cert.X509Certificate;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import java.util.Objects;
import java.util.Set;
import java.util.logging.Level;
import java.util.logging.Logger;

/**
 * A static {@link TlsContext} backed by a single {@link SSLContext}.
 * <p>
 * The cipher suites and protocols applied to every {@link SSLParameters} and {@link SSLEngine} produced by an
 * instance are computed once, at construction time. They are the subset of what the {@link SSLContext} supports
 * (as reported by {@link TlsContext}) that the caller also accepts.
 *
 * @author bjorncs
 */
public class DefaultTlsContext implements TlsContext {

    private static final Logger log = Logger.getLogger(DefaultTlsContext.class.getName());

    private final SSLContext sslContext;
    private final String[] validCiphers;
    private final String[] validProtocols;
    private final PeerAuthentication peerAuthentication;

    /**
     * Creates a context by building a new {@link SSLContext} from the given key material and trust settings.
     * Uses the default allowed cipher suites and protocols from {@link TlsContext}.
     *
     * @param certificates         certificate chain of this endpoint; if empty, no key store is configured
     * @param privateKey           private key for {@code certificates}; only used when {@code certificates} is non-empty
     * @param caCertificates       trusted CA certificates; if empty, no explicit trust store is configured
     * @param authorizedPeers      peer authorization rules, or {@code null} to use an empty rule set with
     *                             {@link AuthorizationMode#DISABLE}
     * @param mode                 authorization mode; only used when {@code authorizedPeers} is non-null
     * @param peerAuthentication   whether client authentication is wanted, needed or disabled
     * @param hostnameVerification hostname verification setting passed to the trust manager
     */
    public DefaultTlsContext(List<X509Certificate> certificates,
                             PrivateKey privateKey,
                             List<X509Certificate> caCertificates,
                             AuthorizedPeers authorizedPeers,
                             AuthorizationMode mode,
                             PeerAuthentication peerAuthentication,
                             HostnameVerification hostnameVerification) {
        this(createSslContext(certificates, privateKey, caCertificates, authorizedPeers, mode, hostnameVerification), peerAuthentication);
    }

    /**
     * Wraps an existing {@link SSLContext}, using the default allowed cipher suites and protocols from {@link TlsContext}.
     *
     * @param sslContext         the context to wrap; must not be {@code null}
     * @param peerAuthentication whether client authentication is wanted, needed or disabled; must not be {@code null}
     */
    public DefaultTlsContext(SSLContext sslContext, PeerAuthentication peerAuthentication) {
        this(sslContext, TlsContext.ALLOWED_CIPHER_SUITES, TlsContext.ALLOWED_PROTOCOLS, peerAuthentication);
    }

    /**
     * Wraps an existing {@link SSLContext}, restricting cipher suites and protocols to the given accepted sets.
     *
     * @throws NullPointerException  if {@code sslContext} or {@code peerAuthentication} is {@code null}
     * @throws IllegalStateException if none of the accepted ciphers or protocols are supported by {@code sslContext}
     */
    DefaultTlsContext(SSLContext sslContext, Set<String> acceptedCiphers, Set<String> acceptedProtocols, PeerAuthentication peerAuthentication) {
        // Fail fast: a null peerAuthentication would otherwise surface only later, as a message-less NPE
        // from the switch in createSslParameters().
        this.sslContext = Objects.requireNonNull(sslContext, "sslContext");
        this.peerAuthentication = Objects.requireNonNull(peerAuthentication, "peerAuthentication");
        this.validCiphers = getAllowedCiphers(sslContext, acceptedCiphers);
        this.validProtocols = getAllowedProtocols(sslContext, acceptedProtocols);
    }

    /** Returns the cipher suites supported by {@code sslContext} that are also in {@code acceptedCiphers}. */
    private static String[] getAllowedCiphers(SSLContext sslContext, Set<String> acceptedCiphers) {
        return filterAccepted(TlsContext.getAllowedCipherSuites(sslContext), acceptedCiphers, "ciphers", "cipher suites");
    }

    /** Returns the protocols supported by {@code sslContext} that are also in {@code acceptedProtocols}. */
    private static String[] getAllowedProtocols(SSLContext sslContext, Set<String> acceptedProtocols) {
        return filterAccepted(TlsContext.getAllowedProtocols(sslContext), acceptedProtocols, "protocols", "protocols");
    }

    /**
     * Returns the elements of {@code supported} that are also contained in {@code accepted}, in the iteration
     * order of {@code supported}.
     *
     * @param errorNoun plural noun used in the exception message (e.g. "ciphers")
     * @param logNoun   plural noun used in the FINE log message (e.g. "cipher suites")
     * @throws IllegalStateException if no element of {@code supported} is accepted
     */
    private static String[] filterAccepted(Set<String> supported, Set<String> accepted, String errorNoun, String logNoun) {
        String[] allowed = supported.stream()
                .filter(accepted::contains)
                .toArray(String[]::new);
        if (allowed.length == 0) {
            throw new IllegalStateException(
                    String.format("None of the accepted %s are supported (supported=%s, accepted=%s)",
                                  errorNoun, supported, accepted));
        }
        log.log(Level.FINE, () -> String.format("Allowed %s that are supported: %s", logNoun, Arrays.toString(allowed)));
        return allowed;
    }

    @Override
    public SSLContext context() {
        return sslContext;
    }

    /** Returns a freshly created {@link SSLParameters} instance on every call. */
    @Override
    public SSLParameters parameters() {
        return createSslParameters();
    }

    @Override
    public SSLEngine createSslEngine() {
        SSLEngine sslEngine = sslContext.createSSLEngine();
        sslEngine.setSSLParameters(createSslParameters());
        return sslEngine;
    }

    @Override
    public SSLEngine createSslEngine(String peerHost, int peerPort) {
        SSLEngine sslEngine = sslContext.createSSLEngine(peerHost, peerPort);
        sslEngine.setSSLParameters(createSslParameters());
        return sslEngine;
    }

    /**
     * Builds parameters starting from the context's defaults. It then applies the filtered cipher suites and
     * protocols, plus the client-authentication setting derived from {@link #peerAuthentication}.
     * For {@code DISABLED}, the client-auth flags are left as returned by {@link SSLContext#getDefaultSSLParameters()}.
     */
    private SSLParameters createSslParameters() {
        SSLParameters newParameters = sslContext.getDefaultSSLParameters();
        newParameters.setCipherSuites(validCiphers);
        newParameters.setProtocols(validProtocols);
        switch (peerAuthentication) {
            case WANT:
                newParameters.setWantClientAuth(true);
                break;
            case NEED:
                newParameters.setNeedClientAuth(true);
                break;
            case DISABLED:
                break;
            default:
                throw new UnsupportedOperationException("Unknown peer authentication: " + peerAuthentication);
        }
        return newParameters;
    }

    /**
     * Builds an {@link SSLContext} whose trust manager is always a {@link PeerAuthorizerTrustManager}.
     * When {@code authorizedPeers} is {@code null}, an empty peer set with {@link AuthorizationMode#DISABLE} is used
     * and {@code mode} is ignored.
     */
    private static SSLContext createSslContext(List<X509Certificate> certificates,
                                               PrivateKey privateKey,
                                               List<X509Certificate> caCertificates,
                                               AuthorizedPeers authorizedPeers,
                                               AuthorizationMode mode,
                                               HostnameVerification hostnameVerification) {
        SslContextBuilder builder = new SslContextBuilder();
        if (!certificates.isEmpty()) {
            builder.withKeyStore(privateKey, certificates);
        }
        if (!caCertificates.isEmpty()) {
            builder.withTrustStore(caCertificates);
        }
        if (authorizedPeers != null) {
            builder.withTrustManagerFactory(truststore -> new PeerAuthorizerTrustManager(authorizedPeers, mode, hostnameVerification, truststore));
        } else {
            builder.withTrustManagerFactory(truststore -> new PeerAuthorizerTrustManager(
                    new AuthorizedPeers(Collections.emptySet()), AuthorizationMode.DISABLE, hostnameVerification, truststore));
        }
        return builder.build();
    }

}