diff options
author | Valerij Fredriksen <valerijf@yahooinc.com> | 2023-06-06 15:52:41 +0200 |
---|---|---|
committer | Valerij Fredriksen <valerijf@yahooinc.com> | 2023-06-06 15:52:41 +0200 |
commit | a5c36c88fe03eb16908e7066df2be7fc08fef7ce (patch) | |
tree | e5732e52acca99978e1fd3bb0c3f803bfd8d8863 /jdisc-security-filters | |
parent | 212a1934ff38662183609827ac91a67a34179eb0 (diff) |
Allow subdomains in CORS filters
Diffstat (limited to 'jdisc-security-filters')
5 files changed, 91 insertions, 21 deletions
diff --git a/jdisc-security-filters/src/main/java/com/yahoo/jdisc/http/filter/security/cors/CorsLogic.java b/jdisc-security-filters/src/main/java/com/yahoo/jdisc/http/filter/security/cors/CorsLogic.java index e261f420e1c..f24778d1241 100644 --- a/jdisc-security-filters/src/main/java/com/yahoo/jdisc/http/filter/security/cors/CorsLogic.java +++ b/jdisc-security-filters/src/main/java/com/yahoo/jdisc/http/filter/security/cors/CorsLogic.java @@ -2,15 +2,19 @@ package com.yahoo.jdisc.http.filter.security.cors; import java.time.Duration; +import java.util.ArrayList; +import java.util.Collection; +import java.util.HashSet; +import java.util.List; import java.util.Map; import java.util.Set; import java.util.TreeMap; +import java.util.regex.Pattern; /** * @author bjorncs */ class CorsLogic { - private CorsLogic() {} static final String CORS_PREFLIGHT_REQUEST_CACHE_TTL = Long.toString(Duration.ofDays(7).getSeconds()); @@ -25,23 +29,49 @@ class CorsLogic { "Vary", "*" ); - static Map<String, String> createCorsResponseHeaders(String requestOriginHeader, - Set<String> allowedOrigins) { + private final boolean allowAnyOrigin; + private final Set<String> allowedOrigins; + private final List<Pattern> allowedOriginPatterns; + private CorsLogic(boolean allowAnyOrigin, Set<String> allowedOrigins, List<Pattern> allowedOriginPatterns) { + this.allowAnyOrigin = allowAnyOrigin; + this.allowedOrigins = Set.copyOf(allowedOrigins); + this.allowedOriginPatterns = List.copyOf(allowedOriginPatterns); + } + + boolean originMatches(String origin) { + if (allowAnyOrigin) return true; + if (allowedOrigins.contains(origin)) return true; + return allowedOriginPatterns.stream().anyMatch(pattern -> pattern.matcher(origin).matches()); + } + + Map<String, String> createCorsResponseHeaders(String requestOriginHeader) { if (requestOriginHeader == null) return Map.of(); TreeMap<String, String> headers = new TreeMap<>(); - if (requestOriginMatchesAnyAllowed(requestOriginHeader, allowedOrigins)) + if (originMatches(requestOriginHeader)) headers.put(ALLOW_ORIGIN_HEADER, requestOriginHeader); headers.putAll(ACCESS_CONTROL_HEADERS); return headers; } - static Map<String, String> createCorsPreflightResponseHeaders(String requestOriginHeader, - Set<String> allowedOrigins) { - return createCorsResponseHeaders(requestOriginHeader, allowedOrigins); + Map<String, String> preflightResponseHeaders(String requestOriginHeader) { + return createCorsResponseHeaders(requestOriginHeader); } - private static boolean requestOriginMatchesAnyAllowed(String requestOrigin, Set<String> allowedUrls) { - return allowedUrls.stream().anyMatch(requestOrigin::equals) || allowedUrls.contains("*"); + static CorsLogic forAllowedOrigins(Collection<String> allowedOrigins) { + Set<String> allowedOriginsVerbatim = new HashSet<>(); + List<Pattern> allowedOriginPatterns = new ArrayList<>(); + for (String allowedOrigin : allowedOrigins) { + if (allowedOrigin.isBlank()) continue; + if (allowedOrigin.length() > 0) { + if ("*".equals(allowedOrigin)) + return new CorsLogic(true, Set.of(), List.of()); + else if (allowedOrigin.contains("*")) + allowedOriginPatterns.add(Pattern.compile(allowedOrigin.replace(".", "\\.").replace("*", ".*"))); + else + allowedOriginsVerbatim.add(allowedOrigin); + } + } + return new CorsLogic(false, allowedOriginsVerbatim, allowedOriginPatterns); } } diff --git a/jdisc-security-filters/src/main/java/com/yahoo/jdisc/http/filter/security/cors/CorsPreflightRequestFilter.java b/jdisc-security-filters/src/main/java/com/yahoo/jdisc/http/filter/security/cors/CorsPreflightRequestFilter.java index e2efd2d220c..935e738b5e3 100644 --- a/jdisc-security-filters/src/main/java/com/yahoo/jdisc/http/filter/security/cors/CorsPreflightRequestFilter.java +++ b/jdisc-security-filters/src/main/java/com/yahoo/jdisc/http/filter/security/cors/CorsPreflightRequestFilter.java @@ -10,8 +10,6 @@ import com.yahoo.jdisc.http.filter.DiscFilterRequest; import com.yahoo.jdisc.http.filter.SecurityRequestFilter; import com.yahoo.yolean.chain.Provides; -import java.util.Set; - import static com.yahoo.jdisc.http.HttpRequest.Method.OPTIONS; /** @@ -33,11 +31,11 @@ import static com.yahoo.jdisc.http.HttpRequest.Method.OPTIONS; */ @Provides("CorsPreflightRequestFilter") public class CorsPreflightRequestFilter implements SecurityRequestFilter { - private final Set<String> allowedUrls; + private final CorsLogic cors; @Inject public CorsPreflightRequestFilter(CorsFilterConfig config) { - this.allowedUrls = Set.copyOf(config.allowedUrls()); + this.cors = CorsLogic.forAllowedOrigins(config.allowedUrls()); } @Override @@ -46,8 +44,7 @@ public class CorsPreflightRequestFilter implements SecurityRequestFilter { return; HttpResponse response = HttpResponse.newInstance(Response.Status.OK); - String origin = discFilterRequest.getHeader("Origin"); - CorsLogic.createCorsPreflightResponseHeaders(origin, allowedUrls) + cors.preflightResponseHeaders(discFilterRequest.getHeader("Origin")) .forEach(response.headers()::put); ContentChannel cc = responseHandler.handleResponse(response); diff --git a/jdisc-security-filters/src/main/java/com/yahoo/jdisc/http/filter/security/cors/CorsResponseFilter.java b/jdisc-security-filters/src/main/java/com/yahoo/jdisc/http/filter/security/cors/CorsResponseFilter.java index f56965ea6a8..4b6c7211d11 100644 --- a/jdisc-security-filters/src/main/java/com/yahoo/jdisc/http/filter/security/cors/CorsResponseFilter.java +++ b/jdisc-security-filters/src/main/java/com/yahoo/jdisc/http/filter/security/cors/CorsResponseFilter.java @@ -8,9 +8,6 @@ import com.yahoo.jdisc.http.filter.RequestView; import com.yahoo.jdisc.http.filter.SecurityResponseFilter; import com.yahoo.yolean.chain.Provides; -import java.util.Set; - - /** * @author gv * @author Tony Vaagenes @@ -19,16 +16,16 @@ import java.util.Set; @Provides("CorsResponseFilter") public class CorsResponseFilter extends AbstractResource implements SecurityResponseFilter { - private final Set<String> allowedUrls; + private final CorsLogic cors; @Inject public CorsResponseFilter(CorsFilterConfig config) { - this.allowedUrls = Set.copyOf(config.allowedUrls()); + this.cors = CorsLogic.forAllowedOrigins(config.allowedUrls()); } @Override public void filter(DiscFilterResponse response, RequestView request) { - CorsLogic.createCorsResponseHeaders(request.getFirstHeader("Origin").orElse(null), allowedUrls) + cors.createCorsResponseHeaders(request.getFirstHeader("Origin").orElse(null)) .forEach(response::setHeader); } diff --git a/jdisc-security-filters/src/test/java/com/yahoo/jdisc/http/filter/security/cors/CorsLogicTest.java b/jdisc-security-filters/src/test/java/com/yahoo/jdisc/http/filter/security/cors/CorsLogicTest.java new file mode 100644 index 00000000000..60b5edde97d --- /dev/null +++ b/jdisc-security-filters/src/test/java/com/yahoo/jdisc/http/filter/security/cors/CorsLogicTest.java @@ -0,0 +1,40 @@ +// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package com.yahoo.jdisc.http.filter.security.cors; + +import org.junit.jupiter.api.Test; + +import java.util.List; + +import static org.junit.jupiter.api.Assertions.assertEquals; + +/** + * @author freva + */ +class CorsLogicTest { + + @Test + void wildcard_matches_everything() { + CorsLogic logic = CorsLogic.forAllowedOrigins(List.of("*")); + assertMatches(logic, true, "http://any.origin", "https://any.origin", "http://any.origin:8080"); + } + + @Test + void matches_verbatim_and_pattern() { + CorsLogic logic = CorsLogic.forAllowedOrigins(List.of("http://my.origin", "http://*.domain.origin", "*://do.main", "*.tld")); + assertMatches(logic, true, + "http://my.origin", // Matches verbatim + "http://any.domain.origin", // Matches first pattern + "http://any.sub.domain.origin", // Matches first pattern + "http://do.main", "https://do.main", // Matches second pattern + "https://any.thing.tld"); // Matches third pattern + assertMatches(logic, false, + "https://my.origin", // Different scheme from verbatim + "http://domain.origin", // Missing subdomain to match the first pattern + "https://sub.do.main"); // Second pattern, but with subdomain + } + + private static void assertMatches(CorsLogic logic, boolean expected, String... origins) { + for (String origin : origins) + assertEquals(expected, logic.originMatches(origin), origin); + } +} diff --git a/jdisc-security-filters/src/test/java/com/yahoo/jdisc/http/filter/security/cors/CorsResponseFilterTest.java b/jdisc-security-filters/src/test/java/com/yahoo/jdisc/http/filter/security/cors/CorsResponseFilterTest.java index 7762fde1a72..1fded811eed 100644 --- a/jdisc-security-filters/src/test/java/com/yahoo/jdisc/http/filter/security/cors/CorsResponseFilterTest.java +++ b/jdisc-security-filters/src/test/java/com/yahoo/jdisc/http/filter/security/cors/CorsResponseFilterTest.java @@ -54,6 +54,12 @@ public class CorsResponseFilterTest { assertEquals("http://any.origin", headers.get(ALLOW_ORIGIN_HEADER)); } + @Test + void matches_subdomains() { + Map<String, String> headers = doFilterRequest(newResponseFilter("http://*.domain.origin"), "http://any.domain.origin"); + assertEquals("http://any.domain.origin", headers.get(ALLOW_ORIGIN_HEADER)); + } + private static Map<String, String> doFilterRequest(SecurityResponseFilter filter, String originUrl) { TestResponse response = new TestResponse(); filter.filter(response, newRequestView(originUrl)); |