aboutsummaryrefslogtreecommitdiffstats
path: root/container-core/src/test/java/com/yahoo/jdisc/http/server/jetty/FilterTestCase.java
diff options
context:
space:
mode:
Diffstat (limited to 'container-core/src/test/java/com/yahoo/jdisc/http/server/jetty/FilterTestCase.java')
-rw-r--r--container-core/src/test/java/com/yahoo/jdisc/http/server/jetty/FilterTestCase.java667
1 files changed, 667 insertions, 0 deletions
diff --git a/container-core/src/test/java/com/yahoo/jdisc/http/server/jetty/FilterTestCase.java b/container-core/src/test/java/com/yahoo/jdisc/http/server/jetty/FilterTestCase.java
new file mode 100644
index 00000000000..a67656dd5ca
--- /dev/null
+++ b/container-core/src/test/java/com/yahoo/jdisc/http/server/jetty/FilterTestCase.java
@@ -0,0 +1,667 @@
+// Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+package com.yahoo.jdisc.http.server.jetty;
+
+import com.google.inject.AbstractModule;
+import com.google.inject.util.Modules;
+import com.yahoo.container.logging.ConnectionLog;
+import com.yahoo.container.logging.RequestLog;
+import com.yahoo.jdisc.AbstractResource;
+import com.yahoo.jdisc.Request;
+import com.yahoo.jdisc.ResourceReference;
+import com.yahoo.jdisc.Response;
+import com.yahoo.jdisc.handler.AbstractRequestHandler;
+import com.yahoo.jdisc.handler.CompletionHandler;
+import com.yahoo.jdisc.handler.ContentChannel;
+import com.yahoo.jdisc.handler.ResponseDispatch;
+import com.yahoo.jdisc.handler.ResponseHandler;
+import com.yahoo.jdisc.http.ConnectorConfig;
+import com.yahoo.jdisc.http.HttpRequest;
+import com.yahoo.jdisc.http.HttpResponse;
+import com.yahoo.jdisc.http.ServerConfig;
+import com.yahoo.jdisc.http.ServletPathsConfig;
+import com.yahoo.jdisc.http.filter.RequestFilter;
+import com.yahoo.jdisc.http.filter.ResponseFilter;
+import com.yahoo.jdisc.http.filter.ResponseHeaderFilter;
+import com.yahoo.jdisc.http.filter.chain.RequestFilterChain;
+import com.yahoo.jdisc.http.filter.chain.ResponseFilterChain;
+import com.yahoo.jdisc.http.guiceModules.ConnectorFactoryRegistryModule;
+import org.junit.Test;
+import org.mockito.ArgumentCaptor;
+
+import java.io.IOException;
+import java.nio.ByteBuffer;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.concurrent.CountDownLatch;
+import java.util.concurrent.TimeUnit;
+import java.util.concurrent.atomic.AtomicReference;
+
+import static org.hamcrest.CoreMatchers.containsString;
+import static org.hamcrest.CoreMatchers.is;
+import static org.hamcrest.MatcherAssert.assertThat;
+import static org.mockito.ArgumentMatchers.any;
+import static org.mockito.ArgumentMatchers.isNull;
+import static org.mockito.Mockito.mock;
+import static org.mockito.Mockito.never;
+import static org.mockito.Mockito.times;
+import static org.mockito.Mockito.verify;
+import static org.mockito.Mockito.when;
+
+/**
+ * @author Oyvind Bakksjo
+ * @author bjorncs
+ */
+public class FilterTestCase {
+ @Test
+ public void requireThatRequestFilterIsNotRunOnUnboundPath() throws Exception {
+ RequestFilterMockBase filter = mock(RequestFilterMockBase.class);
+ FilterBindings filterBindings = new FilterBindings.Builder()
+ .addRequestFilter("my-request-filter", filter)
+ .addRequestFilterBinding("my-request-filter", "http://*/filtered/*")
+ .build();
+ final MyRequestHandler requestHandler = new MyRequestHandler();
+ final TestDriver testDriver = newDriver(requestHandler, filterBindings);
+
+ testDriver.client().get("/status.html");
+
+ assertThat(requestHandler.awaitInvocation(), is(true));
+ verify(filter, never()).filter(any(HttpRequest.class), any(ResponseHandler.class));
+
+ assertThat(testDriver.close(), is(true));
+ }
+
+ @Test
+ public void requireThatRequestFilterIsRunOnBoundPath() throws Exception {
+ final RequestFilter filter = mock(RequestFilterMockBase.class);
+ FilterBindings filterBindings = new FilterBindings.Builder()
+ .addRequestFilter("my-request-filter", filter)
+ .addRequestFilterBinding("my-request-filter", "http://*/filtered/*")
+ .build();
+ final MyRequestHandler requestHandler = new MyRequestHandler();
+ final TestDriver testDriver = newDriver(requestHandler, filterBindings);
+
+ testDriver.client().get("/filtered/status.html");
+
+ assertThat(requestHandler.awaitInvocation(), is(true));
+ verify(filter, times(1)).filter(any(HttpRequest.class), any(ResponseHandler.class));
+
+ assertThat(testDriver.close(), is(true));
+ }
+
+ @Test
+ public void requireThatRequestFilterChangesAreSeenByRequestHandler() throws Exception {
+ final RequestFilter filter = new HeaderRequestFilter("foo", "bar");
+ FilterBindings filterBindings = new FilterBindings.Builder()
+ .addRequestFilter("my-request-filter", filter)
+ .addRequestFilterBinding("my-request-filter", "http://*/*")
+ .build();
+ final MyRequestHandler requestHandler = new MyRequestHandler();
+ final TestDriver testDriver = newDriver(requestHandler, filterBindings);
+
+ testDriver.client().get("status.html");
+
+ assertThat(requestHandler.awaitInvocation(), is(true));
+ assertThat(requestHandler.getHeaderMap().get("foo").get(0), is("bar"));
+
+ assertThat(testDriver.close(), is(true));
+ }
+
+ @Test
+ public void requireThatRequestFilterCanRespond() throws Exception {
+ FilterBindings filterBindings = new FilterBindings.Builder()
+ .addRequestFilter("my-request-filter", new RespondForbiddenFilter())
+ .addRequestFilterBinding("my-request-filter", "http://*/*")
+ .build();
+ final MyRequestHandler requestHandler = new MyRequestHandler();
+ final TestDriver testDriver = newDriver(requestHandler, filterBindings);
+
+ testDriver.client().get("/status.html").expectStatusCode(is(Response.Status.FORBIDDEN));
+
+ assertThat(requestHandler.hasBeenInvokedYet(), is(false));
+
+ assertThat(testDriver.close(), is(true));
+ }
+
+ @Test
+ public void requireThatFilterCanHaveNullCompletionHandler() throws Exception {
+ final int responseStatus = Response.Status.OK;
+ final String responseMessage = "Excellent";
+ FilterBindings filterBindings = new FilterBindings.Builder()
+ .addRequestFilter("my-request-filter", new NullCompletionHandlerFilter(responseStatus, responseMessage))
+ .addRequestFilterBinding("my-request-filter", "http://*/*")
+ .build();
+ final MyRequestHandler requestHandler = new MyRequestHandler();
+ final TestDriver testDriver = newDriver(requestHandler, filterBindings);
+
+ testDriver.client().get("/status.html")
+ .expectStatusCode(is(responseStatus))
+ .expectContent(is(responseMessage));
+
+ assertThat(requestHandler.hasBeenInvokedYet(), is(false));
+
+ assertThat(testDriver.close(), is(true));
+ }
+
+ @Test
+ public void requireThatRequestFilterExecutionIsExceptionSafe() throws Exception {
+ FilterBindings filterBindings = new FilterBindings.Builder()
+ .addRequestFilter("my-request-filter", new ThrowingRequestFilter())
+ .addRequestFilterBinding("my-request-filter", "http://*/*")
+ .build();
+ final MyRequestHandler requestHandler = new MyRequestHandler();
+ final TestDriver testDriver = newDriver(requestHandler, filterBindings);
+
+ testDriver.client().get("/status.html").expectStatusCode(is(Response.Status.INTERNAL_SERVER_ERROR));
+
+ assertThat(requestHandler.hasBeenInvokedYet(), is(false));
+
+ assertThat(testDriver.close(), is(true));
+ }
+
+ @Test
+ public void requireThatResponseFilterIsNotRunOnUnboundPath() throws Exception {
+ final ResponseFilter filter = mock(ResponseFilterMockBase.class);
+ FilterBindings filterBindings = new FilterBindings.Builder()
+ .addResponseFilter("my-response-filter", filter)
+ .addResponseFilterBinding("my-response-filter", "http://*/filtered/*")
+ .build();
+ final MyRequestHandler requestHandler = new MyRequestHandler();
+ final TestDriver testDriver = newDriver(requestHandler, filterBindings);
+
+ testDriver.client().get("/status.html");
+
+ assertThat(requestHandler.awaitInvocation(), is(true));
+ verify(filter, never()).filter(any(Response.class), any(Request.class));
+
+ assertThat(testDriver.close(), is(true));
+ }
+
+ @Test
+ public void requireThatResponseFilterIsRunOnBoundPath() throws Exception {
+ final ResponseFilter filter = mock(ResponseFilterMockBase.class);
+ FilterBindings filterBindings = new FilterBindings.Builder()
+ .addResponseFilter("my-response-filter", filter)
+ .addResponseFilterBinding("my-response-filter", "http://*/filtered/*")
+ .build();
+ final MyRequestHandler requestHandler = new MyRequestHandler();
+ final TestDriver testDriver = newDriver(requestHandler, filterBindings);
+
+ testDriver.client().get("/filtered/status.html");
+
+ assertThat(requestHandler.awaitInvocation(), is(true));
+ verify(filter, times(1)).filter(any(Response.class), any(Request.class));
+
+ assertThat(testDriver.close(), is(true));
+ }
+
+ @Test
+ public void requireThatResponseFilterChangesAreWrittenToResponse() throws Exception {
+ FilterBindings filterBindings = new FilterBindings.Builder()
+ .addResponseFilter("my-response-filter", new HeaderResponseFilter("foo", "bar"))
+ .addResponseFilterBinding("my-response-filter", "http://*/*")
+ .build();
+ final MyRequestHandler requestHandler = new MyRequestHandler();
+ final TestDriver testDriver = newDriver(requestHandler, filterBindings);
+
+ testDriver.client().get("/status.html")
+ .expectHeader("foo", is("bar"));
+
+ assertThat(requestHandler.awaitInvocation(), is(true));
+
+ assertThat(testDriver.close(), is(true));
+ }
+
+ @Test
+ public void requireThatResponseFilterExecutionIsExceptionSafe() throws Exception {
+ FilterBindings filterBindings = new FilterBindings.Builder()
+ .addResponseFilter("my-response-filter", new ThrowingResponseFilter())
+ .addResponseFilterBinding("my-response-filter", "http://*/*")
+ .build();
+ final MyRequestHandler requestHandler = new MyRequestHandler();
+ final TestDriver testDriver = newDriver(requestHandler, filterBindings);
+
+ testDriver.client().get("/status.html").expectStatusCode(is(Response.Status.INTERNAL_SERVER_ERROR));
+
+ assertThat(requestHandler.awaitInvocation(), is(true));
+
+ assertThat(testDriver.close(), is(true));
+ }
+
+ @Test
+ public void requireThatRequestFilterAndResponseFilterCanBindToSamePath() throws Exception {
+ final RequestFilter requestFilter = mock(RequestFilterMockBase.class);
+ final ResponseFilter responseFilter = mock(ResponseFilterMockBase.class);
+ final String uriPattern = "http://*/*";
+ FilterBindings filterBindings = new FilterBindings.Builder()
+ .addRequestFilter("my-request-filter", requestFilter)
+ .addRequestFilterBinding("my-request-filter", uriPattern)
+ .addResponseFilter("my-response-filter", responseFilter)
+ .addResponseFilterBinding("my-response-filter", uriPattern)
+ .build();
+ final MyRequestHandler requestHandler = new MyRequestHandler();
+ final TestDriver testDriver = newDriver(requestHandler, filterBindings);
+
+ testDriver.client().get("/status.html");
+
+ assertThat(requestHandler.awaitInvocation(), is(true));
+ verify(requestFilter, times(1)).filter(any(HttpRequest.class), any(ResponseHandler.class));
+ verify(responseFilter, times(1)).filter(any(Response.class), any(Request.class));
+
+ assertThat(testDriver.close(), is(true));
+ }
+
+ @Test
+ public void requireThatResponseFromRequestFilterGoesThroughResponseFilter() throws Exception {
+ FilterBindings filterBindings = new FilterBindings.Builder()
+ .addRequestFilter("my-request-filter", new RespondForbiddenFilter())
+ .addRequestFilterBinding("my-request-filter", "http://*/*")
+ .addResponseFilter("my-response-filter", new HeaderResponseFilter("foo", "bar"))
+ .addResponseFilterBinding("my-response-filter", "http://*/*")
+ .build();
+ final MyRequestHandler requestHandler = new MyRequestHandler();
+ final TestDriver testDriver = newDriver(requestHandler, filterBindings);
+
+ testDriver.client().get("/status.html")
+ .expectStatusCode(is(Response.Status.FORBIDDEN))
+ .expectHeader("foo", is("bar"));
+
+ assertThat(requestHandler.hasBeenInvokedYet(), is(false));
+
+ assertThat(testDriver.close(), is(true));
+ }
+
+ @Test
+ public void requireThatRequestFilterChainRetainsFilters() {
+ final RequestFilter requestFilter1 = mock(RequestFilter.class);
+ final RequestFilter requestFilter2 = mock(RequestFilter.class);
+
+ verify(requestFilter1, never()).refer();
+ verify(requestFilter2, never()).refer();
+ final ResourceReference reference1 = mock(ResourceReference.class);
+ final ResourceReference reference2 = mock(ResourceReference.class);
+ when(requestFilter1.refer()).thenReturn(reference1);
+ when(requestFilter2.refer()).thenReturn(reference2);
+ final RequestFilter chain = RequestFilterChain.newInstance(requestFilter1, requestFilter2);
+ verify(requestFilter1, times(1)).refer();
+ verify(requestFilter2, times(1)).refer();
+
+ verify(reference1, never()).close();
+ verify(reference2, never()).close();
+ chain.release();
+ verify(reference1, times(1)).close();
+ verify(reference2, times(1)).close();
+ }
+
+ @Test
+ public void requireThatRequestFilterChainIsRun() throws Exception {
+ final RequestFilter requestFilter1 = mock(RequestFilter.class);
+ final RequestFilter requestFilter2 = mock(RequestFilter.class);
+ final RequestFilter requestFilterChain = RequestFilterChain.newInstance(requestFilter1, requestFilter2);
+ final HttpRequest request = null;
+ final ResponseHandler responseHandler = null;
+ requestFilterChain.filter(request, responseHandler);
+ verify(requestFilter1).filter(isNull(), any(ResponseHandler.class));
+ verify(requestFilter2).filter(isNull(), any(ResponseHandler.class));
+ }
+
+ @Test
+ public void requireThatRequestFilterChainCallsFilterWithOriginalRequest() throws Exception {
+ final RequestFilter requestFilter = mock(RequestFilter.class);
+ final RequestFilter requestFilterChain = RequestFilterChain.newInstance(requestFilter);
+ final HttpRequest request = mock(HttpRequest.class);
+ final ResponseHandler responseHandler = null;
+ requestFilterChain.filter(request, responseHandler);
+
+ // Check that the filter is called with the same request argument as the chain was,
+ // in a manner that allows the request object to be wrapped.
+ final ArgumentCaptor<HttpRequest> requestCaptor = ArgumentCaptor.forClass(HttpRequest.class);
+ verify(requestFilter).filter(requestCaptor.capture(), isNull());
+ verify(request, never()).getUri();
+ requestCaptor.getValue().getUri();
+ verify(request, times(1)).getUri();
+ }
+
+ @Test
+ public void requireThatRequestFilterChainCallsFilterWithOriginalResponseHandler() throws Exception {
+ final RequestFilter requestFilter = mock(RequestFilter.class);
+ final RequestFilter requestFilterChain = RequestFilterChain.newInstance(requestFilter);
+ final HttpRequest request = null;
+ final ResponseHandler responseHandler = mock(ResponseHandler.class);
+ requestFilterChain.filter(request, responseHandler);
+
+ // Check that the filter is called with the same response handler argument as the chain was,
+ // in a manner that allows the handler object to be wrapped.
+ final ArgumentCaptor<ResponseHandler> responseHandlerCaptor = ArgumentCaptor.forClass(ResponseHandler.class);
+ verify(requestFilter).filter(isNull(), responseHandlerCaptor.capture());
+ verify(responseHandler, never()).handleResponse(any(Response.class));
+ responseHandlerCaptor.getValue().handleResponse(mock(Response.class));
+ verify(responseHandler, times(1)).handleResponse(any(Response.class));
+ }
+
+ @Test
+ public void requireThatRequestFilterCanTerminateChain() throws Exception {
+ final RequestFilter requestFilter1 = new RespondForbiddenFilter();
+ final RequestFilter requestFilter2 = mock(RequestFilter.class);
+ final RequestFilter requestFilterChain = RequestFilterChain.newInstance(requestFilter1, requestFilter2);
+ final HttpRequest request = null;
+ final ResponseHandler responseHandler = mock(ResponseHandler.class);
+ when(responseHandler.handleResponse(any(Response.class))).thenReturn(mock(ContentChannel.class));
+
+ requestFilterChain.filter(request, responseHandler);
+
+ verify(requestFilter2, never()).filter(any(HttpRequest.class), any(ResponseHandler.class));
+
+ final ArgumentCaptor<Response> responseCaptor = ArgumentCaptor.forClass(Response.class);
+ verify(responseHandler).handleResponse(responseCaptor.capture());
+ assertThat(responseCaptor.getValue().getStatus(), is(Response.Status.FORBIDDEN));
+ }
+
+ @Test
+ public void requireThatResponseFilterChainRetainsFilters() {
+ final ResponseFilter responseFilter1 = mock(ResponseFilter.class);
+ final ResponseFilter responseFilter2 = mock(ResponseFilter.class);
+
+ verify(responseFilter1, never()).refer();
+ verify(responseFilter2, never()).refer();
+ final ResourceReference reference1 = mock(ResourceReference.class);
+ final ResourceReference reference2 = mock(ResourceReference.class);
+ when(responseFilter1.refer()).thenReturn(reference1);
+ when(responseFilter2.refer()).thenReturn(reference2);
+ final ResponseFilter chain = ResponseFilterChain.newInstance(responseFilter1, responseFilter2);
+ verify(responseFilter1, times(1)).refer();
+ verify(responseFilter2, times(1)).refer();
+
+ verify(reference1, never()).close();
+ verify(reference2, never()).close();
+ chain.release();
+ verify(reference1, times(1)).close();
+ verify(reference2, times(1)).close();
+ }
+
+ @Test
+ public void requireThatResponseFilterChainIsRun() {
+ final ResponseFilter responseFilter1 = new ResponseHeaderFilter("foo", "bar");
+ final ResponseFilter responseFilter2 = mock(ResponseFilter.class);
+ final int statusCode = Response.Status.BAD_GATEWAY;
+ final Response response = new Response(statusCode);
+ final Request request = null;
+
+ ResponseFilterChain.newInstance(responseFilter1, responseFilter2).filter(response, request);
+
+ final ArgumentCaptor<Response> responseCaptor = ArgumentCaptor.forClass(Response.class);
+ verify(responseFilter2).filter(responseCaptor.capture(), isNull());
+ assertThat(responseCaptor.getValue().getStatus(), is(statusCode));
+ assertThat(responseCaptor.getValue().headers().getFirst("foo"), is("bar"));
+
+ assertThat(response.getStatus(), is(statusCode));
+ assertThat(response.headers().getFirst("foo"), is("bar"));
+ }
+
+ @Test
+ public void requireThatDefaultRequestFilterChainIsRunIfNoOtherFilterChainMatches() throws IOException, InterruptedException {
+ RequestFilter filterWithBinding = mock(RequestFilter.class);
+ RequestFilter defaultFilter = mock(RequestFilter.class);
+ String defaultFilterId = "default-request-filter";
+ FilterBindings filterBindings = new FilterBindings.Builder()
+ .addRequestFilter("my-request-filter", filterWithBinding)
+ .addRequestFilterBinding("my-request-filter", "http://*/filtered/*")
+ .addRequestFilter(defaultFilterId, defaultFilter)
+ .setRequestFilterDefaultForPort(defaultFilterId, 0)
+ .build();
+ MyRequestHandler requestHandler = new MyRequestHandler();
+ TestDriver testDriver = newDriver(requestHandler, filterBindings);
+
+ testDriver.client().get("/status.html");
+
+ assertThat(requestHandler.awaitInvocation(), is(true));
+ verify(defaultFilter, times(1)).filter(any(HttpRequest.class), any(ResponseHandler.class));
+ verify(filterWithBinding, never()).filter(any(HttpRequest.class), any(ResponseHandler.class));
+
+ assertThat(testDriver.close(), is(true));
+ }
+
+ @Test
+ public void requireThatDefaultResponseFilterChainIsRunIfNoOtherFilterChainMatches() throws IOException, InterruptedException {
+ ResponseFilter filterWithBinding = mock(ResponseFilter.class);
+ ResponseFilter defaultFilter = mock(ResponseFilter.class);
+ String defaultFilterId = "default-response-filter";
+ FilterBindings filterBindings = new FilterBindings.Builder()
+ .addResponseFilter("my-response-filter", filterWithBinding)
+ .addResponseFilterBinding("my-response-filter", "http://*/filtered/*")
+ .addResponseFilter(defaultFilterId, defaultFilter)
+ .setResponseFilterDefaultForPort(defaultFilterId, 0)
+ .build();
+ MyRequestHandler requestHandler = new MyRequestHandler();
+ TestDriver testDriver = newDriver(requestHandler, filterBindings);
+
+ testDriver.client().get("/status.html");
+
+ assertThat(requestHandler.awaitInvocation(), is(true));
+ verify(defaultFilter, times(1)).filter(any(Response.class), any(Request.class));
+ verify(filterWithBinding, never()).filter(any(Response.class), any(Request.class));
+
+ assertThat(testDriver.close(), is(true));
+ }
+
+ @Test
+ public void requireThatRequestFilterWithBindingMatchHasPrecedenceOverDefaultFilter() throws IOException, InterruptedException {
+ RequestFilterMockBase filterWithBinding = mock(RequestFilterMockBase.class);
+ RequestFilterMockBase defaultFilter = mock(RequestFilterMockBase.class);
+ String defaultFilterId = "default-request-filter";
+ FilterBindings filterBindings = new FilterBindings.Builder()
+ .addRequestFilter("my-request-filter", filterWithBinding)
+ .addRequestFilterBinding("my-request-filter", "http://*/filtered/*")
+ .addRequestFilter(defaultFilterId, defaultFilter)
+ .setRequestFilterDefaultForPort(defaultFilterId, 0)
+ .build();
+ MyRequestHandler requestHandler = new MyRequestHandler();
+ TestDriver testDriver = newDriver(requestHandler, filterBindings);
+
+ testDriver.client().get("/filtered/status.html");
+
+ assertThat(requestHandler.awaitInvocation(), is(true));
+ verify(defaultFilter, never()).filter(any(HttpRequest.class), any(ResponseHandler.class));
+ verify(filterWithBinding).filter(any(HttpRequest.class), any(ResponseHandler.class));
+
+ assertThat(testDriver.close(), is(true));
+ }
+
+ @Test
+ public void requireThatResponseFilterWithBindingMatchHasPrecedenceOverDefaultFilter() throws IOException, InterruptedException {
+ ResponseFilter filterWithBinding = mock(ResponseFilter.class);
+ ResponseFilter defaultFilter = mock(ResponseFilter.class);
+ String defaultFilterId = "default-response-filter";
+ FilterBindings filterBindings = new FilterBindings.Builder()
+ .addResponseFilter("my-response-filter", filterWithBinding)
+ .addResponseFilterBinding("my-response-filter", "http://*/filtered/*")
+ .addResponseFilter(defaultFilterId, defaultFilter)
+ .setResponseFilterDefaultForPort(defaultFilterId, 0)
+ .build();
+ MyRequestHandler requestHandler = new MyRequestHandler();
+ TestDriver testDriver = newDriver(requestHandler, filterBindings);
+
+ testDriver.client().get("/filtered/status.html");
+
+ assertThat(requestHandler.awaitInvocation(), is(true));
+ verify(defaultFilter, never()).filter(any(Response.class), any(Request.class));
+ verify(filterWithBinding, times(1)).filter(any(Response.class), any(Request.class));
+
+ assertThat(testDriver.close(), is(true));
+ }
+
+ @Test
+ public void requireThatMetricAreReported() throws IOException, InterruptedException {
+ FilterBindings filterBindings = new FilterBindings.Builder()
+ .addRequestFilter("my-request-filter", mock(RequestFilter.class))
+ .addRequestFilterBinding("my-request-filter", "http://*/*")
+ .build();
+ MetricConsumerMock metricConsumerMock = new MetricConsumerMock();
+ MyRequestHandler requestHandler = new MyRequestHandler();
+ TestDriver testDriver = newDriver(requestHandler, filterBindings, metricConsumerMock, false);
+
+ testDriver.client().get("/status.html");
+ assertThat(requestHandler.awaitInvocation(), is(true));
+ verify(metricConsumerMock.mockitoMock())
+ .add(MetricDefinitions.FILTERING_REQUEST_HANDLED, 1L, MetricConsumerMock.STATIC_CONTEXT);
+ verify(metricConsumerMock.mockitoMock(), never())
+ .add(MetricDefinitions.FILTERING_REQUEST_UNHANDLED, 1L, MetricConsumerMock.STATIC_CONTEXT);
+ verify(metricConsumerMock.mockitoMock(), never())
+ .add(MetricDefinitions.FILTERING_RESPONSE_HANDLED, 1L, MetricConsumerMock.STATIC_CONTEXT);
+ verify(metricConsumerMock.mockitoMock())
+ .add(MetricDefinitions.FILTERING_RESPONSE_UNHANDLED, 1L, MetricConsumerMock.STATIC_CONTEXT);
+ assertThat(testDriver.close(), is(true));
+ }
+
+ @Test
+ public void requireThatStrictFilteringRejectsRequestsNotMatchingFilterChains() throws IOException {
+ RequestFilter filter = mock(RequestFilter.class);
+ FilterBindings filterBindings = new FilterBindings.Builder()
+ .addRequestFilter("my-request-filter", filter)
+ .addRequestFilterBinding("my-request-filter", "http://*/filtered/*")
+ .build();
+ MyRequestHandler requestHandler = new MyRequestHandler();
+ TestDriver testDriver = newDriver(requestHandler, filterBindings, new MetricConsumerMock(), true);
+
+ testDriver.client().get("/unfiltered/")
+ .expectStatusCode(is(Response.Status.FORBIDDEN))
+ .expectContent(containsString("Request did not match any request filter chain"));
+ verify(filter, never()).filter(any(), any());
+ assertThat(testDriver.close(), is(true));
+ }
+
+ private static TestDriver newDriver(MyRequestHandler requestHandler, FilterBindings filterBindings) {
+ return newDriver(requestHandler, filterBindings, new MetricConsumerMock(), false);
+ }
+
+ private static TestDriver newDriver(
+ MyRequestHandler requestHandler,
+ FilterBindings filterBindings,
+ MetricConsumerMock metricConsumer,
+ boolean strictFiltering) {
+ return TestDriver.newInstance(
+ JettyHttpServer.class,
+ requestHandler,
+ newFilterModule(filterBindings, metricConsumer, strictFiltering));
+ }
+
+ private static com.google.inject.Module newFilterModule(
+ FilterBindings filterBindings, MetricConsumerMock metricConsumer, boolean strictFiltering) {
+ return Modules.combine(
+ new AbstractModule() {
+ @Override
+ protected void configure() {
+
+ bind(FilterBindings.class).toInstance(filterBindings);
+ bind(ServerConfig.class).toInstance(new ServerConfig(new ServerConfig.Builder().strictFiltering(strictFiltering)));
+ bind(ConnectorConfig.class).toInstance(new ConnectorConfig(new ConnectorConfig.Builder()));
+ bind(ServletPathsConfig.class).toInstance(new ServletPathsConfig(new ServletPathsConfig.Builder()));
+ bind(ConnectionLog.class).toInstance(new VoidConnectionLog());
+ bind(RequestLog.class).toInstance(new VoidRequestLog());
+ }
+ },
+ new ConnectorFactoryRegistryModule(),
+ metricConsumer.asGuiceModule());
+ }
+
+ private static abstract class RequestFilterMockBase extends AbstractResource implements RequestFilter {}
+ private static abstract class ResponseFilterMockBase extends AbstractResource implements ResponseFilter {}
+
+ private static class MyRequestHandler extends AbstractRequestHandler {
+ private final CountDownLatch invocationLatch = new CountDownLatch(1);
+ private final AtomicReference<Map<String, List<String>>> headerCopy = new AtomicReference<>(null);
+
+ @Override
+ public ContentChannel handleRequest(final Request request, final ResponseHandler handler) {
+ try {
+ headerCopy.set(new HashMap<String, List<String>>(request.headers()));
+ ResponseDispatch.newInstance(Response.Status.OK).dispatch(handler);
+ return null;
+ } finally {
+ invocationLatch.countDown();
+ }
+ }
+
+ public boolean hasBeenInvokedYet() {
+ return invocationLatch.getCount() == 0L;
+ }
+
+ public boolean awaitInvocation() throws InterruptedException {
+ return invocationLatch.await(60, TimeUnit.SECONDS);
+ }
+
+ public Map<String, List<String>> getHeaderMap() {
+ return headerCopy.get();
+ }
+ }
+
+ private static class RespondForbiddenFilter extends AbstractResource implements RequestFilter {
+ @Override
+ public void filter(final HttpRequest request, final ResponseHandler handler) {
+ ResponseDispatch.newInstance(Response.Status.FORBIDDEN).dispatch(handler);
+ }
+ }
+
+ private static class ThrowingRequestFilter extends AbstractResource implements RequestFilter {
+ @Override
+ public void filter(final HttpRequest request, final ResponseHandler handler) {
+ throw new RuntimeException();
+ }
+ }
+
+ private static class ThrowingResponseFilter extends AbstractResource implements ResponseFilter {
+ @Override
+ public void filter(final Response response, final Request request) {
+ throw new RuntimeException();
+ }
+ }
+
+ private static class HeaderRequestFilter extends AbstractResource implements RequestFilter {
+ private final String key;
+ private final String val;
+
+ public HeaderRequestFilter(final String key, final String val) {
+ this.key = key;
+ this.val = val;
+ }
+
+ @Override
+ public void filter(final HttpRequest request, final ResponseHandler handler) {
+ request.headers().add(key, val);
+ }
+ }
+
+ private static class HeaderResponseFilter extends AbstractResource implements ResponseFilter {
+ private final String key;
+ private final String val;
+
+ public HeaderResponseFilter(final String key, final String val) {
+ this.key = key;
+ this.val = val;
+ }
+
+ @Override
+ public void filter(final Response response, final Request request) {
+ response.headers().add(key, val);
+ }
+ }
+
+ public class NullCompletionHandlerFilter extends AbstractResource implements RequestFilter {
+ private final int responseStatus;
+ private final String responseMessage;
+
+ public NullCompletionHandlerFilter(final int responseStatus, final String responseMessage) {
+ this.responseStatus = responseStatus;
+ this.responseMessage = responseMessage;
+ }
+
+ @Override
+ public void filter(final HttpRequest request, final ResponseHandler responseHandler) {
+ final HttpResponse response = HttpResponse.newInstance(responseStatus);
+ final ContentChannel channel = responseHandler.handleResponse(response);
+ final CompletionHandler completionHandler = null;
+ channel.write(ByteBuffer.wrap(responseMessage.getBytes()), completionHandler);
+ channel.close(null);
+ }
+ }
+}