diff options
author | Bjørn Christian Seime <bjorncs@verizonmedia.com> | 2020-11-16 11:19:22 +0100 |
---|---|---|
committer | Bjørn Christian Seime <bjorncs@verizonmedia.com> | 2020-11-17 13:28:07 +0100 |
commit | cdbf498c9182ac06bb21ce8ae7b745dd8c8f3c83 (patch) | |
tree | 21a21f94ef0d888d74597b6142956c1b41901b13 | |
parent | e4c14623ad4ecbe6337a49d2176621c528bf7c22 (diff) |
Support default request/response filters per connector
Filter requests using default request/response filter if no other filters matches the request uri.
12 files changed, 374 insertions, 213 deletions
diff --git a/container-disc/src/main/java/com/yahoo/container/jdisc/FilterBindingsProvider.java b/container-disc/src/main/java/com/yahoo/container/jdisc/FilterBindingsProvider.java index 195aee93246..6527a368113 100644 --- a/container-disc/src/main/java/com/yahoo/container/jdisc/FilterBindingsProvider.java +++ b/container-disc/src/main/java/com/yahoo/container/jdisc/FilterBindingsProvider.java @@ -1,16 +1,21 @@ // Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. package com.yahoo.container.jdisc; +import com.google.inject.Inject; import com.yahoo.component.ComponentId; +import com.yahoo.component.ComponentSpecification; import com.yahoo.component.provider.ComponentRegistry; import com.yahoo.container.di.componentgraph.Provider; import com.yahoo.container.http.filter.FilterChainRepository; import com.yahoo.jdisc.http.ServerConfig; +import com.yahoo.jdisc.http.filter.RequestFilter; +import com.yahoo.jdisc.http.filter.ResponseFilter; import com.yahoo.jdisc.http.filter.SecurityRequestFilter; +import com.yahoo.jdisc.http.filter.SecurityRequestFilterChain; import com.yahoo.jdisc.http.server.jetty.FilterBindings; -import java.util.ArrayList; -import java.util.List; +import java.util.HashSet; +import java.util.Set; /** * Provides filter bindings based on vespa config. @@ -20,35 +25,100 @@ import java.util.List; */ public class FilterBindingsProvider implements Provider<FilterBindings> { + private static final ComponentId SEARCH_SERVER_COMPONENT_ID = ComponentId.fromString("SearchServer"); + private final FilterBindings filterBindings; + @Inject public FilterBindingsProvider(ComponentId componentId, ServerConfig config, FilterChainRepository filterChainRepository, ComponentRegistry<SecurityRequestFilter> legacyRequestFilters) { try { - this.filterBindings = FilterUtil.setupFilters( - componentId, - legacyRequestFilters, - toFilterSpecs(config.filter()), - filterChainRepository); + FilterBindings.Builder builder = new FilterBindings.Builder(); + configureLegacyFilters(builder, componentId, legacyRequestFilters); + configureFilters(builder, config, filterChainRepository); + this.filterBindings = builder.build(); } catch (Exception e) { throw new RuntimeException( "Invalid config for http server '" + componentId.getNamespace() + "': " + e.getMessage(), e); } } - private List<FilterUtil.FilterSpec> toFilterSpecs(List<ServerConfig.Filter> inFilters) { - List<FilterUtil.FilterSpec> outFilters = new ArrayList<>(); - for (ServerConfig.Filter inFilter : inFilters) { - outFilters.add(new FilterUtil.FilterSpec(inFilter.id(), inFilter.binding())); + // TODO(gjoranv): remove + private static void configureLegacyFilters( + FilterBindings.Builder builder, + ComponentId id, + ComponentRegistry<SecurityRequestFilter> legacyRequestFilters) { + ComponentId serverName = id.getNamespace(); + if (SEARCH_SERVER_COMPONENT_ID.equals(serverName) && !legacyRequestFilters.allComponents().isEmpty()) { + String filterId = "legacy-filters"; + builder.addRequestFilter(filterId, SecurityRequestFilterChain.newInstance(legacyRequestFilters.allComponents())); + builder.addRequestFilterBinding(filterId, "http://*/*"); + } + } + + private static void configureFilters( + FilterBindings.Builder builder, ServerConfig config, FilterChainRepository filterRepository) { + addFilterInstances(builder, config, filterRepository); + addFilterBindings(builder, config, filterRepository); + addPortDefaultFilters(builder, config, filterRepository); + } + + private static void addFilterInstances( + FilterBindings.Builder builder, ServerConfig config, FilterChainRepository filterRepository) { + Set<String> filterIds = new HashSet<>(); + config.filter().forEach(filterBinding -> filterIds.add(filterBinding.id())); + config.defaultFilters().forEach(defaultFilter -> filterIds.add(defaultFilter.filterId())); + for (String filterId : filterIds) { + Object filterInstance = getFilterInstance(filterRepository, filterId); + if (filterInstance instanceof RequestFilter && filterInstance instanceof ResponseFilter) { + throw new IllegalArgumentException("The filter " + filterInstance.getClass().getName() + + " is unsupported since it's both a RequestFilter and a ResponseFilter."); + } else if (filterInstance instanceof RequestFilter) { + builder.addRequestFilter(filterId, (RequestFilter)filterInstance); + } else if (filterInstance instanceof ResponseFilter) { + builder.addResponseFilter(filterId, (ResponseFilter)filterInstance); + } else if (filterInstance == null) { + throw new IllegalArgumentException("No http filter with id " + filterId); + } else { + throw new IllegalArgumentException("Unknown filter type: " + filterInstance.getClass().getName()); + } + } + } + + private static void addFilterBindings( + FilterBindings.Builder builder, ServerConfig config, FilterChainRepository filterRepository) { + for (ServerConfig.Filter filterBinding : config.filter()) { + if (isRequestFilter(filterRepository, filterBinding.id())) { + builder.addRequestFilterBinding(filterBinding.id(), filterBinding.binding()); + } else { + builder.addResponseFilterBinding(filterBinding.id(), filterBinding.binding()); + } + } + } + + private static void addPortDefaultFilters( + FilterBindings.Builder builder, ServerConfig config, FilterChainRepository filterRepository) { + for (ServerConfig.DefaultFilters defaultFilter : config.defaultFilters()) { + if (isRequestFilter(filterRepository, defaultFilter.filterId())) { + builder.setRequestFilterDefaultForPort(defaultFilter.filterId(), defaultFilter.localPort()); + } else { + builder.setResponseFilterDefaultForPort(defaultFilter.filterId(), defaultFilter.localPort()); + } } - return outFilters; + } + + private static boolean isRequestFilter(FilterChainRepository filterRepository, String filterId) { + return getFilterInstance(filterRepository, filterId) instanceof RequestFilter; + } + + private static Object getFilterInstance(FilterChainRepository filterRepository, String filterId) { + return filterRepository.getFilter(ComponentSpecification.fromString(filterId)); } @Override public FilterBindings get() { return filterBindings; } - @Override - public void deconstruct() {} + @Override public void deconstruct() {} } diff --git a/container-disc/src/main/java/com/yahoo/container/jdisc/FilterUtil.java b/container-disc/src/main/java/com/yahoo/container/jdisc/FilterUtil.java deleted file mode 100644 index 52829d6710e..00000000000 --- a/container-disc/src/main/java/com/yahoo/container/jdisc/FilterUtil.java +++ /dev/null @@ -1,99 +0,0 @@ -// Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. -package com.yahoo.container.jdisc; - -import com.yahoo.component.ComponentId; -import com.yahoo.component.ComponentSpecification; -import com.yahoo.component.provider.ComponentRegistry; -import com.yahoo.container.http.filter.FilterChainRepository; -import com.yahoo.jdisc.http.filter.RequestFilter; -import com.yahoo.jdisc.http.filter.ResponseFilter; -import com.yahoo.jdisc.http.filter.SecurityRequestFilter; -import com.yahoo.jdisc.http.filter.SecurityRequestFilterChain; -import com.yahoo.jdisc.http.server.jetty.FilterBindings; - -import java.util.List; - -/** - * Helper class to set up filter binding repositories based on config. - * - * @author Øyvind Bakksjø - * @author bjorncs - */ -class FilterUtil { - - private static final ComponentId SEARCH_SERVER_COMPONENT_ID = ComponentId.fromString("SearchServer"); - - private final FilterBindings.Builder builder = new FilterBindings.Builder(); - - private FilterUtil() {} - - private void configureFilters(List<FilterSpec> filtersConfig, FilterChainRepository filterChainRepository) { - for (FilterSpec filterConfig : filtersConfig) { - Object filter = filterChainRepository.getFilter(ComponentSpecification.fromString(filterConfig.getId())); - if (filter == null) { - throw new RuntimeException("No http filter with id " + filterConfig.getId()); - } - addFilter(filter, filterConfig.getBinding(), filterConfig.getId()); - } - } - - private void addFilter(Object filter, String binding, String filterId) { - if (filter instanceof RequestFilter && filter instanceof ResponseFilter) { - throw new RuntimeException("The filter " + filter.getClass().getName() + - " is unsupported since it's both a RequestFilter and a ResponseFilter."); - } else if (filter instanceof RequestFilter) { - builder.addRequestFilter(filterId, binding, (RequestFilter) filter); - } else if (filter instanceof ResponseFilter) { - builder.addResponseFilter(filterId, binding, (ResponseFilter) filter); - } else { - throw new RuntimeException("Unknown filter type " + filter.getClass().getName()); - } - } - - // TODO(gjoranv): remove - private void configureLegacyFilters(ComponentId id, ComponentRegistry<SecurityRequestFilter> legacyRequestFilters) { - ComponentId serverName = id.getNamespace(); - if (SEARCH_SERVER_COMPONENT_ID.equals(serverName) && !legacyRequestFilters.allComponents().isEmpty()) { - builder.addRequestFilter( - "legacy-filters", "http://*/*", SecurityRequestFilterChain.newInstance(legacyRequestFilters.allComponents())); - } - } - - /** - * Populates binding repositories with filters based on config. - */ - public static FilterBindings setupFilters( - ComponentId componentId, - ComponentRegistry<SecurityRequestFilter> legacyRequestFilters, - List<FilterSpec> filtersConfig, - FilterChainRepository filterChainRepository) { - FilterUtil filterUtil = new FilterUtil(); - - // TODO(gjoranv): remove - filterUtil.configureLegacyFilters(componentId, legacyRequestFilters); - - filterUtil.configureFilters(filtersConfig, filterChainRepository); - - return filterUtil.builder.build(); - } - - public static class FilterSpec { - - private final String id; - private final String binding; - - public FilterSpec(String id, String binding) { - this.id = id; - this.binding = binding; - } - - public String getId() { - return id; - } - - public String getBinding() { - return binding; - } - } - -} diff --git a/jdisc_http_service/abi-spec.json b/jdisc_http_service/abi-spec.json index c19ce5ce83c..8bf7f30964a 100644 --- a/jdisc_http_service/abi-spec.json +++ b/jdisc_http_service/abi-spec.json @@ -740,6 +740,8 @@ "public com.yahoo.jdisc.http.ServerConfig$Builder removeRawPostBodyForWwwUrlEncodedPost(boolean)", "public com.yahoo.jdisc.http.ServerConfig$Builder filter(com.yahoo.jdisc.http.ServerConfig$Filter$Builder)", "public com.yahoo.jdisc.http.ServerConfig$Builder filter(java.util.List)", + "public com.yahoo.jdisc.http.ServerConfig$Builder defaultFilters(com.yahoo.jdisc.http.ServerConfig$DefaultFilters$Builder)", + "public com.yahoo.jdisc.http.ServerConfig$Builder defaultFilters(java.util.List)", "public com.yahoo.jdisc.http.ServerConfig$Builder maxWorkerThreads(int)", "public com.yahoo.jdisc.http.ServerConfig$Builder minWorkerThreads(int)", "public com.yahoo.jdisc.http.ServerConfig$Builder stopTimeout(double)", @@ -754,11 +756,43 @@ ], "fields": [ "public java.util.List filter", + "public java.util.List defaultFilters", "public com.yahoo.jdisc.http.ServerConfig$Jmx$Builder jmx", "public com.yahoo.jdisc.http.ServerConfig$Metric$Builder metric", "public com.yahoo.jdisc.http.ServerConfig$AccessLog$Builder accessLog" ] }, + "com.yahoo.jdisc.http.ServerConfig$DefaultFilters$Builder": { + "superClass": "java.lang.Object", + "interfaces": [ + "com.yahoo.config.ConfigBuilder" + ], + "attributes": [ + "public" + ], + "methods": [ + "public void <init>()", + "public void <init>(com.yahoo.jdisc.http.ServerConfig$DefaultFilters)", + "public com.yahoo.jdisc.http.ServerConfig$DefaultFilters$Builder filterId(java.lang.String)", + "public com.yahoo.jdisc.http.ServerConfig$DefaultFilters$Builder localPort(int)", + "public com.yahoo.jdisc.http.ServerConfig$DefaultFilters build()" + ], + "fields": [] + }, + "com.yahoo.jdisc.http.ServerConfig$DefaultFilters": { + "superClass": "com.yahoo.config.InnerNode", + "interfaces": [], + "attributes": [ + "public", + "final" + ], + "methods": [ + "public void <init>(com.yahoo.jdisc.http.ServerConfig$DefaultFilters$Builder)", + "public java.lang.String filterId()", + "public int localPort()" + ], + "fields": [] + }, "com.yahoo.jdisc.http.ServerConfig$Filter$Builder": { "superClass": "java.lang.Object", "interfaces": [ @@ -894,6 +928,8 @@ "public boolean removeRawPostBodyForWwwUrlEncodedPost()", "public java.util.List filter()", "public com.yahoo.jdisc.http.ServerConfig$Filter filter(int)", + "public java.util.List defaultFilters()", + "public com.yahoo.jdisc.http.ServerConfig$DefaultFilters defaultFilters(int)", "public int maxWorkerThreads()", "public int minWorkerThreads()", "public double stopTimeout()", diff --git a/jdisc_http_service/src/main/java/com/yahoo/jdisc/http/server/jetty/FilterBindings.java b/jdisc_http_service/src/main/java/com/yahoo/jdisc/http/server/jetty/FilterBindings.java index 301c92a4583..310f3c9a646 100644 --- a/jdisc_http_service/src/main/java/com/yahoo/jdisc/http/server/jetty/FilterBindings.java +++ b/jdisc_http_service/src/main/java/com/yahoo/jdisc/http/server/jetty/FilterBindings.java @@ -3,18 +3,15 @@ package com.yahoo.jdisc.http.server.jetty; import com.yahoo.jdisc.application.BindingRepository; import com.yahoo.jdisc.application.BindingSet; -import com.yahoo.jdisc.application.UriPattern; import com.yahoo.jdisc.http.filter.RequestFilter; import com.yahoo.jdisc.http.filter.ResponseFilter; import java.net.URI; import java.util.Collection; +import java.util.Collections; import java.util.Map; import java.util.Optional; -import java.util.stream.Stream; -import java.util.stream.StreamSupport; - -import static java.util.stream.Collectors.toSet; +import java.util.TreeMap; /** * Resolves request/response filter (chain) from a {@link URI} instance. @@ -24,87 +21,82 @@ import static java.util.stream.Collectors.toSet; */ public class FilterBindings { - private final BindingSet<FilterHolder<RequestFilter>> requestFilters; - private final BindingSet<FilterHolder<ResponseFilter>> responseFilters; + private final Map<String, RequestFilter> requestFilters; + private final Map<String, ResponseFilter> responseFilters; + private final Map<Integer, String> defaultRequestFilters; + private final Map<Integer, String> defaultResponseFilters; + private final BindingSet<String> requestFilterBindings; + private final BindingSet<String> responseFilterBindings; private FilterBindings( - BindingSet<FilterHolder<RequestFilter>> requestFilters, - BindingSet<FilterHolder<ResponseFilter>> responseFilters) { + Map<String, RequestFilter> requestFilters, + Map<String, ResponseFilter> responseFilters, + Map<Integer, String> defaultRequestFilters, + Map<Integer, String> defaultResponseFilters, + BindingSet<String> requestFilterBindings, + BindingSet<String> responseFilterBindings) { this.requestFilters = requestFilters; this.responseFilters = responseFilters; + this.defaultRequestFilters = defaultRequestFilters; + this.defaultResponseFilters = defaultResponseFilters; + this.requestFilterBindings = requestFilterBindings; + this.responseFilterBindings = responseFilterBindings; } - public Optional<String> resolveRequestFilter(URI uri) { return resolveFilterId(requestFilters, uri); } - - public Optional<String> resolveResponseFilter(URI uri) { return resolveFilterId(responseFilters, uri); } - - public RequestFilter getRequestFilter(String filterId) { return getFilterInstance(requestFilters, filterId); } - - public ResponseFilter getResponseFilter(String filterId) { return getFilterInstance(responseFilters, filterId); } - - public Collection<String> requestFilterIds() { return filterIds(requestFilters); } - - public Collection<String> responseFilterIds() { return filterIds(responseFilters); } + public Optional<String> resolveRequestFilter(URI uri, int localPort) { + String filterId = requestFilterBindings.resolve(uri); + if (filterId != null) return Optional.of(filterId); + return Optional.ofNullable(defaultRequestFilters.get(localPort)); + } - public Collection<RequestFilter> requestFilters() { return filters(requestFilters); } + public Optional<String> resolveResponseFilter(URI uri, int localPort) { + String filterId = responseFilterBindings.resolve(uri); + if (filterId != null) return Optional.of(filterId); + return Optional.ofNullable(defaultResponseFilters.get(localPort)); + } - public Collection<ResponseFilter> responseFilters() { return filters(responseFilters); } + public RequestFilter getRequestFilter(String filterId) { return requestFilters.get(filterId); } - private static <T> Optional<String> resolveFilterId(BindingSet<FilterHolder<T>> filters, URI uri) { - return Optional.ofNullable(filters.resolve(uri)) - .map(holder -> holder.filterId); - } + public ResponseFilter getResponseFilter(String filterId) { return responseFilters.get(filterId); } - private static <T> T getFilterInstance(BindingSet<FilterHolder<T>> filters, String filterId) { - return stream(filters) - .filter(filterEntry -> filterId.equals(filterEntry.getValue().filterId)) - .map(filterEntry -> filterEntry.getValue().filterInstance) - .findAny() - .orElseThrow(() -> new IllegalArgumentException("No filter with id " + filterId)); - } + public Collection<String> requestFilterIds() { return requestFilters.keySet(); } - private static <T> Collection<String> filterIds(BindingSet<FilterHolder<T>> filters) { - return stream(filters) - .map(filterEntry -> filterEntry.getValue().filterId) - .collect(toSet()); - } + public Collection<String> responseFilterIds() { return responseFilters.keySet(); } - private static <T> Collection<T> filters(BindingSet<FilterHolder<T>> filters) { - return stream(filters) - .map(filterEntry -> filterEntry.getValue().filterInstance) - .collect(toSet()); - } + public Collection<RequestFilter> requestFilters() { return requestFilters.values(); } - private static <T> Stream<Map.Entry<UriPattern, FilterHolder<T>>> stream(BindingSet<FilterHolder<T>> filters) { - return StreamSupport.stream(filters.spliterator(), false); - } + public Collection<ResponseFilter> responseFilters() { return responseFilters.values(); } public static class Builder { - private final BindingRepository<FilterHolder<RequestFilter>> requestFilters = new BindingRepository<>(); - private final BindingRepository<FilterHolder<ResponseFilter>> responseFilters = new BindingRepository<>(); + private final Map<String, RequestFilter> requestFilters = new TreeMap<>(); + private final Map<String, ResponseFilter> responseFilters = new TreeMap<>(); + private final Map<Integer, String> defaultRequestFilters = new TreeMap<>(); + private final Map<Integer, String> defaultResponseFilters = new TreeMap<>(); + private final BindingRepository<String> requestFilterBindings = new BindingRepository<>(); + private final BindingRepository<String> responseFilterBindings = new BindingRepository<>(); public Builder() {} - public Builder addRequestFilter(String id, String binding, RequestFilter filter) { - requestFilters.bind(binding, new FilterHolder<>(id, filter)); - return this; - } + public Builder addRequestFilter(String id, RequestFilter filter) { requestFilters.put(id, filter); return this; } - public Builder addResponseFilter(String id, String binding, ResponseFilter filter) { - responseFilters.bind(binding, new FilterHolder<>(id, filter)); - return this; - } + public Builder addResponseFilter(String id, ResponseFilter filter) { responseFilters.put(id, filter); return this; } - public FilterBindings build() { return new FilterBindings(requestFilters.activate(), responseFilters.activate()); } - } + public Builder addRequestFilterBinding(String id, String binding) { requestFilterBindings.bind(binding, id); return this; } + + public Builder addResponseFilterBinding(String id, String binding) { responseFilterBindings.bind(binding, id); return this; } + + public Builder setRequestFilterDefaultForPort(String id, int port) { defaultRequestFilters.put(port, id); return this; } - private static class FilterHolder<T> { - final String filterId; - final T filterInstance; + public Builder setResponseFilterDefaultForPort(String id, int port) { defaultResponseFilters.put(port, id); return this; } - FilterHolder(String filterId, T filterInstance) { - this.filterId = filterId; - this.filterInstance = filterInstance; + public FilterBindings build() { + return new FilterBindings( + Collections.unmodifiableMap(requestFilters), + Collections.unmodifiableMap(responseFilters), + Collections.unmodifiableMap(defaultRequestFilters), + Collections.unmodifiableMap(defaultResponseFilters), + requestFilterBindings.activate(), + responseFilterBindings.activate()); } } } diff --git a/jdisc_http_service/src/main/java/com/yahoo/jdisc/http/server/jetty/FilterResolver.java b/jdisc_http_service/src/main/java/com/yahoo/jdisc/http/server/jetty/FilterResolver.java new file mode 100644 index 00000000000..badb0e736ae --- /dev/null +++ b/jdisc_http_service/src/main/java/com/yahoo/jdisc/http/server/jetty/FilterResolver.java @@ -0,0 +1,35 @@ +// Copyright Verizon Media. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package com.yahoo.jdisc.http.server.jetty; + +import com.yahoo.jdisc.http.filter.RequestFilter; +import com.yahoo.jdisc.http.filter.ResponseFilter; + +import javax.servlet.http.HttpServletRequest; +import java.net.URI; +import java.util.Optional; + +import static com.yahoo.jdisc.http.server.jetty.JDiscHttpServlet.getConnector; + +/** + * Resolve request/response filter (chain) based on {@link FilterBindings}. + * + * @author bjorncs + */ +class FilterResolver { + + private final FilterBindings bindings; + + FilterResolver(FilterBindings bindings) { + this.bindings = bindings; + } + + Optional<RequestFilter> resolveRequestFilter(HttpServletRequest servletRequest, URI jdiscUri) { + return bindings.resolveRequestFilter(jdiscUri, getConnector(servletRequest).listenPort()) + .map(bindings::getRequestFilter); + } + + Optional<ResponseFilter> resolveResponseFilter(HttpServletRequest servletRequest, URI jdiscUri) { + return bindings.resolveResponseFilter(jdiscUri, getConnector(servletRequest).listenPort()) + .map(bindings::getResponseFilter); + } +} diff --git a/jdisc_http_service/src/main/java/com/yahoo/jdisc/http/server/jetty/FilteringRequestHandler.java b/jdisc_http_service/src/main/java/com/yahoo/jdisc/http/server/jetty/FilteringRequestHandler.java index 7f761d6ab4a..0923e6688f4 100644 --- a/jdisc_http_service/src/main/java/com/yahoo/jdisc/http/server/jetty/FilteringRequestHandler.java +++ b/jdisc_http_service/src/main/java/com/yahoo/jdisc/http/server/jetty/FilteringRequestHandler.java @@ -16,6 +16,7 @@ import com.yahoo.jdisc.http.core.CompletionHandlers; import com.yahoo.jdisc.http.filter.RequestFilter; import com.yahoo.jdisc.http.filter.ResponseFilter; +import javax.servlet.http.HttpServletRequest; import java.nio.ByteBuffer; import java.util.Objects; import java.util.concurrent.atomic.AtomicBoolean; @@ -41,10 +42,12 @@ class FilteringRequestHandler extends AbstractRequestHandler { }; - private final FilterBindings filterBindings; + private final FilterResolver filterResolver; + private final HttpServletRequest servletRequest; - public FilteringRequestHandler(FilterBindings filterBindings) { - this.filterBindings = filterBindings; + public FilteringRequestHandler(FilterResolver filterResolver, HttpServletRequest servletRequest) { + this.filterResolver = filterResolver; + this.servletRequest = servletRequest; } @Override @@ -52,12 +55,11 @@ class FilteringRequestHandler extends AbstractRequestHandler { Preconditions.checkArgument(request instanceof HttpRequest, "Expected HttpRequest, got " + request); Objects.requireNonNull(originalResponseHandler, "responseHandler"); - RequestFilter requestFilter = filterBindings.resolveRequestFilter(request.getUri()) - .map(filterBindings::getRequestFilter) + RequestFilter requestFilter = filterResolver.resolveRequestFilter(servletRequest, request.getUri()) .orElse(null); - ResponseFilter responseFilter = filterBindings.resolveResponseFilter(request.getUri()) - .map(filterBindings::getResponseFilter) + ResponseFilter responseFilter = filterResolver.resolveResponseFilter(servletRequest, request.getUri()) .orElse(null); + // Not using request.connect() here - it adds logic for error handling that we'd rather leave to the framework. RequestHandler resolvedRequestHandler = request.container().resolveHandler(request); diff --git a/jdisc_http_service/src/main/java/com/yahoo/jdisc/http/server/jetty/HttpRequestDispatch.java b/jdisc_http_service/src/main/java/com/yahoo/jdisc/http/server/jetty/HttpRequestDispatch.java index ed5095fb06f..940009e7520 100644 --- a/jdisc_http_service/src/main/java/com/yahoo/jdisc/http/server/jetty/HttpRequestDispatch.java +++ b/jdisc_http_service/src/main/java/com/yahoo/jdisc/http/server/jetty/HttpRequestDispatch.java @@ -212,7 +212,7 @@ class HttpRequestDispatch { AccessLogEntry accessLogEntry, HttpServletRequest servletRequest) { RequestHandler requestHandler = wrapHandlerIfFormPost( - new FilteringRequestHandler(context.filterBindings), + new FilteringRequestHandler(context.filterResolver, servletRequest), servletRequest, context.serverConfig.removeRawPostBodyForWwwUrlEncodedPost()); return new AccessLoggingRequestHandler(requestHandler, accessLogEntry); diff --git a/jdisc_http_service/src/main/java/com/yahoo/jdisc/http/server/jetty/JDiscContext.java b/jdisc_http_service/src/main/java/com/yahoo/jdisc/http/server/jetty/JDiscContext.java index 4667ff3975b..a308149beb5 100644 --- a/jdisc_http_service/src/main/java/com/yahoo/jdisc/http/server/jetty/JDiscContext.java +++ b/jdisc_http_service/src/main/java/com/yahoo/jdisc/http/server/jetty/JDiscContext.java @@ -8,7 +8,7 @@ import com.yahoo.jdisc.service.CurrentContainer; import java.util.concurrent.Executor; public class JDiscContext { - final FilterBindings filterBindings; + final FilterResolver filterResolver; final CurrentContainer container; final Executor janitor; final Metric metric; @@ -20,7 +20,7 @@ public class JDiscContext { Metric metric, ServerConfig serverConfig) { - this.filterBindings = filterBindings; + this.filterResolver = new FilterResolver(filterBindings); this.container = container; this.janitor = janitor; this.metric = metric; diff --git a/jdisc_http_service/src/main/java/com/yahoo/jdisc/http/server/jetty/JDiscFilterInvokerFilter.java b/jdisc_http_service/src/main/java/com/yahoo/jdisc/http/server/jetty/JDiscFilterInvokerFilter.java index 5a904299e44..f046ccd5439 100644 --- a/jdisc_http_service/src/main/java/com/yahoo/jdisc/http/server/jetty/JDiscFilterInvokerFilter.java +++ b/jdisc_http_service/src/main/java/com/yahoo/jdisc/http/server/jetty/JDiscFilterInvokerFilter.java @@ -3,6 +3,7 @@ package com.yahoo.jdisc.http.server.jetty; import com.yahoo.container.logging.AccessLogEntry; import com.yahoo.jdisc.handler.ResponseHandler; +import com.yahoo.jdisc.http.filter.RequestFilter; import javax.servlet.AsyncContext; import javax.servlet.AsyncListener; @@ -75,8 +76,7 @@ class JDiscFilterInvokerFilter implements Filter { private void runChainAndResponseFilters(URI uri, HttpServletRequest request, HttpServletResponse response, FilterChain chain) throws IOException, ServletException { Optional<OneTimeRunnable> responseFilterInvoker = - jDiscContext.filterBindings.resolveResponseFilter(uri) - .map(jDiscContext.filterBindings::getResponseFilter) + jDiscContext.filterResolver.resolveResponseFilter(request, uri) .map(responseFilter -> new OneTimeRunnable(() -> filterInvoker.invokeResponseFilterChain(responseFilter, uri, request, response))); @@ -106,12 +106,12 @@ class JDiscFilterInvokerFilter implements Filter { private HttpServletRequest runRequestFilterWithMatchingBinding(AtomicReference<Boolean> responseReturned, URI uri, HttpServletRequest request, HttpServletResponse response) throws IOException { try { - String requestFilterId = jDiscContext.filterBindings.resolveRequestFilter(uri).orElse(null); - if (requestFilterId == null) + RequestFilter requestFilter = jDiscContext.filterResolver.resolveRequestFilter(request, uri).orElse(null); + if (requestFilter == null) return request; ResponseHandler responseHandler = createResponseHandler(responseReturned, request, response); - return filterInvoker.invokeRequestFilterChain(jDiscContext.filterBindings.getRequestFilter(requestFilterId), uri, request, responseHandler); + return filterInvoker.invokeRequestFilterChain(requestFilter, uri, request, responseHandler); } catch (Exception e) { throw new RuntimeException("Failed running request filter chain for uri " + uri, e); } diff --git a/jdisc_http_service/src/main/resources/configdefinitions/jdisc.http.jdisc.http.server.def b/jdisc_http_service/src/main/resources/configdefinitions/jdisc.http.jdisc.http.server.def index 81f9859fabd..f33dc35ea0b 100644 --- a/jdisc_http_service/src/main/resources/configdefinitions/jdisc.http.jdisc.http.server.def +++ b/jdisc_http_service/src/main/resources/configdefinitions/jdisc.http.jdisc.http.server.def @@ -27,6 +27,12 @@ filter[].id string # The binding of a filter filter[].binding string +# Filter id for a default filter (chain) +defaultFilters[].filterId string + +# The local port which the default filter should be applied to +defaultFilters[].localPort int + # Max number of threads in underlying Jetty pool maxWorkerThreads int default = 200 diff --git a/jdisc_http_service/src/test/java/com/yahoo/jdisc/http/server/jetty/FilterTestCase.java b/jdisc_http_service/src/test/java/com/yahoo/jdisc/http/server/jetty/FilterTestCase.java index a978e42f7cb..b2e28c6af67 100644 --- a/jdisc_http_service/src/test/java/com/yahoo/jdisc/http/server/jetty/FilterTestCase.java +++ b/jdisc_http_service/src/test/java/com/yahoo/jdisc/http/server/jetty/FilterTestCase.java @@ -46,14 +46,16 @@ import static org.mockito.Mockito.verify; import static org.mockito.Mockito.when; /** - * @author bakksjo + * @author Oyvind Bakksjo + * @author bjorncs */ public class FilterTestCase { @Test public void requireThatRequestFilterIsNotRunOnUnboundPath() throws Exception { RequestFilterMockBase filter = mock(RequestFilterMockBase.class); FilterBindings filterBindings = new FilterBindings.Builder() - .addRequestFilter("my-request-filter", "http://*/filtered/*", filter) + .addRequestFilter("my-request-filter", filter) + .addRequestFilterBinding("my-request-filter", "http://*/filtered/*") .build(); final MyRequestHandler requestHandler = new MyRequestHandler(); final TestDriver testDriver = newDriver(requestHandler, filterBindings); @@ -70,7 +72,8 @@ public class FilterTestCase { public void requireThatRequestFilterIsRunOnBoundPath() throws Exception { final RequestFilter filter = mock(RequestFilterMockBase.class); FilterBindings filterBindings = new FilterBindings.Builder() - .addRequestFilter("my-request-filter", "http://*/filtered/*", filter) + .addRequestFilter("my-request-filter", filter) + .addRequestFilterBinding("my-request-filter", "http://*/filtered/*") .build(); final MyRequestHandler requestHandler = new MyRequestHandler(); final TestDriver testDriver = newDriver(requestHandler, filterBindings); @@ -87,7 +90,8 @@ public class FilterTestCase { public void requireThatRequestFilterChangesAreSeenByRequestHandler() throws Exception { final RequestFilter filter = new HeaderRequestFilter("foo", "bar"); FilterBindings filterBindings = new FilterBindings.Builder() - .addRequestFilter("my-request-filter", "http://*/*", filter) + .addRequestFilter("my-request-filter", filter) + .addRequestFilterBinding("my-request-filter", "http://*/*") .build(); final MyRequestHandler requestHandler = new MyRequestHandler(); final TestDriver testDriver = newDriver(requestHandler, filterBindings); @@ -103,7 +107,8 @@ public class FilterTestCase { @Test public void requireThatRequestFilterCanRespond() throws Exception { FilterBindings filterBindings = new FilterBindings.Builder() - .addRequestFilter("my-request-filter", "http://*/*", new RespondForbiddenFilter()) + .addRequestFilter("my-request-filter", new RespondForbiddenFilter()) + .addRequestFilterBinding("my-request-filter", "http://*/*") .build(); final MyRequestHandler requestHandler = new MyRequestHandler(); final TestDriver testDriver = newDriver(requestHandler, filterBindings); @@ -120,7 +125,8 @@ public class FilterTestCase { final int responseStatus = Response.Status.OK; final String responseMessage = "Excellent"; FilterBindings filterBindings = new FilterBindings.Builder() - .addRequestFilter("my-request-filter", "http://*/*", new NullCompletionHandlerFilter(responseStatus, responseMessage)) + .addRequestFilter("my-request-filter", new NullCompletionHandlerFilter(responseStatus, responseMessage)) + .addRequestFilterBinding("my-request-filter", "http://*/*") .build(); final MyRequestHandler requestHandler = new MyRequestHandler(); final TestDriver testDriver = newDriver(requestHandler, filterBindings); @@ -137,7 +143,8 @@ public class FilterTestCase { @Test public void requireThatRequestFilterExecutionIsExceptionSafe() throws Exception { FilterBindings filterBindings = new FilterBindings.Builder() - .addRequestFilter("my-request-filter", "http://*/*", new ThrowingRequestFilter()) + .addRequestFilter("my-request-filter", new ThrowingRequestFilter()) + .addRequestFilterBinding("my-request-filter", "http://*/*") .build(); final MyRequestHandler requestHandler = new MyRequestHandler(); final TestDriver testDriver = newDriver(requestHandler, filterBindings); @@ -153,7 +160,8 @@ public class FilterTestCase { public void requireThatResponseFilterIsNotRunOnUnboundPath() throws Exception { final ResponseFilter filter = mock(ResponseFilterMockBase.class); FilterBindings filterBindings = new FilterBindings.Builder() - .addResponseFilter("my-response-filter", "http://*/filtered/*", filter) + .addResponseFilter("my-response-filter", filter) + .addResponseFilterBinding("my-response-filter", "http://*/filtered/*") .build(); final MyRequestHandler requestHandler = new MyRequestHandler(); final TestDriver testDriver = newDriver(requestHandler, filterBindings); @@ -170,7 +178,8 @@ public class FilterTestCase { public void requireThatResponseFilterIsRunOnBoundPath() throws Exception { final ResponseFilter filter = mock(ResponseFilterMockBase.class); FilterBindings filterBindings = new FilterBindings.Builder() - .addResponseFilter("my-response-filter", "http://*/filtered/*", filter) + .addResponseFilter("my-response-filter", filter) + .addResponseFilterBinding("my-response-filter", "http://*/filtered/*") .build(); final MyRequestHandler requestHandler = new MyRequestHandler(); final TestDriver testDriver = newDriver(requestHandler, filterBindings); @@ -186,7 +195,8 @@ public class FilterTestCase { @Test public void requireThatResponseFilterChangesAreWrittenToResponse() throws Exception { FilterBindings filterBindings = new FilterBindings.Builder() - .addResponseFilter("my-response-filter", "http://*/*", new HeaderResponseFilter("foo", "bar")) + .addResponseFilter("my-response-filter", new HeaderResponseFilter("foo", "bar")) + .addResponseFilterBinding("my-response-filter", "http://*/*") .build(); final MyRequestHandler requestHandler = new MyRequestHandler(); final TestDriver testDriver = newDriver(requestHandler, filterBindings); @@ -202,7 +212,8 @@ public class FilterTestCase { @Test public void requireThatResponseFilterExecutionIsExceptionSafe() throws Exception { FilterBindings filterBindings = new FilterBindings.Builder() - .addResponseFilter("my-response-filter", "http://*/*", new ThrowingResponseFilter()) + .addResponseFilter("my-response-filter", new ThrowingResponseFilter()) + .addResponseFilterBinding("my-response-filter", "http://*/*") .build(); final MyRequestHandler requestHandler = new MyRequestHandler(); final TestDriver testDriver = newDriver(requestHandler, filterBindings); @@ -220,8 +231,10 @@ public class FilterTestCase { final ResponseFilter responseFilter = mock(ResponseFilterMockBase.class); final String uriPattern = "http://*/*"; FilterBindings filterBindings = new FilterBindings.Builder() - .addRequestFilter("my-request-filter", uriPattern, requestFilter) - .addResponseFilter("my-response-filter", uriPattern, responseFilter) + .addRequestFilter("my-request-filter", requestFilter) + .addRequestFilterBinding("my-request-filter", uriPattern) + .addResponseFilter("my-response-filter", responseFilter) + .addResponseFilterBinding("my-response-filter", uriPattern) .build(); final MyRequestHandler requestHandler = new MyRequestHandler(); final TestDriver testDriver = newDriver(requestHandler, filterBindings); @@ -238,8 +251,10 @@ public class FilterTestCase { @Test public void requireThatResponseFromRequestFilterGoesThroughResponseFilter() throws Exception { FilterBindings filterBindings = new FilterBindings.Builder() - .addRequestFilter("my-request-filter", "http://*/*", new RespondForbiddenFilter()) - .addResponseFilter("my-response-filter", "http://*/*", new HeaderResponseFilter("foo", "bar")) + .addRequestFilter("my-request-filter", new RespondForbiddenFilter()) + .addRequestFilterBinding("my-request-filter", "http://*/*") + .addResponseFilter("my-response-filter", new HeaderResponseFilter("foo", "bar")) + .addResponseFilterBinding("my-response-filter", "http://*/*") .build(); final MyRequestHandler requestHandler = new MyRequestHandler(); final TestDriver testDriver = newDriver(requestHandler, filterBindings); @@ -380,10 +395,111 @@ public class FilterTestCase { assertThat(response.headers().getFirst("foo"), is("bar")); } - private static TestDriver newDriver( - final MyRequestHandler requestHandler, - FilterBindings filterBindings) - throws IOException { + @Test + public void requireThatDefaultRequestFilterChainIsRunIfNoOtherFilterChainMatches() throws IOException, InterruptedException { + RequestFilter filterWithBinding = mock(RequestFilter.class); + RequestFilter defaultFilter = mock(RequestFilter.class); + String defaultFilterId = "default-request-filter"; + FilterBindings filterBindings = new FilterBindings.Builder() + .addRequestFilter("my-request-filter", filterWithBinding) + .addRequestFilterBinding("my-request-filter", "http://*/filtered/*") + .addRequestFilter(defaultFilterId, defaultFilter) + .setRequestFilterDefaultForPort(defaultFilterId, 0) + .build(); + MyRequestHandler requestHandler = new MyRequestHandler(); + TestDriver testDriver = TestDriver.newInstance( + JettyHttpServer.class, + requestHandler, + newFilterModule(filterBindings)); + + testDriver.client().get("/status.html"); + + assertThat(requestHandler.awaitInvocation(), is(true)); + verify(defaultFilter, times(1)).filter(any(HttpRequest.class), any(ResponseHandler.class)); + verify(filterWithBinding, never()).filter(any(HttpRequest.class), any(ResponseHandler.class)); + + assertThat(testDriver.close(), is(true)); + } + + @Test + public void requireThatDefaultResponseFilterChainIsRunIfNoOtherFilterChainMatches() throws IOException, InterruptedException { + ResponseFilter filterWithBinding = mock(ResponseFilter.class); + ResponseFilter defaultFilter = mock(ResponseFilter.class); + String defaultFilterId = "default-response-filter"; + FilterBindings filterBindings = new FilterBindings.Builder() + .addResponseFilter("my-response-filter", filterWithBinding) + .addResponseFilterBinding("my-response-filter", "http://*/filtered/*") + .addResponseFilter(defaultFilterId, defaultFilter) + .setResponseFilterDefaultForPort(defaultFilterId, 0) + .build(); + MyRequestHandler requestHandler = new MyRequestHandler(); + TestDriver testDriver = TestDriver.newInstance( + JettyHttpServer.class, + requestHandler, + newFilterModule(filterBindings)); + + testDriver.client().get("/status.html"); + + assertThat(requestHandler.awaitInvocation(), is(true)); + verify(defaultFilter, times(1)).filter(any(Response.class), any(Request.class)); + verify(filterWithBinding, never()).filter(any(Response.class), any(Request.class)); + + assertThat(testDriver.close(), is(true)); + } + + @Test + public void requireThatRequestFilterWithBindingMatchHasPrecedenceOverDefaultFilter() throws IOException, InterruptedException { + RequestFilterMockBase filterWithBinding = mock(RequestFilterMockBase.class); + RequestFilterMockBase defaultFilter = mock(RequestFilterMockBase.class); + String defaultFilterId = "default-request-filter"; + FilterBindings filterBindings = new FilterBindings.Builder() + .addRequestFilter("my-request-filter", filterWithBinding) + .addRequestFilterBinding("my-request-filter", "http://*/filtered/*") + .addRequestFilter(defaultFilterId, defaultFilter) + .setRequestFilterDefaultForPort(defaultFilterId, 0) + .build(); + MyRequestHandler requestHandler = new MyRequestHandler(); + TestDriver testDriver = TestDriver.newInstance( + JettyHttpServer.class, + requestHandler, + newFilterModule(filterBindings)); + + testDriver.client().get("/filtered/status.html"); + + assertThat(requestHandler.awaitInvocation(), is(true)); + verify(defaultFilter, never()).filter(any(HttpRequest.class), any(ResponseHandler.class)); + verify(filterWithBinding).filter(any(HttpRequest.class), any(ResponseHandler.class)); + + assertThat(testDriver.close(), is(true)); + } + + @Test + public void requireThatResponsFilterWithBindingMatchHasPrecedenceOverDefaultFilter() throws IOException, InterruptedException { + ResponseFilter filterWithBinding = mock(ResponseFilter.class); + ResponseFilter defaultFilter = mock(ResponseFilter.class); + String defaultFilterId = "default-response-filter"; + FilterBindings filterBindings = new FilterBindings.Builder() + .addResponseFilter("my-response-filter", filterWithBinding) + .addResponseFilterBinding("my-response-filter", "http://*/filtered/*") + .addResponseFilter(defaultFilterId, defaultFilter) + .setResponseFilterDefaultForPort(defaultFilterId, 0) + .build(); + MyRequestHandler requestHandler = new MyRequestHandler(); + TestDriver testDriver = TestDriver.newInstance( + JettyHttpServer.class, + requestHandler, + newFilterModule(filterBindings)); + + testDriver.client().get("/filtered/status.html"); + + assertThat(requestHandler.awaitInvocation(), is(true)); + verify(defaultFilter, never()).filter(any(Response.class), any(Request.class)); + verify(filterWithBinding, times(1)).filter(any(Response.class), any(Request.class)); + + assertThat(testDriver.close(), is(true)); + } + + private static TestDriver newDriver(MyRequestHandler requestHandler, FilterBindings filterBindings) { return TestDriver.newInstance( JettyHttpServer.class, requestHandler, diff --git a/jdisc_http_service/src/test/java/com/yahoo/jdisc/http/server/jetty/servlet/JDiscFilterForServletTest.java b/jdisc_http_service/src/test/java/com/yahoo/jdisc/http/server/jetty/servlet/JDiscFilterForServletTest.java index 272d6fbb66c..16969a47b84 100644 --- a/jdisc_http_service/src/test/java/com/yahoo/jdisc/http/server/jetty/servlet/JDiscFilterForServletTest.java +++ b/jdisc_http_service/src/test/java/com/yahoo/jdisc/http/server/jetty/servlet/JDiscFilterForServletTest.java @@ -31,6 +31,7 @@ import static org.hamcrest.CoreMatchers.is; /** * @author Tony Vaagenes + * @author bjorncs */ public class JDiscFilterForServletTest extends ServletTestBase { @Test @@ -79,14 +80,16 @@ public class JDiscFilterForServletTest extends ServletTestBase { private TestDriver requestFilterTestDriver() throws IOException { FilterBindings filterBindings = new FilterBindings.Builder() - .addRequestFilter("my-request-filter", "http://*/*", new TestRequestFilter()) + .addRequestFilter("my-request-filter", new TestRequestFilter()) + .addRequestFilterBinding("my-request-filter", "http://*/*") .build(); return TestDrivers.newInstance(dummyRequestHandler, bindings(filterBindings)); } private TestDriver responseFilterTestDriver() throws IOException { FilterBindings filterBindings = new FilterBindings.Builder() - .addResponseFilter("my-response-filter", "http://*/*", new TestResponseFilter()) + .addResponseFilter("my-response-filter", new TestResponseFilter()) + .addResponseFilterBinding("my-response-filter", "http://*/*") .build(); return TestDrivers.newInstance(dummyRequestHandler, bindings(filterBindings)); } |