diff options
5 files changed, 103 insertions, 85 deletions
diff --git a/controller-api/src/main/java/com/yahoo/vespa/hosted/controller/api/integration/chef/ChefMock.java b/controller-api/src/main/java/com/yahoo/vespa/hosted/controller/api/integration/chef/ChefMock.java index 1b2dad34b8d..bd19cfe6ce1 100644 --- a/controller-api/src/main/java/com/yahoo/vespa/hosted/controller/api/integration/chef/ChefMock.java +++ b/controller-api/src/main/java/com/yahoo/vespa/hosted/controller/api/integration/chef/ChefMock.java @@ -16,7 +16,6 @@ import java.util.ArrayList; import java.util.HashMap; import java.util.List; import java.util.Map; -import java.util.stream.Collectors; /** * @author mpolden @@ -24,11 +23,14 @@ import java.util.stream.Collectors; public class ChefMock implements Chef { private final NodeResult result; + private final PartialNodeResult partialResult; private final List<String> chefEnvironments; public ChefMock() { result = new NodeResult(); result.rows = new ArrayList<>(); + partialResult = new PartialNodeResult(); + partialResult.rows = new ArrayList<>(); chefEnvironments = new ArrayList<>(); chefEnvironments.add("hosted-verified-prod"); chefEnvironments.add("hosted-infra-cd"); @@ -59,8 +61,14 @@ public class ChefMock implements Chef { return null; } - public void addSearchResult(ChefNode node) { + public ChefMock addSearchResult(ChefNode node) { result.rows.add(node); + return this; + } + + public ChefMock addPartialResult(List<PartialNode> partialNodes) { + partialResult.rows.addAll(partialNodes); + return this; } @Override @@ -76,13 +84,15 @@ public class ChefMock implements Chef { @Override public PartialNodeResult partialSearchNodes(String query, List<AttributeMapping> returnAttributes) { PartialNodeResult partialNodeResult = new PartialNodeResult(); - partialNodeResult.rows = result.rows.stream() - .map(chefNode -> { - Map<String, String> data = new HashMap<>(); - data.put("fqdn", chefNode.name); - return new PartialNode(data); - }) - .collect(Collectors.toList()); + partialNodeResult.rows = new ArrayList<>(); + partialNodeResult.rows.addAll(partialResult.rows); + result.rows.stream() + .map(chefNode -> { + Map<String, String> data = new HashMap<>(); + data.put("fqdn", chefNode.name); + return new PartialNode(data); + }) + .forEach(node -> partialNodeResult.rows.add(node)); return partialNodeResult; } diff --git a/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/athenz/filter/AthenzPrincipalFilterTest.java b/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/athenz/filter/AthenzPrincipalFilterTest.java index cec3930f9dd..93f3a4e39db 100644 --- a/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/athenz/filter/AthenzPrincipalFilterTest.java +++ b/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/athenz/filter/AthenzPrincipalFilterTest.java @@ -1,6 +1,7 @@ // Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. package com.yahoo.vespa.hosted.controller.athenz.filter; +import com.yahoo.application.container.handler.Request; import com.yahoo.jdisc.Response; import com.yahoo.jdisc.handler.ContentChannel; import com.yahoo.jdisc.handler.ReadableContentChannel; @@ -14,6 +15,7 @@ import com.yahoo.vespa.athenz.tls.KeyAlgorithm; import com.yahoo.vespa.athenz.tls.KeyUtils; import com.yahoo.vespa.athenz.tls.X509CertificateBuilder; import com.yahoo.vespa.hosted.controller.api.integration.athenz.InvalidTokenException; +import com.yahoo.vespa.hosted.controller.restapi.ApplicationRequestToDiscFilterRequestWrapper; import org.junit.Before; import org.junit.Test; @@ -26,22 +28,22 @@ import java.security.KeyPair; import java.security.cert.X509Certificate; import java.time.Duration; import java.time.Instant; +import java.util.HashMap; +import java.util.Map; import java.util.Objects; +import java.util.Optional; import java.util.Set; import static com.yahoo.jdisc.Response.Status.UNAUTHORIZED; import static com.yahoo.vespa.athenz.tls.SignatureAlgorithm.SHA256_WITH_RSA; -import static java.util.Collections.emptyList; import static java.util.Collections.singleton; import static java.util.Collections.singletonList; import static java.util.stream.Collectors.joining; import static org.hamcrest.Matchers.containsString; import static org.hamcrest.Matchers.equalTo; import static org.hamcrest.Matchers.notNullValue; +import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertThat; -import static org.mockito.Mockito.mock; -import static org.mockito.Mockito.verify; -import static org.mockito.Mockito.when; /** * @author bjorncs @@ -55,111 +57,94 @@ public class AthenzPrincipalFilterTest { private static final String ORIGIN = "http://localhost"; private static final Set<String> CORS_ALLOWED_URLS = singleton(ORIGIN); - private NTokenValidator validator; + private NTokenValidatorMock validator; + private ResponseHandlerMock responseHandler; @Before public void before() { - validator = mock(NTokenValidator.class); + this.validator = new NTokenValidatorMock(); + this.responseHandler = new ResponseHandlerMock(); } @Test public void valid_ntoken_is_accepted() { - DiscFilterRequest request = createRequestMock(); + Request request = defaultRequest(); + AthenzPrincipal principal = new AthenzPrincipal(IDENTITY, NTOKEN); - when(request.getHeader(ATHENZ_PRINCIPAL_HEADER)).thenReturn(NTOKEN.getRawToken()); - when(request.getClientCertificateChain()).thenReturn(emptyList()); - when(validator.validate(NTOKEN)).thenReturn(principal); + validator.add(NTOKEN, principal); AthenzPrincipalFilter filter = new AthenzPrincipalFilter(validator, Runnable::run, ATHENZ_PRINCIPAL_HEADER, CORS_ALLOWED_URLS); - filter.filter(request, new ResponseHandlerMock()); - - verify(request).setUserPrincipal(principal); - } + DiscFilterRequest filterRequest = new ApplicationRequestToDiscFilterRequestWrapper(request); + filter.filter(filterRequest, new ResponseHandlerMock()); - private DiscFilterRequest createRequestMock() { - DiscFilterRequest request = mock(DiscFilterRequest.class); - when(request.getHeader("Origin")).thenReturn(ORIGIN); - return request; + assertEquals(principal, filterRequest.getUserPrincipal()); } @Test public void missing_token_and_certificate_is_unauthorized() { - DiscFilterRequest request = createRequestMock(); - when(request.getHeader(ATHENZ_PRINCIPAL_HEADER)).thenReturn(null); - when(request.getClientCertificateChain()).thenReturn(emptyList()); - - ResponseHandlerMock responseHandler = new ResponseHandlerMock(); - AthenzPrincipalFilter filter = new AthenzPrincipalFilter(validator, Runnable::run, ATHENZ_PRINCIPAL_HEADER, CORS_ALLOWED_URLS); - filter.filter(request, responseHandler); + DiscFilterRequest filterRequest = new ApplicationRequestToDiscFilterRequestWrapper(new Request("/")); + filter.filter(filterRequest, responseHandler); assertUnauthorized(responseHandler, "Unable to authenticate Athenz identity"); } @Test public void invalid_token_is_unauthorized() { - DiscFilterRequest request = createRequestMock(); - String errorMessage = "Invalid token"; - when(request.getHeader(ATHENZ_PRINCIPAL_HEADER)).thenReturn(NTOKEN.getRawToken()); - when(request.getClientCertificateChain()).thenReturn(emptyList()); - when(validator.validate(NTOKEN)).thenThrow(new InvalidTokenException(errorMessage)); - - ResponseHandlerMock responseHandler = new ResponseHandlerMock(); + Request request = defaultRequest(); AthenzPrincipalFilter filter = new AthenzPrincipalFilter(validator, Runnable::run, ATHENZ_PRINCIPAL_HEADER, CORS_ALLOWED_URLS); - filter.filter(request, responseHandler); + DiscFilterRequest filterRequest = new ApplicationRequestToDiscFilterRequestWrapper(request); + filter.filter(filterRequest, responseHandler); + String errorMessage = "Invalid token"; assertUnauthorized(responseHandler, errorMessage); } @Test public void certificate_is_accepted() { - DiscFilterRequest request = createRequestMock(); - when(request.getHeader(ATHENZ_PRINCIPAL_HEADER)).thenReturn(null); - when(request.getClientCertificateChain()).thenReturn(singletonList(CERTIFICATE)); - - ResponseHandlerMock responseHandler = new ResponseHandlerMock(); - AthenzPrincipalFilter filter = new AthenzPrincipalFilter(validator, Runnable::run, ATHENZ_PRINCIPAL_HEADER, CORS_ALLOWED_URLS); - filter.filter(request, responseHandler); + DiscFilterRequest filterRequest = new ApplicationRequestToDiscFilterRequestWrapper(new Request("/"), singletonList(CERTIFICATE)); + filter.filter(filterRequest, responseHandler); AthenzPrincipal expectedPrincipal = new AthenzPrincipal(IDENTITY); - verify(request).setUserPrincipal(expectedPrincipal); + assertEquals(expectedPrincipal, filterRequest.getUserPrincipal()); } @Test public void both_ntoken_and_certificate_is_accepted() { - DiscFilterRequest request = createRequestMock(); - AthenzPrincipal principalWithToken = new AthenzPrincipal(IDENTITY, NTOKEN); - when(request.getHeader(ATHENZ_PRINCIPAL_HEADER)).thenReturn(NTOKEN.getRawToken()); - when(request.getClientCertificateChain()).thenReturn(singletonList(CERTIFICATE)); - when(validator.validate(NTOKEN)).thenReturn(principalWithToken); + Request request = defaultRequest(); - ResponseHandlerMock responseHandler = new ResponseHandlerMock(); + AthenzPrincipal principalWithToken = new AthenzPrincipal(IDENTITY, NTOKEN); + validator.add(NTOKEN, principalWithToken); AthenzPrincipalFilter filter = new AthenzPrincipalFilter(validator, Runnable::run, ATHENZ_PRINCIPAL_HEADER, CORS_ALLOWED_URLS); - filter.filter(request, responseHandler); + DiscFilterRequest filterRequest = new ApplicationRequestToDiscFilterRequestWrapper(request, singletonList(CERTIFICATE)); + filter.filter(filterRequest, responseHandler); - verify(request).setUserPrincipal(principalWithToken); + assertEquals(principalWithToken, filterRequest.getUserPrincipal()); } @Test public void conflicting_ntoken_and_certificate_is_unauthorized() { - DiscFilterRequest request = createRequestMock(); - AthenzUser conflictingIdentity = AthenzUser.fromUserId("mallory"); - when(request.getHeader(ATHENZ_PRINCIPAL_HEADER)).thenReturn(NTOKEN.getRawToken()); - when(request.getClientCertificateChain()) - .thenReturn(singletonList(createSelfSignedCertificate(conflictingIdentity))); - when(validator.validate(NTOKEN)).thenReturn(new AthenzPrincipal(IDENTITY)); - - ResponseHandlerMock responseHandler = new ResponseHandlerMock(); + Request request = defaultRequest(); + validator.add(NTOKEN, new AthenzPrincipal(IDENTITY)); + AthenzUser conflictingIdentity = AthenzUser.fromUserId("mallory"); + DiscFilterRequest filterRequest = new ApplicationRequestToDiscFilterRequestWrapper(request, singletonList(createSelfSignedCertificate(conflictingIdentity))); AthenzPrincipalFilter filter = new AthenzPrincipalFilter(validator, Runnable::run, ATHENZ_PRINCIPAL_HEADER, CORS_ALLOWED_URLS); - filter.filter(request, responseHandler); + filter.filter(filterRequest, responseHandler); assertUnauthorized(responseHandler, "Identity in principal token does not match x509 CN"); } + private static Request defaultRequest() { + Request request = new Request("/"); + request.getHeaders().add("Origin", ORIGIN); + request.getHeaders().add(ATHENZ_PRINCIPAL_HEADER, NTOKEN.getRawToken()); + return request; + } + private static void assertUnauthorized(ResponseHandlerMock responseHandler, String expectedMessageSubstring) { assertThat(responseHandler.response, notNullValue()); assertThat(responseHandler.response.getStatus(), equalTo(UNAUTHORIZED)); @@ -188,6 +173,29 @@ public class AthenzPrincipalFilterTest { } + private static class NTokenValidatorMock extends NTokenValidator { + + private final Map<NToken, AthenzPrincipal> validTokens = new HashMap<>(); + + NTokenValidatorMock() { + super((service, keyId) -> Optional.empty()); + } + + public NTokenValidatorMock add(NToken token, AthenzPrincipal principal) { + validTokens.put(token, principal); + return this; + } + + @Override + AthenzPrincipal validate(NToken token) throws InvalidTokenException { + if (!validTokens.containsKey(token)) { + throw new InvalidTokenException("Invalid token"); + } + return validTokens.get(token); + } + + } + private static X509Certificate createSelfSignedCertificate(AthenzIdentity identity) { KeyPair keyPair = KeyUtils.generateKeypair(KeyAlgorithm.RSA, 512); X500Principal x500Name = new X500Principal("CN="+ identity.getFullName()); diff --git a/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/maintenance/MetricsReporterTest.java b/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/maintenance/MetricsReporterTest.java index e189a9243db..3364eed3066 100644 --- a/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/maintenance/MetricsReporterTest.java +++ b/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/maintenance/MetricsReporterTest.java @@ -10,8 +10,7 @@ import com.yahoo.vespa.hosted.controller.Controller; import com.yahoo.vespa.hosted.controller.ControllerTester; import com.yahoo.vespa.hosted.controller.MetricsMock; import com.yahoo.vespa.hosted.controller.MetricsMock.MapContext; -import com.yahoo.vespa.hosted.controller.api.integration.chef.AttributeMapping; -import com.yahoo.vespa.hosted.controller.api.integration.chef.Chef; +import com.yahoo.vespa.hosted.controller.api.integration.chef.ChefMock; import com.yahoo.vespa.hosted.controller.api.integration.chef.rest.PartialNodeResult; import com.yahoo.vespa.hosted.controller.application.ApplicationPackage; import com.yahoo.vespa.hosted.controller.deployment.ApplicationPackageBuilder; @@ -19,7 +18,6 @@ import com.yahoo.vespa.hosted.controller.deployment.DeploymentTester; import com.yahoo.vespa.hosted.controller.persistence.MockCuratorDb; import org.junit.Before; import org.junit.Test; -import org.mockito.Mockito; import java.io.IOException; import java.io.UncheckedIOException; @@ -38,9 +36,6 @@ import static com.yahoo.vespa.hosted.controller.api.integration.deployment.JobTy import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertFalse; import static org.junit.Assert.assertNull; -import static org.mockito.Matchers.anyListOf; -import static org.mockito.Matchers.anyString; -import static org.mockito.Mockito.when; /** * @author mortent @@ -170,7 +165,7 @@ public class MetricsReporterTest { private MetricsReporter createReporter(Clock clock, Controller controller, MetricsMock metricsMock, SystemName system) { - Chef client = Mockito.mock(Chef.class); + ChefMock chef = new ChefMock(); PartialNodeResult result; try { result = new ObjectMapper() @@ -179,9 +174,8 @@ public class MetricsReporterTest { } catch (IOException e) { throw new UncheckedIOException(e); } - when(client.partialSearchNodes(anyString(), anyListOf(AttributeMapping.class))).thenReturn(result); - - return new MetricsReporter(controller, metricsMock, client, clock, new JobControl(new MockCuratorDb()), system); + chef.addPartialResult(result.rows); + return new MetricsReporter(controller, metricsMock, chef, clock, new JobControl(new MockCuratorDb()), system); } private Map<MapContext, Map<String, Number>> getMetricsByHost(String hostname) { diff --git a/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/restapi/ApplicationRequestToDiscFilterRequestWrapper.java b/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/restapi/ApplicationRequestToDiscFilterRequestWrapper.java index eee0519b12b..4883bde99b1 100644 --- a/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/restapi/ApplicationRequestToDiscFilterRequestWrapper.java +++ b/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/restapi/ApplicationRequestToDiscFilterRequestWrapper.java @@ -27,9 +27,14 @@ import java.util.concurrent.TimeUnit; public class ApplicationRequestToDiscFilterRequestWrapper extends DiscFilterRequest { private final Request request; + private final List<X509Certificate> clientCertificateChain; private Principal userPrincipal; public ApplicationRequestToDiscFilterRequestWrapper(Request request) { + this(request, Collections.emptyList()); + } + + public ApplicationRequestToDiscFilterRequestWrapper(Request request, List<X509Certificate> clientCertificateChain) { super(new ServletOrJdiscHttpRequest() { @Override public void copyHeaders(HeaderFields target) { @@ -93,6 +98,7 @@ public class ApplicationRequestToDiscFilterRequestWrapper extends DiscFilterRequ }); this.request = request; this.userPrincipal = request.getUserPrincipal().orElse(null); + this.clientCertificateChain = clientCertificateChain; } public Request getUpdatedRequest() { @@ -178,7 +184,7 @@ public class ApplicationRequestToDiscFilterRequestWrapper extends DiscFilterRequ @Override public List<X509Certificate> getClientCertificateChain() { - return Collections.emptyList(); + return clientCertificateChain; } @Override diff --git a/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/restapi/filter/ControllerAuthorizationFilterTest.java b/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/restapi/filter/ControllerAuthorizationFilterTest.java index df3600eae9b..af11923b7f2 100644 --- a/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/restapi/filter/ControllerAuthorizationFilterTest.java +++ b/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/restapi/filter/ControllerAuthorizationFilterTest.java @@ -2,6 +2,7 @@ package com.yahoo.vespa.hosted.controller.restapi.filter; import com.fasterxml.jackson.databind.ObjectMapper; +import com.yahoo.application.container.handler.Request; import com.yahoo.config.provision.TenantName; import com.yahoo.jdisc.http.HttpRequest.Method; import com.yahoo.jdisc.http.filter.DiscFilterRequest; @@ -17,6 +18,7 @@ import com.yahoo.vespa.hosted.controller.api.integration.athenz.ApplicationActio import com.yahoo.vespa.hosted.controller.api.integration.athenz.HostedAthenzIdentities; import com.yahoo.vespa.hosted.controller.athenz.mock.AthenzClientFactoryMock; import com.yahoo.vespa.hosted.controller.athenz.mock.AthenzDbMock; +import com.yahoo.vespa.hosted.controller.restapi.ApplicationRequestToDiscFilterRequestWrapper; import org.junit.Test; import java.io.IOException; @@ -35,13 +37,12 @@ import static java.util.Collections.singletonList; import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertFalse; import static org.junit.Assert.assertTrue; -import static org.mockito.Mockito.mock; -import static org.mockito.Mockito.when; /** * @author bjorncs */ public class ControllerAuthorizationFilterTest { + private static final ObjectMapper mapper = new ObjectMapper(); private static final AthenzUser USER = user("john"); @@ -157,11 +158,9 @@ public class ControllerAuthorizationFilterTest { } private static DiscFilterRequest createRequest(Method method, String path, AthenzIdentity identity) { - DiscFilterRequest request = mock(DiscFilterRequest.class); - when(request.getMethod()).thenReturn(method.name()); - when(request.getRequestURI()).thenReturn(path); - when(request.getUserPrincipal()).thenReturn(new AthenzPrincipal(identity)); - return request; + Request request = new Request(path, new byte[0], Request.Method.valueOf(method.name()), + new AthenzPrincipal(identity)); + return new ApplicationRequestToDiscFilterRequestWrapper(request); } private static String getErrorMessage(MockResponseHandler responseHandler) { @@ -185,4 +184,5 @@ public class ControllerAuthorizationFilterTest { this.message = message; } } + } |