From 244f51c38338b66ff2f7b01dbef268f76aa3549d Mon Sep 17 00:00:00 2001 From: Bjørn Christian Seime Date: Wed, 11 Nov 2020 16:23:46 +0100 Subject: Revert "Remove implicit and potensially false dependency on jdisc_http_service" --- container-core/pom.xml | 6 - .../http/filter/FilterChainRepository.java | 201 +++++++++++++++++++++ 2 files changed, 201 insertions(+), 6 deletions(-) create mode 100644 container-core/src/main/java/com/yahoo/container/http/filter/FilterChainRepository.java (limited to 'container-core') diff --git a/container-core/pom.xml b/container-core/pom.xml index 68556c70e22..0fbb590a1de 100644 --- a/container-core/pom.xml +++ b/container-core/pom.xml @@ -99,12 +99,6 @@ ${project.version} provided - - com.yahoo.vespa - jdisc_http_service - ${project.version} - provided - com.yahoo.vespa linguistics diff --git a/container-core/src/main/java/com/yahoo/container/http/filter/FilterChainRepository.java b/container-core/src/main/java/com/yahoo/container/http/filter/FilterChainRepository.java new file mode 100644 index 00000000000..31bceca9337 --- /dev/null +++ b/container-core/src/main/java/com/yahoo/container/http/filter/FilterChainRepository.java @@ -0,0 +1,201 @@ +// Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package com.yahoo.container.http.filter; + +import com.yahoo.component.AbstractComponent; +import com.yahoo.component.ComponentId; +import com.yahoo.component.ComponentSpecification; +import com.yahoo.component.chain.Chain; +import com.yahoo.component.chain.ChainedComponent; +import com.yahoo.component.chain.ChainsConfigurer; +import com.yahoo.component.chain.model.ChainsModel; +import com.yahoo.component.chain.model.ChainsModelBuilder; +import com.yahoo.component.provider.ComponentRegistry; +import com.yahoo.container.core.ChainsConfig; +import com.yahoo.jdisc.http.filter.RequestFilter; +import com.yahoo.jdisc.http.filter.ResponseFilter; +import com.yahoo.jdisc.http.filter.SecurityRequestFilter; +import com.yahoo.jdisc.http.filter.SecurityRequestFilterChain; +import com.yahoo.jdisc.http.filter.SecurityResponseFilter; +import com.yahoo.jdisc.http.filter.SecurityResponseFilterChain; +import com.yahoo.jdisc.http.filter.chain.RequestFilterChain; +import com.yahoo.jdisc.http.filter.chain.ResponseFilterChain; +import com.yahoo.processing.execution.chain.ChainRegistry; + +import java.util.ArrayList; +import java.util.Collections; +import java.util.List; +import java.util.Set; +import java.util.logging.Logger; + +import static java.util.Collections.emptyList; +import static java.util.stream.Collectors.toList; +import static java.util.stream.Collectors.toSet; + +/** + * Creates JDisc request/response filter chains. + * + * @author Tony Vaagenes + * @author bjorncs + */ +public class FilterChainRepository extends AbstractComponent { + private static final Logger log = Logger.getLogger(FilterChainRepository.class.getName()); + + private final ComponentRegistry filterAndChains; + + public FilterChainRepository(ChainsConfig chainsConfig, + ComponentRegistry requestFilters, + ComponentRegistry responseFilters, + ComponentRegistry securityRequestFilters, + ComponentRegistry securityResponseFilters) { + ComponentRegistry filterAndChains = new ComponentRegistry<>(); + addAllFilters(filterAndChains, requestFilters, responseFilters, securityRequestFilters, securityResponseFilters); + addAllChains(filterAndChains, chainsConfig, requestFilters, responseFilters, securityRequestFilters, securityResponseFilters); + filterAndChains.freeze(); + this.filterAndChains = filterAndChains; + } + + public Object getFilter(ComponentSpecification componentSpecification) { + return filterAndChains.getComponent(componentSpecification); + } + + private static void addAllFilters(ComponentRegistry destination, + ComponentRegistry... registries) { + for (ComponentRegistry registry : registries) { + registry.allComponentsById() + .forEach((id, filter) -> destination.register(id, wrapIfSecurityFilter(filter))); + } + } + + private static void addAllChains(ComponentRegistry destination, + ChainsConfig chainsConfig, + ComponentRegistry... filters) { + ChainRegistry chainRegistry = buildChainRegistry(chainsConfig, filters); + chainRegistry.allComponents() + .forEach(chain -> destination.register(chain.getId(), toJDiscChain(chain))); + } + + private static ChainRegistry buildChainRegistry(ChainsConfig chainsConfig, + ComponentRegistry... filters) { + ChainRegistry chainRegistry = new ChainRegistry<>(); + ChainsModel chainsModel = ChainsModelBuilder.buildFromConfig(chainsConfig); + ChainsConfigurer.prepareChainRegistry(chainRegistry, chainsModel, allFiltersWrapped(filters)); + removeEmptyChains(chainRegistry); + chainRegistry.freeze(); + return chainRegistry; + } + + private static void removeEmptyChains(ChainRegistry chainRegistry) { + chainRegistry.allComponents().stream() + .filter(chain -> chain.components().isEmpty()) + .map(Chain::getId) + .peek(id -> log.warning("Removing empty filter chain: " + id)) + .forEach(chainRegistry::unregister); + } + + @SuppressWarnings("unchecked") + private static Object toJDiscChain(Chain chain) { + if (chain.components().isEmpty()) + throw new IllegalArgumentException("Empty filter chain: " + chain.getId()); + checkFilterTypesCompatible(chain); + List jdiscFilters = chain.components().stream() + .map(filterWrapper -> filterWrapper.filter) + .collect(toList()); + List wrappedFilters = wrapSecurityFilters(jdiscFilters); + Object head = wrappedFilters.get(0); + if (wrappedFilters.size() == 1) return head; + else if (head instanceof RequestFilter) + return RequestFilterChain.newInstance((List) wrappedFilters); + else if (head instanceof ResponseFilter) + return ResponseFilterChain.newInstance((List) wrappedFilters); + throw new IllegalStateException(); + } + + private static List wrapSecurityFilters(List filters) { + List aggregatedSecurityFilters = new ArrayList<>(); + List wrappedFilters = new ArrayList<>(); + for (Object filter : filters) { + if (isSecurityFilter(filter)) { + aggregatedSecurityFilters.add(filter); + } else { + if (!aggregatedSecurityFilters.isEmpty()) { + wrappedFilters.add(createSecurityChain(aggregatedSecurityFilters)); + aggregatedSecurityFilters.clear(); + } + wrappedFilters.add(filter); + } + } + if (!aggregatedSecurityFilters.isEmpty()) { + wrappedFilters.add(createSecurityChain(aggregatedSecurityFilters)); + } + return wrappedFilters; + } + + private static void checkFilterTypesCompatible(Chain chain) { + Set requestFilters = chain.components().stream() + .filter(filter -> filter instanceof RequestFilter || filter instanceof SecurityRequestFilter) + .map(FilterWrapper::getId) + .collect(toSet()); + Set responseFilters = chain.components().stream() + .filter(filter -> filter instanceof ResponseFilter || filter instanceof SecurityResponseFilter) + .map(FilterWrapper::getId) + .collect(toSet()); + if (!requestFilters.isEmpty() && !responseFilters.isEmpty()) { + throw new IllegalArgumentException( + String.format( + "Can't mix request and response filters in chain %s: request filters: %s, response filters: %s.", + chain.getId(), requestFilters, responseFilters)); + } + } + + private static ComponentRegistry allFiltersWrapped(ComponentRegistry... registries) { + ComponentRegistry wrappedFilters = new ComponentRegistry<>(); + for (ComponentRegistry registry : registries) { + registry.allComponentsById() + .forEach((id, filter) -> wrappedFilters.register(id, new FilterWrapper(id, filter))); + } + wrappedFilters.freeze(); + return wrappedFilters; + } + + private static Object wrapIfSecurityFilter(Object filter) { + if (isSecurityFilter(filter)) return createSecurityChain(Collections.singletonList(filter)); + return filter; + } + + @SuppressWarnings("unchecked") + private static Object createSecurityChain(List filters) { + Object head = filters.get(0); + if (head instanceof SecurityRequestFilter) + return SecurityRequestFilterChain.newInstance((List) filters); + else if (head instanceof SecurityResponseFilter) + return SecurityResponseFilterChain.newInstance((List) filters); + throw new IllegalArgumentException("Unexpected class " + head.getClass()); + } + + private static boolean isSecurityFilter(Object filter) { + return filter instanceof SecurityRequestFilter || filter instanceof SecurityResponseFilter; + } + + private static class FilterWrapper extends ChainedComponent { + public final Object filter; + public final Class filterType; + + public FilterWrapper(ComponentId id, Object filter) { + super(id); + this.filter = filter; + this.filterType = getFilterType(filter); + } + + private static Class getFilterType(Object filter) { + if (filter instanceof RequestFilter) + return RequestFilter.class; + else if (filter instanceof ResponseFilter) + return ResponseFilter.class; + else if (filter instanceof SecurityRequestFilter) + return SecurityRequestFilter.class; + else if (filter instanceof SecurityResponseFilter) + return SecurityResponseFilter.class; + throw new IllegalArgumentException("Unsupported filter type: " + filter.getClass().getName()); + } + } +} -- cgit v1.2.3