diff options
3 files changed, 308 insertions, 1 deletions
diff --git a/jdisc-security-filters/pom.xml b/jdisc-security-filters/pom.xml index 867a32cc170..d4adfd23bac 100644 --- a/jdisc-security-filters/pom.xml +++ b/jdisc-security-filters/pom.xml @@ -47,7 +47,22 @@ <artifactId>mockito-core</artifactId> <scope>test</scope> </dependency> - + <dependency> + <groupId>org.junit.jupiter</groupId> + <artifactId>junit-jupiter</artifactId> + <scope>test</scope> + </dependency> + <dependency> + <groupId>org.junit.vintage</groupId> + <artifactId>junit-vintage-engine</artifactId> + <scope>test</scope> + </dependency> + <dependency> + <groupId>com.yahoo.vespa</groupId> + <artifactId>testutil</artifactId> + <version>${project.version}</version> + <scope>test</scope> + </dependency> </dependencies> <build> diff --git a/jdisc-security-filters/src/main/java/com/yahoo/jdisc/http/filter/security/rule/RuleBasedRequestFilter.java b/jdisc-security-filters/src/main/java/com/yahoo/jdisc/http/filter/security/rule/RuleBasedRequestFilter.java new file mode 100644 index 00000000000..71f1965c764 --- /dev/null +++ b/jdisc-security-filters/src/main/java/com/yahoo/jdisc/http/filter/security/rule/RuleBasedRequestFilter.java @@ -0,0 +1,118 @@ +// Copyright Verizon Media. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package com.yahoo.jdisc.http.filter.security.rule; + +import com.google.inject.Inject; +import com.yahoo.jdisc.Metric; +import com.yahoo.jdisc.http.filter.DiscFilterRequest; +import com.yahoo.jdisc.http.filter.security.base.JsonSecurityRequestFilterBase; +import com.yahoo.jdisc.http.filter.security.rule.RuleBasedFilterConfig.Rule.Action; +import com.yahoo.restapi.Path; + +import java.net.URI; +import java.util.List; +import java.util.Map; +import java.util.Optional; +import java.util.Set; +import java.util.logging.Level; +import java.util.logging.Logger; +import java.util.stream.Collectors; + +/** + * Security request filter that filters requests based on host, method and uri path. + * + * @author bjorncs + */ +public class RuleBasedRequestFilter extends JsonSecurityRequestFilterBase { + + private static final Logger log = Logger.getLogger(RuleBasedRequestFilter.class.getName()); + + private final Metric metric; + private final boolean dryrun; + private final List<Rule> rules; + private final ErrorResponse defaultResponse; + + @Inject + public RuleBasedRequestFilter(Metric metric, RuleBasedFilterConfig config) { + this.metric = metric; + this.dryrun = config.dryrun(); + this.rules = Rule.fromConfig(config.rule()); + this.defaultResponse = createDefaultResponse(config.defaultRule()); + } + + @Override + protected Optional<ErrorResponse> filter(DiscFilterRequest request) { + String method = request.getMethod(); + URI uri = request.getUri(); + for (Rule rule : rules) { + if (rule.matches(method, uri)) { + log.log(Level.FINE, () -> + String.format("Request '%h' with method '%s' and uri '%s' matched rule '%s'", request, method, uri, rule.name)); + return responseFor(request, rule.name, rule.response); + } + } + return responseFor(request, "default", defaultResponse); + } + + private static ErrorResponse createDefaultResponse(RuleBasedFilterConfig.DefaultRule defaultRule) { + switch (defaultRule.action()) { + case ALLOW: return null; + case BLOCK: return new ErrorResponse(defaultRule.blockResponseCode(), defaultRule.blockResponseMessage()); + default: throw new IllegalArgumentException(defaultRule.action().name()); + } + } + + private Optional<ErrorResponse> responseFor(DiscFilterRequest request, String ruleName, ErrorResponse response) { + int statusCode = response != null ? response.getResponse().getStatus() : 0; + Metric.Context metricContext = metric.createContext(Map.of( + "rule", ruleName, + "dryrun", Boolean.toString(dryrun), + "statusCode", Integer.toString(statusCode))); + if (response != null) { + metric.add("jdisc.http.filter.rule.blocked_requests", 1L, metricContext); + log.log(Level.FINE, () -> String.format( + "Blocking request '%h' with status code '%d' using rule '%s' (dryrun=%b)", request, statusCode, ruleName, dryrun)); + return dryrun ? Optional.empty() : Optional.of(response); + } else { + metric.add("jdisc.http.filter.rule.allowed_requests", 1L, metricContext); + log.log(Level.FINE, () -> String.format("Allowing request '%h' using rule '%s' (dryrun=%b)", request, ruleName, dryrun)); + return Optional.empty(); + } + } + + private static class Rule { + + final String name; + final Set<String> hostnames; + final Set<String> methods; + final Set<String> pathGlobExpressions; + final ErrorResponse response; + + static List<Rule> fromConfig(List<RuleBasedFilterConfig.Rule> config) { + return config.stream() + .map(Rule::new) + .collect(Collectors.toList()); + } + + Rule(RuleBasedFilterConfig.Rule config) { + this.name = config.name(); + this.hostnames = Set.copyOf(config.hostNames()); + this.methods = config.methods().stream() + .map(m -> m.name().toUpperCase()) + .collect(Collectors.toSet()); + this.pathGlobExpressions = Set.copyOf(config.pathExpressions()); + this.response = config.action() == Action.Enum.BLOCK + ? new ErrorResponse(config.blockResponseCode(), config.blockResponseMessage()) + : null; + } + + boolean matches(String method, URI uri) { + boolean methodMatches = methods.isEmpty() || methods.contains(method.toUpperCase()); + String host = uri.getHost(); + boolean hostnameMatches = hostnames.isEmpty() || (host != null && hostnames.contains(host)); + Path pathMatcher = new Path(uri); + boolean pathMatches = pathGlobExpressions.isEmpty() || pathGlobExpressions.stream().anyMatch(pathMatcher::matches); + return methodMatches && hostnameMatches && pathMatches; + } + + } +} diff --git a/jdisc-security-filters/src/test/java/com/yahoo/jdisc/http/filter/security/rule/RuleBasedRequestFilterTest.java b/jdisc-security-filters/src/test/java/com/yahoo/jdisc/http/filter/security/rule/RuleBasedRequestFilterTest.java new file mode 100644 index 00000000000..c67d3b430c8 --- /dev/null +++ b/jdisc-security-filters/src/test/java/com/yahoo/jdisc/http/filter/security/rule/RuleBasedRequestFilterTest.java @@ -0,0 +1,174 @@ +// Copyright Verizon Media. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package com.yahoo.jdisc.http.filter.security.rule; + +import com.fasterxml.jackson.databind.JsonNode; +import com.fasterxml.jackson.databind.ObjectMapper; +import com.fasterxml.jackson.databind.node.ObjectNode; +import com.yahoo.container.jdisc.RequestHandlerTestDriver.MockResponseHandler; +import com.yahoo.jdisc.Metric; +import com.yahoo.jdisc.Response; +import com.yahoo.jdisc.http.filter.DiscFilterRequest; +import com.yahoo.jdisc.http.filter.security.rule.RuleBasedFilterConfig.DefaultRule; +import com.yahoo.jdisc.http.filter.security.rule.RuleBasedFilterConfig.Rule; +import com.yahoo.test.json.JsonTestHelper; +import com.yahoo.vespa.jdk8compat.List; +import org.junit.jupiter.api.Test; + +import java.io.IOException; +import java.net.URI; +import java.util.Set; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertNotNull; +import static org.junit.jupiter.api.Assertions.assertNull; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; + +/** + * @author bjorncs + */ +class RuleBasedRequestFilterTest { + + private static final ObjectMapper jsonMapper = new ObjectMapper(); + + @Test + void matches_rule_that_allows_all_methods_and_paths() { + RuleBasedFilterConfig config = new RuleBasedFilterConfig.Builder() + .dryrun(false) + .defaultRule(new DefaultRule.Builder() + .action(DefaultRule.Action.Enum.BLOCK)) + .rule(new Rule.Builder() + .name("first") + .hostNames("myserver") + .pathExpressions(List.of()) + .methods(List.of()) + .action(Rule.Action.Enum.ALLOW)) + .build(); + + Metric metric = mock(Metric.class); + RuleBasedRequestFilter filter = new RuleBasedRequestFilter(metric, config); + MockResponseHandler responseHandler = new MockResponseHandler(); + filter.filter(request("PATCH", "http://myserver:80/path-to-resource"), responseHandler); + + assertAllowed(responseHandler, metric); + + } + + @Test + void performs_action_on_first_matching_rule() throws IOException { + RuleBasedFilterConfig config = new RuleBasedFilterConfig.Builder() + .dryrun(false) + .defaultRule(new DefaultRule.Builder() + .action(DefaultRule.Action.Enum.ALLOW)) + .rule(new Rule.Builder() + .name("first") + .pathExpressions("/path-to-resource") + .methods(Rule.Methods.Enum.DELETE) + .action(Rule.Action.Enum.BLOCK) + .blockResponseCode(403)) + .rule(new Rule.Builder() + .name("second") + .pathExpressions("/path-to-resource") + .methods(Rule.Methods.Enum.GET) + .action(Rule.Action.Enum.BLOCK) + .blockResponseCode(404)) + .build(); + + Metric metric = mock(Metric.class); + RuleBasedRequestFilter filter = new RuleBasedRequestFilter(metric, config); + MockResponseHandler responseHandler = new MockResponseHandler(); + filter.filter(request("GET", "http://myserver:80/path-to-resource"), responseHandler); + + assertBlocked(responseHandler, metric, 404, ""); + } + + @Test + void performs_default_action_if_no_rule_matches() throws IOException { + RuleBasedFilterConfig config = new RuleBasedFilterConfig.Builder() + .dryrun(false) + .defaultRule(new DefaultRule.Builder() + .action(DefaultRule.Action.Enum.BLOCK) + .blockResponseCode(403) + .blockResponseMessage("my custom message")) + .rule(new Rule.Builder() + .name("rule") + .pathExpressions("/path-to-resource") + .methods(Rule.Methods.Enum.GET) + .action(Rule.Action.Enum.ALLOW)) + .build(); + + Metric metric = mock(Metric.class); + RuleBasedRequestFilter filter = new RuleBasedRequestFilter(metric, config); + MockResponseHandler responseHandler = new MockResponseHandler(); + filter.filter(request("POST", "http://myserver:80/"), responseHandler); + + assertBlocked(responseHandler, metric, 403, "my custom message"); + } + + @Test + void matches_rule_with_multiple_alternatives_for_host_path_and_method() throws IOException { + RuleBasedFilterConfig config = new RuleBasedFilterConfig.Builder() + .dryrun(false) + .defaultRule(new DefaultRule.Builder() + .action(DefaultRule.Action.Enum.ALLOW)) + .rule(new Rule.Builder() + .name("rule") + .hostNames(Set.of("server1", "server2", "server3")) + .pathExpressions(Set.of("/path-to-resource/{*}", "/another-path")) + .methods(Set.of(Rule.Methods.Enum.GET, Rule.Methods.POST, Rule.Methods.DELETE)) + .action(Rule.Action.Enum.BLOCK) + .blockResponseCode(404) + .blockResponseMessage("not found")) + .build(); + + Metric metric = mock(Metric.class); + RuleBasedRequestFilter filter = new RuleBasedRequestFilter(metric, config); + MockResponseHandler responseHandler = new MockResponseHandler(); + filter.filter(request("POST", "https://server1:443/path-to-resource/id/1/subid/2"), responseHandler); + + assertBlocked(responseHandler, metric, 404, "not found"); + } + + @Test + void no_filtering_if_request_is_allowed() { + RuleBasedFilterConfig config = new RuleBasedFilterConfig.Builder() + .dryrun(false) + .defaultRule(new DefaultRule.Builder() + .action(DefaultRule.Action.Enum.ALLOW)) + .build(); + + Metric metric = mock(Metric.class); + RuleBasedRequestFilter filter = new RuleBasedRequestFilter(metric, config); + MockResponseHandler responseHandler = new MockResponseHandler(); + filter.filter(request("DELETE", "http://myserver:80/"), responseHandler); + + assertAllowed(responseHandler, metric); + } + + private static DiscFilterRequest request(String method, String uri) { + DiscFilterRequest request = mock(DiscFilterRequest.class); + when(request.getMethod()).thenReturn(method); + when(request.getUri()).thenReturn(URI.create(uri)); + return request; + } + + private static void assertAllowed(MockResponseHandler handler, Metric metric) { + verify(metric).add(eq("jdisc.http.filter.rule.allowed_requests"), eq(1L), any()); + assertNull(handler.getResponse()); + } + + private static void assertBlocked(MockResponseHandler handler, Metric metric, int expectedCode, String expectedMessage) throws IOException { + verify(metric).add(eq("jdisc.http.filter.rule.blocked_requests"), eq(1L), any()); + Response response = handler.getResponse(); + assertNotNull(response); + assertEquals(expectedCode, response.getStatus()); + ObjectNode expectedJson = jsonMapper.createObjectNode(); + expectedJson.put("message", expectedMessage).put("code", expectedCode); + JsonNode actualJson = jsonMapper.readTree(handler.readAll().getBytes()); + JsonTestHelper.assertJsonEquals(expectedJson, actualJson); + } + +}
\ No newline at end of file |