summaryrefslogtreecommitdiffstats
path: root/vespa-http-client
diff options
context:
space:
mode:
authorBjørn Christian Seime <bjorncs@yahoo-inc.com>2017-01-06 16:48:29 +0100
committerBjørn Christian Seime <bjorncs@yahoo-inc.com>2017-01-09 14:42:41 +0100
commit183fb672d3e2fd13d84226baaf49d5e95e377379 (patch)
tree29e9ae91d4e29f8b252508906268f331ac0d6891 /vespa-http-client
parent55d84c60e183beaef9ae9206b2e9c541780cbb8e (diff)
Add support for dynamic headers
Diffstat (limited to 'vespa-http-client')
-rw-r--r--vespa-http-client/src/main/java/com/yahoo/vespa/http/client/config/ConnectionParams.java42
-rw-r--r--vespa-http-client/src/main/java/com/yahoo/vespa/http/client/core/communication/ApacheGatewayConnection.java16
-rw-r--r--vespa-http-client/src/test/java/com/yahoo/vespa/http/client/config/ConnectionParamsTest.java17
-rw-r--r--vespa-http-client/src/test/java/com/yahoo/vespa/http/client/core/communication/ApacheGatewayConnectionTest.java48
4 files changed, 113 insertions, 10 deletions
diff --git a/vespa-http-client/src/main/java/com/yahoo/vespa/http/client/config/ConnectionParams.java b/vespa-http-client/src/main/java/com/yahoo/vespa/http/client/config/ConnectionParams.java
index a3896808177..a1863262ba0 100644
--- a/vespa-http-client/src/main/java/com/yahoo/vespa/http/client/config/ConnectionParams.java
+++ b/vespa-http-client/src/main/java/com/yahoo/vespa/http/client/config/ConnectionParams.java
@@ -4,15 +4,14 @@ package com.yahoo.vespa.http.client.config;
import com.google.common.annotations.Beta;
import com.google.common.collect.ArrayListMultimap;
import com.google.common.collect.Multimap;
-import com.yahoo.vespa.http.client.Session;
-
import net.jcip.annotations.Immutable;
import javax.net.ssl.SSLContext;
-
import java.util.Collection;
import java.util.Collections;
+import java.util.HashMap;
import java.util.Map;
+import java.util.Objects;
import java.util.concurrent.TimeUnit;
/**
@@ -36,6 +35,7 @@ public final class ConnectionParams {
private SSLContext sslContext = null;
private long connectionTimeout = TimeUnit.SECONDS.toMillis(60);
private final Multimap<String, String> headers = ArrayListMultimap.create();
+ private final Map<String, HeaderProvider> headerProviders = new HashMap<>();
private int numPersistentConnectionsPerEndpoint = 8;
private String proxyHost = null;
private int proxyPort = 8080;
@@ -73,6 +73,24 @@ public final class ConnectionParams {
}
/**
+ * Adds a header provider for dynamic headers; headers where the value may change during a feeding session
+ * (e.g. security tokens with limited life time). Only one {@link HeaderProvider} is allowed for a given header name.
+ *
+ * @param provider A provider for a dynamic header
+ * @return pointer to builder.
+ * @throws IllegalArgumentException if a provider is already registered for the given header name
+ */
+ public Builder addDynamicHeader(String headerName, HeaderProvider provider) {
+ Objects.requireNonNull(headerName, "Header name cannot be null");
+ Objects.requireNonNull(provider, "Header provider cannot be null");
+ if (headerProviders.containsKey(headerName)) {
+ throw new IllegalArgumentException("Provider already registered for name '" + headerName + "'");
+ }
+ headerProviders.put(headerName, provider);
+ return this;
+ }
+
+ /**
* The number of connections between the http client and the gateways. A very low number can result
* in the network not fully utilized and the round-trip time can be a limiting factor. A low number
* can cause skew in distribution of load between gateways. A too high number will cause
@@ -203,6 +221,7 @@ public final class ConnectionParams {
sslContext,
connectionTimeout,
headers,
+ headerProviders,
numPersistentConnectionsPerEndpoint,
proxyHost,
proxyPort,
@@ -254,6 +273,7 @@ public final class ConnectionParams {
private final SSLContext sslContext;
private final long connectionTimeout;
private final Multimap<String, String> headers = ArrayListMultimap.create();
+ private final Map<String, HeaderProvider> headerProviders = new HashMap<>();
private final int numPersistentConnectionsPerEndpoint;
private final String proxyHost;
private final int proxyPort;
@@ -270,6 +290,7 @@ public final class ConnectionParams {
SSLContext sslContext,
long connectionTimeout,
Multimap<String, String> headers,
+ Map<String, HeaderProvider> headerProviders,
int numPersistentConnectionsPerEndpoint,
String proxyHost,
int proxyPort,
@@ -284,6 +305,7 @@ public final class ConnectionParams {
this.sslContext = sslContext;
this.connectionTimeout = connectionTimeout;
this.headers.putAll(headers);
+ this.headerProviders.putAll(headerProviders);
this.numPersistentConnectionsPerEndpoint = numPersistentConnectionsPerEndpoint;
this.proxyHost = proxyHost;
this.proxyPort = proxyPort;
@@ -305,6 +327,10 @@ public final class ConnectionParams {
return Collections.unmodifiableCollection(headers.entries());
}
+ public Map<String, HeaderProvider> getDynamicHeaders() {
+ return Collections.unmodifiableMap(headerProviders);
+ }
+
public int getNumPersistentConnectionsPerEndpoint() {
return numPersistentConnectionsPerEndpoint;
}
@@ -347,4 +373,14 @@ public final class ConnectionParams {
return printTraceToStdErr;
}
+ /**
+ * A header provider that provides a header value. {@link #getHeaderValue()} is called each time a new HTTP request
+ * is constructed by {@link com.yahoo.vespa.http.client.FeedClient}.
+ *
+ * Important: The implementation of {@link #getHeaderValue()} must be thread-safe!
+ */
+ public interface HeaderProvider {
+ String getHeaderValue();
+ }
+
}
diff --git a/vespa-http-client/src/main/java/com/yahoo/vespa/http/client/core/communication/ApacheGatewayConnection.java b/vespa-http-client/src/main/java/com/yahoo/vespa/http/client/core/communication/ApacheGatewayConnection.java
index f3d751f67c7..6fae0765e14 100644
--- a/vespa-http-client/src/main/java/com/yahoo/vespa/http/client/core/communication/ApacheGatewayConnection.java
+++ b/vespa-http-client/src/main/java/com/yahoo/vespa/http/client/core/communication/ApacheGatewayConnection.java
@@ -16,6 +16,8 @@ import org.apache.http.StatusLine;
import org.apache.http.client.HttpClient;
import org.apache.http.client.config.RequestConfig;
import org.apache.http.client.methods.HttpPost;
+import org.apache.http.config.Registry;
+import org.apache.http.config.RegistryBuilder;
import org.apache.http.conn.socket.ConnectionSocketFactory;
import org.apache.http.conn.socket.PlainConnectionSocketFactory;
import org.apache.http.conn.ssl.SSLConnectionSocketFactory;
@@ -23,9 +25,6 @@ import org.apache.http.entity.InputStreamEntity;
import org.apache.http.impl.client.HttpClientBuilder;
import org.apache.http.impl.conn.PoolingHttpClientConnectionManager;
-import org.apache.http.config.Registry;
-import org.apache.http.config.RegistryBuilder;
-
import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.IOException;
@@ -36,13 +35,13 @@ import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Map;
+import java.util.Objects;
+import java.util.UUID;
import java.util.concurrent.TimeUnit;
import java.util.logging.Level;
import java.util.logging.Logger;
import java.util.zip.GZIPOutputStream;
-import java.util.UUID;
-
/**
* @author <a href="mailto:einarmr@yahoo-inc.com">Einar M R Rosenvinge</a>
* @since 5.1.20
@@ -228,6 +227,13 @@ class ApacheGatewayConnection implements GatewayConnection {
for (Map.Entry<String, String> extraHeader : connectionParams.getHeaders()) {
httpPost.addHeader(extraHeader.getKey(), extraHeader.getValue());
}
+ connectionParams.getDynamicHeaders().forEach((headerName, provider) -> {
+ String headerValue = Objects.requireNonNull(
+ provider.getHeaderValue(),
+ provider.getClass().getName() + ".getHeader() returned null as header value!");
+ httpPost.addHeader(headerName, headerValue);
+ });
+
if (useCompression) {
httpPost.setHeader("Content-Encoding", "gzip");
}
diff --git a/vespa-http-client/src/test/java/com/yahoo/vespa/http/client/config/ConnectionParamsTest.java b/vespa-http-client/src/test/java/com/yahoo/vespa/http/client/config/ConnectionParamsTest.java
index 49ffabbf1d0..cb9270fdd3f 100644
--- a/vespa-http-client/src/test/java/com/yahoo/vespa/http/client/config/ConnectionParamsTest.java
+++ b/vespa-http-client/src/test/java/com/yahoo/vespa/http/client/config/ConnectionParamsTest.java
@@ -7,12 +7,11 @@ import javax.net.ssl.SSLContext;
import java.security.NoSuchAlgorithmException;
import java.util.Iterator;
import java.util.Map;
-import java.util.concurrent.TimeUnit;
import static org.hamcrest.CoreMatchers.equalTo;
-import static org.hamcrest.CoreMatchers.not;
import static org.hamcrest.core.Is.is;
import static org.hamcrest.core.IsNull.nullValue;
+import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertThat;
/**
@@ -57,4 +56,18 @@ public class ConnectionParamsTest {
assertThat(header3.getValue(), equalTo("Apple"));
}
+ @Test
+ public void header_providers_are_registered() {
+ ConnectionParams.HeaderProvider dummyProvider1 = () -> "fooValue";
+ ConnectionParams.HeaderProvider dummyProvider2 = () -> "barValue";
+ ConnectionParams params = new ConnectionParams.Builder()
+ .addDynamicHeader("foo", dummyProvider1)
+ .addDynamicHeader("bar", dummyProvider2)
+ .build();
+ Map<String, ConnectionParams.HeaderProvider> headerProviders = params.getDynamicHeaders();
+ assertEquals(2, headerProviders.size());
+ assertEquals(dummyProvider1, headerProviders.get("foo"));
+ assertEquals(dummyProvider2, headerProviders.get("bar"));
+ }
+
}
diff --git a/vespa-http-client/src/test/java/com/yahoo/vespa/http/client/core/communication/ApacheGatewayConnectionTest.java b/vespa-http-client/src/test/java/com/yahoo/vespa/http/client/core/communication/ApacheGatewayConnectionTest.java
index 48fe0461eab..456e5184f6e 100644
--- a/vespa-http-client/src/test/java/com/yahoo/vespa/http/client/core/communication/ApacheGatewayConnectionTest.java
+++ b/vespa-http-client/src/test/java/com/yahoo/vespa/http/client/core/communication/ApacheGatewayConnectionTest.java
@@ -32,11 +32,14 @@ import java.util.concurrent.atomic.AtomicInteger;
import static org.hamcrest.MatcherAssert.assertThat;
import static org.hamcrest.core.Is.is;
+import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertNotNull;
import static org.junit.Assert.assertTrue;
import static org.mockito.Mockito.any;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.stub;
+import static org.mockito.Mockito.times;
+import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when;
@@ -244,6 +247,51 @@ public class ApacheGatewayConnectionTest {
assertTrue(verifyContentSentLatch.await(10, TimeUnit.SECONDS));
}
+ @Test
+ public void dynamic_headers_are_added_to_the_response() throws IOException, ServerResponseException, InterruptedException {
+ ConnectionParams.HeaderProvider headerProvider = mock(ConnectionParams.HeaderProvider.class);
+ when(headerProvider.getHeaderValue())
+ .thenReturn("v1")
+ .thenReturn("v2")
+ .thenReturn("v3");
+
+ ConnectionParams connectionParams = new ConnectionParams.Builder()
+ .addDynamicHeader("foo", headerProvider)
+ .build();
+
+ CountDownLatch verifyContentSentLatch = new CountDownLatch(1);
+
+ AtomicInteger counter = new AtomicInteger(1);
+ ApacheGatewayConnection.HttpClientFactory mockFactory = mockHttpClientFactory(post -> {
+ Header[] fooHeader = post.getHeaders("foo");
+ assertEquals(1, fooHeader.length);
+ assertEquals("foo", fooHeader[0].getName());
+ assertEquals("v" + counter.getAndIncrement(), fooHeader[0].getValue());
+ verifyContentSentLatch.countDown();
+ return httpResponse("clientId", "3");
+
+ });
+
+ ApacheGatewayConnection apacheGatewayConnection =
+ new ApacheGatewayConnection(
+ Endpoint.create("hostname", 666, false),
+ new FeedParams.Builder().build(),
+ "",
+ connectionParams,
+ mockFactory,
+ "clientId");
+ apacheGatewayConnection.connect();
+ apacheGatewayConnection.handshake();
+
+ List<Document> documents = new ArrayList<>();
+ documents.add(createDoc("42", "content", true));
+ apacheGatewayConnection.writeOperations(documents);
+ apacheGatewayConnection.writeOperations(documents);
+ assertTrue(verifyContentSentLatch.await(10, TimeUnit.SECONDS));
+
+ verify(headerProvider, times(3)).getHeaderValue(); // 1x connect(), 2x writeOperations()
+ }
+
private static ApacheGatewayConnection.HttpClientFactory mockHttpClientFactory(HttpExecuteMock httpExecuteMock) throws IOException {
ApacheGatewayConnection.HttpClientFactory mockFactory =
mock(ApacheGatewayConnection.HttpClientFactory.class);