diff options
author | Bjørn Christian Seime <bjorncs@verizonmedia.com> | 2020-11-16 11:19:22 +0100 |
---|---|---|
committer | Bjørn Christian Seime <bjorncs@verizonmedia.com> | 2020-11-17 13:28:07 +0100 |
commit | cdbf498c9182ac06bb21ce8ae7b745dd8c8f3c83 (patch) | |
tree | 21a21f94ef0d888d74597b6142956c1b41901b13 /jdisc_http_service | |
parent | e4c14623ad4ecbe6337a49d2176621c528bf7c22 (diff) |
Support default request/response filters per connector
Filter requests using default request/response filter if no other filters matches the request uri.
Diffstat (limited to 'jdisc_http_service')
10 files changed, 290 insertions, 100 deletions
diff --git a/jdisc_http_service/abi-spec.json b/jdisc_http_service/abi-spec.json index c19ce5ce83c..8bf7f30964a 100644 --- a/jdisc_http_service/abi-spec.json +++ b/jdisc_http_service/abi-spec.json @@ -740,6 +740,8 @@ "public com.yahoo.jdisc.http.ServerConfig$Builder removeRawPostBodyForWwwUrlEncodedPost(boolean)", "public com.yahoo.jdisc.http.ServerConfig$Builder filter(com.yahoo.jdisc.http.ServerConfig$Filter$Builder)", "public com.yahoo.jdisc.http.ServerConfig$Builder filter(java.util.List)", + "public com.yahoo.jdisc.http.ServerConfig$Builder defaultFilters(com.yahoo.jdisc.http.ServerConfig$DefaultFilters$Builder)", + "public com.yahoo.jdisc.http.ServerConfig$Builder defaultFilters(java.util.List)", "public com.yahoo.jdisc.http.ServerConfig$Builder maxWorkerThreads(int)", "public com.yahoo.jdisc.http.ServerConfig$Builder minWorkerThreads(int)", "public com.yahoo.jdisc.http.ServerConfig$Builder stopTimeout(double)", @@ -754,11 +756,43 @@ ], "fields": [ "public java.util.List filter", + "public java.util.List defaultFilters", "public com.yahoo.jdisc.http.ServerConfig$Jmx$Builder jmx", "public com.yahoo.jdisc.http.ServerConfig$Metric$Builder metric", "public com.yahoo.jdisc.http.ServerConfig$AccessLog$Builder accessLog" ] }, + "com.yahoo.jdisc.http.ServerConfig$DefaultFilters$Builder": { + "superClass": "java.lang.Object", + "interfaces": [ + "com.yahoo.config.ConfigBuilder" + ], + "attributes": [ + "public" + ], + "methods": [ + "public void <init>()", + "public void <init>(com.yahoo.jdisc.http.ServerConfig$DefaultFilters)", + "public com.yahoo.jdisc.http.ServerConfig$DefaultFilters$Builder filterId(java.lang.String)", + "public com.yahoo.jdisc.http.ServerConfig$DefaultFilters$Builder localPort(int)", + "public com.yahoo.jdisc.http.ServerConfig$DefaultFilters build()" + ], + "fields": [] + }, + "com.yahoo.jdisc.http.ServerConfig$DefaultFilters": { + "superClass": "com.yahoo.config.InnerNode", + "interfaces": [], + "attributes": [ + "public", + "final" + ], + "methods": [ + "public void <init>(com.yahoo.jdisc.http.ServerConfig$DefaultFilters$Builder)", + "public java.lang.String filterId()", + "public int localPort()" + ], + "fields": [] + }, "com.yahoo.jdisc.http.ServerConfig$Filter$Builder": { "superClass": "java.lang.Object", "interfaces": [ @@ -894,6 +928,8 @@ "public boolean removeRawPostBodyForWwwUrlEncodedPost()", "public java.util.List filter()", "public com.yahoo.jdisc.http.ServerConfig$Filter filter(int)", + "public java.util.List defaultFilters()", + "public com.yahoo.jdisc.http.ServerConfig$DefaultFilters defaultFilters(int)", "public int maxWorkerThreads()", "public int minWorkerThreads()", "public double stopTimeout()", diff --git a/jdisc_http_service/src/main/java/com/yahoo/jdisc/http/server/jetty/FilterBindings.java b/jdisc_http_service/src/main/java/com/yahoo/jdisc/http/server/jetty/FilterBindings.java index 301c92a4583..310f3c9a646 100644 --- a/jdisc_http_service/src/main/java/com/yahoo/jdisc/http/server/jetty/FilterBindings.java +++ b/jdisc_http_service/src/main/java/com/yahoo/jdisc/http/server/jetty/FilterBindings.java @@ -3,18 +3,15 @@ package com.yahoo.jdisc.http.server.jetty; import com.yahoo.jdisc.application.BindingRepository; import com.yahoo.jdisc.application.BindingSet; -import com.yahoo.jdisc.application.UriPattern; import com.yahoo.jdisc.http.filter.RequestFilter; import com.yahoo.jdisc.http.filter.ResponseFilter; import java.net.URI; import java.util.Collection; +import java.util.Collections; import java.util.Map; import java.util.Optional; -import java.util.stream.Stream; -import java.util.stream.StreamSupport; - -import static java.util.stream.Collectors.toSet; +import java.util.TreeMap; /** * Resolves request/response filter (chain) from a {@link URI} instance. @@ -24,87 +21,82 @@ import static java.util.stream.Collectors.toSet; */ public class FilterBindings { - private final BindingSet<FilterHolder<RequestFilter>> requestFilters; - private final BindingSet<FilterHolder<ResponseFilter>> responseFilters; + private final Map<String, RequestFilter> requestFilters; + private final Map<String, ResponseFilter> responseFilters; + private final Map<Integer, String> defaultRequestFilters; + private final Map<Integer, String> defaultResponseFilters; + private final BindingSet<String> requestFilterBindings; + private final BindingSet<String> responseFilterBindings; private FilterBindings( - BindingSet<FilterHolder<RequestFilter>> requestFilters, - BindingSet<FilterHolder<ResponseFilter>> responseFilters) { + Map<String, RequestFilter> requestFilters, + Map<String, ResponseFilter> responseFilters, + Map<Integer, String> defaultRequestFilters, + Map<Integer, String> defaultResponseFilters, + BindingSet<String> requestFilterBindings, + BindingSet<String> responseFilterBindings) { this.requestFilters = requestFilters; this.responseFilters = responseFilters; + this.defaultRequestFilters = defaultRequestFilters; + this.defaultResponseFilters = defaultResponseFilters; + this.requestFilterBindings = requestFilterBindings; + this.responseFilterBindings = responseFilterBindings; } - public Optional<String> resolveRequestFilter(URI uri) { return resolveFilterId(requestFilters, uri); } - - public Optional<String> resolveResponseFilter(URI uri) { return resolveFilterId(responseFilters, uri); } - - public RequestFilter getRequestFilter(String filterId) { return getFilterInstance(requestFilters, filterId); } - - public ResponseFilter getResponseFilter(String filterId) { return getFilterInstance(responseFilters, filterId); } - - public Collection<String> requestFilterIds() { return filterIds(requestFilters); } - - public Collection<String> responseFilterIds() { return filterIds(responseFilters); } + public Optional<String> resolveRequestFilter(URI uri, int localPort) { + String filterId = requestFilterBindings.resolve(uri); + if (filterId != null) return Optional.of(filterId); + return Optional.ofNullable(defaultRequestFilters.get(localPort)); + } - public Collection<RequestFilter> requestFilters() { return filters(requestFilters); } + public Optional<String> resolveResponseFilter(URI uri, int localPort) { + String filterId = responseFilterBindings.resolve(uri); + if (filterId != null) return Optional.of(filterId); + return Optional.ofNullable(defaultResponseFilters.get(localPort)); + } - public Collection<ResponseFilter> responseFilters() { return filters(responseFilters); } + public RequestFilter getRequestFilter(String filterId) { return requestFilters.get(filterId); } - private static <T> Optional<String> resolveFilterId(BindingSet<FilterHolder<T>> filters, URI uri) { - return Optional.ofNullable(filters.resolve(uri)) - .map(holder -> holder.filterId); - } + public ResponseFilter getResponseFilter(String filterId) { return responseFilters.get(filterId); } - private static <T> T getFilterInstance(BindingSet<FilterHolder<T>> filters, String filterId) { - return stream(filters) - .filter(filterEntry -> filterId.equals(filterEntry.getValue().filterId)) - .map(filterEntry -> filterEntry.getValue().filterInstance) - .findAny() - .orElseThrow(() -> new IllegalArgumentException("No filter with id " + filterId)); - } + public Collection<String> requestFilterIds() { return requestFilters.keySet(); } - private static <T> Collection<String> filterIds(BindingSet<FilterHolder<T>> filters) { - return stream(filters) - .map(filterEntry -> filterEntry.getValue().filterId) - .collect(toSet()); - } + public Collection<String> responseFilterIds() { return responseFilters.keySet(); } - private static <T> Collection<T> filters(BindingSet<FilterHolder<T>> filters) { - return stream(filters) - .map(filterEntry -> filterEntry.getValue().filterInstance) - .collect(toSet()); - } + public Collection<RequestFilter> requestFilters() { return requestFilters.values(); } - private static <T> Stream<Map.Entry<UriPattern, FilterHolder<T>>> stream(BindingSet<FilterHolder<T>> filters) { - return StreamSupport.stream(filters.spliterator(), false); - } + public Collection<ResponseFilter> responseFilters() { return responseFilters.values(); } public static class Builder { - private final BindingRepository<FilterHolder<RequestFilter>> requestFilters = new BindingRepository<>(); - private final BindingRepository<FilterHolder<ResponseFilter>> responseFilters = new BindingRepository<>(); + private final Map<String, RequestFilter> requestFilters = new TreeMap<>(); + private final Map<String, ResponseFilter> responseFilters = new TreeMap<>(); + private final Map<Integer, String> defaultRequestFilters = new TreeMap<>(); + private final Map<Integer, String> defaultResponseFilters = new TreeMap<>(); + private final BindingRepository<String> requestFilterBindings = new BindingRepository<>(); + private final BindingRepository<String> responseFilterBindings = new BindingRepository<>(); public Builder() {} - public Builder addRequestFilter(String id, String binding, RequestFilter filter) { - requestFilters.bind(binding, new FilterHolder<>(id, filter)); - return this; - } + public Builder addRequestFilter(String id, RequestFilter filter) { requestFilters.put(id, filter); return this; } - public Builder addResponseFilter(String id, String binding, ResponseFilter filter) { - responseFilters.bind(binding, new FilterHolder<>(id, filter)); - return this; - } + public Builder addResponseFilter(String id, ResponseFilter filter) { responseFilters.put(id, filter); return this; } - public FilterBindings build() { return new FilterBindings(requestFilters.activate(), responseFilters.activate()); } - } + public Builder addRequestFilterBinding(String id, String binding) { requestFilterBindings.bind(binding, id); return this; } + + public Builder addResponseFilterBinding(String id, String binding) { responseFilterBindings.bind(binding, id); return this; } + + public Builder setRequestFilterDefaultForPort(String id, int port) { defaultRequestFilters.put(port, id); return this; } - private static class FilterHolder<T> { - final String filterId; - final T filterInstance; + public Builder setResponseFilterDefaultForPort(String id, int port) { defaultResponseFilters.put(port, id); return this; } - FilterHolder(String filterId, T filterInstance) { - this.filterId = filterId; - this.filterInstance = filterInstance; + public FilterBindings build() { + return new FilterBindings( + Collections.unmodifiableMap(requestFilters), + Collections.unmodifiableMap(responseFilters), + Collections.unmodifiableMap(defaultRequestFilters), + Collections.unmodifiableMap(defaultResponseFilters), + requestFilterBindings.activate(), + responseFilterBindings.activate()); } } } diff --git a/jdisc_http_service/src/main/java/com/yahoo/jdisc/http/server/jetty/FilterResolver.java b/jdisc_http_service/src/main/java/com/yahoo/jdisc/http/server/jetty/FilterResolver.java new file mode 100644 index 00000000000..badb0e736ae --- /dev/null +++ b/jdisc_http_service/src/main/java/com/yahoo/jdisc/http/server/jetty/FilterResolver.java @@ -0,0 +1,35 @@ +// Copyright Verizon Media. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package com.yahoo.jdisc.http.server.jetty; + +import com.yahoo.jdisc.http.filter.RequestFilter; +import com.yahoo.jdisc.http.filter.ResponseFilter; + +import javax.servlet.http.HttpServletRequest; +import java.net.URI; +import java.util.Optional; + +import static com.yahoo.jdisc.http.server.jetty.JDiscHttpServlet.getConnector; + +/** + * Resolve request/response filter (chain) based on {@link FilterBindings}. + * + * @author bjorncs + */ +class FilterResolver { + + private final FilterBindings bindings; + + FilterResolver(FilterBindings bindings) { + this.bindings = bindings; + } + + Optional<RequestFilter> resolveRequestFilter(HttpServletRequest servletRequest, URI jdiscUri) { + return bindings.resolveRequestFilter(jdiscUri, getConnector(servletRequest).listenPort()) + .map(bindings::getRequestFilter); + } + + Optional<ResponseFilter> resolveResponseFilter(HttpServletRequest servletRequest, URI jdiscUri) { + return bindings.resolveResponseFilter(jdiscUri, getConnector(servletRequest).listenPort()) + .map(bindings::getResponseFilter); + } +} diff --git a/jdisc_http_service/src/main/java/com/yahoo/jdisc/http/server/jetty/FilteringRequestHandler.java b/jdisc_http_service/src/main/java/com/yahoo/jdisc/http/server/jetty/FilteringRequestHandler.java index 7f761d6ab4a..0923e6688f4 100644 --- a/jdisc_http_service/src/main/java/com/yahoo/jdisc/http/server/jetty/FilteringRequestHandler.java +++ b/jdisc_http_service/src/main/java/com/yahoo/jdisc/http/server/jetty/FilteringRequestHandler.java @@ -16,6 +16,7 @@ import com.yahoo.jdisc.http.core.CompletionHandlers; import com.yahoo.jdisc.http.filter.RequestFilter; import com.yahoo.jdisc.http.filter.ResponseFilter; +import javax.servlet.http.HttpServletRequest; import java.nio.ByteBuffer; import java.util.Objects; import java.util.concurrent.atomic.AtomicBoolean; @@ -41,10 +42,12 @@ class FilteringRequestHandler extends AbstractRequestHandler { }; - private final FilterBindings filterBindings; + private final FilterResolver filterResolver; + private final HttpServletRequest servletRequest; - public FilteringRequestHandler(FilterBindings filterBindings) { - this.filterBindings = filterBindings; + public FilteringRequestHandler(FilterResolver filterResolver, HttpServletRequest servletRequest) { + this.filterResolver = filterResolver; + this.servletRequest = servletRequest; } @Override @@ -52,12 +55,11 @@ class FilteringRequestHandler extends AbstractRequestHandler { Preconditions.checkArgument(request instanceof HttpRequest, "Expected HttpRequest, got " + request); Objects.requireNonNull(originalResponseHandler, "responseHandler"); - RequestFilter requestFilter = filterBindings.resolveRequestFilter(request.getUri()) - .map(filterBindings::getRequestFilter) + RequestFilter requestFilter = filterResolver.resolveRequestFilter(servletRequest, request.getUri()) .orElse(null); - ResponseFilter responseFilter = filterBindings.resolveResponseFilter(request.getUri()) - .map(filterBindings::getResponseFilter) + ResponseFilter responseFilter = filterResolver.resolveResponseFilter(servletRequest, request.getUri()) .orElse(null); + // Not using request.connect() here - it adds logic for error handling that we'd rather leave to the framework. RequestHandler resolvedRequestHandler = request.container().resolveHandler(request); diff --git a/jdisc_http_service/src/main/java/com/yahoo/jdisc/http/server/jetty/HttpRequestDispatch.java b/jdisc_http_service/src/main/java/com/yahoo/jdisc/http/server/jetty/HttpRequestDispatch.java index ed5095fb06f..940009e7520 100644 --- a/jdisc_http_service/src/main/java/com/yahoo/jdisc/http/server/jetty/HttpRequestDispatch.java +++ b/jdisc_http_service/src/main/java/com/yahoo/jdisc/http/server/jetty/HttpRequestDispatch.java @@ -212,7 +212,7 @@ class HttpRequestDispatch { AccessLogEntry accessLogEntry, HttpServletRequest servletRequest) { RequestHandler requestHandler = wrapHandlerIfFormPost( - new FilteringRequestHandler(context.filterBindings), + new FilteringRequestHandler(context.filterResolver, servletRequest), servletRequest, context.serverConfig.removeRawPostBodyForWwwUrlEncodedPost()); return new AccessLoggingRequestHandler(requestHandler, accessLogEntry); diff --git a/jdisc_http_service/src/main/java/com/yahoo/jdisc/http/server/jetty/JDiscContext.java b/jdisc_http_service/src/main/java/com/yahoo/jdisc/http/server/jetty/JDiscContext.java index 4667ff3975b..a308149beb5 100644 --- a/jdisc_http_service/src/main/java/com/yahoo/jdisc/http/server/jetty/JDiscContext.java +++ b/jdisc_http_service/src/main/java/com/yahoo/jdisc/http/server/jetty/JDiscContext.java @@ -8,7 +8,7 @@ import com.yahoo.jdisc.service.CurrentContainer; import java.util.concurrent.Executor; public class JDiscContext { - final FilterBindings filterBindings; + final FilterResolver filterResolver; final CurrentContainer container; final Executor janitor; final Metric metric; @@ -20,7 +20,7 @@ public class JDiscContext { Metric metric, ServerConfig serverConfig) { - this.filterBindings = filterBindings; + this.filterResolver = new FilterResolver(filterBindings); this.container = container; this.janitor = janitor; this.metric = metric; diff --git a/jdisc_http_service/src/main/java/com/yahoo/jdisc/http/server/jetty/JDiscFilterInvokerFilter.java b/jdisc_http_service/src/main/java/com/yahoo/jdisc/http/server/jetty/JDiscFilterInvokerFilter.java index 5a904299e44..f046ccd5439 100644 --- a/jdisc_http_service/src/main/java/com/yahoo/jdisc/http/server/jetty/JDiscFilterInvokerFilter.java +++ b/jdisc_http_service/src/main/java/com/yahoo/jdisc/http/server/jetty/JDiscFilterInvokerFilter.java @@ -3,6 +3,7 @@ package com.yahoo.jdisc.http.server.jetty; import com.yahoo.container.logging.AccessLogEntry; import com.yahoo.jdisc.handler.ResponseHandler; +import com.yahoo.jdisc.http.filter.RequestFilter; import javax.servlet.AsyncContext; import javax.servlet.AsyncListener; @@ -75,8 +76,7 @@ class JDiscFilterInvokerFilter implements Filter { private void runChainAndResponseFilters(URI uri, HttpServletRequest request, HttpServletResponse response, FilterChain chain) throws IOException, ServletException { Optional<OneTimeRunnable> responseFilterInvoker = - jDiscContext.filterBindings.resolveResponseFilter(uri) - .map(jDiscContext.filterBindings::getResponseFilter) + jDiscContext.filterResolver.resolveResponseFilter(request, uri) .map(responseFilter -> new OneTimeRunnable(() -> filterInvoker.invokeResponseFilterChain(responseFilter, uri, request, response))); @@ -106,12 +106,12 @@ class JDiscFilterInvokerFilter implements Filter { private HttpServletRequest runRequestFilterWithMatchingBinding(AtomicReference<Boolean> responseReturned, URI uri, HttpServletRequest request, HttpServletResponse response) throws IOException { try { - String requestFilterId = jDiscContext.filterBindings.resolveRequestFilter(uri).orElse(null); - if (requestFilterId == null) + RequestFilter requestFilter = jDiscContext.filterResolver.resolveRequestFilter(request, uri).orElse(null); + if (requestFilter == null) return request; ResponseHandler responseHandler = createResponseHandler(responseReturned, request, response); - return filterInvoker.invokeRequestFilterChain(jDiscContext.filterBindings.getRequestFilter(requestFilterId), uri, request, responseHandler); + return filterInvoker.invokeRequestFilterChain(requestFilter, uri, request, responseHandler); } catch (Exception e) { throw new RuntimeException("Failed running request filter chain for uri " + uri, e); } diff --git a/jdisc_http_service/src/main/resources/configdefinitions/jdisc.http.jdisc.http.server.def b/jdisc_http_service/src/main/resources/configdefinitions/jdisc.http.jdisc.http.server.def index 81f9859fabd..f33dc35ea0b 100644 --- a/jdisc_http_service/src/main/resources/configdefinitions/jdisc.http.jdisc.http.server.def +++ b/jdisc_http_service/src/main/resources/configdefinitions/jdisc.http.jdisc.http.server.def @@ -27,6 +27,12 @@ filter[].id string # The binding of a filter filter[].binding string +# Filter id for a default filter (chain) +defaultFilters[].filterId string + +# The local port which the default filter should be applied to +defaultFilters[].localPort int + # Max number of threads in underlying Jetty pool maxWorkerThreads int default = 200 diff --git a/jdisc_http_service/src/test/java/com/yahoo/jdisc/http/server/jetty/FilterTestCase.java b/jdisc_http_service/src/test/java/com/yahoo/jdisc/http/server/jetty/FilterTestCase.java index a978e42f7cb..b2e28c6af67 100644 --- a/jdisc_http_service/src/test/java/com/yahoo/jdisc/http/server/jetty/FilterTestCase.java +++ b/jdisc_http_service/src/test/java/com/yahoo/jdisc/http/server/jetty/FilterTestCase.java @@ -46,14 +46,16 @@ import static org.mockito.Mockito.verify; import static org.mockito.Mockito.when; /** - * @author bakksjo + * @author Oyvind Bakksjo + * @author bjorncs */ public class FilterTestCase { @Test public void requireThatRequestFilterIsNotRunOnUnboundPath() throws Exception { RequestFilterMockBase filter = mock(RequestFilterMockBase.class); FilterBindings filterBindings = new FilterBindings.Builder() - .addRequestFilter("my-request-filter", "http://*/filtered/*", filter) + .addRequestFilter("my-request-filter", filter) + .addRequestFilterBinding("my-request-filter", "http://*/filtered/*") .build(); final MyRequestHandler requestHandler = new MyRequestHandler(); final TestDriver testDriver = newDriver(requestHandler, filterBindings); @@ -70,7 +72,8 @@ public class FilterTestCase { public void requireThatRequestFilterIsRunOnBoundPath() throws Exception { final RequestFilter filter = mock(RequestFilterMockBase.class); FilterBindings filterBindings = new FilterBindings.Builder() - .addRequestFilter("my-request-filter", "http://*/filtered/*", filter) + .addRequestFilter("my-request-filter", filter) + .addRequestFilterBinding("my-request-filter", "http://*/filtered/*") .build(); final MyRequestHandler requestHandler = new MyRequestHandler(); final TestDriver testDriver = newDriver(requestHandler, filterBindings); @@ -87,7 +90,8 @@ public class FilterTestCase { public void requireThatRequestFilterChangesAreSeenByRequestHandler() throws Exception { final RequestFilter filter = new HeaderRequestFilter("foo", "bar"); FilterBindings filterBindings = new FilterBindings.Builder() - .addRequestFilter("my-request-filter", "http://*/*", filter) + .addRequestFilter("my-request-filter", filter) + .addRequestFilterBinding("my-request-filter", "http://*/*") .build(); final MyRequestHandler requestHandler = new MyRequestHandler(); final TestDriver testDriver = newDriver(requestHandler, filterBindings); @@ -103,7 +107,8 @@ public class FilterTestCase { @Test public void requireThatRequestFilterCanRespond() throws Exception { FilterBindings filterBindings = new FilterBindings.Builder() - .addRequestFilter("my-request-filter", "http://*/*", new RespondForbiddenFilter()) + .addRequestFilter("my-request-filter", new RespondForbiddenFilter()) + .addRequestFilterBinding("my-request-filter", "http://*/*") .build(); final MyRequestHandler requestHandler = new MyRequestHandler(); final TestDriver testDriver = newDriver(requestHandler, filterBindings); @@ -120,7 +125,8 @@ public class FilterTestCase { final int responseStatus = Response.Status.OK; final String responseMessage = "Excellent"; FilterBindings filterBindings = new FilterBindings.Builder() - .addRequestFilter("my-request-filter", "http://*/*", new NullCompletionHandlerFilter(responseStatus, responseMessage)) + .addRequestFilter("my-request-filter", new NullCompletionHandlerFilter(responseStatus, responseMessage)) + .addRequestFilterBinding("my-request-filter", "http://*/*") .build(); final MyRequestHandler requestHandler = new MyRequestHandler(); final TestDriver testDriver = newDriver(requestHandler, filterBindings); @@ -137,7 +143,8 @@ public class FilterTestCase { @Test public void requireThatRequestFilterExecutionIsExceptionSafe() throws Exception { FilterBindings filterBindings = new FilterBindings.Builder() - .addRequestFilter("my-request-filter", "http://*/*", new ThrowingRequestFilter()) + .addRequestFilter("my-request-filter", new ThrowingRequestFilter()) + .addRequestFilterBinding("my-request-filter", "http://*/*") .build(); final MyRequestHandler requestHandler = new MyRequestHandler(); final TestDriver testDriver = newDriver(requestHandler, filterBindings); @@ -153,7 +160,8 @@ public class FilterTestCase { public void requireThatResponseFilterIsNotRunOnUnboundPath() throws Exception { final ResponseFilter filter = mock(ResponseFilterMockBase.class); FilterBindings filterBindings = new FilterBindings.Builder() - .addResponseFilter("my-response-filter", "http://*/filtered/*", filter) + .addResponseFilter("my-response-filter", filter) + .addResponseFilterBinding("my-response-filter", "http://*/filtered/*") .build(); final MyRequestHandler requestHandler = new MyRequestHandler(); final TestDriver testDriver = newDriver(requestHandler, filterBindings); @@ -170,7 +178,8 @@ public class FilterTestCase { public void requireThatResponseFilterIsRunOnBoundPath() throws Exception { final ResponseFilter filter = mock(ResponseFilterMockBase.class); FilterBindings filterBindings = new FilterBindings.Builder() - .addResponseFilter("my-response-filter", "http://*/filtered/*", filter) + .addResponseFilter("my-response-filter", filter) + .addResponseFilterBinding("my-response-filter", "http://*/filtered/*") .build(); final MyRequestHandler requestHandler = new MyRequestHandler(); final TestDriver testDriver = newDriver(requestHandler, filterBindings); @@ -186,7 +195,8 @@ public class FilterTestCase { @Test public void requireThatResponseFilterChangesAreWrittenToResponse() throws Exception { FilterBindings filterBindings = new FilterBindings.Builder() - .addResponseFilter("my-response-filter", "http://*/*", new HeaderResponseFilter("foo", "bar")) + .addResponseFilter("my-response-filter", new HeaderResponseFilter("foo", "bar")) + .addResponseFilterBinding("my-response-filter", "http://*/*") .build(); final MyRequestHandler requestHandler = new MyRequestHandler(); final TestDriver testDriver = newDriver(requestHandler, filterBindings); @@ -202,7 +212,8 @@ public class FilterTestCase { @Test public void requireThatResponseFilterExecutionIsExceptionSafe() throws Exception { FilterBindings filterBindings = new FilterBindings.Builder() - .addResponseFilter("my-response-filter", "http://*/*", new ThrowingResponseFilter()) + .addResponseFilter("my-response-filter", new ThrowingResponseFilter()) + .addResponseFilterBinding("my-response-filter", "http://*/*") .build(); final MyRequestHandler requestHandler = new MyRequestHandler(); final TestDriver testDriver = newDriver(requestHandler, filterBindings); @@ -220,8 +231,10 @@ public class FilterTestCase { final ResponseFilter responseFilter = mock(ResponseFilterMockBase.class); final String uriPattern = "http://*/*"; FilterBindings filterBindings = new FilterBindings.Builder() - .addRequestFilter("my-request-filter", uriPattern, requestFilter) - .addResponseFilter("my-response-filter", uriPattern, responseFilter) + .addRequestFilter("my-request-filter", requestFilter) + .addRequestFilterBinding("my-request-filter", uriPattern) + .addResponseFilter("my-response-filter", responseFilter) + .addResponseFilterBinding("my-response-filter", uriPattern) .build(); final MyRequestHandler requestHandler = new MyRequestHandler(); final TestDriver testDriver = newDriver(requestHandler, filterBindings); @@ -238,8 +251,10 @@ public class FilterTestCase { @Test public void requireThatResponseFromRequestFilterGoesThroughResponseFilter() throws Exception { FilterBindings filterBindings = new FilterBindings.Builder() - .addRequestFilter("my-request-filter", "http://*/*", new RespondForbiddenFilter()) - .addResponseFilter("my-response-filter", "http://*/*", new HeaderResponseFilter("foo", "bar")) + .addRequestFilter("my-request-filter", new RespondForbiddenFilter()) + .addRequestFilterBinding("my-request-filter", "http://*/*") + .addResponseFilter("my-response-filter", new HeaderResponseFilter("foo", "bar")) + .addResponseFilterBinding("my-response-filter", "http://*/*") .build(); final MyRequestHandler requestHandler = new MyRequestHandler(); final TestDriver testDriver = newDriver(requestHandler, filterBindings); @@ -380,10 +395,111 @@ public class FilterTestCase { assertThat(response.headers().getFirst("foo"), is("bar")); } - private static TestDriver newDriver( - final MyRequestHandler requestHandler, - FilterBindings filterBindings) - throws IOException { + @Test + public void requireThatDefaultRequestFilterChainIsRunIfNoOtherFilterChainMatches() throws IOException, InterruptedException { + RequestFilter filterWithBinding = mock(RequestFilter.class); + RequestFilter defaultFilter = mock(RequestFilter.class); + String defaultFilterId = "default-request-filter"; + FilterBindings filterBindings = new FilterBindings.Builder() + .addRequestFilter("my-request-filter", filterWithBinding) + .addRequestFilterBinding("my-request-filter", "http://*/filtered/*") + .addRequestFilter(defaultFilterId, defaultFilter) + .setRequestFilterDefaultForPort(defaultFilterId, 0) + .build(); + MyRequestHandler requestHandler = new MyRequestHandler(); + TestDriver testDriver = TestDriver.newInstance( + JettyHttpServer.class, + requestHandler, + newFilterModule(filterBindings)); + + testDriver.client().get("/status.html"); + + assertThat(requestHandler.awaitInvocation(), is(true)); + verify(defaultFilter, times(1)).filter(any(HttpRequest.class), any(ResponseHandler.class)); + verify(filterWithBinding, never()).filter(any(HttpRequest.class), any(ResponseHandler.class)); + + assertThat(testDriver.close(), is(true)); + } + + @Test + public void requireThatDefaultResponseFilterChainIsRunIfNoOtherFilterChainMatches() throws IOException, InterruptedException { + ResponseFilter filterWithBinding = mock(ResponseFilter.class); + ResponseFilter defaultFilter = mock(ResponseFilter.class); + String defaultFilterId = "default-response-filter"; + FilterBindings filterBindings = new FilterBindings.Builder() + .addResponseFilter("my-response-filter", filterWithBinding) + .addResponseFilterBinding("my-response-filter", "http://*/filtered/*") + .addResponseFilter(defaultFilterId, defaultFilter) + .setResponseFilterDefaultForPort(defaultFilterId, 0) + .build(); + MyRequestHandler requestHandler = new MyRequestHandler(); + TestDriver testDriver = TestDriver.newInstance( + JettyHttpServer.class, + requestHandler, + newFilterModule(filterBindings)); + + testDriver.client().get("/status.html"); + + assertThat(requestHandler.awaitInvocation(), is(true)); + verify(defaultFilter, times(1)).filter(any(Response.class), any(Request.class)); + verify(filterWithBinding, never()).filter(any(Response.class), any(Request.class)); + + assertThat(testDriver.close(), is(true)); + } + + @Test + public void requireThatRequestFilterWithBindingMatchHasPrecedenceOverDefaultFilter() throws IOException, InterruptedException { + RequestFilterMockBase filterWithBinding = mock(RequestFilterMockBase.class); + RequestFilterMockBase defaultFilter = mock(RequestFilterMockBase.class); + String defaultFilterId = "default-request-filter"; + FilterBindings filterBindings = new FilterBindings.Builder() + .addRequestFilter("my-request-filter", filterWithBinding) + .addRequestFilterBinding("my-request-filter", "http://*/filtered/*") + .addRequestFilter(defaultFilterId, defaultFilter) + .setRequestFilterDefaultForPort(defaultFilterId, 0) + .build(); + MyRequestHandler requestHandler = new MyRequestHandler(); + TestDriver testDriver = TestDriver.newInstance( + JettyHttpServer.class, + requestHandler, + newFilterModule(filterBindings)); + + testDriver.client().get("/filtered/status.html"); + + assertThat(requestHandler.awaitInvocation(), is(true)); + verify(defaultFilter, never()).filter(any(HttpRequest.class), any(ResponseHandler.class)); + verify(filterWithBinding).filter(any(HttpRequest.class), any(ResponseHandler.class)); + + assertThat(testDriver.close(), is(true)); + } + + @Test + public void requireThatResponsFilterWithBindingMatchHasPrecedenceOverDefaultFilter() throws IOException, InterruptedException { + ResponseFilter filterWithBinding = mock(ResponseFilter.class); + ResponseFilter defaultFilter = mock(ResponseFilter.class); + String defaultFilterId = "default-response-filter"; + FilterBindings filterBindings = new FilterBindings.Builder() + .addResponseFilter("my-response-filter", filterWithBinding) + .addResponseFilterBinding("my-response-filter", "http://*/filtered/*") + .addResponseFilter(defaultFilterId, defaultFilter) + .setResponseFilterDefaultForPort(defaultFilterId, 0) + .build(); + MyRequestHandler requestHandler = new MyRequestHandler(); + TestDriver testDriver = TestDriver.newInstance( + JettyHttpServer.class, + requestHandler, + newFilterModule(filterBindings)); + + testDriver.client().get("/filtered/status.html"); + + assertThat(requestHandler.awaitInvocation(), is(true)); + verify(defaultFilter, never()).filter(any(Response.class), any(Request.class)); + verify(filterWithBinding, times(1)).filter(any(Response.class), any(Request.class)); + + assertThat(testDriver.close(), is(true)); + } + + private static TestDriver newDriver(MyRequestHandler requestHandler, FilterBindings filterBindings) { return TestDriver.newInstance( JettyHttpServer.class, requestHandler, diff --git a/jdisc_http_service/src/test/java/com/yahoo/jdisc/http/server/jetty/servlet/JDiscFilterForServletTest.java b/jdisc_http_service/src/test/java/com/yahoo/jdisc/http/server/jetty/servlet/JDiscFilterForServletTest.java index 272d6fbb66c..16969a47b84 100644 --- a/jdisc_http_service/src/test/java/com/yahoo/jdisc/http/server/jetty/servlet/JDiscFilterForServletTest.java +++ b/jdisc_http_service/src/test/java/com/yahoo/jdisc/http/server/jetty/servlet/JDiscFilterForServletTest.java @@ -31,6 +31,7 @@ import static org.hamcrest.CoreMatchers.is; /** * @author Tony Vaagenes + * @author bjorncs */ public class JDiscFilterForServletTest extends ServletTestBase { @Test @@ -79,14 +80,16 @@ public class JDiscFilterForServletTest extends ServletTestBase { private TestDriver requestFilterTestDriver() throws IOException { FilterBindings filterBindings = new FilterBindings.Builder() - .addRequestFilter("my-request-filter", "http://*/*", new TestRequestFilter()) + .addRequestFilter("my-request-filter", new TestRequestFilter()) + .addRequestFilterBinding("my-request-filter", "http://*/*") .build(); return TestDrivers.newInstance(dummyRequestHandler, bindings(filterBindings)); } private TestDriver responseFilterTestDriver() throws IOException { FilterBindings filterBindings = new FilterBindings.Builder() - .addResponseFilter("my-response-filter", "http://*/*", new TestResponseFilter()) + .addResponseFilter("my-response-filter", new TestResponseFilter()) + .addResponseFilterBinding("my-response-filter", "http://*/*") .build(); return TestDrivers.newInstance(dummyRequestHandler, bindings(filterBindings)); } |