diff options
author | Martin Polden <mpolden@mpolden.no> | 2018-06-20 16:23:38 +0200 |
---|---|---|
committer | Martin Polden <mpolden@mpolden.no> | 2018-06-21 13:10:48 +0200 |
commit | 63bae0f0d7d6d3e381c2760ed6be5613637c6268 (patch) | |
tree | 35ec6e98b7f6732c410262399f4e0c0a49e7c57e | |
parent | 02282b4dff1d753ecebf359f95f9582912f845e4 (diff) |
Remove Mockito usage from AthenzPrincipalFilterTest
2 files changed, 73 insertions, 59 deletions
diff --git a/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/athenz/filter/AthenzPrincipalFilterTest.java b/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/athenz/filter/AthenzPrincipalFilterTest.java index cec3930f9dd..93f3a4e39db 100644 --- a/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/athenz/filter/AthenzPrincipalFilterTest.java +++ b/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/athenz/filter/AthenzPrincipalFilterTest.java @@ -1,6 +1,7 @@ // Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. package com.yahoo.vespa.hosted.controller.athenz.filter; +import com.yahoo.application.container.handler.Request; import com.yahoo.jdisc.Response; import com.yahoo.jdisc.handler.ContentChannel; import com.yahoo.jdisc.handler.ReadableContentChannel; @@ -14,6 +15,7 @@ import com.yahoo.vespa.athenz.tls.KeyAlgorithm; import com.yahoo.vespa.athenz.tls.KeyUtils; import com.yahoo.vespa.athenz.tls.X509CertificateBuilder; import com.yahoo.vespa.hosted.controller.api.integration.athenz.InvalidTokenException; +import com.yahoo.vespa.hosted.controller.restapi.ApplicationRequestToDiscFilterRequestWrapper; import org.junit.Before; import org.junit.Test; @@ -26,22 +28,22 @@ import java.security.KeyPair; import java.security.cert.X509Certificate; import java.time.Duration; import java.time.Instant; +import java.util.HashMap; +import java.util.Map; import java.util.Objects; +import java.util.Optional; import java.util.Set; import static com.yahoo.jdisc.Response.Status.UNAUTHORIZED; import static com.yahoo.vespa.athenz.tls.SignatureAlgorithm.SHA256_WITH_RSA; -import static java.util.Collections.emptyList; import static java.util.Collections.singleton; import static java.util.Collections.singletonList; import static java.util.stream.Collectors.joining; import static org.hamcrest.Matchers.containsString; import static org.hamcrest.Matchers.equalTo; import static org.hamcrest.Matchers.notNullValue; +import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertThat; -import static org.mockito.Mockito.mock; -import static org.mockito.Mockito.verify; -import static org.mockito.Mockito.when; /** * @author bjorncs @@ -55,111 +57,94 @@ public class AthenzPrincipalFilterTest { private static final String ORIGIN = "http://localhost"; private static final Set<String> CORS_ALLOWED_URLS = singleton(ORIGIN); - private NTokenValidator validator; + private NTokenValidatorMock validator; + private ResponseHandlerMock responseHandler; @Before public void before() { - validator = mock(NTokenValidator.class); + this.validator = new NTokenValidatorMock(); + this.responseHandler = new ResponseHandlerMock(); } @Test public void valid_ntoken_is_accepted() { - DiscFilterRequest request = createRequestMock(); + Request request = defaultRequest(); + AthenzPrincipal principal = new AthenzPrincipal(IDENTITY, NTOKEN); - when(request.getHeader(ATHENZ_PRINCIPAL_HEADER)).thenReturn(NTOKEN.getRawToken()); - when(request.getClientCertificateChain()).thenReturn(emptyList()); - when(validator.validate(NTOKEN)).thenReturn(principal); + validator.add(NTOKEN, principal); AthenzPrincipalFilter filter = new AthenzPrincipalFilter(validator, Runnable::run, ATHENZ_PRINCIPAL_HEADER, CORS_ALLOWED_URLS); - filter.filter(request, new ResponseHandlerMock()); - - verify(request).setUserPrincipal(principal); - } + DiscFilterRequest filterRequest = new ApplicationRequestToDiscFilterRequestWrapper(request); + filter.filter(filterRequest, new ResponseHandlerMock()); - private DiscFilterRequest createRequestMock() { - DiscFilterRequest request = mock(DiscFilterRequest.class); - when(request.getHeader("Origin")).thenReturn(ORIGIN); - return request; + assertEquals(principal, filterRequest.getUserPrincipal()); } @Test public void missing_token_and_certificate_is_unauthorized() { - DiscFilterRequest request = createRequestMock(); - when(request.getHeader(ATHENZ_PRINCIPAL_HEADER)).thenReturn(null); - when(request.getClientCertificateChain()).thenReturn(emptyList()); - - ResponseHandlerMock responseHandler = new ResponseHandlerMock(); - AthenzPrincipalFilter filter = new AthenzPrincipalFilter(validator, Runnable::run, ATHENZ_PRINCIPAL_HEADER, CORS_ALLOWED_URLS); - filter.filter(request, responseHandler); + DiscFilterRequest filterRequest = new ApplicationRequestToDiscFilterRequestWrapper(new Request("/")); + filter.filter(filterRequest, responseHandler); assertUnauthorized(responseHandler, "Unable to authenticate Athenz identity"); } @Test public void invalid_token_is_unauthorized() { - DiscFilterRequest request = createRequestMock(); - String errorMessage = "Invalid token"; - when(request.getHeader(ATHENZ_PRINCIPAL_HEADER)).thenReturn(NTOKEN.getRawToken()); - when(request.getClientCertificateChain()).thenReturn(emptyList()); - when(validator.validate(NTOKEN)).thenThrow(new InvalidTokenException(errorMessage)); - - ResponseHandlerMock responseHandler = new ResponseHandlerMock(); + Request request = defaultRequest(); AthenzPrincipalFilter filter = new AthenzPrincipalFilter(validator, Runnable::run, ATHENZ_PRINCIPAL_HEADER, CORS_ALLOWED_URLS); - filter.filter(request, responseHandler); + DiscFilterRequest filterRequest = new ApplicationRequestToDiscFilterRequestWrapper(request); + filter.filter(filterRequest, responseHandler); + String errorMessage = "Invalid token"; assertUnauthorized(responseHandler, errorMessage); } @Test public void certificate_is_accepted() { - DiscFilterRequest request = createRequestMock(); - when(request.getHeader(ATHENZ_PRINCIPAL_HEADER)).thenReturn(null); - when(request.getClientCertificateChain()).thenReturn(singletonList(CERTIFICATE)); - - ResponseHandlerMock responseHandler = new ResponseHandlerMock(); - AthenzPrincipalFilter filter = new AthenzPrincipalFilter(validator, Runnable::run, ATHENZ_PRINCIPAL_HEADER, CORS_ALLOWED_URLS); - filter.filter(request, responseHandler); + DiscFilterRequest filterRequest = new ApplicationRequestToDiscFilterRequestWrapper(new Request("/"), singletonList(CERTIFICATE)); + filter.filter(filterRequest, responseHandler); AthenzPrincipal expectedPrincipal = new AthenzPrincipal(IDENTITY); - verify(request).setUserPrincipal(expectedPrincipal); + assertEquals(expectedPrincipal, filterRequest.getUserPrincipal()); } @Test public void both_ntoken_and_certificate_is_accepted() { - DiscFilterRequest request = createRequestMock(); - AthenzPrincipal principalWithToken = new AthenzPrincipal(IDENTITY, NTOKEN); - when(request.getHeader(ATHENZ_PRINCIPAL_HEADER)).thenReturn(NTOKEN.getRawToken()); - when(request.getClientCertificateChain()).thenReturn(singletonList(CERTIFICATE)); - when(validator.validate(NTOKEN)).thenReturn(principalWithToken); + Request request = defaultRequest(); - ResponseHandlerMock responseHandler = new ResponseHandlerMock(); + AthenzPrincipal principalWithToken = new AthenzPrincipal(IDENTITY, NTOKEN); + validator.add(NTOKEN, principalWithToken); AthenzPrincipalFilter filter = new AthenzPrincipalFilter(validator, Runnable::run, ATHENZ_PRINCIPAL_HEADER, CORS_ALLOWED_URLS); - filter.filter(request, responseHandler); + DiscFilterRequest filterRequest = new ApplicationRequestToDiscFilterRequestWrapper(request, singletonList(CERTIFICATE)); + filter.filter(filterRequest, responseHandler); - verify(request).setUserPrincipal(principalWithToken); + assertEquals(principalWithToken, filterRequest.getUserPrincipal()); } @Test public void conflicting_ntoken_and_certificate_is_unauthorized() { - DiscFilterRequest request = createRequestMock(); - AthenzUser conflictingIdentity = AthenzUser.fromUserId("mallory"); - when(request.getHeader(ATHENZ_PRINCIPAL_HEADER)).thenReturn(NTOKEN.getRawToken()); - when(request.getClientCertificateChain()) - .thenReturn(singletonList(createSelfSignedCertificate(conflictingIdentity))); - when(validator.validate(NTOKEN)).thenReturn(new AthenzPrincipal(IDENTITY)); - - ResponseHandlerMock responseHandler = new ResponseHandlerMock(); + Request request = defaultRequest(); + validator.add(NTOKEN, new AthenzPrincipal(IDENTITY)); + AthenzUser conflictingIdentity = AthenzUser.fromUserId("mallory"); + DiscFilterRequest filterRequest = new ApplicationRequestToDiscFilterRequestWrapper(request, singletonList(createSelfSignedCertificate(conflictingIdentity))); AthenzPrincipalFilter filter = new AthenzPrincipalFilter(validator, Runnable::run, ATHENZ_PRINCIPAL_HEADER, CORS_ALLOWED_URLS); - filter.filter(request, responseHandler); + filter.filter(filterRequest, responseHandler); assertUnauthorized(responseHandler, "Identity in principal token does not match x509 CN"); } + private static Request defaultRequest() { + Request request = new Request("/"); + request.getHeaders().add("Origin", ORIGIN); + request.getHeaders().add(ATHENZ_PRINCIPAL_HEADER, NTOKEN.getRawToken()); + return request; + } + private static void assertUnauthorized(ResponseHandlerMock responseHandler, String expectedMessageSubstring) { assertThat(responseHandler.response, notNullValue()); assertThat(responseHandler.response.getStatus(), equalTo(UNAUTHORIZED)); @@ -188,6 +173,29 @@ public class AthenzPrincipalFilterTest { } + private static class NTokenValidatorMock extends NTokenValidator { + + private final Map<NToken, AthenzPrincipal> validTokens = new HashMap<>(); + + NTokenValidatorMock() { + super((service, keyId) -> Optional.empty()); + } + + public NTokenValidatorMock add(NToken token, AthenzPrincipal principal) { + validTokens.put(token, principal); + return this; + } + + @Override + AthenzPrincipal validate(NToken token) throws InvalidTokenException { + if (!validTokens.containsKey(token)) { + throw new InvalidTokenException("Invalid token"); + } + return validTokens.get(token); + } + + } + private static X509Certificate createSelfSignedCertificate(AthenzIdentity identity) { KeyPair keyPair = KeyUtils.generateKeypair(KeyAlgorithm.RSA, 512); X500Principal x500Name = new X500Principal("CN="+ identity.getFullName()); diff --git a/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/restapi/ApplicationRequestToDiscFilterRequestWrapper.java b/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/restapi/ApplicationRequestToDiscFilterRequestWrapper.java index eee0519b12b..4883bde99b1 100644 --- a/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/restapi/ApplicationRequestToDiscFilterRequestWrapper.java +++ b/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/restapi/ApplicationRequestToDiscFilterRequestWrapper.java @@ -27,9 +27,14 @@ import java.util.concurrent.TimeUnit; public class ApplicationRequestToDiscFilterRequestWrapper extends DiscFilterRequest { private final Request request; + private final List<X509Certificate> clientCertificateChain; private Principal userPrincipal; public ApplicationRequestToDiscFilterRequestWrapper(Request request) { + this(request, Collections.emptyList()); + } + + public ApplicationRequestToDiscFilterRequestWrapper(Request request, List<X509Certificate> clientCertificateChain) { super(new ServletOrJdiscHttpRequest() { @Override public void copyHeaders(HeaderFields target) { @@ -93,6 +98,7 @@ public class ApplicationRequestToDiscFilterRequestWrapper extends DiscFilterRequ }); this.request = request; this.userPrincipal = request.getUserPrincipal().orElse(null); + this.clientCertificateChain = clientCertificateChain; } public Request getUpdatedRequest() { @@ -178,7 +184,7 @@ public class ApplicationRequestToDiscFilterRequestWrapper extends DiscFilterRequ @Override public List<X509Certificate> getClientCertificateChain() { - return Collections.emptyList(); + return clientCertificateChain; } @Override |