summaryrefslogtreecommitdiffstats
path: root/security-utils/src/main/java/com/yahoo/security/tls/DefaultTlsContext.java
blob: b2edf2f1ebcac336e6cace17ef812d854b3625bc (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
package com.yahoo.security.tls;

import com.yahoo.security.SslContextBuilder;
import com.yahoo.security.tls.authz.PeerAuthorizerTrustManager;
import com.yahoo.security.tls.policy.AuthorizedPeers;

import javax.net.ssl.SSLContext;
import javax.net.ssl.SSLEngine;
import javax.net.ssl.SSLParameters;
import java.security.PrivateKey;
import java.security.cert.X509Certificate;
import java.util.Arrays;
import java.util.List;
import java.util.Set;
import java.util.logging.Level;
import java.util.logging.Logger;

/**
 * A {@link TlsContext} backed by a single, fixed {@link SSLContext}.
 *
 * <p>The cipher suites and protocols applied to the {@link SSLParameters} and {@link SSLEngine}
 * instances handed out by this class are computed once, at construction time, by filtering what
 * the {@link SSLContext} reports as supported against the allow-lists declared on {@link TlsContext}.
 *
 * @author bjorncs
 */
public class DefaultTlsContext implements TlsContext {

    private static final Logger log = Logger.getLogger(DefaultTlsContext.class.getName());

    private final SSLContext sslContext;
    private final String[] validCiphers;
    private final String[] validProtocols;

    /**
     * Builds an {@link SSLContext} from the given key material and trust anchors and wraps it.
     *
     * @param certificates    certificates for the key store; the key store is only configured when this list is non-empty
     * @param privateKey      private key handed to the key store together with {@code certificates}
     * @param caCertificates  certificates for the trust store; the trust store is only configured when this list is non-empty
     * @param authorizedPeers peer authorization rules, or {@code null} to install a trust manager with an
     *                        empty rule set and {@link AuthorizationMode#DISABLE}
     * @param mode            authorization mode used with {@code authorizedPeers}; ignored when that is {@code null}
     */
    public DefaultTlsContext(List<X509Certificate> certificates,
                             PrivateKey privateKey,
                             List<X509Certificate> caCertificates,
                             AuthorizedPeers authorizedPeers,
                             AuthorizationMode mode) {
        this(createSslContext(certificates, privateKey, caCertificates, authorizedPeers, mode));
    }

    /**
     * Wraps an existing {@link SSLContext}, accepting every cipher suite in
     * {@link TlsContext#ALLOWED_CIPHER_SUITES}.
     *
     * @param sslContext the context to wrap
     */
    public DefaultTlsContext(SSLContext sslContext) {
        this(sslContext, TlsContext.ALLOWED_CIPHER_SUITES);
    }

    /**
     * Wraps an existing {@link SSLContext}, additionally restricting cipher suites to {@code acceptedCiphers}.
     *
     * @param sslContext      the context to wrap
     * @param acceptedCiphers cipher suites the caller is willing to use
     * @throws IllegalStateException    if no supported cipher suite is both allowed and accepted
     * @throws IllegalArgumentException if no supported protocol is allowed
     */
    DefaultTlsContext(SSLContext sslContext, Set<String> acceptedCiphers) {
        this.sslContext = sslContext;
        // Ciphers are validated before protocols; keep this order so the first failure reported is stable.
        this.validCiphers = getAllowedCiphers(sslContext, acceptedCiphers);
        this.validProtocols = getAllowedProtocols(sslContext);
    }

    /**
     * Returns the cipher suites supported by {@code sslContext} that are present in both
     * {@code ALLOWED_CIPHER_SUITES} and {@code acceptedCiphers}, in the order the context reports them.
     *
     * @throws IllegalStateException if the resulting selection is empty
     */
    private static String[] getAllowedCiphers(SSLContext sslContext, Set<String> acceptedCiphers) {
        String[] supported = sslContext.getSupportedSSLParameters().getCipherSuites();
        String[] usable = Arrays.stream(supported)
                .filter(cipher -> ALLOWED_CIPHER_SUITES.contains(cipher))
                .filter(cipher -> acceptedCiphers.contains(cipher))
                .toArray(String[]::new);
        if (usable.length > 0) {
            log.log(Level.FINE, () -> String.format("Allowed cipher suites that are supported: %s", List.of(usable)));
            return usable;
        }
        String message = String.format(
                "None of the allowed cipher suites are supported " +
                        "(allowed-cipher-suites=%s, supported-cipher-suites=%s, accepted-cipher-suites=%s)",
                ALLOWED_CIPHER_SUITES, List.of(supported), acceptedCiphers);
        throw new IllegalStateException(message);
    }

    /**
     * Returns the protocols supported by {@code sslContext} that are present in {@code ALLOWED_PROTOCOLS},
     * in the order the context reports them.
     *
     * @throws IllegalArgumentException if the resulting selection is empty
     */
    private static String[] getAllowedProtocols(SSLContext sslContext) {
        String[] supported = sslContext.getSupportedSSLParameters().getProtocols();
        String[] usable = Arrays.stream(supported)
                .filter(ALLOWED_PROTOCOLS::contains)
                .toArray(String[]::new);
        if (usable.length > 0) {
            log.log(Level.FINE, () -> String.format("Allowed protocols that are supported: %s", List.of(usable)));
            return usable;
        }
        String message = String.format(
                "None of the allowed protocols are supported (allowed-protocols=%s, supported-protocols=%s)",
                ALLOWED_PROTOCOLS, List.of(supported));
        throw new IllegalArgumentException(message);
    }

    /** Returns the wrapped {@link SSLContext}. */
    @Override
    public SSLContext context() {
        return sslContext;
    }

    /** Returns a freshly created {@link SSLParameters} carrying the validated ciphers and protocols. */
    @Override
    public SSLParameters parameters() {
        return createSslParameters();
    }

    /** Creates an {@link SSLEngine} from the wrapped context and applies this context's parameters to it. */
    @Override
    public SSLEngine createSslEngine() {
        return configure(sslContext.createSSLEngine());
    }

    /**
     * Creates an {@link SSLEngine} for the given peer from the wrapped context and applies this
     * context's parameters to it.
     *
     * @param peerHost peer host passed on to {@link SSLContext#createSSLEngine(String, int)}
     * @param peerPort peer port passed on to {@link SSLContext#createSSLEngine(String, int)}
     */
    @Override
    public SSLEngine createSslEngine(String peerHost, int peerPort) {
        return configure(sslContext.createSSLEngine(peerHost, peerPort));
    }

    /** Applies a new set of this context's parameters to {@code engine} and returns the same engine. */
    private SSLEngine configure(SSLEngine engine) {
        engine.setSSLParameters(createSslParameters());
        return engine;
    }

    /**
     * Starts from the wrapped context's default parameters and overrides cipher suites and protocols
     * with the validated selections; client authentication is always set to required.
     */
    private SSLParameters createSslParameters() {
        SSLParameters params = sslContext.getDefaultSSLParameters();
        params.setCipherSuites(validCiphers);
        params.setProtocols(validProtocols);
        params.setNeedClientAuth(true);
        return params;
    }

    /**
     * Assembles an {@link SSLContext} through {@link SslContextBuilder}.
     *
     * <p>The key store and trust store are only configured for non-empty certificate lists. A
     * {@link PeerAuthorizerTrustManager} is always installed: with the supplied rules and mode when
     * {@code authorizedPeers} is given, otherwise with an empty rule set and {@link AuthorizationMode#DISABLE}.
     */
    private static SSLContext createSslContext(List<X509Certificate> certificates,
                                               PrivateKey privateKey,
                                               List<X509Certificate> caCertificates,
                                               AuthorizedPeers authorizedPeers,
                                               AuthorizationMode mode) {
        SslContextBuilder contextBuilder = new SslContextBuilder();
        if (!certificates.isEmpty()) {
            contextBuilder.withKeyStore(privateKey, certificates);
        }
        if (!caCertificates.isEmpty()) {
            contextBuilder.withTrustStore(caCertificates);
        }
        if (authorizedPeers == null) {
            // No rules supplied: the supplied mode is ignored and authorization is disabled.
            contextBuilder.withTrustManagerFactory(
                    truststore -> new PeerAuthorizerTrustManager(new AuthorizedPeers(Set.of()), AuthorizationMode.DISABLE, truststore));
        } else {
            contextBuilder.withTrustManagerFactory(
                    truststore -> new PeerAuthorizerTrustManager(authorizedPeers, mode, truststore));
        }
        return contextBuilder.build();
    }

}