// Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. package com.yahoo.container.http.filter; import com.yahoo.component.AbstractComponent; import com.yahoo.component.ComponentId; import com.yahoo.component.ComponentSpecification; import com.yahoo.component.chain.Chain; import com.yahoo.component.chain.ChainedComponent; import com.yahoo.component.chain.ChainsConfigurer; import com.yahoo.component.chain.model.ChainsModel; import com.yahoo.component.chain.model.ChainsModelBuilder; import com.yahoo.component.provider.ComponentRegistry; import com.yahoo.container.core.ChainsConfig; import com.yahoo.jdisc.http.filter.RequestFilter; import com.yahoo.jdisc.http.filter.ResponseFilter; import com.yahoo.jdisc.http.filter.SecurityRequestFilter; import com.yahoo.jdisc.http.filter.SecurityRequestFilterChain; import com.yahoo.jdisc.http.filter.SecurityResponseFilter; import com.yahoo.jdisc.http.filter.SecurityResponseFilterChain; import com.yahoo.jdisc.http.filter.chain.RequestFilterChain; import com.yahoo.jdisc.http.filter.chain.ResponseFilterChain; import com.yahoo.processing.execution.chain.ChainRegistry; import java.util.ArrayList; import java.util.Collections; import java.util.List; import java.util.Set; import java.util.logging.Logger; import static java.util.Collections.emptyList; import static java.util.stream.Collectors.toSet; /** * Creates JDisc request/response filter chains. * * @author Tony Vaagenes * @author bjorncs */ public class FilterChainRepository extends AbstractComponent { private static final Logger log = Logger.getLogger(FilterChainRepository.class.getName()); private final ComponentRegistry filterAndChains; public FilterChainRepository(ChainsConfig chainsConfig, ComponentRegistry requestFilters, ComponentRegistry responseFilters, ComponentRegistry securityRequestFilters, ComponentRegistry securityResponseFilters) { ComponentRegistry filterAndChains = new ComponentRegistry<>(); addAllFilters(filterAndChains, requestFilters, responseFilters, securityRequestFilters, securityResponseFilters); addAllChains(filterAndChains, chainsConfig, requestFilters, responseFilters, securityRequestFilters, securityResponseFilters); filterAndChains.freeze(); this.filterAndChains = filterAndChains; } public Object getFilter(ComponentSpecification componentSpecification) { return filterAndChains.getComponent(componentSpecification); } private static void addAllFilters(ComponentRegistry destination, ComponentRegistry... registries) { for (ComponentRegistry registry : registries) { registry.allComponentsById() .forEach((id, filter) -> destination.register(id, wrapIfSecurityFilter(filter))); } } private static void addAllChains(ComponentRegistry destination, ChainsConfig chainsConfig, ComponentRegistry... filters) { ChainRegistry chainRegistry = buildChainRegistry(chainsConfig, filters); chainRegistry.allComponents() .forEach(chain -> destination.register(chain.getId(), toJDiscChain(chain))); } private static ChainRegistry buildChainRegistry(ChainsConfig chainsConfig, ComponentRegistry... filters) { ChainRegistry chainRegistry = new ChainRegistry<>(); ChainsModel chainsModel = ChainsModelBuilder.buildFromConfig(chainsConfig); ChainsConfigurer.prepareChainRegistry(chainRegistry, chainsModel, allFiltersWrapped(filters)); removeEmptyChains(chainRegistry); chainRegistry.freeze(); return chainRegistry; } private static void removeEmptyChains(ChainRegistry chainRegistry) { chainRegistry.allComponents().stream() .filter(chain -> chain.components().isEmpty()) .map(Chain::getId) .peek(id -> log.warning("Removing empty filter chain: " + id)) .forEach(chainRegistry::unregister); } @SuppressWarnings("unchecked") private static Object toJDiscChain(Chain chain) { if (chain.components().isEmpty()) throw new IllegalArgumentException("Empty filter chain: " + chain.getId()); checkFilterTypesCompatible(chain); List jdiscFilters = chain.components().stream() .map(filterWrapper -> filterWrapper.filter) .toList(); List wrappedFilters = wrapSecurityFilters(jdiscFilters); Object head = wrappedFilters.get(0); if (wrappedFilters.size() == 1) return head; else if (head instanceof RequestFilter) return RequestFilterChain.newInstance((List) wrappedFilters); else if (head instanceof ResponseFilter) return ResponseFilterChain.newInstance((List) wrappedFilters); throw new IllegalStateException(); } private static List wrapSecurityFilters(List filters) { List aggregatedSecurityFilters = new ArrayList<>(); List wrappedFilters = new ArrayList<>(); for (Object filter : filters) { if (isSecurityFilter(filter)) { aggregatedSecurityFilters.add(filter); } else { if (!aggregatedSecurityFilters.isEmpty()) { wrappedFilters.add(createSecurityChain(aggregatedSecurityFilters)); aggregatedSecurityFilters.clear(); } wrappedFilters.add(filter); } } if (!aggregatedSecurityFilters.isEmpty()) { wrappedFilters.add(createSecurityChain(aggregatedSecurityFilters)); } return wrappedFilters; } private static void checkFilterTypesCompatible(Chain chain) { Set requestFilters = chain.components().stream() .filter(filter -> filter instanceof RequestFilter || filter instanceof SecurityRequestFilter) .map(FilterWrapper::getId) .collect(toSet()); Set responseFilters = chain.components().stream() .filter(filter -> filter instanceof ResponseFilter || filter instanceof SecurityResponseFilter) .map(FilterWrapper::getId) .collect(toSet()); if (!requestFilters.isEmpty() && !responseFilters.isEmpty()) { throw new IllegalArgumentException( String.format( "Can't mix request and response filters in chain %s: request filters: %s, response filters: %s.", chain.getId(), requestFilters, responseFilters)); } } private static ComponentRegistry allFiltersWrapped(ComponentRegistry... registries) { ComponentRegistry wrappedFilters = new ComponentRegistry<>(); for (ComponentRegistry registry : registries) { registry.allComponentsById() .forEach((id, filter) -> wrappedFilters.register(id, new FilterWrapper(id, filter))); } wrappedFilters.freeze(); return wrappedFilters; } private static Object wrapIfSecurityFilter(Object filter) { if (isSecurityFilter(filter)) return createSecurityChain(Collections.singletonList(filter)); return filter; } @SuppressWarnings("unchecked") private static Object createSecurityChain(List filters) { Object head = filters.get(0); if (head instanceof SecurityRequestFilter) return SecurityRequestFilterChain.newInstance((List) filters); else if (head instanceof SecurityResponseFilter) return SecurityResponseFilterChain.newInstance((List) filters); throw new IllegalArgumentException("Unexpected class " + head.getClass()); } private static boolean isSecurityFilter(Object filter) { return filter instanceof SecurityRequestFilter || filter instanceof SecurityResponseFilter; } private static class FilterWrapper extends ChainedComponent { public final Object filter; public final Class filterType; public FilterWrapper(ComponentId id, Object filter) { super(id); this.filter = filter; this.filterType = getFilterType(filter); } private static Class getFilterType(Object filter) { if (filter instanceof RequestFilter) return RequestFilter.class; else if (filter instanceof ResponseFilter) return ResponseFilter.class; else if (filter instanceof SecurityRequestFilter) return SecurityRequestFilter.class; else if (filter instanceof SecurityResponseFilter) return SecurityResponseFilter.class; throw new IllegalArgumentException("Unsupported filter type: " + filter.getClass().getName()); } } }