aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorMartin Polden <mpolden@mpolden.no>2018-06-20 16:23:38 +0200
committerMartin Polden <mpolden@mpolden.no>2018-06-21 13:10:48 +0200
commit63bae0f0d7d6d3e381c2760ed6be5613637c6268 (patch)
tree35ec6e98b7f6732c410262399f4e0c0a49e7c57e
parent02282b4dff1d753ecebf359f95f9582912f845e4 (diff)
Remove Mockito usage from AthenzPrincipalFilterTest
-rw-r--r--controller-server/src/test/java/com/yahoo/vespa/hosted/controller/athenz/filter/AthenzPrincipalFilterTest.java124
-rw-r--r--controller-server/src/test/java/com/yahoo/vespa/hosted/controller/restapi/ApplicationRequestToDiscFilterRequestWrapper.java8
2 files changed, 73 insertions, 59 deletions
diff --git a/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/athenz/filter/AthenzPrincipalFilterTest.java b/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/athenz/filter/AthenzPrincipalFilterTest.java
index cec3930f9dd..93f3a4e39db 100644
--- a/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/athenz/filter/AthenzPrincipalFilterTest.java
+++ b/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/athenz/filter/AthenzPrincipalFilterTest.java
@@ -1,6 +1,7 @@
// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
package com.yahoo.vespa.hosted.controller.athenz.filter;
+import com.yahoo.application.container.handler.Request;
import com.yahoo.jdisc.Response;
import com.yahoo.jdisc.handler.ContentChannel;
import com.yahoo.jdisc.handler.ReadableContentChannel;
@@ -14,6 +15,7 @@ import com.yahoo.vespa.athenz.tls.KeyAlgorithm;
import com.yahoo.vespa.athenz.tls.KeyUtils;
import com.yahoo.vespa.athenz.tls.X509CertificateBuilder;
import com.yahoo.vespa.hosted.controller.api.integration.athenz.InvalidTokenException;
+import com.yahoo.vespa.hosted.controller.restapi.ApplicationRequestToDiscFilterRequestWrapper;
import org.junit.Before;
import org.junit.Test;
@@ -26,22 +28,22 @@ import java.security.KeyPair;
import java.security.cert.X509Certificate;
import java.time.Duration;
import java.time.Instant;
+import java.util.HashMap;
+import java.util.Map;
import java.util.Objects;
+import java.util.Optional;
import java.util.Set;
import static com.yahoo.jdisc.Response.Status.UNAUTHORIZED;
import static com.yahoo.vespa.athenz.tls.SignatureAlgorithm.SHA256_WITH_RSA;
-import static java.util.Collections.emptyList;
import static java.util.Collections.singleton;
import static java.util.Collections.singletonList;
import static java.util.stream.Collectors.joining;
import static org.hamcrest.Matchers.containsString;
import static org.hamcrest.Matchers.equalTo;
import static org.hamcrest.Matchers.notNullValue;
+import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertThat;
-import static org.mockito.Mockito.mock;
-import static org.mockito.Mockito.verify;
-import static org.mockito.Mockito.when;
/**
* @author bjorncs
@@ -55,111 +57,94 @@ public class AthenzPrincipalFilterTest {
private static final String ORIGIN = "http://localhost";
private static final Set<String> CORS_ALLOWED_URLS = singleton(ORIGIN);
- private NTokenValidator validator;
+ private NTokenValidatorMock validator;
+ private ResponseHandlerMock responseHandler;
@Before
public void before() {
- validator = mock(NTokenValidator.class);
+ this.validator = new NTokenValidatorMock();
+ this.responseHandler = new ResponseHandlerMock();
}
@Test
public void valid_ntoken_is_accepted() {
- DiscFilterRequest request = createRequestMock();
+ Request request = defaultRequest();
+
AthenzPrincipal principal = new AthenzPrincipal(IDENTITY, NTOKEN);
- when(request.getHeader(ATHENZ_PRINCIPAL_HEADER)).thenReturn(NTOKEN.getRawToken());
- when(request.getClientCertificateChain()).thenReturn(emptyList());
- when(validator.validate(NTOKEN)).thenReturn(principal);
+ validator.add(NTOKEN, principal);
AthenzPrincipalFilter filter = new AthenzPrincipalFilter(validator, Runnable::run, ATHENZ_PRINCIPAL_HEADER, CORS_ALLOWED_URLS);
- filter.filter(request, new ResponseHandlerMock());
-
- verify(request).setUserPrincipal(principal);
- }
+ DiscFilterRequest filterRequest = new ApplicationRequestToDiscFilterRequestWrapper(request);
+ filter.filter(filterRequest, new ResponseHandlerMock());
- private DiscFilterRequest createRequestMock() {
- DiscFilterRequest request = mock(DiscFilterRequest.class);
- when(request.getHeader("Origin")).thenReturn(ORIGIN);
- return request;
+ assertEquals(principal, filterRequest.getUserPrincipal());
}
@Test
public void missing_token_and_certificate_is_unauthorized() {
- DiscFilterRequest request = createRequestMock();
- when(request.getHeader(ATHENZ_PRINCIPAL_HEADER)).thenReturn(null);
- when(request.getClientCertificateChain()).thenReturn(emptyList());
-
- ResponseHandlerMock responseHandler = new ResponseHandlerMock();
-
AthenzPrincipalFilter filter = new AthenzPrincipalFilter(validator, Runnable::run, ATHENZ_PRINCIPAL_HEADER, CORS_ALLOWED_URLS);
- filter.filter(request, responseHandler);
+ DiscFilterRequest filterRequest = new ApplicationRequestToDiscFilterRequestWrapper(new Request("/"));
+ filter.filter(filterRequest, responseHandler);
assertUnauthorized(responseHandler, "Unable to authenticate Athenz identity");
}
@Test
public void invalid_token_is_unauthorized() {
- DiscFilterRequest request = createRequestMock();
- String errorMessage = "Invalid token";
- when(request.getHeader(ATHENZ_PRINCIPAL_HEADER)).thenReturn(NTOKEN.getRawToken());
- when(request.getClientCertificateChain()).thenReturn(emptyList());
- when(validator.validate(NTOKEN)).thenThrow(new InvalidTokenException(errorMessage));
-
- ResponseHandlerMock responseHandler = new ResponseHandlerMock();
+ Request request = defaultRequest();
AthenzPrincipalFilter filter = new AthenzPrincipalFilter(validator, Runnable::run, ATHENZ_PRINCIPAL_HEADER, CORS_ALLOWED_URLS);
- filter.filter(request, responseHandler);
+ DiscFilterRequest filterRequest = new ApplicationRequestToDiscFilterRequestWrapper(request);
+ filter.filter(filterRequest, responseHandler);
+ String errorMessage = "Invalid token";
assertUnauthorized(responseHandler, errorMessage);
}
@Test
public void certificate_is_accepted() {
- DiscFilterRequest request = createRequestMock();
- when(request.getHeader(ATHENZ_PRINCIPAL_HEADER)).thenReturn(null);
- when(request.getClientCertificateChain()).thenReturn(singletonList(CERTIFICATE));
-
- ResponseHandlerMock responseHandler = new ResponseHandlerMock();
-
AthenzPrincipalFilter filter = new AthenzPrincipalFilter(validator, Runnable::run, ATHENZ_PRINCIPAL_HEADER, CORS_ALLOWED_URLS);
- filter.filter(request, responseHandler);
+ DiscFilterRequest filterRequest = new ApplicationRequestToDiscFilterRequestWrapper(new Request("/"), singletonList(CERTIFICATE));
+ filter.filter(filterRequest, responseHandler);
AthenzPrincipal expectedPrincipal = new AthenzPrincipal(IDENTITY);
- verify(request).setUserPrincipal(expectedPrincipal);
+ assertEquals(expectedPrincipal, filterRequest.getUserPrincipal());
}
@Test
public void both_ntoken_and_certificate_is_accepted() {
- DiscFilterRequest request = createRequestMock();
- AthenzPrincipal principalWithToken = new AthenzPrincipal(IDENTITY, NTOKEN);
- when(request.getHeader(ATHENZ_PRINCIPAL_HEADER)).thenReturn(NTOKEN.getRawToken());
- when(request.getClientCertificateChain()).thenReturn(singletonList(CERTIFICATE));
- when(validator.validate(NTOKEN)).thenReturn(principalWithToken);
+ Request request = defaultRequest();
- ResponseHandlerMock responseHandler = new ResponseHandlerMock();
+ AthenzPrincipal principalWithToken = new AthenzPrincipal(IDENTITY, NTOKEN);
+ validator.add(NTOKEN, principalWithToken);
AthenzPrincipalFilter filter = new AthenzPrincipalFilter(validator, Runnable::run, ATHENZ_PRINCIPAL_HEADER, CORS_ALLOWED_URLS);
- filter.filter(request, responseHandler);
+ DiscFilterRequest filterRequest = new ApplicationRequestToDiscFilterRequestWrapper(request, singletonList(CERTIFICATE));
+ filter.filter(filterRequest, responseHandler);
- verify(request).setUserPrincipal(principalWithToken);
+ assertEquals(principalWithToken, filterRequest.getUserPrincipal());
}
@Test
public void conflicting_ntoken_and_certificate_is_unauthorized() {
- DiscFilterRequest request = createRequestMock();
- AthenzUser conflictingIdentity = AthenzUser.fromUserId("mallory");
- when(request.getHeader(ATHENZ_PRINCIPAL_HEADER)).thenReturn(NTOKEN.getRawToken());
- when(request.getClientCertificateChain())
- .thenReturn(singletonList(createSelfSignedCertificate(conflictingIdentity)));
- when(validator.validate(NTOKEN)).thenReturn(new AthenzPrincipal(IDENTITY));
-
- ResponseHandlerMock responseHandler = new ResponseHandlerMock();
+ Request request = defaultRequest();
+ validator.add(NTOKEN, new AthenzPrincipal(IDENTITY));
+ AthenzUser conflictingIdentity = AthenzUser.fromUserId("mallory");
+ DiscFilterRequest filterRequest = new ApplicationRequestToDiscFilterRequestWrapper(request, singletonList(createSelfSignedCertificate(conflictingIdentity)));
AthenzPrincipalFilter filter = new AthenzPrincipalFilter(validator, Runnable::run, ATHENZ_PRINCIPAL_HEADER, CORS_ALLOWED_URLS);
- filter.filter(request, responseHandler);
+ filter.filter(filterRequest, responseHandler);
assertUnauthorized(responseHandler, "Identity in principal token does not match x509 CN");
}
+ private static Request defaultRequest() {
+ Request request = new Request("/");
+ request.getHeaders().add("Origin", ORIGIN);
+ request.getHeaders().add(ATHENZ_PRINCIPAL_HEADER, NTOKEN.getRawToken());
+ return request;
+ }
+
private static void assertUnauthorized(ResponseHandlerMock responseHandler, String expectedMessageSubstring) {
assertThat(responseHandler.response, notNullValue());
assertThat(responseHandler.response.getStatus(), equalTo(UNAUTHORIZED));
@@ -188,6 +173,29 @@ public class AthenzPrincipalFilterTest {
}
+ private static class NTokenValidatorMock extends NTokenValidator {
+
+ private final Map<NToken, AthenzPrincipal> validTokens = new HashMap<>();
+
+ NTokenValidatorMock() {
+ super((service, keyId) -> Optional.empty());
+ }
+
+ public NTokenValidatorMock add(NToken token, AthenzPrincipal principal) {
+ validTokens.put(token, principal);
+ return this;
+ }
+
+ @Override
+ AthenzPrincipal validate(NToken token) throws InvalidTokenException {
+ if (!validTokens.containsKey(token)) {
+ throw new InvalidTokenException("Invalid token");
+ }
+ return validTokens.get(token);
+ }
+
+ }
+
private static X509Certificate createSelfSignedCertificate(AthenzIdentity identity) {
KeyPair keyPair = KeyUtils.generateKeypair(KeyAlgorithm.RSA, 512);
X500Principal x500Name = new X500Principal("CN="+ identity.getFullName());
diff --git a/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/restapi/ApplicationRequestToDiscFilterRequestWrapper.java b/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/restapi/ApplicationRequestToDiscFilterRequestWrapper.java
index eee0519b12b..4883bde99b1 100644
--- a/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/restapi/ApplicationRequestToDiscFilterRequestWrapper.java
+++ b/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/restapi/ApplicationRequestToDiscFilterRequestWrapper.java
@@ -27,9 +27,14 @@ import java.util.concurrent.TimeUnit;
public class ApplicationRequestToDiscFilterRequestWrapper extends DiscFilterRequest {
private final Request request;
+ private final List<X509Certificate> clientCertificateChain;
private Principal userPrincipal;
public ApplicationRequestToDiscFilterRequestWrapper(Request request) {
+ this(request, Collections.emptyList());
+ }
+
+ public ApplicationRequestToDiscFilterRequestWrapper(Request request, List<X509Certificate> clientCertificateChain) {
super(new ServletOrJdiscHttpRequest() {
@Override
public void copyHeaders(HeaderFields target) {
@@ -93,6 +98,7 @@ public class ApplicationRequestToDiscFilterRequestWrapper extends DiscFilterRequ
});
this.request = request;
this.userPrincipal = request.getUserPrincipal().orElse(null);
+ this.clientCertificateChain = clientCertificateChain;
}
public Request getUpdatedRequest() {
@@ -178,7 +184,7 @@ public class ApplicationRequestToDiscFilterRequestWrapper extends DiscFilterRequ
@Override
public List<X509Certificate> getClientCertificateChain() {
- return Collections.emptyList();
+ return clientCertificateChain;
}
@Override