aboutsummaryrefslogtreecommitdiffstats
path: root/jdisc-security-filters
diff options
context:
space:
mode:
authorBjørn Christian Seime <bjorncs@verizonmedia.com>2021-02-17 16:02:08 +0100
committerBjørn Christian Seime <bjorncs@verizonmedia.com>2021-02-17 16:04:10 +0100
commit4717ca675a011455fa68ec12cacfc26033a434a6 (patch)
treeab9e5f68acd5b0cc56fd3684f3f717c0b0e94fad /jdisc-security-filters
parent2e55f8118174a1e6fe5faa5ca9daf88f4be82461 (diff)
Add rule based request filter
Diffstat (limited to 'jdisc-security-filters')
-rw-r--r--jdisc-security-filters/pom.xml17
-rw-r--r--jdisc-security-filters/src/main/java/com/yahoo/jdisc/http/filter/security/rule/RuleBasedRequestFilter.java118
-rw-r--r--jdisc-security-filters/src/test/java/com/yahoo/jdisc/http/filter/security/rule/RuleBasedRequestFilterTest.java174
3 files changed, 308 insertions, 1 deletions
diff --git a/jdisc-security-filters/pom.xml b/jdisc-security-filters/pom.xml
index 867a32cc170..d4adfd23bac 100644
--- a/jdisc-security-filters/pom.xml
+++ b/jdisc-security-filters/pom.xml
@@ -47,7 +47,22 @@
<artifactId>mockito-core</artifactId>
<scope>test</scope>
</dependency>
-
+ <dependency>
+ <groupId>org.junit.jupiter</groupId>
+ <artifactId>junit-jupiter</artifactId>
+ <scope>test</scope>
+ </dependency>
+ <dependency>
+ <groupId>org.junit.vintage</groupId>
+ <artifactId>junit-vintage-engine</artifactId>
+ <scope>test</scope>
+ </dependency>
+ <dependency>
+ <groupId>com.yahoo.vespa</groupId>
+ <artifactId>testutil</artifactId>
+ <version>${project.version}</version>
+ <scope>test</scope>
+ </dependency>
</dependencies>
<build>
diff --git a/jdisc-security-filters/src/main/java/com/yahoo/jdisc/http/filter/security/rule/RuleBasedRequestFilter.java b/jdisc-security-filters/src/main/java/com/yahoo/jdisc/http/filter/security/rule/RuleBasedRequestFilter.java
new file mode 100644
index 00000000000..71f1965c764
--- /dev/null
+++ b/jdisc-security-filters/src/main/java/com/yahoo/jdisc/http/filter/security/rule/RuleBasedRequestFilter.java
@@ -0,0 +1,118 @@
+// Copyright Verizon Media. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+package com.yahoo.jdisc.http.filter.security.rule;
+
+import com.google.inject.Inject;
+import com.yahoo.jdisc.Metric;
+import com.yahoo.jdisc.http.filter.DiscFilterRequest;
+import com.yahoo.jdisc.http.filter.security.base.JsonSecurityRequestFilterBase;
+import com.yahoo.jdisc.http.filter.security.rule.RuleBasedFilterConfig.Rule.Action;
+import com.yahoo.restapi.Path;
+
+import java.net.URI;
+import java.util.List;
+import java.util.Map;
+import java.util.Optional;
+import java.util.Set;
+import java.util.logging.Level;
+import java.util.logging.Logger;
+import java.util.stream.Collectors;
+
+/**
+ * Security request filter that filters requests based on host, method and uri path.
+ *
+ * @author bjorncs
+ */
+public class RuleBasedRequestFilter extends JsonSecurityRequestFilterBase {
+
+ private static final Logger log = Logger.getLogger(RuleBasedRequestFilter.class.getName());
+
+ private final Metric metric;
+ private final boolean dryrun;
+ private final List<Rule> rules;
+ private final ErrorResponse defaultResponse;
+
+ @Inject
+ public RuleBasedRequestFilter(Metric metric, RuleBasedFilterConfig config) {
+ this.metric = metric;
+ this.dryrun = config.dryrun();
+ this.rules = Rule.fromConfig(config.rule());
+ this.defaultResponse = createDefaultResponse(config.defaultRule());
+ }
+
+ @Override
+ protected Optional<ErrorResponse> filter(DiscFilterRequest request) {
+ String method = request.getMethod();
+ URI uri = request.getUri();
+ for (Rule rule : rules) {
+ if (rule.matches(method, uri)) {
+ log.log(Level.FINE, () ->
+ String.format("Request '%h' with method '%s' and uri '%s' matched rule '%s'", request, method, uri, rule.name));
+ return responseFor(request, rule.name, rule.response);
+ }
+ }
+ return responseFor(request, "default", defaultResponse);
+ }
+
+ private static ErrorResponse createDefaultResponse(RuleBasedFilterConfig.DefaultRule defaultRule) {
+ switch (defaultRule.action()) {
+ case ALLOW: return null;
+ case BLOCK: return new ErrorResponse(defaultRule.blockResponseCode(), defaultRule.blockResponseMessage());
+ default: throw new IllegalArgumentException(defaultRule.action().name());
+ }
+ }
+
+ private Optional<ErrorResponse> responseFor(DiscFilterRequest request, String ruleName, ErrorResponse response) {
+ int statusCode = response != null ? response.getResponse().getStatus() : 0;
+ Metric.Context metricContext = metric.createContext(Map.of(
+ "rule", ruleName,
+ "dryrun", Boolean.toString(dryrun),
+ "statusCode", Integer.toString(statusCode)));
+ if (response != null) {
+ metric.add("jdisc.http.filter.rule.blocked_requests", 1L, metricContext);
+ log.log(Level.FINE, () -> String.format(
+ "Blocking request '%h' with status code '%d' using rule '%s' (dryrun=%b)", request, statusCode, ruleName, dryrun));
+ return dryrun ? Optional.empty() : Optional.of(response);
+ } else {
+ metric.add("jdisc.http.filter.rule.allowed_requests", 1L, metricContext);
+ log.log(Level.FINE, () -> String.format("Allowing request '%h' using rule '%s' (dryrun=%b)", request, ruleName, dryrun));
+ return Optional.empty();
+ }
+ }
+
+ private static class Rule {
+
+ final String name;
+ final Set<String> hostnames;
+ final Set<String> methods;
+ final Set<String> pathGlobExpressions;
+ final ErrorResponse response;
+
+ static List<Rule> fromConfig(List<RuleBasedFilterConfig.Rule> config) {
+ return config.stream()
+ .map(Rule::new)
+ .collect(Collectors.toList());
+ }
+
+ Rule(RuleBasedFilterConfig.Rule config) {
+ this.name = config.name();
+ this.hostnames = Set.copyOf(config.hostNames());
+ this.methods = config.methods().stream()
+ .map(m -> m.name().toUpperCase())
+ .collect(Collectors.toSet());
+ this.pathGlobExpressions = Set.copyOf(config.pathExpressions());
+ this.response = config.action() == Action.Enum.BLOCK
+ ? new ErrorResponse(config.blockResponseCode(), config.blockResponseMessage())
+ : null;
+ }
+
+ boolean matches(String method, URI uri) {
+ boolean methodMatches = methods.isEmpty() || methods.contains(method.toUpperCase());
+ String host = uri.getHost();
+ boolean hostnameMatches = hostnames.isEmpty() || (host != null && hostnames.contains(host));
+ Path pathMatcher = new Path(uri);
+ boolean pathMatches = pathGlobExpressions.isEmpty() || pathGlobExpressions.stream().anyMatch(pathMatcher::matches);
+ return methodMatches && hostnameMatches && pathMatches;
+ }
+
+ }
+}
diff --git a/jdisc-security-filters/src/test/java/com/yahoo/jdisc/http/filter/security/rule/RuleBasedRequestFilterTest.java b/jdisc-security-filters/src/test/java/com/yahoo/jdisc/http/filter/security/rule/RuleBasedRequestFilterTest.java
new file mode 100644
index 00000000000..c67d3b430c8
--- /dev/null
+++ b/jdisc-security-filters/src/test/java/com/yahoo/jdisc/http/filter/security/rule/RuleBasedRequestFilterTest.java
@@ -0,0 +1,174 @@
+// Copyright Verizon Media. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+package com.yahoo.jdisc.http.filter.security.rule;
+
+import com.fasterxml.jackson.databind.JsonNode;
+import com.fasterxml.jackson.databind.ObjectMapper;
+import com.fasterxml.jackson.databind.node.ObjectNode;
+import com.yahoo.container.jdisc.RequestHandlerTestDriver.MockResponseHandler;
+import com.yahoo.jdisc.Metric;
+import com.yahoo.jdisc.Response;
+import com.yahoo.jdisc.http.filter.DiscFilterRequest;
+import com.yahoo.jdisc.http.filter.security.rule.RuleBasedFilterConfig.DefaultRule;
+import com.yahoo.jdisc.http.filter.security.rule.RuleBasedFilterConfig.Rule;
+import com.yahoo.test.json.JsonTestHelper;
+import com.yahoo.vespa.jdk8compat.List;
+import org.junit.jupiter.api.Test;
+
+import java.io.IOException;
+import java.net.URI;
+import java.util.Set;
+
+import static org.junit.jupiter.api.Assertions.assertEquals;
+import static org.junit.jupiter.api.Assertions.assertNotNull;
+import static org.junit.jupiter.api.Assertions.assertNull;
+import static org.mockito.ArgumentMatchers.any;
+import static org.mockito.ArgumentMatchers.eq;
+import static org.mockito.Mockito.mock;
+import static org.mockito.Mockito.verify;
+import static org.mockito.Mockito.when;
+
+/**
+ * @author bjorncs
+ */
+class RuleBasedRequestFilterTest {
+
+ private static final ObjectMapper jsonMapper = new ObjectMapper();
+
+ @Test
+ void matches_rule_that_allows_all_methods_and_paths() {
+ RuleBasedFilterConfig config = new RuleBasedFilterConfig.Builder()
+ .dryrun(false)
+ .defaultRule(new DefaultRule.Builder()
+ .action(DefaultRule.Action.Enum.BLOCK))
+ .rule(new Rule.Builder()
+ .name("first")
+ .hostNames("myserver")
+ .pathExpressions(List.of())
+ .methods(List.of())
+ .action(Rule.Action.Enum.ALLOW))
+ .build();
+
+ Metric metric = mock(Metric.class);
+ RuleBasedRequestFilter filter = new RuleBasedRequestFilter(metric, config);
+ MockResponseHandler responseHandler = new MockResponseHandler();
+ filter.filter(request("PATCH", "http://myserver:80/path-to-resource"), responseHandler);
+
+ assertAllowed(responseHandler, metric);
+
+ }
+
+ @Test
+ void performs_action_on_first_matching_rule() throws IOException {
+ RuleBasedFilterConfig config = new RuleBasedFilterConfig.Builder()
+ .dryrun(false)
+ .defaultRule(new DefaultRule.Builder()
+ .action(DefaultRule.Action.Enum.ALLOW))
+ .rule(new Rule.Builder()
+ .name("first")
+ .pathExpressions("/path-to-resource")
+ .methods(Rule.Methods.Enum.DELETE)
+ .action(Rule.Action.Enum.BLOCK)
+ .blockResponseCode(403))
+ .rule(new Rule.Builder()
+ .name("second")
+ .pathExpressions("/path-to-resource")
+ .methods(Rule.Methods.Enum.GET)
+ .action(Rule.Action.Enum.BLOCK)
+ .blockResponseCode(404))
+ .build();
+
+ Metric metric = mock(Metric.class);
+ RuleBasedRequestFilter filter = new RuleBasedRequestFilter(metric, config);
+ MockResponseHandler responseHandler = new MockResponseHandler();
+ filter.filter(request("GET", "http://myserver:80/path-to-resource"), responseHandler);
+
+ assertBlocked(responseHandler, metric, 404, "");
+ }
+
+ @Test
+ void performs_default_action_if_no_rule_matches() throws IOException {
+ RuleBasedFilterConfig config = new RuleBasedFilterConfig.Builder()
+ .dryrun(false)
+ .defaultRule(new DefaultRule.Builder()
+ .action(DefaultRule.Action.Enum.BLOCK)
+ .blockResponseCode(403)
+ .blockResponseMessage("my custom message"))
+ .rule(new Rule.Builder()
+ .name("rule")
+ .pathExpressions("/path-to-resource")
+ .methods(Rule.Methods.Enum.GET)
+ .action(Rule.Action.Enum.ALLOW))
+ .build();
+
+ Metric metric = mock(Metric.class);
+ RuleBasedRequestFilter filter = new RuleBasedRequestFilter(metric, config);
+ MockResponseHandler responseHandler = new MockResponseHandler();
+ filter.filter(request("POST", "http://myserver:80/"), responseHandler);
+
+ assertBlocked(responseHandler, metric, 403, "my custom message");
+ }
+
+ @Test
+ void matches_rule_with_multiple_alternatives_for_host_path_and_method() throws IOException {
+ RuleBasedFilterConfig config = new RuleBasedFilterConfig.Builder()
+ .dryrun(false)
+ .defaultRule(new DefaultRule.Builder()
+ .action(DefaultRule.Action.Enum.ALLOW))
+ .rule(new Rule.Builder()
+ .name("rule")
+ .hostNames(Set.of("server1", "server2", "server3"))
+ .pathExpressions(Set.of("/path-to-resource/{*}", "/another-path"))
+ .methods(Set.of(Rule.Methods.Enum.GET, Rule.Methods.POST, Rule.Methods.DELETE))
+ .action(Rule.Action.Enum.BLOCK)
+ .blockResponseCode(404)
+ .blockResponseMessage("not found"))
+ .build();
+
+ Metric metric = mock(Metric.class);
+ RuleBasedRequestFilter filter = new RuleBasedRequestFilter(metric, config);
+ MockResponseHandler responseHandler = new MockResponseHandler();
+ filter.filter(request("POST", "https://server1:443/path-to-resource/id/1/subid/2"), responseHandler);
+
+ assertBlocked(responseHandler, metric, 404, "not found");
+ }
+
+ @Test
+ void no_filtering_if_request_is_allowed() {
+ RuleBasedFilterConfig config = new RuleBasedFilterConfig.Builder()
+ .dryrun(false)
+ .defaultRule(new DefaultRule.Builder()
+ .action(DefaultRule.Action.Enum.ALLOW))
+ .build();
+
+ Metric metric = mock(Metric.class);
+ RuleBasedRequestFilter filter = new RuleBasedRequestFilter(metric, config);
+ MockResponseHandler responseHandler = new MockResponseHandler();
+ filter.filter(request("DELETE", "http://myserver:80/"), responseHandler);
+
+ assertAllowed(responseHandler, metric);
+ }
+
+ private static DiscFilterRequest request(String method, String uri) {
+ DiscFilterRequest request = mock(DiscFilterRequest.class);
+ when(request.getMethod()).thenReturn(method);
+ when(request.getUri()).thenReturn(URI.create(uri));
+ return request;
+ }
+
+ private static void assertAllowed(MockResponseHandler handler, Metric metric) {
+ verify(metric).add(eq("jdisc.http.filter.rule.allowed_requests"), eq(1L), any());
+ assertNull(handler.getResponse());
+ }
+
+ private static void assertBlocked(MockResponseHandler handler, Metric metric, int expectedCode, String expectedMessage) throws IOException {
+ verify(metric).add(eq("jdisc.http.filter.rule.blocked_requests"), eq(1L), any());
+ Response response = handler.getResponse();
+ assertNotNull(response);
+ assertEquals(expectedCode, response.getStatus());
+ ObjectNode expectedJson = jsonMapper.createObjectNode();
+ expectedJson.put("message", expectedMessage).put("code", expectedCode);
+ JsonNode actualJson = jsonMapper.readTree(handler.readAll().getBytes());
+ JsonTestHelper.assertJsonEquals(expectedJson, actualJson);
+ }
+
+} \ No newline at end of file