aboutsummaryrefslogtreecommitdiffstats
path: root/container-core/src/main/java/com/yahoo/jdisc/http/ssl/impl/ConfiguredSslContextFactoryProvider.java
blob: 8e2f080d4ced78c6b6e0f2b4dc264b72249a73d3 (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
package com.yahoo.jdisc.http.ssl.impl;

import com.yahoo.jdisc.http.ConnectorConfig;
import com.yahoo.jdisc.http.SslProvider;
import com.yahoo.security.KeyUtils;
import com.yahoo.security.SslContextBuilder;
import com.yahoo.security.X509CertificateUtils;
import com.yahoo.security.AutoReloadingX509KeyManager;
import com.yahoo.security.tls.TlsContext;

import javax.net.ssl.SSLContext;
import java.io.IOException;
import java.io.UncheckedIOException;
import java.nio.charset.StandardCharsets;
import java.nio.file.Files;
import java.nio.file.Paths;
import java.security.PrivateKey;
import java.security.cert.X509Certificate;
import java.util.ArrayList;
import java.util.List;
import java.util.Optional;

/**
 * An implementation of {@link SslProvider} that uses the {@link ConnectorConfig} to configure SSL.
 * <p>
 * The private key and certificate chain are taken from the connector's {@code ssl} config, either as inline PEM
 * strings or as PEM files. When <em>both</em> the private key and the certificate are given as files, an
 * {@link AutoReloadingX509KeyManager} is created from those files and is closed again by {@link #close()}.
 * Otherwise the key material is read once and installed as a static key store.
 *
 * @author bjorncs
 */
public class ConfiguredSslContextFactoryProvider implements SslProvider {

    /** Created by {@link #configureSsl} only when both private key and certificate are file based; may be null. */
    private volatile AutoReloadingX509KeyManager keyManager;
    private final ConnectorConfig connectorConfig;

    /**
     * @param connectorConfig connector config; its {@code ssl} section is validated eagerly when SSL is enabled
     * @throws IllegalArgumentException if SSL is enabled and the certificate / private key settings are missing
     *                                  or specified both inline and as a file
     */
    public ConfiguredSslContextFactoryProvider(ConnectorConfig connectorConfig) {
        validateConfig(connectorConfig.ssl());
        this.connectorConfig = connectorConfig;
    }

    /**
     * Builds an {@link SSLContext} from the connector's SSL config and applies it to {@code ssl}, together with
     * the client authentication mode, the enabled protocol versions and the enabled cipher suites.
     * Empty protocol / cipher suite lists in the config fall back to the values allowed by {@link TlsContext}.
     *
     * @param ssl  the connector SSL settings to populate
     * @param name connector name (used in error messages only)
     * @param port connector port (used in error messages only)
     * @throws IllegalStateException    if SSL is not enabled in the connector config
     * @throws IllegalArgumentException if the configured client auth mode is not handled
     */
    @Override
    public void configureSsl(ConnectorSsl ssl, String name, int port) {
        ConnectorConfig.Ssl sslConfig = connectorConfig.ssl();
        if (!sslConfig.enabled()) {
            throw new IllegalStateException("SSL is not enabled for connector '" + name + "' on port " + port);
        }

        SslContextBuilder builder = new SslContextBuilder();
        if (sslConfig.certificateFile().isBlank() || sslConfig.privateKeyFile().isBlank()) {
            // At least one of key/certificate is inline: load both once (each from inline PEM or its file).
            PrivateKey privateKey = KeyUtils.fromPemEncodedPrivateKey(getPrivateKey(sslConfig));
            List<X509Certificate> certificates = X509CertificateUtils.certificateListFromPem(getCertificate(sslConfig));
            builder.withKeyStore(privateKey, certificates);
        } else {
            // Both are files: use a key manager backed by the files; it is released in close().
            keyManager = AutoReloadingX509KeyManager.fromPemFiles(Paths.get(sslConfig.privateKeyFile()), Paths.get(sslConfig.certificateFile()));
            builder.withKeyManager(keyManager);
        }
        List<X509Certificate> caCertificates = getCaCertificates(sslConfig)
                .map(X509CertificateUtils::certificateListFromPem)
                .orElse(List.of());
        builder.withTrustStore(caCertificates);

        SSLContext sslContext = builder.build();

        ssl.setSslContext(sslContext);

        switch (sslConfig.clientAuth()) {
            case NEED_AUTH:
                ssl.setClientAuth(ConnectorSsl.ClientAuth.NEED);
                break;
            case WANT_AUTH:
                ssl.setClientAuth(ConnectorSsl.ClientAuth.WANT);
                break;
            case DISABLED:
                ssl.setClientAuth(ConnectorSsl.ClientAuth.DISABLED);
                break;
            default:
                throw new IllegalArgumentException(sslConfig.clientAuth().toString());
        }

        List<String> protocols = !sslConfig.enabledProtocols().isEmpty()
                ? sslConfig.enabledProtocols()
                : new ArrayList<>(TlsContext.getAllowedProtocols(sslContext));
        ssl.setEnabledProtocolVersions(protocols);

        List<String> ciphers = !sslConfig.enabledCipherSuites().isEmpty()
                ? sslConfig.enabledCipherSuites()
                : new ArrayList<>(TlsContext.getAllowedCipherSuites(sslContext));
        ssl.setEnabledCipherSuites(ciphers);
    }

    /** Closes the file-backed key manager, if {@link #configureSsl} created one. */
    @Override
    public void close() {
        if (keyManager != null) {
            keyManager.close();
        }
    }

    /**
     * Checks that, when SSL is enabled, exactly one of inline value / file is given for both the certificate
     * and the private key. Does nothing when SSL is disabled.
     *
     * @throws IllegalArgumentException on a missing or ambiguous certificate / private key setting
     */
    private static void validateConfig(ConnectorConfig.Ssl config) {
        if (!config.enabled()) return;

        if(hasBoth(config.certificate(), config.certificateFile()))
            throw new IllegalArgumentException("Specified both certificate and certificate file.");

        if(hasBoth(config.privateKey(), config.privateKeyFile()))
            throw new IllegalArgumentException("Specified both private key and private key file.");

        if(hasNeither(config.certificate(), config.certificateFile()))
            throw new IllegalArgumentException("Specified neither certificate or certificate file.");

        if(hasNeither(config.privateKey(), config.privateKeyFile()))
            throw new IllegalArgumentException("Specified neither private key or private key file.");
    }

    /** @return true if neither string is blank */
    private static boolean hasBoth(String a, String b) { return !a.isBlank() && !b.isBlank(); }
    /** @return true if both strings are blank */
    private static boolean hasNeither(String a, String b) { return a.isBlank() && b.isBlank(); }

    /**
     * Collects the PEM-encoded CA certificates from the inline value and/or the CA certificate file.
     * When both are given, the inline PEM comes first, separated from the file contents by a newline.
     *
     * @return the combined PEM text, or empty if neither inline value nor file is configured
     * @throws UncheckedIOException if the CA certificate file cannot be read
     */
    Optional<String> getCaCertificates(ConnectorConfig.Ssl sslConfig) {
        String inlinePem = sslConfig.caCertificate();
        String pemFile = sslConfig.caCertificateFile();
        if (inlinePem.isBlank() && pemFile.isBlank()) return Optional.empty();

        var pem = new StringBuilder();
        if (!inlinePem.isBlank()) {
            pem.append(inlinePem);
        }
        if (!pemFile.isBlank()) {
            if (pem.length() > 0) pem.append('\n');
            pem.append(readToString(pemFile));
        }
        return Optional.of(pem.toString());
    }

    /** @return the inline PEM private key if set, otherwise the contents of the private key file */
    private static String getPrivateKey(ConnectorConfig.Ssl config) {
        if(!config.privateKey().isBlank()) return config.privateKey();
        return readToString(config.privateKeyFile());
    }

    /** @return the inline PEM certificate if set, otherwise the contents of the certificate file */
    private static String getCertificate(ConnectorConfig.Ssl config) {
        if(!config.certificate().isBlank()) return config.certificate();
        return readToString(config.certificateFile());
    }

    /**
     * Reads a whole file as UTF-8 text.
     *
     * @throws UncheckedIOException if the file cannot be read; the message names the file and the cause is kept
     */
    static String readToString(String filename) {
        try {
            return Files.readString(Paths.get(filename), StandardCharsets.UTF_8);
        } catch (IOException e) {
            throw new UncheckedIOException("Failed to read file '" + filename + "'", e);
        }
    }

}