aboutsummaryrefslogtreecommitdiffstats
path: root/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/restapi/systemflags/FlagsClient.java
blob: 2b53b1a32f55d09c49d38495f9a2d46af77c65db (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
// Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
package com.yahoo.vespa.hosted.controller.restapi.systemflags;

import ai.vespa.util.http.hc4.SslConnectionSocketFactory;
import ai.vespa.util.http.hc4.retry.DelayedConnectionLevelRetryHandler;
import com.fasterxml.jackson.databind.JsonNode;
import com.fasterxml.jackson.databind.ObjectMapper;
import com.yahoo.vespa.athenz.api.AthenzIdentity;
import com.yahoo.vespa.athenz.tls.AthenzIdentityVerifier;
import com.yahoo.vespa.flags.FlagId;
import com.yahoo.vespa.flags.json.FlagData;
import com.yahoo.vespa.hosted.controller.api.integration.ControllerIdentityProvider;
import com.yahoo.vespa.hosted.controller.api.systemflags.v1.FlagsTarget;
import com.yahoo.vespa.hosted.controller.api.systemflags.v1.wire.WireErrorResponse;
import org.apache.http.HttpEntity;
import org.apache.http.HttpResponse;
import org.apache.http.HttpStatus;
import org.apache.http.NameValuePair;
import org.apache.http.client.ResponseHandler;
import org.apache.http.client.config.RequestConfig;
import org.apache.http.client.methods.HttpDelete;
import org.apache.http.client.methods.HttpGet;
import org.apache.http.client.methods.HttpPut;
import org.apache.http.client.methods.HttpUriRequest;
import org.apache.http.client.utils.URIBuilder;
import org.apache.http.entity.ContentType;
import org.apache.http.entity.StringEntity;
import org.apache.http.impl.client.CloseableHttpClient;
import org.apache.http.impl.client.HttpClientBuilder;
import org.apache.http.message.BasicNameValuePair;
import org.apache.http.util.EntityUtils;

import javax.net.ssl.HostnameVerifier;
import javax.net.ssl.SSLSession;
import java.io.IOException;
import java.io.UncheckedIOException;
import java.net.URI;
import java.net.URISyntaxException;
import java.time.Duration;
import java.util.ArrayList;
import java.util.List;
import java.util.Set;

import static java.util.stream.Collectors.toSet;

/**
 * A client for the /flags/v1 REST API on configservers and controllers.
 *
 * <p>All requests go through a single {@link CloseableHttpClient} created in the constructor.
 * Any response status other than 200 OK is reported as a {@link FlagsException}; I/O failures
 * are reported as {@link UncheckedIOException}.
 *
 * @author bjorncs
 */
class FlagsClient {

    private static final String FLAGS_V1_PATH = "/flags/v1";

    // Shared by all instances; parses the /defined listing and JSON error bodies.
    private static final ObjectMapper mapper = new ObjectMapper();

    private final CloseableHttpClient client;

    /**
     * @param identityProvider supplies the SSL socket factory used for every request
     * @param targets          targets whose Athenz HTTPS identities are accepted during hostname verification
     */
    FlagsClient(ControllerIdentityProvider identityProvider, Set<FlagsTarget> targets) {
        this.client = createClient(identityProvider, targets);
    }

    /**
     * Fetches all flag data from the target via {@code GET /flags/v1/data?recursive=true}.
     *
     * @throws FlagsException       if the target responds with a status other than 200 OK
     * @throws UncheckedIOException on I/O failure
     */
    List<FlagData> listFlagData(FlagsTarget target) throws FlagsException, UncheckedIOException {
        HttpGet request = new HttpGet(createUri(target, "/data", List.of(new BasicNameValuePair("recursive", "true"))));
        return executeRequest(request, response -> {
            verifySuccess(response, null);
            return FlagData.deserializeList(EntityUtils.toByteArray(response.getEntity()));
        });
    }

    /**
     * Lists the flags defined on the target via {@code GET /flags/v1/defined}. Each top-level
     * field name of the returned JSON document is taken as a flag id.
     *
     * @throws FlagsException       if the target responds with a status other than 200 OK
     * @throws UncheckedIOException on I/O failure
     */
    List<FlagId> listDefinedFlags(FlagsTarget target) throws FlagsException, UncheckedIOException {
        HttpGet request = new HttpGet(createUri(target, "/defined", List.of()));
        return executeRequest(request, response -> {
            verifySuccess(response, null);
            JsonNode json = mapper.readTree(response.getEntity().getContent());
            List<FlagId> flagIds = new ArrayList<>();
            json.fieldNames().forEachRemaining(fieldName -> flagIds.add(new FlagId(fieldName)));
            return flagIds;
        });
    }

    /**
     * Writes the given flag data to the target via {@code PUT /flags/v1/data/<flag-id>} with a JSON body.
     *
     * @throws FlagsException       if the target responds with a status other than 200 OK
     * @throws UncheckedIOException on I/O failure
     */
    void putFlagData(FlagsTarget target, FlagData flagData) throws FlagsException, UncheckedIOException {
        HttpPut request = new HttpPut(createUri(target, "/data/" + flagData.id().toString(), List.of()));
        request.setEntity(jsonContent(flagData.serializeToJson()));
        executeRequest(request, response -> {
            verifySuccess(response, flagData.id());
            return null;
        });
    }

    /**
     * Deletes the flag's data on the target via {@code DELETE /flags/v1/data/<flag-id>?force=true}.
     *
     * @throws FlagsException       if the target responds with a status other than 200 OK
     * @throws UncheckedIOException on I/O failure
     */
    void deleteFlagData(FlagsTarget target, FlagId flagId) throws FlagsException, UncheckedIOException {
        HttpDelete request = new HttpDelete(createUri(target, "/data/" + flagId.toString(), List.of(new BasicNameValuePair("force", "true"))));
        executeRequest(request, response -> {
            verifySuccess(response, flagId);
            return null;
        });
    }

    /**
     * Builds the underlying HTTP client. It has:
     * <ul>
     *   <li>connection-level retries with exponential backoff (1s initial delay, 20s max delay, and a
     *       count of 5, presumably the maximum number of retries);</li>
     *   <li>10s connect and connection-request timeouts and a 20s socket timeout;</li>
     *   <li>at most 2 connections per route and 100 connections in total.</li>
     * </ul>
     */
    private static CloseableHttpClient createClient(ControllerIdentityProvider identityProvider, Set<FlagsTarget> targets) {
        DelayedConnectionLevelRetryHandler retryHandler = DelayedConnectionLevelRetryHandler.Builder
                .withExponentialBackoff(Duration.ofSeconds(1), Duration.ofSeconds(20), 5)
                .build();

        return HttpClientBuilder.create()
                .setUserAgent("controller-flags-v1-client")
                .setSSLSocketFactory(SslConnectionSocketFactory.of(
                        identityProvider.getConfigServerSslSocketFactory(), new FlagTargetsHostnameVerifier(targets)))
                .setDefaultRequestConfig(RequestConfig.custom()
                                                 .setConnectTimeout((int) Duration.ofSeconds(10).toMillis())
                                                 .setConnectionRequestTimeout((int) Duration.ofSeconds(10).toMillis())
                                                 .setSocketTimeout((int) Duration.ofSeconds(20).toMillis())
                                                 .build())
                .setMaxConnPerRoute(2)
                .setMaxConnTotal(100)
                .setRetryHandler(retryHandler)
                .build();
    }

    /** Executes the request with the given handler, rethrowing any {@link IOException} as {@link UncheckedIOException}. */
    private <T> T executeRequest(HttpUriRequest request, ResponseHandler<T> handler) {
        try {
            return client.execute(request, handler);
        } catch (IOException e) {
            throw new UncheckedIOException(e);
        }
    }

    /** Resolves {@code /flags/v1<subPath>} with the given query parameters against the target's endpoint. */
    private static URI createUri(FlagsTarget target, String subPath, List<NameValuePair> queryParams) {
        try {
            return new URIBuilder(target.endpoint())
                    .setPath(FLAGS_V1_PATH + subPath)
                    .setParameters(queryParams)
                    .build();
        } catch (URISyntaxException e) {
            throw new RuntimeException(e); // should never happen
        }
    }

    /**
     * Throws a {@link FlagsException} unless the response status is 200 OK.
     *
     * @param flagId the flag the request concerned, or null if it was not flag-specific
     */
    private static void verifySuccess(HttpResponse response, FlagId flagId) throws IOException {
        if (!success(response)) {
            throw createFlagsException(response, flagId);
        }
    }

    /**
     * Builds a {@link FlagsException} describing a non-successful response.
     *
     * <p>If the body is declared as JSON, it is parsed as a {@link WireErrorResponse} so the error
     * code and message are reported separately. In every other case the status code is still
     * reported. This covers a missing body, a missing Content-Type header, a non-JSON body, and a
     * body that is not a valid error document. When there is a body, its raw text is used as the
     * message.
     */
    private static FlagsException createFlagsException(HttpResponse response, FlagId flagId) throws IOException {
        int statusCode = response.getStatusLine().getStatusCode();
        HttpEntity entity = response.getEntity();
        if (entity == null) {
            // EntityUtils.toString rejects a null entity, which would hide the actual error status
            return new FlagsException(statusCode, flagId, null, "<empty response body>");
        }
        String content = EntityUtils.toString(entity);
        ContentType contentType = ContentType.get(entity); // null when the response has no Content-Type header
        if (contentType != null
                && ContentType.APPLICATION_JSON.getMimeType().equalsIgnoreCase(contentType.getMimeType())) {
            try {
                WireErrorResponse error = mapper.readValue(content, WireErrorResponse.class);
                return new FlagsException(statusCode, flagId, error.errorCode, error.message);
            } catch (IOException ignored) {
                // Not a valid error document: report the raw body below rather than losing the status code
            }
        }
        return new FlagsException(statusCode, flagId, null, content);
    }

    /** Only 200 OK counts as success; every other status, including other 2xx codes, is treated as an error. */
    private static boolean success(HttpResponse response) {
        return response.getStatusLine().getStatusCode() == HttpStatus.SC_OK;
    }

    /** Wraps the JSON string in an entity with Content-Type {@code application/json}. */
    private static StringEntity jsonContent(String json) {
        return new StringEntity(json, ContentType.APPLICATION_JSON);
    }

    /**
     * Accepts the hostname {@code localhost} (used when talking to controllers). Any other hostname
     * is delegated to an {@link AthenzIdentityVerifier} built from the targets' Athenz HTTPS identities.
     */
    private static class FlagTargetsHostnameVerifier implements HostnameVerifier {

        private final AthenzIdentityVerifier athenzVerifier;

        FlagTargetsHostnameVerifier(Set<FlagsTarget> targets) {
            this.athenzVerifier = createAthenzIdentityVerifier(targets);
        }

        /** Collects the Athenz HTTPS identities of all targets that have one. */
        private static AthenzIdentityVerifier createAthenzIdentityVerifier(Set<FlagsTarget> targets) {
            Set<AthenzIdentity> identities = targets.stream()
                    .flatMap(target -> target.athenzHttpsIdentity().stream())
                    .collect(toSet());
            return new AthenzIdentityVerifier(identities);
        }

        @Override
        public boolean verify(String hostname, SSLSession session) {
            return "localhost".equals(hostname) /* for controllers */ || athenzVerifier.verify(hostname, session);
        }
    }

    /**
     * Thrown when a /flags/v1 request receives a non-200 response. The message has the form
     * {@code Received <status>[/<errorCode>][ for flag '<id>']: <message>}.
     */
    static class FlagsException extends RuntimeException {

        private FlagsException(int statusCode, FlagId flagId, String errorCode, String errorMessage) {
            super(createErrorMessage(statusCode, flagId, errorCode, errorMessage));
        }

        private static String createErrorMessage(int statusCode, FlagId flagId, String errorCode, String errorMessage) {
            StringBuilder builder = new StringBuilder().append("Received ").append(statusCode);
            if (errorCode != null) {
                builder.append('/').append(errorCode);
            }
            if (flagId != null) {
                builder.append(" for flag '").append(flagId).append("'");
            }
            return builder.append(": ").append(errorMessage).toString();
        }
    }
}