aboutsummaryrefslogtreecommitdiffstats
path: root/container-disc/src/main/java/com/yahoo/container/jdisc/FilterBindingsProvider.java
blob: 6ec4f27f3275e35a8f70a935a8ea5ff7ac4d18d9 (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
// Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
package com.yahoo.container.jdisc;

import com.yahoo.component.annotation.Inject;
import com.yahoo.component.ComponentId;
import com.yahoo.component.ComponentSpecification;
import com.yahoo.component.provider.ComponentRegistry;
import com.yahoo.container.di.componentgraph.Provider;
import com.yahoo.container.http.filter.FilterChainRepository;
import com.yahoo.jdisc.http.ServerConfig;
import com.yahoo.jdisc.http.filter.RequestFilter;
import com.yahoo.jdisc.http.filter.ResponseFilter;
import com.yahoo.jdisc.http.filter.SecurityRequestFilter;
import com.yahoo.jdisc.http.filter.SecurityRequestFilterChain;
import com.yahoo.jdisc.http.server.jetty.FilterBindings;

import java.util.HashSet;
import java.util.Set;

/**
 * Provides filter bindings based on vespa config.
 *
 * @author Oyvind Bakksjo
 * @author bjorncs
 */
public class FilterBindingsProvider implements Provider<FilterBindings> {

    private static final ComponentId SEARCH_SERVER_COMPONENT_ID = ComponentId.fromString("SearchServer");

    private final FilterBindings filterBindings;

    /**
     * Builds the filter bindings eagerly; the resulting instance is what {@link #get()} hands out.
     *
     * @param componentId           id of this component; its namespace is used as the http server name
     * @param config                server config listing filter bindings, port default filters and strict filtering
     * @param filterChainRepository repository the configured filter ids are resolved against
     * @param legacyRequestFilters  registry of legacy security request filters
     * @throws RuntimeException if anything fails while assembling the bindings (original exception kept as cause)
     */
    @Inject
    public FilterBindingsProvider(ComponentId componentId,
                                  ServerConfig config,
                                  FilterChainRepository filterChainRepository,
                                  ComponentRegistry<SecurityRequestFilter> legacyRequestFilters) {
        this.filterBindings = createBindings(componentId, config, filterChainRepository, legacyRequestFilters);
    }

    /**
     * Assembles the bindings: legacy filters first, then configured filters, then the strict filtering flag.
     * Any failure is rethrown as a {@link RuntimeException} naming the http server.
     */
    private static FilterBindings createBindings(ComponentId componentId,
                                                 ServerConfig config,
                                                 FilterChainRepository filterChainRepository,
                                                 ComponentRegistry<SecurityRequestFilter> legacyRequestFilters) {
        try {
            FilterBindings.Builder bindings = new FilterBindings.Builder();
            configureLegacyFilters(bindings, componentId, legacyRequestFilters);
            configureFilters(bindings, config, filterChainRepository);
            bindings.setStrictFiltering(config.strictFiltering());
            return bindings.build();
        } catch (Exception e) {
            throw new RuntimeException(
                    "Invalid config for http server '" + componentId.getNamespace() + "': " + e.getMessage(), e);
        }
    }

    // TODO(gjoranv): remove
    /**
     * Registers all legacy security request filters as one chain bound to {@code http://*}{@code /*}.
     * Does nothing unless the component's namespace is the "SearchServer" id and the registry is non-empty.
     */
    private static void configureLegacyFilters(
            FilterBindings.Builder builder,
            ComponentId id,
            ComponentRegistry<SecurityRequestFilter> legacyRequestFilters) {
        if (!SEARCH_SERVER_COMPONENT_ID.equals(id.getNamespace())) {
            return;
        }
        if (legacyRequestFilters.allComponents().isEmpty()) {
            return;
        }
        String legacyChainId = "legacy-filters";
        builder.addRequestFilter(
                legacyChainId, SecurityRequestFilterChain.newInstance(legacyRequestFilters.allComponents()));
        builder.addRequestFilterBinding(legacyChainId, "http://*/*");
    }

    /** Applies the configured filters in three steps: instances, uri bindings, port defaults. */
    private static void configureFilters(
            FilterBindings.Builder builder, ServerConfig config, FilterChainRepository filterRepository) {
        addFilterInstances(builder, config, filterRepository);
        addFilterBindings(builder, config, filterRepository);
        addPortDefaultFilters(builder, config, filterRepository);
    }

    /**
     * Resolves every distinct filter id referenced by a binding or a port default
     * and registers the resolved instance with the builder.
     */
    private static void addFilterInstances(
            FilterBindings.Builder builder, ServerConfig config, FilterChainRepository filterRepository) {
        Set<String> referencedIds = new HashSet<>();
        for (ServerConfig.Filter binding : config.filter()) {
            referencedIds.add(binding.id());
        }
        for (ServerConfig.DefaultFilters portDefault : config.defaultFilters()) {
            referencedIds.add(portDefault.filterId());
        }
        for (String id : referencedIds) {
            registerFilter(builder, id, lookupFilter(filterRepository, id));
        }
    }

    /**
     * Registers a single resolved filter as either a request filter or a response filter.
     *
     * @throws IllegalArgumentException if the instance is missing, is both kinds at once, or is neither kind
     */
    private static void registerFilter(FilterBindings.Builder builder, String filterId, Object instance) {
        if (instance == null) {
            throw new IllegalArgumentException("No http filter with id " + filterId);
        }
        boolean handlesRequests = instance instanceof RequestFilter;
        boolean handlesResponses = instance instanceof ResponseFilter;
        if (handlesRequests && handlesResponses) {
            throw new IllegalArgumentException("The filter " + instance.getClass().getName()
                    + " is unsupported since it's both a RequestFilter and a ResponseFilter.");
        }
        if (handlesRequests) {
            builder.addRequestFilter(filterId, (RequestFilter) instance);
            return;
        }
        if (handlesResponses) {
            builder.addResponseFilter(filterId, (ResponseFilter) instance);
            return;
        }
        throw new IllegalArgumentException("Unknown filter type: " + instance.getClass().getName());
    }

    /** Adds each configured binding as a request binding or, if the filter is not a request filter, a response binding. */
    private static void addFilterBindings(
            FilterBindings.Builder builder, ServerConfig config, FilterChainRepository filterRepository) {
        for (ServerConfig.Filter entry : config.filter()) {
            String id = entry.id();
            if (!isRequestFilter(filterRepository, id)) {
                builder.addResponseFilterBinding(id, entry.binding());
                continue;
            }
            builder.addRequestFilterBinding(id, entry.binding());
        }
    }

    /** Sets each configured port default as a request default or, if the filter is not a request filter, a response default. */
    private static void addPortDefaultFilters(
            FilterBindings.Builder builder, ServerConfig config, FilterChainRepository filterRepository) {
        for (ServerConfig.DefaultFilters entry : config.defaultFilters()) {
            String id = entry.filterId();
            if (!isRequestFilter(filterRepository, id)) {
                builder.setResponseFilterDefaultForPort(id, entry.localPort());
                continue;
            }
            builder.setRequestFilterDefaultForPort(id, entry.localPort());
        }
    }

    /** Returns true when the filter resolved for the given id is a {@link RequestFilter}. */
    private static boolean isRequestFilter(FilterChainRepository filterRepository, String filterId) {
        Object resolved = lookupFilter(filterRepository, filterId);
        return resolved instanceof RequestFilter;
    }

    /** Resolves a filter id against the repository, parsing the id as a component specification. */
    private static Object lookupFilter(FilterChainRepository filterRepository, String filterId) {
        ComponentSpecification spec = ComponentSpecification.fromString(filterId);
        return filterRepository.getFilter(spec);
    }

    /** Returns the bindings built in the constructor. */
    @Override
    public FilterBindings get() {
        return filterBindings;
    }

    /** Nothing to release. */
    @Override
    public void deconstruct() {}

}