diff options
Diffstat (limited to 'jdisc_http_service/src/main/java/com/yahoo/jdisc')
6 files changed, 108 insertions, 79 deletions
diff --git a/jdisc_http_service/src/main/java/com/yahoo/jdisc/http/server/jetty/FilterBindings.java b/jdisc_http_service/src/main/java/com/yahoo/jdisc/http/server/jetty/FilterBindings.java index 301c92a4583..310f3c9a646 100644 --- a/jdisc_http_service/src/main/java/com/yahoo/jdisc/http/server/jetty/FilterBindings.java +++ b/jdisc_http_service/src/main/java/com/yahoo/jdisc/http/server/jetty/FilterBindings.java @@ -3,18 +3,15 @@ package com.yahoo.jdisc.http.server.jetty; import com.yahoo.jdisc.application.BindingRepository; import com.yahoo.jdisc.application.BindingSet; -import com.yahoo.jdisc.application.UriPattern; import com.yahoo.jdisc.http.filter.RequestFilter; import com.yahoo.jdisc.http.filter.ResponseFilter; import java.net.URI; import java.util.Collection; +import java.util.Collections; import java.util.Map; import java.util.Optional; -import java.util.stream.Stream; -import java.util.stream.StreamSupport; - -import static java.util.stream.Collectors.toSet; +import java.util.TreeMap; /** * Resolves request/response filter (chain) from a {@link URI} instance. @@ -24,87 +21,82 @@ import static java.util.stream.Collectors.toSet; */ public class FilterBindings { - private final BindingSet<FilterHolder<RequestFilter>> requestFilters; - private final BindingSet<FilterHolder<ResponseFilter>> responseFilters; + private final Map<String, RequestFilter> requestFilters; + private final Map<String, ResponseFilter> responseFilters; + private final Map<Integer, String> defaultRequestFilters; + private final Map<Integer, String> defaultResponseFilters; + private final BindingSet<String> requestFilterBindings; + private final BindingSet<String> responseFilterBindings; private FilterBindings( - BindingSet<FilterHolder<RequestFilter>> requestFilters, - BindingSet<FilterHolder<ResponseFilter>> responseFilters) { + Map<String, RequestFilter> requestFilters, + Map<String, ResponseFilter> responseFilters, + Map<Integer, String> defaultRequestFilters, + Map<Integer, String> defaultResponseFilters, + BindingSet<String> requestFilterBindings, + BindingSet<String> responseFilterBindings) { this.requestFilters = requestFilters; this.responseFilters = responseFilters; + this.defaultRequestFilters = defaultRequestFilters; + this.defaultResponseFilters = defaultResponseFilters; + this.requestFilterBindings = requestFilterBindings; + this.responseFilterBindings = responseFilterBindings; } - public Optional<String> resolveRequestFilter(URI uri) { return resolveFilterId(requestFilters, uri); } - - public Optional<String> resolveResponseFilter(URI uri) { return resolveFilterId(responseFilters, uri); } - - public RequestFilter getRequestFilter(String filterId) { return getFilterInstance(requestFilters, filterId); } - - public ResponseFilter getResponseFilter(String filterId) { return getFilterInstance(responseFilters, filterId); } - - public Collection<String> requestFilterIds() { return filterIds(requestFilters); } - - public Collection<String> responseFilterIds() { return filterIds(responseFilters); } + public Optional<String> resolveRequestFilter(URI uri, int localPort) { + String filterId = requestFilterBindings.resolve(uri); + if (filterId != null) return Optional.of(filterId); + return Optional.ofNullable(defaultRequestFilters.get(localPort)); + } - public Collection<RequestFilter> requestFilters() { return filters(requestFilters); } + public Optional<String> resolveResponseFilter(URI uri, int localPort) { + String filterId = responseFilterBindings.resolve(uri); + if (filterId != null) return Optional.of(filterId); + return Optional.ofNullable(defaultResponseFilters.get(localPort)); + } - public Collection<ResponseFilter> responseFilters() { return filters(responseFilters); } + public RequestFilter getRequestFilter(String filterId) { return requestFilters.get(filterId); } - private static <T> Optional<String> resolveFilterId(BindingSet<FilterHolder<T>> filters, URI uri) { - return Optional.ofNullable(filters.resolve(uri)) - .map(holder -> holder.filterId); - } + public ResponseFilter getResponseFilter(String filterId) { return responseFilters.get(filterId); } - private static <T> T getFilterInstance(BindingSet<FilterHolder<T>> filters, String filterId) { - return stream(filters) - .filter(filterEntry -> filterId.equals(filterEntry.getValue().filterId)) - .map(filterEntry -> filterEntry.getValue().filterInstance) - .findAny() - .orElseThrow(() -> new IllegalArgumentException("No filter with id " + filterId)); - } + public Collection<String> requestFilterIds() { return requestFilters.keySet(); } - private static <T> Collection<String> filterIds(BindingSet<FilterHolder<T>> filters) { - return stream(filters) - .map(filterEntry -> filterEntry.getValue().filterId) - .collect(toSet()); - } + public Collection<String> responseFilterIds() { return responseFilters.keySet(); } - private static <T> Collection<T> filters(BindingSet<FilterHolder<T>> filters) { - return stream(filters) - .map(filterEntry -> filterEntry.getValue().filterInstance) - .collect(toSet()); - } + public Collection<RequestFilter> requestFilters() { return requestFilters.values(); } - private static <T> Stream<Map.Entry<UriPattern, FilterHolder<T>>> stream(BindingSet<FilterHolder<T>> filters) { - return StreamSupport.stream(filters.spliterator(), false); - } + public Collection<ResponseFilter> responseFilters() { return responseFilters.values(); } public static class Builder { - private final BindingRepository<FilterHolder<RequestFilter>> requestFilters = new BindingRepository<>(); - private final BindingRepository<FilterHolder<ResponseFilter>> responseFilters = new BindingRepository<>(); + private final Map<String, RequestFilter> requestFilters = new TreeMap<>(); + private final Map<String, ResponseFilter> responseFilters = new TreeMap<>(); + private final Map<Integer, String> defaultRequestFilters = new TreeMap<>(); + private final Map<Integer, String> defaultResponseFilters = new TreeMap<>(); + private final BindingRepository<String> requestFilterBindings = new BindingRepository<>(); + private final BindingRepository<String> responseFilterBindings = new BindingRepository<>(); public Builder() {} - public Builder addRequestFilter(String id, String binding, RequestFilter filter) { - requestFilters.bind(binding, new FilterHolder<>(id, filter)); - return this; - } + public Builder addRequestFilter(String id, RequestFilter filter) { requestFilters.put(id, filter); return this; } - public Builder addResponseFilter(String id, String binding, ResponseFilter filter) { - responseFilters.bind(binding, new FilterHolder<>(id, filter)); - return this; - } + public Builder addResponseFilter(String id, ResponseFilter filter) { responseFilters.put(id, filter); return this; } - public FilterBindings build() { return new FilterBindings(requestFilters.activate(), responseFilters.activate()); } - } + public Builder addRequestFilterBinding(String id, String binding) { requestFilterBindings.bind(binding, id); return this; } + + public Builder addResponseFilterBinding(String id, String binding) { responseFilterBindings.bind(binding, id); return this; } + + public Builder setRequestFilterDefaultForPort(String id, int port) { defaultRequestFilters.put(port, id); return this; } - private static class FilterHolder<T> { - final String filterId; - final T filterInstance; + public Builder setResponseFilterDefaultForPort(String id, int port) { defaultResponseFilters.put(port, id); return this; } - FilterHolder(String filterId, T filterInstance) { - this.filterId = filterId; - this.filterInstance = filterInstance; + public FilterBindings build() { + return new FilterBindings( + Collections.unmodifiableMap(requestFilters), + Collections.unmodifiableMap(responseFilters), + Collections.unmodifiableMap(defaultRequestFilters), + Collections.unmodifiableMap(defaultResponseFilters), + requestFilterBindings.activate(), + responseFilterBindings.activate()); } } } diff --git a/jdisc_http_service/src/main/java/com/yahoo/jdisc/http/server/jetty/FilterResolver.java b/jdisc_http_service/src/main/java/com/yahoo/jdisc/http/server/jetty/FilterResolver.java new file mode 100644 index 00000000000..badb0e736ae --- /dev/null +++ b/jdisc_http_service/src/main/java/com/yahoo/jdisc/http/server/jetty/FilterResolver.java @@ -0,0 +1,35 @@ +// Copyright Verizon Media. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package com.yahoo.jdisc.http.server.jetty; + +import com.yahoo.jdisc.http.filter.RequestFilter; +import com.yahoo.jdisc.http.filter.ResponseFilter; + +import javax.servlet.http.HttpServletRequest; +import java.net.URI; +import java.util.Optional; + +import static com.yahoo.jdisc.http.server.jetty.JDiscHttpServlet.getConnector; + +/** + * Resolve request/response filter (chain) based on {@link FilterBindings}. + * + * @author bjorncs + */ +class FilterResolver { + + private final FilterBindings bindings; + + FilterResolver(FilterBindings bindings) { + this.bindings = bindings; + } + + Optional<RequestFilter> resolveRequestFilter(HttpServletRequest servletRequest, URI jdiscUri) { + return bindings.resolveRequestFilter(jdiscUri, getConnector(servletRequest).listenPort()) + .map(bindings::getRequestFilter); + } + + Optional<ResponseFilter> resolveResponseFilter(HttpServletRequest servletRequest, URI jdiscUri) { + return bindings.resolveResponseFilter(jdiscUri, getConnector(servletRequest).listenPort()) + .map(bindings::getResponseFilter); + } +} diff --git a/jdisc_http_service/src/main/java/com/yahoo/jdisc/http/server/jetty/FilteringRequestHandler.java b/jdisc_http_service/src/main/java/com/yahoo/jdisc/http/server/jetty/FilteringRequestHandler.java index 7f761d6ab4a..0923e6688f4 100644 --- a/jdisc_http_service/src/main/java/com/yahoo/jdisc/http/server/jetty/FilteringRequestHandler.java +++ b/jdisc_http_service/src/main/java/com/yahoo/jdisc/http/server/jetty/FilteringRequestHandler.java @@ -16,6 +16,7 @@ import com.yahoo.jdisc.http.core.CompletionHandlers; import com.yahoo.jdisc.http.filter.RequestFilter; import com.yahoo.jdisc.http.filter.ResponseFilter; +import javax.servlet.http.HttpServletRequest; import java.nio.ByteBuffer; import java.util.Objects; import java.util.concurrent.atomic.AtomicBoolean; @@ -41,10 +42,12 @@ class FilteringRequestHandler extends AbstractRequestHandler { }; - private final FilterBindings filterBindings; + private final FilterResolver filterResolver; + private final HttpServletRequest servletRequest; - public FilteringRequestHandler(FilterBindings filterBindings) { - this.filterBindings = filterBindings; + public FilteringRequestHandler(FilterResolver filterResolver, HttpServletRequest servletRequest) { + this.filterResolver = filterResolver; + this.servletRequest = servletRequest; } @Override @@ -52,12 +55,11 @@ class FilteringRequestHandler extends AbstractRequestHandler { Preconditions.checkArgument(request instanceof HttpRequest, "Expected HttpRequest, got " + request); Objects.requireNonNull(originalResponseHandler, "responseHandler"); - RequestFilter requestFilter = filterBindings.resolveRequestFilter(request.getUri()) - .map(filterBindings::getRequestFilter) + RequestFilter requestFilter = filterResolver.resolveRequestFilter(servletRequest, request.getUri()) .orElse(null); - ResponseFilter responseFilter = filterBindings.resolveResponseFilter(request.getUri()) - .map(filterBindings::getResponseFilter) + ResponseFilter responseFilter = filterResolver.resolveResponseFilter(servletRequest, request.getUri()) .orElse(null); + // Not using request.connect() here - it adds logic for error handling that we'd rather leave to the framework. RequestHandler resolvedRequestHandler = request.container().resolveHandler(request); diff --git a/jdisc_http_service/src/main/java/com/yahoo/jdisc/http/server/jetty/HttpRequestDispatch.java b/jdisc_http_service/src/main/java/com/yahoo/jdisc/http/server/jetty/HttpRequestDispatch.java index ed5095fb06f..940009e7520 100644 --- a/jdisc_http_service/src/main/java/com/yahoo/jdisc/http/server/jetty/HttpRequestDispatch.java +++ b/jdisc_http_service/src/main/java/com/yahoo/jdisc/http/server/jetty/HttpRequestDispatch.java @@ -212,7 +212,7 @@ class HttpRequestDispatch { AccessLogEntry accessLogEntry, HttpServletRequest servletRequest) { RequestHandler requestHandler = wrapHandlerIfFormPost( - new FilteringRequestHandler(context.filterBindings), + new FilteringRequestHandler(context.filterResolver, servletRequest), servletRequest, context.serverConfig.removeRawPostBodyForWwwUrlEncodedPost()); return new AccessLoggingRequestHandler(requestHandler, accessLogEntry); diff --git a/jdisc_http_service/src/main/java/com/yahoo/jdisc/http/server/jetty/JDiscContext.java b/jdisc_http_service/src/main/java/com/yahoo/jdisc/http/server/jetty/JDiscContext.java index 4667ff3975b..a308149beb5 100644 --- a/jdisc_http_service/src/main/java/com/yahoo/jdisc/http/server/jetty/JDiscContext.java +++ b/jdisc_http_service/src/main/java/com/yahoo/jdisc/http/server/jetty/JDiscContext.java @@ -8,7 +8,7 @@ import com.yahoo.jdisc.service.CurrentContainer; import java.util.concurrent.Executor; public class JDiscContext { - final FilterBindings filterBindings; + final FilterResolver filterResolver; final CurrentContainer container; final Executor janitor; final Metric metric; @@ -20,7 +20,7 @@ public class JDiscContext { Metric metric, ServerConfig serverConfig) { - this.filterBindings = filterBindings; + this.filterResolver = new FilterResolver(filterBindings); this.container = container; this.janitor = janitor; this.metric = metric; diff --git a/jdisc_http_service/src/main/java/com/yahoo/jdisc/http/server/jetty/JDiscFilterInvokerFilter.java b/jdisc_http_service/src/main/java/com/yahoo/jdisc/http/server/jetty/JDiscFilterInvokerFilter.java index 5a904299e44..f046ccd5439 100644 --- a/jdisc_http_service/src/main/java/com/yahoo/jdisc/http/server/jetty/JDiscFilterInvokerFilter.java +++ b/jdisc_http_service/src/main/java/com/yahoo/jdisc/http/server/jetty/JDiscFilterInvokerFilter.java @@ -3,6 +3,7 @@ package com.yahoo.jdisc.http.server.jetty; import com.yahoo.container.logging.AccessLogEntry; import com.yahoo.jdisc.handler.ResponseHandler; +import com.yahoo.jdisc.http.filter.RequestFilter; import javax.servlet.AsyncContext; import javax.servlet.AsyncListener; @@ -75,8 +76,7 @@ class JDiscFilterInvokerFilter implements Filter { private void runChainAndResponseFilters(URI uri, HttpServletRequest request, HttpServletResponse response, FilterChain chain) throws IOException, ServletException { Optional<OneTimeRunnable> responseFilterInvoker = - jDiscContext.filterBindings.resolveResponseFilter(uri) - .map(jDiscContext.filterBindings::getResponseFilter) + jDiscContext.filterResolver.resolveResponseFilter(request, uri) .map(responseFilter -> new OneTimeRunnable(() -> filterInvoker.invokeResponseFilterChain(responseFilter, uri, request, response))); @@ -106,12 +106,12 @@ class JDiscFilterInvokerFilter implements Filter { private HttpServletRequest runRequestFilterWithMatchingBinding(AtomicReference<Boolean> responseReturned, URI uri, HttpServletRequest request, HttpServletResponse response) throws IOException { try { - String requestFilterId = jDiscContext.filterBindings.resolveRequestFilter(uri).orElse(null); - if (requestFilterId == null) + RequestFilter requestFilter = jDiscContext.filterResolver.resolveRequestFilter(request, uri).orElse(null); + if (requestFilter == null) return request; ResponseHandler responseHandler = createResponseHandler(responseReturned, request, response); - return filterInvoker.invokeRequestFilterChain(jDiscContext.filterBindings.getRequestFilter(requestFilterId), uri, request, responseHandler); + return filterInvoker.invokeRequestFilterChain(requestFilter, uri, request, responseHandler); } catch (Exception e) { throw new RuntimeException("Failed running request filter chain for uri " + uri, e); } |