diff options
Diffstat (limited to 'jdisc-security-filters/src/main/java/com/yahoo/jdisc/http/filter/security/rule/RuleBasedRequestFilter.java')
-rw-r--r-- | jdisc-security-filters/src/main/java/com/yahoo/jdisc/http/filter/security/rule/RuleBasedRequestFilter.java | 17 |
1 files changed, 13 insertions, 4 deletions
diff --git a/jdisc-security-filters/src/main/java/com/yahoo/jdisc/http/filter/security/rule/RuleBasedRequestFilter.java b/jdisc-security-filters/src/main/java/com/yahoo/jdisc/http/filter/security/rule/RuleBasedRequestFilter.java index 71f1965c764..7bdc386e4b4 100644 --- a/jdisc-security-filters/src/main/java/com/yahoo/jdisc/http/filter/security/rule/RuleBasedRequestFilter.java +++ b/jdisc-security-filters/src/main/java/com/yahoo/jdisc/http/filter/security/rule/RuleBasedRequestFilter.java @@ -3,6 +3,7 @@ package com.yahoo.jdisc.http.filter.security.rule; import com.google.inject.Inject; import com.yahoo.jdisc.Metric; +import com.yahoo.jdisc.Response; import com.yahoo.jdisc.http.filter.DiscFilterRequest; import com.yahoo.jdisc.http.filter.security.base.JsonSecurityRequestFilterBase; import com.yahoo.jdisc.http.filter.security.rule.RuleBasedFilterConfig.Rule.Action; @@ -56,7 +57,11 @@ public class RuleBasedRequestFilter extends JsonSecurityRequestFilterBase { private static ErrorResponse createDefaultResponse(RuleBasedFilterConfig.DefaultRule defaultRule) { switch (defaultRule.action()) { case ALLOW: return null; - case BLOCK: return new ErrorResponse(defaultRule.blockResponseCode(), defaultRule.blockResponseMessage()); + case BLOCK: { + Response response = new Response(defaultRule.blockResponseCode()); + defaultRule.blockResponseHeaders().forEach(h -> response.headers().add(h.name(), h.value())); + return new ErrorResponse(response, defaultRule.blockResponseMessage()); + } default: throw new IllegalArgumentException(defaultRule.action().name()); } } @@ -100,9 +105,13 @@ public class RuleBasedRequestFilter extends JsonSecurityRequestFilterBase { .map(m -> m.name().toUpperCase()) .collect(Collectors.toSet()); this.pathGlobExpressions = Set.copyOf(config.pathExpressions()); - this.response = config.action() == Action.Enum.BLOCK - ? new ErrorResponse(config.blockResponseCode(), config.blockResponseMessage()) - : null; + this.response = config.action() == Action.Enum.BLOCK ? createResponse(config) : null; + } + + private static ErrorResponse createResponse(RuleBasedFilterConfig.Rule config) { + Response response = new Response(config.blockResponseCode()); + config.blockResponseHeaders().forEach(h -> response.headers().add(h.name(), h.value())); + return new ErrorResponse(response, config.blockResponseMessage()); } boolean matches(String method, URI uri) { |