aboutsummaryrefslogtreecommitdiffstats
path: root/jdisc-security-filters/src/main/java/com/yahoo/jdisc/http/filter/security/cloud/CloudTokenDataPlaneFilter.java
blob: 699aa5c9187eff3f80acde9e2798e7c03f13d40b (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
// Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
package com.yahoo.jdisc.http.filter.security.cloud;

import com.yahoo.component.annotation.Inject;
import com.yahoo.container.logging.AccessLogEntry;
import com.yahoo.jdisc.Response;
import com.yahoo.jdisc.http.filter.DiscFilterRequest;
import com.yahoo.jdisc.http.filter.security.base.JsonSecurityRequestFilterBase;
import com.yahoo.jdisc.http.filter.security.cloud.config.CloudTokenDataPlaneFilterConfig;
import com.yahoo.security.token.Token;
import com.yahoo.security.token.TokenCheckHash;
import com.yahoo.security.token.TokenDomain;
import com.yahoo.security.token.TokenFingerprint;

import java.time.Clock;
import java.time.Instant;
import java.util.ArrayList;
import java.util.EnumSet;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.Set;
import java.util.TreeSet;
import java.util.logging.Logger;

import static com.yahoo.jdisc.http.server.jetty.AccessLoggingRequestHandler.CONTEXT_KEY_ACCESS_LOG_ENTRY;

/**
 * Token data plane filter for Cloud.
 * <p>
 * Requests must carry an {@code Authorization} header using the {@code Bearer} scheme. The token is hashed within the
 * configured {@link TokenDomain} and looked up among the token versions of every configured client that holds the
 * permission required by the request. The request is accepted only when the matching, non-expired token versions
 * amount to exactly one distinct version. The ids and permissions of all matching clients are then attached to the
 * request via {@code ClientPrincipal}.
 *
 * @author bjorncs
 */
public class CloudTokenDataPlaneFilter extends JsonSecurityRequestFilterBase {

    private static final Logger log = Logger.getLogger(CloudTokenDataPlaneFilter.class.getName());

    /** Prefix of an {@code Authorization} header value using the bearer scheme (scheme matched case-insensitively). */
    private static final String BEARER_PREFIX = "Bearer ";

    /**
     * Length in bytes of the check hash computed from a request token. Request hashes are looked up by equality
     * against the configured check access hashes, so presumably this must match the length those were generated
     * with — confirm against the config producer.
     */
    static final int CHECK_HASH_BYTES = 32;

    private final List<Client> allowedClients;
    private final TokenDomain tokenDomain;
    private final Clock clock;

    /** Creates the filter from config, using the system UTC clock for token expiry checks. */
    @Inject
    public CloudTokenDataPlaneFilter(CloudTokenDataPlaneFilterConfig cfg) {
        this(cfg, Clock.systemUTC());
    }

    /**
     * Creates the filter from config with an explicit clock (for tests).
     *
     * @throws IllegalArgumentException if the client definitions in {@code cfg} are invalid (see {@link #parseClients})
     */
    CloudTokenDataPlaneFilter(CloudTokenDataPlaneFilterConfig cfg, Clock clock) {
        this.tokenDomain = TokenDomain.of(cfg.tokenContext());
        this.clock = clock;
        this.allowedClients = parseClients(cfg);
    }

    /**
     * Converts the client definitions in {@code cfg} into an immutable list of clients whose tokens are indexed by
     * check access hash.
     *
     * @throws IllegalArgumentException if a client id is repeated, a client has no tokens, or a token's
     *                                  check access hashes, fingerprints and expirations differ in count
     */
    private static List<Client> parseClients(CloudTokenDataPlaneFilterConfig cfg) {
        Set<String> ids = new HashSet<>();
        List<Client> clients = new ArrayList<>(cfg.clients().size());
        for (var c : cfg.clients()) {
            if (!ids.add(c.id()))
                throw new IllegalArgumentException("Clients definition has duplicate id '%s'".formatted(c.id()));
            if (c.tokens().isEmpty())
                throw new IllegalArgumentException("Client '%s' has no tokens configured".formatted(c.id()));
            var tokens = new HashMap<TokenCheckHash, TokenVersion>();
            for (var token : c.tokens()) {
                // The three lists are parallel: index i describes version i of the token.
                var hashes = token.checkAccessHashes();
                var fingerprints = token.fingerprints();
                var expirations = token.expirations();
                if (fingerprints.size() != hashes.size() || expirations.size() != hashes.size())
                    throw new IllegalArgumentException(
                            ("Token '%s' of client '%s' has inconsistent versions: " +
                             "%d check access hashes, %d fingerprints, %d expirations")
                                    .formatted(token.id(), c.id(), hashes.size(), fingerprints.size(), expirations.size()));
                for (int version = 0; version < hashes.size(); version++) {
                    var tokenVersion = TokenVersion.of(
                            token.id(), fingerprints.get(version), hashes.get(version), expirations.get(version));
                    tokens.put(tokenVersion.accessHash(), tokenVersion);
                }
            }
            clients.add(new Client(c.id(), Permission.setOf(c.permissions()), tokens));
        }
        log.fine(() -> "Configured clients with ids %s".formatted(ids));
        return List.copyOf(clients);
    }

    /**
     * Authenticates and authorizes {@code req} using its bearer token.
     *
     * @return empty if the request may proceed; otherwise a 401 response when no bearer token is present, or a 403
     *         response when no required permission is resolved for the request, no client granting that permission
     *         has a non-expired token matching the request token, or more than one distinct token version matches
     */
    @Override
    protected Optional<ErrorResponse> filter(DiscFilterRequest req) {
        var now = clock.instant();
        var bearerToken = requestBearerToken(req).orElse(null);
        if (bearerToken == null) {
            log.fine("Missing bearer token");
            return Optional.of(new ErrorResponse(Response.Status.UNAUTHORIZED, "Unauthorized"));
        }
        var permission = Permission.getRequiredPermission(req).orElse(null);
        if (permission == null) return Optional.of(new ErrorResponse(Response.Status.FORBIDDEN, "Forbidden"));
        var requestTokenHash = requestTokenHash(bearerToken);
        var clientIds = new TreeSet<String>();
        var permissions = EnumSet.noneOf(Permission.class);
        var matchedTokens = new HashSet<TokenVersion>();
        for (Client c : allowedClients) {
            if (!c.permissions().contains(permission)) continue;
            var matchedToken = c.tokens().get(requestTokenHash);
            if (matchedToken == null) continue;
            // isAfter is strict: a token is still accepted at the exact instant of its expiration.
            var expiration = matchedToken.expiration().orElse(null);
            if (expiration != null && now.isAfter(expiration)) continue;
            matchedTokens.add(matchedToken);
            clientIds.add(c.id());
            permissions.addAll(c.permissions());
        }
        if (clientIds.isEmpty()) return Optional.of(new ErrorResponse(Response.Status.FORBIDDEN, "Forbidden"));
        if (matchedTokens.size() > 1) {
            log.warning("Multiple tokens matched for request %s"
                                .formatted(matchedTokens.stream().map(TokenVersion::id).toList()));
            return Optional.of(new ErrorResponse(Response.Status.FORBIDDEN, "Forbidden"));
        }
        // Every client id was added together with a token, so exactly one token version is present here.
        var matchedToken = matchedTokens.iterator().next();
        addAccessLogEntry(req, "token.id", matchedToken.id());
        addAccessLogEntry(req, "token.hash", matchedToken.fingerprint().toDelimitedHexString());
        addAccessLogEntry(req, "token.exp", matchedToken.expiration().map(Instant::toString).orElse("<none>"));
        ClientPrincipal.attachToRequest(req, clientIds, permissions);
        return Optional.empty();
    }

    /** Computes the check hash of a request token within this filter's token domain. */
    private TokenCheckHash requestTokenHash(String bearerToken) {
        return TokenCheckHash.of(Token.of(tokenDomain, bearerToken), CHECK_HASH_BYTES);
    }

    /**
     * Extracts the credentials from the {@code Authorization} header when it uses the {@code Bearer} scheme.
     * The scheme name is matched case-insensitively, as authentication scheme names are case-insensitive in HTTP.
     *
     * @return the trimmed token, or empty if the header is absent, uses another scheme, or carries a blank token
     */
    private static Optional<String> requestBearerToken(DiscFilterRequest req) {
        return Optional.ofNullable(req.getHeader("Authorization"))
                .filter(h -> h.regionMatches(true, 0, BEARER_PREFIX, 0, BEARER_PREFIX.length()))
                .map(h -> h.substring(BEARER_PREFIX.length()).trim())
                .filter(t -> !t.isBlank());
    }

    /**
     * Adds a key/value pair to the {@link AccessLogEntry} attached to {@code req}. If no entry is attached, the pair
     * is dropped rather than failing an otherwise authorized request over access logging.
     */
    private static void addAccessLogEntry(DiscFilterRequest req, String key, String value) {
        if (req.getAttribute(CONTEXT_KEY_ACCESS_LOG_ENTRY) instanceof AccessLogEntry entry) {
            entry.addKeyValue(key, value);
        } else {
            log.fine(() -> "No access log entry attached to request; dropping '%s'".formatted(key));
        }
    }

    /**
     * One version of a configured token.
     *
     * @param expiration empty when the configured expiration is {@code "<none>"}
     */
    private record TokenVersion(String id, TokenFingerprint fingerprint, TokenCheckHash accessHash, Optional<Instant> expiration) {
        /**
         * Parses a token version from its config representation: hex-encoded fingerprint and check access hash,
         * and an expiration that is either {@code "<none>"} or a string accepted by {@link Instant#parse}.
         */
        static TokenVersion of(String id, String fingerprint, String accessHash, String expiration) {
            return new TokenVersion(id, TokenFingerprint.ofHex(fingerprint), TokenCheckHash.ofHex(accessHash),
                                    expiration.equals("<none>") ? Optional.empty() : Optional.of(Instant.parse(expiration)));
        }
    }

    /** A configured client with its permissions and its token versions keyed by check access hash. */
    private record Client(String id, EnumSet<Permission> permissions, Map<TokenCheckHash, TokenVersion> tokens) {
        Client {
            // Defensive copies so later changes to the caller's collections do not affect the client.
            permissions = EnumSet.copyOf(permissions); tokens = Map.copyOf(tokens);
        }
    }
}