aboutsummaryrefslogtreecommitdiffstats
path: root/container-core/src/main/java/com/yahoo/jdisc/http/ssl/impl/ConfiguredSslContextFactoryProvider.java
blob: 90848f1dfd4a6cd3d540f84ec95318f129d96604 (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
// Copyright 2019 Oath Inc. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
package com.yahoo.jdisc.http.ssl.impl;

import com.yahoo.jdisc.http.ConnectorConfig;
import com.yahoo.jdisc.http.ConnectorConfig.Ssl.ClientAuth;
import com.yahoo.jdisc.http.ssl.SslContextFactoryProvider;
import com.yahoo.security.KeyUtils;
import com.yahoo.security.SslContextBuilder;
import com.yahoo.security.X509CertificateUtils;
import com.yahoo.security.tls.AutoReloadingX509KeyManager;
import com.yahoo.security.tls.TlsContext;
import org.eclipse.jetty.util.ssl.SslContextFactory;

import javax.net.ssl.SSLContext;
import java.io.IOException;
import java.io.UncheckedIOException;
import java.nio.charset.StandardCharsets;
import java.nio.file.Files;
import java.nio.file.Paths;
import java.security.PrivateKey;
import java.security.cert.X509Certificate;
import java.util.ArrayList;
import java.util.List;
import java.util.Optional;

import static com.yahoo.jdisc.http.ssl.impl.SslContextFactoryUtils.setEnabledCipherSuites;
import static com.yahoo.jdisc.http.ssl.impl.SslContextFactoryUtils.setEnabledProtocols;

/**
 * An implementation of {@link SslContextFactoryProvider} that uses the {@link ConnectorConfig} to construct a {@link SslContextFactory}.
 * <p>
 * Key material is taken either from inline PEM strings in the config or from PEM files. When both the private key and the
 * certificate are given as files, an {@link AutoReloadingX509KeyManager} is used instead of a static key store; it is
 * created at most once per provider instance and released by {@link #close()}.
 *
 * @author bjorncs
 */
public class ConfiguredSslContextFactoryProvider implements SslContextFactoryProvider {

    /** Key manager for file-based key material, created lazily by {@link #getOrCreateKeyManager}. Guarded by {@code this}. */
    private AutoReloadingX509KeyManager keyManager;
    private final ConnectorConfig connectorConfig;

    /**
     * @param connectorConfig connector config; its {@code ssl} section is validated eagerly
     * @throws IllegalArgumentException if SSL is enabled and the certificate or the private key is specified both inline
     *                                  and as a file, or not at all
     */
    public ConfiguredSslContextFactoryProvider(ConnectorConfig connectorConfig) {
        validateConfig(connectorConfig.ssl());
        this.connectorConfig = connectorConfig;
    }

    /**
     * Builds a Jetty server-side {@link SslContextFactory} from the connector's SSL config.
     *
     * @param containerId not used by this implementation
     * @param port        not used by this implementation
     * @return a configured {@link SslContextFactory.Server}
     * @throws IllegalStateException if SSL is not enabled in the connector config
     */
    @Override
    public SslContextFactory getInstance(String containerId, int port) {
        ConnectorConfig.Ssl sslConfig = connectorConfig.ssl();
        if (!sslConfig.enabled()) throw new IllegalStateException("SSL is not enabled in the connector config");

        SSLContext sslContext = createSslContext(sslConfig);

        SslContextFactory.Server factory = new SslContextFactory.Server();
        factory.setSslContext(sslContext);

        factory.setNeedClientAuth(sslConfig.clientAuth() == ClientAuth.Enum.NEED_AUTH);
        factory.setWantClientAuth(sslConfig.clientAuth() == ClientAuth.Enum.WANT_AUTH);

        // An empty list in config means "not restricted": fall back to what TlsContext allows for this SSLContext
        List<String> protocols = !sslConfig.enabledProtocols().isEmpty()
                ? sslConfig.enabledProtocols()
                : new ArrayList<>(TlsContext.getAllowedProtocols(sslContext));
        setEnabledProtocols(factory, sslContext, protocols);

        List<String> ciphers = !sslConfig.enabledCipherSuites().isEmpty()
                ? sslConfig.enabledCipherSuites()
                : new ArrayList<>(TlsContext.getAllowedCipherSuites(sslContext));
        setEnabledCipherSuites(factory, sslContext, ciphers);

        return factory;
    }

    /** Builds the {@link SSLContext} holding this connector's key material and trusted CA certificates. */
    private SSLContext createSslContext(ConnectorConfig.Ssl sslConfig) {
        SslContextBuilder builder = new SslContextBuilder();
        if (sslConfig.certificateFile().isBlank() || sslConfig.privateKeyFile().isBlank()) {
            // At least one of key/certificate is inline: load a static snapshot (a file, if given, is read once here)
            PrivateKey privateKey = KeyUtils.fromPemEncodedPrivateKey(getPrivateKey(sslConfig));
            List<X509Certificate> certificates = X509CertificateUtils.certificateListFromPem(getCertificate(sslConfig));
            builder.withKeyStore(privateKey, certificates);
        } else {
            builder.withKeyManager(getOrCreateKeyManager(sslConfig));
        }
        List<X509Certificate> caCertificates = getCaCertificates(sslConfig)
                .map(X509CertificateUtils::certificateListFromPem)
                .orElse(List.of());
        builder.withTrustStore(caCertificates);
        return builder.build();
    }

    /**
     * Returns the file-backed key manager, creating it on first use. The config is immutable, so the same files are
     * always referenced; reusing one instance prevents earlier instances from being leaked (only the instance held in
     * the field is closed by {@link #close()}) when {@link #getInstance} is called more than once.
     */
    private synchronized AutoReloadingX509KeyManager getOrCreateKeyManager(ConnectorConfig.Ssl sslConfig) {
        if (keyManager == null) {
            keyManager = AutoReloadingX509KeyManager.fromPemFiles(
                    Paths.get(sslConfig.privateKeyFile()), Paths.get(sslConfig.certificateFile()));
        }
        return keyManager;
    }

    /** Closes the file-backed key manager, if one was created. */
    @Override
    public synchronized void close() {
        if (keyManager != null) {
            keyManager.close();
            keyManager = null;
        }
    }

    /**
     * Checks that, when SSL is enabled, the certificate and the private key are each specified exactly once
     * (either inline or as a file). Does nothing when SSL is disabled.
     */
    private static void validateConfig(ConnectorConfig.Ssl config) {
        if (!config.enabled()) return;

        if (hasBoth(config.certificate(), config.certificateFile()))
            throw new IllegalArgumentException("Specified both certificate and certificate file.");

        if (hasBoth(config.privateKey(), config.privateKeyFile()))
            throw new IllegalArgumentException("Specified both private key and private key file.");

        if (hasNeither(config.certificate(), config.certificateFile()))
            throw new IllegalArgumentException("Specified neither certificate or certificate file.");

        if (hasNeither(config.privateKey(), config.privateKeyFile()))
            throw new IllegalArgumentException("Specified neither private key or private key file.");
    }

    private static boolean hasBoth(String a, String b) { return !a.isBlank() && !b.isBlank(); }
    private static boolean hasNeither(String a, String b) { return a.isBlank() && b.isBlank(); }

    /** Returns the PEM-encoded CA certificates, preferring the inline value over the file; empty if neither is set. */
    private static Optional<String> getCaCertificates(ConnectorConfig.Ssl sslConfig) {
        if (!sslConfig.caCertificate().isBlank()) {
            return Optional.of(sslConfig.caCertificate());
        } else if (!sslConfig.caCertificateFile().isBlank()) {
            return Optional.of(readToString(sslConfig.caCertificateFile()));
        } else {
            return Optional.empty();
        }
    }

    /** Returns the PEM-encoded private key, either inline or read from the configured file. */
    private static String getPrivateKey(ConnectorConfig.Ssl config) {
        if (!config.privateKey().isBlank()) return config.privateKey();
        return readToString(config.privateKeyFile());
    }

    /** Returns the PEM-encoded certificate chain, either inline or read from the configured file. */
    private static String getCertificate(ConnectorConfig.Ssl config) {
        if (!config.certificate().isBlank()) return config.certificate();
        return readToString(config.certificateFile());
    }

    /**
     * Reads a whole file as UTF-8.
     *
     * @throws UncheckedIOException wrapping the underlying {@link IOException}, with the file name in the message
     */
    private static String readToString(String filename) {
        try {
            return Files.readString(Paths.get(filename), StandardCharsets.UTF_8);
        } catch (IOException e) {
            throw new UncheckedIOException("Failed to read file '" + filename + "'", e);
        }
    }

}