diff options
Diffstat (limited to 'jdisc-security-filters/src/main/java/com/yahoo/jdisc/http/filter/security/athenz/AthenzPrincipalFilter.java')
-rw-r--r-- | jdisc-security-filters/src/main/java/com/yahoo/jdisc/http/filter/security/athenz/AthenzPrincipalFilter.java | 30 |
1 files changed, 25 insertions, 5 deletions
diff --git a/jdisc-security-filters/src/main/java/com/yahoo/jdisc/http/filter/security/athenz/AthenzPrincipalFilter.java b/jdisc-security-filters/src/main/java/com/yahoo/jdisc/http/filter/security/athenz/AthenzPrincipalFilter.java index ad6c82138e1..5b79b806190 100644 --- a/jdisc-security-filters/src/main/java/com/yahoo/jdisc/http/filter/security/athenz/AthenzPrincipalFilter.java +++ b/jdisc-security-filters/src/main/java/com/yahoo/jdisc/http/filter/security/athenz/AthenzPrincipalFilter.java @@ -31,22 +31,31 @@ import java.util.Set; */ public class AthenzPrincipalFilter extends CorsRequestFilterBase { + private static final String RESULT_ATTRIBUTE_PREFIX = "jdisc-security-filters.athenz-principal-filter.result"; + public static final String RESULT_ERROR_CODE_ATTRIBUTE = RESULT_ATTRIBUTE_PREFIX + ".error.code"; + public static final String RESULT_ERROR_MESSAGE_ATTRIBUTE = RESULT_ATTRIBUTE_PREFIX + ".error.message"; + public static final String RESULT_PRINCIPAL = RESULT_ATTRIBUTE_PREFIX + ".principal"; + private final NTokenValidator validator; private final String principalTokenHeader; + private final boolean passthroughMode; @Inject public AthenzPrincipalFilter(AthenzPrincipalFilterConfig athenzPrincipalFilterConfig, CorsFilterConfig corsConfig) { this(new NTokenValidator(Paths.get(athenzPrincipalFilterConfig.athenzConfFile())), athenzPrincipalFilterConfig.principalHeaderName(), - new HashSet<>(corsConfig.allowedUrls())); + new HashSet<>(corsConfig.allowedUrls()), + athenzPrincipalFilterConfig.passthroughMode()); } AthenzPrincipalFilter(NTokenValidator validator, String principalTokenHeader, - Set<String> corsAllowedUrls) { + Set<String> corsAllowedUrls, + boolean passthroughMode) { super(corsAllowedUrls); this.validator = validator; this.principalTokenHeader = principalTokenHeader; + this.passthroughMode = passthroughMode; } @Override @@ -61,7 +70,7 @@ public class AthenzPrincipalFilter extends CorsRequestFilterBase { if (!certificatePrincipal.isPresent() && !nTokenPrincipal.isPresent()) { String errorMessage = "Unable to authenticate Athenz identity. " + "Either client certificate or principal token is required."; - return Optional.of(new ErrorResponse(Response.Status.UNAUTHORIZED, errorMessage)); + return createResponse(request, Response.Status.UNAUTHORIZED, errorMessage); } if (certificatePrincipal.isPresent() && nTokenPrincipal.isPresent() && !certificatePrincipal.get().getIdentity().equals(nTokenPrincipal.get().getIdentity())) { @@ -69,14 +78,15 @@ public class AthenzPrincipalFilter extends CorsRequestFilterBase { "Identity in principal token does not match x509 CN: token-identity=%s, cert-identity=%s", nTokenPrincipal.get().getIdentity().getFullName(), certificatePrincipal.get().getIdentity().getFullName()); - return Optional.of(new ErrorResponse(Response.Status.UNAUTHORIZED, errorMessage)); + return createResponse(request, Response.Status.UNAUTHORIZED, errorMessage); } AthenzPrincipal principal = nTokenPrincipal.orElseGet(certificatePrincipal::get); request.setUserPrincipal(principal); request.setRemoteUser(principal.getName()); + request.setAttribute(RESULT_PRINCIPAL, principal); return Optional.empty(); } catch (Exception e) { - return Optional.of(new ErrorResponse(Response.Status.UNAUTHORIZED, e.getMessage())); + return createResponse(request, Response.Status.UNAUTHORIZED, e.getMessage()); } } @@ -92,4 +102,14 @@ public class AthenzPrincipalFilter extends CorsRequestFilterBase { .map(NToken::new); } + private Optional<ErrorResponse> createResponse(DiscFilterRequest request, int statusCode, String message) { + request.setAttribute(RESULT_ERROR_CODE_ATTRIBUTE, statusCode); + request.setAttribute(RESULT_ERROR_MESSAGE_ATTRIBUTE, message); + if (passthroughMode) { + return Optional.empty(); + } else { + return Optional.of(new ErrorResponse(statusCode, message)); + } + } + } |