summaryrefslogtreecommitdiffstats
path: root/container-disc/src
diff options
context:
space:
mode:
Diffstat (limited to 'container-disc/src')
-rw-r--r--container-disc/src/main/java/com/yahoo/container/jdisc/ContainerThreadFactory.java2
-rw-r--r--container-disc/src/main/java/com/yahoo/container/jdisc/athenz/AthenzIdentityProvider.java8
-rw-r--r--container-disc/src/main/java/com/yahoo/container/jdisc/athenz/AthenzIdentityProviderException.java16
-rw-r--r--container-disc/src/main/java/com/yahoo/container/jdisc/athenz/impl/AthenzCredentials.java51
-rw-r--r--container-disc/src/main/java/com/yahoo/container/jdisc/athenz/impl/AthenzCredentialsService.java93
-rw-r--r--container-disc/src/main/java/com/yahoo/container/jdisc/athenz/impl/AthenzIdentityProviderImpl.java228
-rw-r--r--container-disc/src/main/java/com/yahoo/container/jdisc/athenz/impl/AthenzService.java30
-rw-r--r--container-disc/src/main/java/com/yahoo/container/jdisc/athenz/impl/CryptoUtils.java23
-rw-r--r--container-disc/src/main/java/com/yahoo/container/jdisc/athenz/impl/IdentityDocumentService.java (renamed from container-disc/src/main/java/com/yahoo/container/jdisc/athenz/impl/ServiceProviderApi.java)36
-rw-r--r--container-disc/src/main/java/com/yahoo/container/jdisc/athenz/impl/InstanceIdentity.java46
-rw-r--r--container-disc/src/main/java/com/yahoo/container/jdisc/athenz/impl/InstanceRefreshInformation.java5
-rw-r--r--container-disc/src/main/java/com/yahoo/container/jdisc/athenz/impl/InstanceRegisterInformation.java5
-rw-r--r--container-disc/src/main/java/com/yahoo/container/jdisc/athenz/impl/SignedIdentityDocument.java7
-rw-r--r--container-disc/src/test/java/com/yahoo/container/jdisc/athenz/impl/AthenzIdentityProviderImplTest.java206
14 files changed, 645 insertions, 111 deletions
diff --git a/container-disc/src/main/java/com/yahoo/container/jdisc/ContainerThreadFactory.java b/container-disc/src/main/java/com/yahoo/container/jdisc/ContainerThreadFactory.java
index 379116a5d94..50798a82b60 100644
--- a/container-disc/src/main/java/com/yahoo/container/jdisc/ContainerThreadFactory.java
+++ b/container-disc/src/main/java/com/yahoo/container/jdisc/ContainerThreadFactory.java
@@ -8,7 +8,7 @@ import com.yahoo.jdisc.application.MetricConsumer;
import java.util.concurrent.ThreadFactory;
/**
- * @author <a href="mailto:simon@yahoo-inc.com">Simon Thoresen Hult</a>
+ * @author Simon Thoresen Hult
*/
public class ContainerThreadFactory implements ThreadFactory {
diff --git a/container-disc/src/main/java/com/yahoo/container/jdisc/athenz/AthenzIdentityProvider.java b/container-disc/src/main/java/com/yahoo/container/jdisc/athenz/AthenzIdentityProvider.java
index 19e04e0ae01..033b396bc9b 100644
--- a/container-disc/src/main/java/com/yahoo/container/jdisc/athenz/AthenzIdentityProvider.java
+++ b/container-disc/src/main/java/com/yahoo/container/jdisc/athenz/AthenzIdentityProvider.java
@@ -5,9 +5,7 @@ package com.yahoo.container.jdisc.athenz;
* @author mortent
*/
public interface AthenzIdentityProvider {
-
- String getNToken();
- String getX509Cert();
- String domain();
- String service();
+ String getNToken() throws AthenzIdentityProviderException;
+ String getDomain();
+ String getService();
}
diff --git a/container-disc/src/main/java/com/yahoo/container/jdisc/athenz/AthenzIdentityProviderException.java b/container-disc/src/main/java/com/yahoo/container/jdisc/athenz/AthenzIdentityProviderException.java
new file mode 100644
index 00000000000..fd5839bfc45
--- /dev/null
+++ b/container-disc/src/main/java/com/yahoo/container/jdisc/athenz/AthenzIdentityProviderException.java
@@ -0,0 +1,16 @@
+// Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+package com.yahoo.container.jdisc.athenz;
+
+/**
+ * @author bjorncs
+ */
+public class AthenzIdentityProviderException extends RuntimeException {
+
+ public AthenzIdentityProviderException(String message) {
+ super(message);
+ }
+
+ public AthenzIdentityProviderException(String message, Throwable cause) {
+ super(message, cause);
+ }
+}
diff --git a/container-disc/src/main/java/com/yahoo/container/jdisc/athenz/impl/AthenzCredentials.java b/container-disc/src/main/java/com/yahoo/container/jdisc/athenz/impl/AthenzCredentials.java
new file mode 100644
index 00000000000..790a7c54333
--- /dev/null
+++ b/container-disc/src/main/java/com/yahoo/container/jdisc/athenz/impl/AthenzCredentials.java
@@ -0,0 +1,51 @@
+// Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+package com.yahoo.container.jdisc.athenz.impl;
+
+import java.security.KeyPair;
+import java.security.cert.X509Certificate;
+import java.time.Instant;
+
+/**
+ * @author bjorncs
+ */
+class AthenzCredentials {
+
+ private final String nToken;
+ private final X509Certificate certificate;
+ private final KeyPair keyPair;
+ private final SignedIdentityDocument identityDocument;
+ private final Instant createdAt;
+
+ AthenzCredentials(String nToken,
+ X509Certificate certificate,
+ KeyPair keyPair,
+ SignedIdentityDocument identityDocument,
+ Instant createdAt) {
+ this.nToken = nToken;
+ this.certificate = certificate;
+ this.keyPair = keyPair;
+ this.identityDocument = identityDocument;
+ this.createdAt = createdAt;
+ }
+
+ String getNToken() {
+ return nToken;
+ }
+
+ X509Certificate getCertificate() {
+ return certificate;
+ }
+
+ KeyPair getKeyPair() {
+ return keyPair;
+ }
+
+ SignedIdentityDocument getIdentityDocument() {
+ return identityDocument;
+ }
+
+ Instant getCreatedAt() {
+ return createdAt;
+ }
+
+}
diff --git a/container-disc/src/main/java/com/yahoo/container/jdisc/athenz/impl/AthenzCredentialsService.java b/container-disc/src/main/java/com/yahoo/container/jdisc/athenz/impl/AthenzCredentialsService.java
new file mode 100644
index 00000000000..5786eb9e398
--- /dev/null
+++ b/container-disc/src/main/java/com/yahoo/container/jdisc/athenz/impl/AthenzCredentialsService.java
@@ -0,0 +1,93 @@
+// Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+package com.yahoo.container.jdisc.athenz.impl;
+
+import com.fasterxml.jackson.databind.ObjectMapper;
+import com.yahoo.container.core.identity.IdentityConfig;
+import org.bouncycastle.pkcs.PKCS10CertificationRequest;
+
+import java.io.IOException;
+import java.io.UncheckedIOException;
+import java.security.KeyPair;
+import java.security.cert.X509Certificate;
+import java.time.Clock;
+
+/**
+ * @author bjorncs
+ */
+class AthenzCredentialsService {
+
+ private static final ObjectMapper mapper = new ObjectMapper();
+
+ private final IdentityConfig identityConfig;
+ private final IdentityDocumentService identityDocumentService;
+ private final AthenzService athenzService;
+ private final Clock clock;
+
+ AthenzCredentialsService(IdentityConfig identityConfig,
+ IdentityDocumentService identityDocumentService,
+ AthenzService athenzService,
+ Clock clock) {
+ this.identityConfig = identityConfig;
+ this.identityDocumentService = identityDocumentService;
+ this.athenzService = athenzService;
+ this.clock = clock;
+ }
+
+ AthenzCredentials registerInstance() {
+ KeyPair keyPair = CryptoUtils.createKeyPair();
+ String rawDocument = identityDocumentService.getSignedIdentityDocument();
+ SignedIdentityDocument document = parseSignedIdentityDocument(rawDocument);
+ PKCS10CertificationRequest csr = CryptoUtils.createCSR(identityConfig.domain(),
+ identityConfig.service(),
+ document.dnsSuffix,
+ document.providerUniqueId,
+ keyPair);
+ InstanceRegisterInformation instanceRegisterInformation =
+ new InstanceRegisterInformation(document.providerService,
+ identityConfig.domain(),
+ identityConfig.service(),
+ rawDocument,
+ CryptoUtils.toPem(csr));
+ InstanceIdentity instanceIdentity = athenzService.sendInstanceRegisterRequest(instanceRegisterInformation,
+ document.ztsEndpoint);
+ return toAthenzCredentials(instanceIdentity, keyPair, document);
+ }
+
+ AthenzCredentials updateCredentials(AthenzCredentials currentCredentials) {
+ SignedIdentityDocument document = currentCredentials.getIdentityDocument();
+ KeyPair newKeyPair = CryptoUtils.createKeyPair();
+ PKCS10CertificationRequest csr = CryptoUtils.createCSR(identityConfig.domain(),
+ identityConfig.service(),
+ document.dnsSuffix,
+ document.providerUniqueId,
+ newKeyPair);
+ InstanceRefreshInformation refreshInfo = new InstanceRefreshInformation(CryptoUtils.toPem(csr));
+ InstanceIdentity instanceIdentity =
+ athenzService.sendInstanceRefreshRequest(document.providerService,
+ identityConfig.domain(),
+ identityConfig.service(),
+ document.providerUniqueId,
+ refreshInfo,
+ document.ztsEndpoint,
+ currentCredentials.getCertificate(),
+ currentCredentials.getKeyPair().getPrivate());
+ return toAthenzCredentials(instanceIdentity, newKeyPair, document);
+ }
+
+ private AthenzCredentials toAthenzCredentials(InstanceIdentity instanceIdentity,
+ KeyPair keyPair,
+ SignedIdentityDocument identityDocument) {
+ X509Certificate certificate = instanceIdentity.getX509Certificate();
+ String serviceToken = instanceIdentity.getServiceToken();
+ return new AthenzCredentials(serviceToken, certificate, keyPair, identityDocument, clock.instant());
+ }
+
+ private static SignedIdentityDocument parseSignedIdentityDocument(String rawDocument) {
+ try {
+ return mapper.readValue(rawDocument, SignedIdentityDocument.class);
+ } catch (IOException e) {
+ throw new UncheckedIOException(e);
+ }
+ }
+
+}
diff --git a/container-disc/src/main/java/com/yahoo/container/jdisc/athenz/impl/AthenzIdentityProviderImpl.java b/container-disc/src/main/java/com/yahoo/container/jdisc/athenz/impl/AthenzIdentityProviderImpl.java
index d2c914fc209..356780a0900 100644
--- a/container-disc/src/main/java/com/yahoo/container/jdisc/athenz/impl/AthenzIdentityProviderImpl.java
+++ b/container-disc/src/main/java/com/yahoo/container/jdisc/athenz/impl/AthenzIdentityProviderImpl.java
@@ -1,75 +1,237 @@
// Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
package com.yahoo.container.jdisc.athenz.impl;
-import com.fasterxml.jackson.databind.ObjectMapper;
import com.google.inject.Inject;
import com.yahoo.component.AbstractComponent;
import com.yahoo.container.core.identity.IdentityConfig;
import com.yahoo.container.jdisc.athenz.AthenzIdentityProvider;
+import com.yahoo.container.jdisc.athenz.AthenzIdentityProviderException;
+import com.yahoo.log.LogLevel;
-import java.io.IOException;
-import java.security.KeyPair;
+import java.time.Clock;
+import java.time.Duration;
+import java.time.Instant;
+import java.util.concurrent.CountDownLatch;
+import java.util.concurrent.Executors;
+import java.util.concurrent.ScheduledExecutorService;
+import java.util.concurrent.TimeUnit;
+import java.util.concurrent.atomic.AtomicReference;
+import java.util.logging.Logger;
/**
* @author mortent
+ * @author bjorncs
*/
public final class AthenzIdentityProviderImpl extends AbstractComponent implements AthenzIdentityProvider {
- private final ObjectMapper objectMapper = new ObjectMapper();
+ private static final Logger log = Logger.getLogger(AthenzIdentityProviderImpl.class.getName());
- private InstanceIdentity instanceIdentity;
+ // TODO Make some of these values configurable through config. Match requested expiration of register/update requests.
+ // TODO These should match the requested expiration
+ static final Duration EXPIRES_AFTER = Duration.ofDays(1);
+ static final Duration EXPIRATION_MARGIN = Duration.ofMinutes(30);
+ static final Duration INITIAL_WAIT_NTOKEN = Duration.ofMinutes(5);
+ static final Duration UPDATE_PERIOD = EXPIRES_AFTER.dividedBy(3);
+ static final Duration REDUCED_UPDATE_PERIOD = Duration.ofMinutes(30);
+ static final Duration INITIAL_BACKOFF_DELAY = Duration.ofMinutes(4);
+ static final Duration MAX_REGISTER_BACKOFF_DELAY = Duration.ofHours(1);
+ static final int BACKOFF_DELAY_MULTIPLIER = 2;
+ static final Duration AWAIT_TERMINTATION_TIMEOUT = Duration.ofSeconds(90);
- private final String dnsSuffix;
- private final String providerUniqueId;
+
+ static final String REGISTER_INSTANCE_TAG = "register-instance";
+ static final String UPDATE_CREDENTIALS_TAG = "update-credentials";
+ static final String TIMEOUT_INITIAL_WAIT_TAG = "timeout-initial-wait";
+
+
+ private final AtomicReference<AthenzCredentials> credentials = new AtomicReference<>();
+ private final AtomicReference<Throwable> lastThrowable = new AtomicReference<>();
+ private final CountDownLatch credentialsRetrievedSignal = new CountDownLatch(1);
+ private final AthenzCredentialsService athenzCredentialsService;
+ private final Scheduler scheduler;
+ private final Clock clock;
private final String domain;
private final String service;
@Inject
- public AthenzIdentityProviderImpl(IdentityConfig config) throws IOException {
- this(config, new ServiceProviderApi(config.loadBalancerAddress()), new AthenzService());
+ public AthenzIdentityProviderImpl(IdentityConfig config) {
+ this(config,
+ new AthenzCredentialsService(config,
+ new IdentityDocumentService(config.loadBalancerAddress()),
+ new AthenzService(),
+ Clock.systemUTC()),
+ new ThreadPoolScheduler(),
+ Clock.systemUTC());
}
// Test only
AthenzIdentityProviderImpl(IdentityConfig config,
- ServiceProviderApi serviceProviderApi,
- AthenzService athenzService) throws IOException {
- KeyPair keyPair = CryptoUtils.createKeyPair();
+ AthenzCredentialsService athenzCredentialsService,
+ Scheduler scheduler,
+ Clock clock) {
+ this.athenzCredentialsService = athenzCredentialsService;
+ this.scheduler = scheduler;
+ this.clock = clock;
this.domain = config.domain();
this.service = config.service();
- String rawDocument = serviceProviderApi.getSignedIdentityDocument();
- SignedIdentityDocument document = objectMapper.readValue(rawDocument, SignedIdentityDocument.class);
- this.dnsSuffix = document.dnsSuffix;
- this.providerUniqueId = document.providerUniqueId;
-
- InstanceRegisterInformation instanceRegisterInformation = new InstanceRegisterInformation(
- document.providerService,
- this.domain,
- this.service,
- rawDocument,
- CryptoUtils.toPem(CryptoUtils.createCSR(domain, service, dnsSuffix, providerUniqueId, keyPair)),
- true
- );
- instanceIdentity = athenzService.sendInstanceRegisterRequest( instanceRegisterInformation, document.ztsEndpoint);
+ scheduler.submit(new RegisterInstanceTask());
+ scheduler.schedule(new TimeoutInitialWaitTask(), INITIAL_WAIT_NTOKEN);
}
@Override
public String getNToken() {
- return instanceIdentity.getServiceToken();
+ try {
+ credentialsRetrievedSignal.await();
+ AthenzCredentials credentialsSnapshot = credentials.get();
+ if (credentialsSnapshot == null) {
+ throw new AthenzIdentityProviderException("Could not retrieve Athenz credentials", lastThrowable.get());
+ }
+ if (isExpired(credentialsSnapshot)) {
+ throw new AthenzIdentityProviderException("Athenz credentials are expired", lastThrowable.get());
+ }
+ return credentialsSnapshot.getNToken();
+ } catch (InterruptedException e) {
+ throw new AthenzIdentityProviderException("Failed to register instance credentials", lastThrowable.get());
+ }
}
@Override
- public String getX509Cert() {
- return instanceIdentity.getX509Certificate();
+ public String getDomain() {
+ return domain;
}
@Override
- public String domain() {
- return domain;
+ public String getService() {
+ return service;
}
@Override
- public String service() {
- return service;
+ public void deconstruct() {
+ scheduler.shutdown(AWAIT_TERMINTATION_TIMEOUT);
+ }
+
+ private boolean isExpired(AthenzCredentials credentials) {
+ return clock.instant().isAfter(getExpirationTime(credentials));
+ }
+
+ private static Instant getExpirationTime(AthenzCredentials credentials) {
+ return credentials.getCreatedAt().plus(EXPIRES_AFTER).minus(EXPIRATION_MARGIN);
+ }
+
+ private class RegisterInstanceTask implements RunnableWithTag {
+
+ private final Duration backoffDelay;
+
+ RegisterInstanceTask() {
+ this(INITIAL_BACKOFF_DELAY);
+ }
+
+ RegisterInstanceTask(Duration backoffDelay) {
+ this.backoffDelay = backoffDelay;
+ }
+
+ @Override
+ public void run() {
+ try {
+ credentials.set(athenzCredentialsService.registerInstance());
+ credentialsRetrievedSignal.countDown();
+ scheduler.schedule(new UpdateCredentialsTask(), UPDATE_PERIOD);
+ } catch (Throwable t) {
+ log.log(LogLevel.ERROR, "Failed to register instance: " + t.getMessage(), t);
+ lastThrowable.set(t);
+ Duration nextBackoffDelay = backoffDelay.multipliedBy(BACKOFF_DELAY_MULTIPLIER);
+ if (nextBackoffDelay.compareTo(MAX_REGISTER_BACKOFF_DELAY) > 0) {
+ nextBackoffDelay = MAX_REGISTER_BACKOFF_DELAY;
+ }
+ scheduler.schedule(new RegisterInstanceTask(nextBackoffDelay), backoffDelay);
+ }
+ }
+
+ @Override
+ public String tag() {
+ return REGISTER_INSTANCE_TAG;
+ }
+ }
+
+ private class UpdateCredentialsTask implements RunnableWithTag {
+ @Override
+ public void run() {
+ AthenzCredentials currentCredentials = credentials.get();
+ try {
+ AthenzCredentials newCredentials = isExpired(currentCredentials)
+ ? athenzCredentialsService.registerInstance()
+ : athenzCredentialsService.updateCredentials(currentCredentials);
+ credentials.set(newCredentials);
+ scheduler.schedule(new UpdateCredentialsTask(), UPDATE_PERIOD);
+ } catch (Throwable t) {
+ log.log(LogLevel.WARNING, "Failed to update credentials: " + t.getMessage(), t);
+ lastThrowable.set(t);
+ Duration timeToExpiration = Duration.between(clock.instant(), getExpirationTime(currentCredentials));
+ // NOTE: Update period might be after timeToExpiration, still we do not want to DDoS Athenz.
+ Duration updatePeriod =
+ timeToExpiration.compareTo(UPDATE_PERIOD) > 0 ? UPDATE_PERIOD : REDUCED_UPDATE_PERIOD;
+ scheduler.schedule(new UpdateCredentialsTask(), updatePeriod);
+ }
+ }
+
+ @Override
+ public String tag() {
+ return UPDATE_CREDENTIALS_TAG;
+ }
+ }
+
+ private class TimeoutInitialWaitTask implements RunnableWithTag {
+ @Override
+ public void run() {
+ credentialsRetrievedSignal.countDown();
+ }
+
+ @Override
+ public String tag() {
+ return TIMEOUT_INITIAL_WAIT_TAG;
+ }
}
+
+ private static class ThreadPoolScheduler implements Scheduler {
+
+ private static final Logger log = Logger.getLogger(ThreadPoolScheduler.class.getName());
+
+ private final ScheduledExecutorService executor = Executors.newScheduledThreadPool(0);
+
+ @Override
+ public void schedule(RunnableWithTag runnable, Duration delay) {
+ log.log(LogLevel.FINE, String.format("Scheduling task '%s' in '%s'", runnable.tag(), delay));
+ executor.schedule(runnable, delay.getSeconds(), TimeUnit.SECONDS);
+ }
+
+ @Override
+ public void submit(RunnableWithTag runnable) {
+ log.log(LogLevel.FINE, String.format("Scheduling task '%s' now", runnable.tag()));
+ executor.submit(runnable);
+ }
+
+ @Override
+ public void shutdown(Duration timeout) {
+ try {
+ executor.shutdownNow();
+ executor.awaitTermination(AWAIT_TERMINTATION_TIMEOUT.getSeconds(), TimeUnit.SECONDS);
+ } catch (InterruptedException e) {
+ throw new RuntimeException(e);
+ }
+ }
+
+ }
+
+ public interface Scheduler {
+ void schedule(RunnableWithTag runnable, Duration delay);
+ default void submit(RunnableWithTag runnable) { schedule(runnable, Duration.ZERO); }
+ default void shutdown(Duration timeout) {}
+ }
+
+ public interface RunnableWithTag extends Runnable {
+
+ String tag();
+ }
+
}
diff --git a/container-disc/src/main/java/com/yahoo/container/jdisc/athenz/impl/AthenzService.java b/container-disc/src/main/java/com/yahoo/container/jdisc/athenz/impl/AthenzService.java
index dc1f8956def..898f90e3438 100644
--- a/container-disc/src/main/java/com/yahoo/container/jdisc/athenz/impl/AthenzService.java
+++ b/container-disc/src/main/java/com/yahoo/container/jdisc/athenz/impl/AthenzService.java
@@ -3,6 +3,7 @@ package com.yahoo.container.jdisc.athenz.impl;
import com.fasterxml.jackson.core.JsonProcessingException;
import com.fasterxml.jackson.databind.ObjectMapper;
+import org.apache.http.client.HttpRequestRetryHandler;
import org.apache.http.client.methods.CloseableHttpResponse;
import org.apache.http.client.methods.HttpUriRequest;
import org.apache.http.client.methods.RequestBuilder;
@@ -10,6 +11,7 @@ import org.apache.http.conn.ssl.SSLContextBuilder;
import org.apache.http.entity.ContentType;
import org.apache.http.entity.StringEntity;
import org.apache.http.impl.client.CloseableHttpClient;
+import org.apache.http.impl.client.DefaultHttpRequestRetryHandler;
import org.apache.http.impl.client.HttpClientBuilder;
import org.apache.http.util.EntityUtils;
import org.eclipse.jetty.http.HttpStatus;
@@ -17,6 +19,7 @@ import org.eclipse.jetty.http.HttpStatus;
import javax.net.ssl.SSLContext;
import java.io.IOException;
import java.io.UncheckedIOException;
+import java.net.URI;
import java.security.KeyManagementException;
import java.security.KeyStore;
import java.security.KeyStoreException;
@@ -33,18 +36,19 @@ import java.security.cert.X509Certificate;
*/
public class AthenzService {
- private static final String INSTANCE_API_PATH = "zts/v1/instance";
+ private static final String INSTANCE_API_PATH = "/zts/v1/instance";
private final ObjectMapper objectMapper = new ObjectMapper();
+ private final HttpRequestRetryHandler retryHandler = new DefaultHttpRequestRetryHandler(3, /*requestSentRetryEnabled*/true);
/**
* Send instance register request to ZTS, get InstanceIdentity
*/
public InstanceIdentity sendInstanceRegisterRequest(InstanceRegisterInformation instanceRegisterInformation,
- String ztsEndpoint) {
- try(CloseableHttpClient client = HttpClientBuilder.create().build()) {
+ URI uri) {
+ try(CloseableHttpClient client = HttpClientBuilder.create().setRetryHandler(retryHandler).build()) {
HttpUriRequest postRequest = RequestBuilder.post()
- .setUri(ztsEndpoint + INSTANCE_API_PATH)
+ .setUri(uri.resolve(INSTANCE_API_PATH))
.setEntity(toJsonStringEntity(instanceRegisterInformation))
.build();
return getInstanceIdentity(client, postRequest);
@@ -58,13 +62,16 @@ public class AthenzService {
String instanceServiceName,
String instanceId,
InstanceRefreshInformation instanceRefreshInformation,
- String ztsEndpoint,
+ URI ztsEndpoint,
X509Certificate certicate,
PrivateKey privateKey) {
- try (CloseableHttpClient client = createHttpClientWithTlsAuth(certicate, privateKey)) {
- String uri = String.format("%s/%s/%s/%s/%s",
- ztsEndpoint + INSTANCE_API_PATH,
- providerService, instanceDomain, instanceServiceName, instanceId);
+ try (CloseableHttpClient client = createHttpClientWithTlsAuth(certicate, privateKey, retryHandler)) {
+ URI uri = ztsEndpoint
+ .resolve(INSTANCE_API_PATH + '/')
+ .resolve(providerService + '/')
+ .resolve(instanceDomain + '/')
+ .resolve(instanceServiceName + '/')
+ .resolve(instanceId);
HttpUriRequest postRequest = RequestBuilder.post()
.setUri(uri)
.setEntity(toJsonStringEntity(instanceRefreshInformation))
@@ -92,7 +99,9 @@ public class AthenzService {
return new StringEntity(objectMapper.writeValueAsString(value), ContentType.APPLICATION_JSON);
}
- private static CloseableHttpClient createHttpClientWithTlsAuth(X509Certificate certificate, PrivateKey privateKey) {
+ private static CloseableHttpClient createHttpClientWithTlsAuth(X509Certificate certificate,
+ PrivateKey privateKey,
+ HttpRequestRetryHandler retryHandler) {
try {
String dummyPassword = "athenz";
KeyStore keyStore = KeyStore.getInstance("JKS");
@@ -102,6 +111,7 @@ public class AthenzService {
.loadKeyMaterial(keyStore, dummyPassword.toCharArray())
.build();
return HttpClientBuilder.create()
+ .setRetryHandler(retryHandler)
.setSslcontext(sslContext)
.build();
} catch (KeyStoreException | UnrecoverableKeyException | NoSuchAlgorithmException |
diff --git a/container-disc/src/main/java/com/yahoo/container/jdisc/athenz/impl/CryptoUtils.java b/container-disc/src/main/java/com/yahoo/container/jdisc/athenz/impl/CryptoUtils.java
index 1b109e4bacb..388b40a1fe0 100644
--- a/container-disc/src/main/java/com/yahoo/container/jdisc/athenz/impl/CryptoUtils.java
+++ b/container-disc/src/main/java/com/yahoo/container/jdisc/athenz/impl/CryptoUtils.java
@@ -6,6 +6,9 @@ import org.bouncycastle.asn1.x509.Extension;
import org.bouncycastle.asn1.x509.ExtensionsGenerator;
import org.bouncycastle.asn1.x509.GeneralName;
import org.bouncycastle.asn1.x509.GeneralNames;
+import org.bouncycastle.cert.X509CertificateHolder;
+import org.bouncycastle.cert.jcajce.JcaX509CertificateConverter;
+import org.bouncycastle.jce.provider.BouncyCastleProvider;
import org.bouncycastle.openssl.PEMParser;
import org.bouncycastle.openssl.jcajce.JcaPEMWriter;
import org.bouncycastle.operator.OperatorCreationException;
@@ -23,6 +26,7 @@ import java.io.UncheckedIOException;
import java.security.KeyPair;
import java.security.KeyPairGenerator;
import java.security.NoSuchAlgorithmException;
+import java.security.cert.CertificateException;
import java.security.cert.X509Certificate;
/**
@@ -30,6 +34,8 @@ import java.security.cert.X509Certificate;
*/
class CryptoUtils {
+ private static final BouncyCastleProvider bouncyCastleProvider = new BouncyCastleProvider();
+
private CryptoUtils() {}
static KeyPair createKeyPair() {
@@ -45,7 +51,7 @@ class CryptoUtils {
String identityService,
String dnsSuffix,
String providerUniqueId,
- KeyPair keyPair) throws IOException {
+ KeyPair keyPair) {
try {
// Add SAN dnsname <service>.<domain-with-dashes>.<provider-dnsname-suffix>
// and SAN dnsname <provider-unique-instance-id>.instanceid.athenz.<provider-dnsname-suffix>
@@ -71,6 +77,8 @@ class CryptoUtils {
return requestBuilder.build(new JcaContentSignerBuilder("SHA256withRSA").build(keyPair.getPrivate()));
} catch (OperatorCreationException e) {
throw new RuntimeException(e);
+ } catch (IOException e) {
+ throw new UncheckedIOException(e);
}
}
@@ -87,12 +95,19 @@ class CryptoUtils {
static X509Certificate parseCertificate(String pemEncodedCertificate) {
try (PEMParser parser = new PEMParser(new StringReader(pemEncodedCertificate))) {
Object pemObject = parser.readObject();
- if (!(pemObject instanceof X509Certificate)) {
- throw new IllegalArgumentException("Expeceted X509Certificate instance, got " + pemObject);
+ if (pemObject instanceof X509Certificate) {
+ return (X509Certificate) pemObject;
}
- return (X509Certificate) pemObject;
+ if (pemObject instanceof X509CertificateHolder) {
+ return new JcaX509CertificateConverter()
+ .setProvider(bouncyCastleProvider)
+ .getCertificate((X509CertificateHolder) pemObject);
+ }
+ throw new IllegalArgumentException("Invalid type of PEM object: " + pemObject);
} catch (IOException e) {
throw new UncheckedIOException(e);
+ } catch (CertificateException e) {
+ throw new RuntimeException(e);
}
}
}
diff --git a/container-disc/src/main/java/com/yahoo/container/jdisc/athenz/impl/ServiceProviderApi.java b/container-disc/src/main/java/com/yahoo/container/jdisc/athenz/impl/IdentityDocumentService.java
index 6c1c22d07e0..542a5c739c8 100644
--- a/container-disc/src/main/java/com/yahoo/container/jdisc/athenz/impl/ServiceProviderApi.java
+++ b/container-disc/src/main/java/com/yahoo/container/jdisc/athenz/impl/IdentityDocumentService.java
@@ -3,8 +3,8 @@ package com.yahoo.container.jdisc.athenz.impl;
import com.yahoo.vespa.defaults.Defaults;
import org.apache.http.client.methods.CloseableHttpResponse;
-import org.apache.http.client.methods.HttpUriRequest;
-import org.apache.http.client.methods.RequestBuilder;
+import org.apache.http.client.methods.HttpGet;
+import org.apache.http.client.utils.URIBuilder;
import org.apache.http.conn.ssl.SSLConnectionSocketFactory;
import org.apache.http.conn.ssl.SSLContextBuilder;
import org.apache.http.conn.ssl.TrustSelfSignedStrategy;
@@ -15,20 +15,21 @@ import org.eclipse.jetty.http.HttpStatus;
import java.io.IOException;
import java.net.URI;
-import java.net.URLEncoder;
+import java.net.URISyntaxException;
import java.security.KeyManagementException;
import java.security.KeyStoreException;
import java.security.NoSuchAlgorithmException;
/**
* @author mortent
+ * @author bjorncs
*/
-public class ServiceProviderApi {
+public class IdentityDocumentService {
- private final URI providerUri;
+ private final URI identityDocumentApiUri;
- public ServiceProviderApi(String providerAddress) {
- providerUri = URI.create(String.format("https://%s:8443/athenz/v1/provider", providerAddress));
+ public IdentityDocumentService(String loadBalancerName) {
+ this.identityDocumentApiUri = createIdentityDocumentApiUri(loadBalancerName);
}
/**
@@ -36,11 +37,7 @@ public class ServiceProviderApi {
*/
public String getSignedIdentityDocument() {
try (CloseableHttpClient httpClient = createHttpClient()) {
- // TODO Figure out a proper way of determining the hostname matching what's registred in node-repository
- String uri = providerUri + "/identity-document?hostname=" + URLEncoder.encode(
- Defaults.getDefaults().vespaHostname(), "UTF-8");
- HttpUriRequest request = RequestBuilder.get().setUri(uri).build();
- CloseableHttpResponse idDocResponse = httpClient.execute(request);
+ CloseableHttpResponse idDocResponse = httpClient.execute(new HttpGet(identityDocumentApiUri));
String responseContent = EntityUtils.toString(idDocResponse.getEntity());
if (HttpStatus.isSuccess(idDocResponse.getStatusLine().getStatusCode())) {
return responseContent;
@@ -70,4 +67,19 @@ public class ServiceProviderApi {
}
}
+ private static URI createIdentityDocumentApiUri(String loadBalancerName) {
+ try {
+ // TODO Figure out a proper way of determining the hostname matching what's registred in node-repository
+ return new URIBuilder()
+ .setScheme("https")
+ .setHost(loadBalancerName)
+ .setPort(4443)
+ .setPath("/athenz/v1/provider/identity-document")
+ .addParameter("hostname", Defaults.getDefaults().vespaHostname())
+ .build();
+ } catch (URISyntaxException e) {
+ throw new RuntimeException(e);
+ }
+ }
+
}
diff --git a/container-disc/src/main/java/com/yahoo/container/jdisc/athenz/impl/InstanceIdentity.java b/container-disc/src/main/java/com/yahoo/container/jdisc/athenz/impl/InstanceIdentity.java
index ccb9b12c61a..20bbb2aa67e 100644
--- a/container-disc/src/main/java/com/yahoo/container/jdisc/athenz/impl/InstanceIdentity.java
+++ b/container-disc/src/main/java/com/yahoo/container/jdisc/athenz/impl/InstanceIdentity.java
@@ -4,8 +4,13 @@ package com.yahoo.container.jdisc.athenz.impl;
import com.fasterxml.jackson.annotation.JsonIgnoreProperties;
import com.fasterxml.jackson.annotation.JsonInclude;
import com.fasterxml.jackson.annotation.JsonProperty;
+import com.fasterxml.jackson.core.JsonParser;
+import com.fasterxml.jackson.databind.DeserializationContext;
+import com.fasterxml.jackson.databind.JsonDeserializer;
+import com.fasterxml.jackson.databind.annotation.JsonDeserialize;
-import java.util.Map;
+import java.io.IOException;
+import java.security.cert.X509Certificate;
/**
* Used for deserializing response from ZTS
@@ -15,42 +20,29 @@ import java.util.Map;
@JsonIgnoreProperties(ignoreUnknown = true)
@JsonInclude(JsonInclude.Include.NON_NULL)
public class InstanceIdentity {
- @JsonProperty("attributes") private final Map<String, String> attributes;
- @JsonProperty("provider") private final String provider;
- @JsonProperty("name") private final String name;
- @JsonProperty("instanceId") private final String instanceId;
- @JsonProperty("x509Certificate") private final String x509Certificate;
- @JsonProperty("x509CertificateSigner") private final String x509CertificateSigner;
- @JsonProperty("sshCertificate") private final String sshCertificate;
- @JsonProperty("sshCertificateSigner") private final String sshCertificateSigner;
+ @JsonProperty("x509Certificate") private final X509Certificate x509Certificate;
@JsonProperty("serviceToken") private final String serviceToken;
- public InstanceIdentity(
- @JsonProperty("attributes") Map<String, String> attributes,
- @JsonProperty("provider") String provider,
- @JsonProperty("name") String name,
- @JsonProperty("instanceId") String instanceId,
- @JsonProperty("x509Certificate") String x509Certificate,
- @JsonProperty("x509CertificateSigner") String x509CertificateSigner,
- @JsonProperty("sshCertificate") String sshCertificate,
- @JsonProperty("sshCertificateSigner") String sshCertificateSigner,
- @JsonProperty("serviceToken") String serviceToken) {
- this.attributes = attributes;
- this.provider = provider;
- this.name = name;
- this.instanceId = instanceId;
+ public InstanceIdentity(@JsonProperty("x509Certificate") @JsonDeserialize(using = X509CertificateDeserializer.class)
+ X509Certificate x509Certificate,
+ @JsonProperty("serviceToken") String serviceToken) {
this.x509Certificate = x509Certificate;
- this.x509CertificateSigner = x509CertificateSigner;
- this.sshCertificate = sshCertificate;
- this.sshCertificateSigner = sshCertificateSigner;
this.serviceToken = serviceToken;
}
- public String getX509Certificate() {
+ public X509Certificate getX509Certificate() {
return x509Certificate;
}
public String getServiceToken() {
return serviceToken;
}
+
+ public static class X509CertificateDeserializer extends JsonDeserializer<X509Certificate> {
+ @Override
+ public X509Certificate deserialize(JsonParser parser, DeserializationContext context) throws IOException {
+ return CryptoUtils.parseCertificate(parser.getValueAsString());
+ }
+ }
+
}
diff --git a/container-disc/src/main/java/com/yahoo/container/jdisc/athenz/impl/InstanceRefreshInformation.java b/container-disc/src/main/java/com/yahoo/container/jdisc/athenz/impl/InstanceRefreshInformation.java
index 621eafca3bb..dd893cb3143 100644
--- a/container-disc/src/main/java/com/yahoo/container/jdisc/athenz/impl/InstanceRefreshInformation.java
+++ b/container-disc/src/main/java/com/yahoo/container/jdisc/athenz/impl/InstanceRefreshInformation.java
@@ -15,10 +15,9 @@ public class InstanceRefreshInformation {
@JsonProperty("csr")
private final String csr;
@JsonProperty("token")
- private final boolean requestServiceToken;
+ private final boolean requestServiceToken = true;
- public InstanceRefreshInformation(String csr, boolean requestServiceToken) {
+ public InstanceRefreshInformation(String csr) {
this.csr = csr;
- this.requestServiceToken = requestServiceToken;
}
}
diff --git a/container-disc/src/main/java/com/yahoo/container/jdisc/athenz/impl/InstanceRegisterInformation.java b/container-disc/src/main/java/com/yahoo/container/jdisc/athenz/impl/InstanceRegisterInformation.java
index 61ab810abd5..e2355cb7a2d 100644
--- a/container-disc/src/main/java/com/yahoo/container/jdisc/athenz/impl/InstanceRegisterInformation.java
+++ b/container-disc/src/main/java/com/yahoo/container/jdisc/athenz/impl/InstanceRegisterInformation.java
@@ -26,14 +26,13 @@ public class InstanceRegisterInformation {
@JsonProperty("csr")
private final String csr;
@JsonProperty("token")
- private final boolean token;
+ private final boolean token = true;
- public InstanceRegisterInformation(String provider, String domain, String service, String attestationData, String csr, boolean token) {
+ public InstanceRegisterInformation(String provider, String domain, String service, String attestationData, String csr) {
this.provider = provider;
this.domain = domain;
this.service = service;
this.attestationData = attestationData;
this.csr = csr;
- this.token = token;
}
}
diff --git a/container-disc/src/main/java/com/yahoo/container/jdisc/athenz/impl/SignedIdentityDocument.java b/container-disc/src/main/java/com/yahoo/container/jdisc/athenz/impl/SignedIdentityDocument.java
index d302b3d96ce..5d5b5430859 100644
--- a/container-disc/src/main/java/com/yahoo/container/jdisc/athenz/impl/SignedIdentityDocument.java
+++ b/container-disc/src/main/java/com/yahoo/container/jdisc/athenz/impl/SignedIdentityDocument.java
@@ -5,21 +5,24 @@ import com.fasterxml.jackson.annotation.JsonIgnoreProperties;
import com.fasterxml.jackson.annotation.JsonInclude;
import com.fasterxml.jackson.annotation.JsonProperty;
+import java.net.URI;
+
/**
* @author bjorncs
*/
+// TODO Most of these value should ideally be config provided by config-model
@JsonIgnoreProperties(ignoreUnknown = true)
@JsonInclude(JsonInclude.Include.NON_NULL)
class SignedIdentityDocument {
public final String providerUniqueId;
public final String dnsSuffix;
public final String providerService;
- public final String ztsEndpoint;
+ public final URI ztsEndpoint;
public SignedIdentityDocument(@JsonProperty("provider-unique-id") String providerUniqueId,
@JsonProperty("dns-suffix") String dnsSuffix,
@JsonProperty("provider-service") String providerService,
- @JsonProperty("zts-endpoint") String ztsEndpoint) {
+ @JsonProperty("zts-endpoint") URI ztsEndpoint) {
this.providerUniqueId = providerUniqueId;
this.dnsSuffix = dnsSuffix;
this.providerService = providerService;
diff --git a/container-disc/src/test/java/com/yahoo/container/jdisc/athenz/impl/AthenzIdentityProviderImplTest.java b/container-disc/src/test/java/com/yahoo/container/jdisc/athenz/impl/AthenzIdentityProviderImplTest.java
index 1f64fb0d379..1c0efef2089 100644
--- a/container-disc/src/test/java/com/yahoo/container/jdisc/athenz/impl/AthenzIdentityProviderImplTest.java
+++ b/container-disc/src/test/java/com/yahoo/container/jdisc/athenz/impl/AthenzIdentityProviderImplTest.java
@@ -3,11 +3,30 @@ package com.yahoo.container.jdisc.athenz.impl;
import com.yahoo.container.core.identity.IdentityConfig;
import com.yahoo.container.jdisc.athenz.AthenzIdentityProvider;
-import org.junit.Assert;
+import com.yahoo.container.jdisc.athenz.impl.AthenzIdentityProviderImpl.RunnableWithTag;
+import com.yahoo.container.jdisc.athenz.impl.AthenzIdentityProviderImpl.Scheduler;
+import com.yahoo.test.ManualClock;
import org.junit.Test;
-import java.io.IOException;
+import java.time.Duration;
+import java.time.Instant;
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.List;
+import java.util.Objects;
+import java.util.PriorityQueue;
+import java.util.concurrent.atomic.AtomicInteger;
+import java.util.function.Predicate;
+import static com.yahoo.container.jdisc.athenz.impl.AthenzIdentityProviderImpl.INITIAL_BACKOFF_DELAY;
+import static com.yahoo.container.jdisc.athenz.impl.AthenzIdentityProviderImpl.INITIAL_WAIT_NTOKEN;
+import static com.yahoo.container.jdisc.athenz.impl.AthenzIdentityProviderImpl.MAX_REGISTER_BACKOFF_DELAY;
+import static com.yahoo.container.jdisc.athenz.impl.AthenzIdentityProviderImpl.REDUCED_UPDATE_PERIOD;
+import static com.yahoo.container.jdisc.athenz.impl.AthenzIdentityProviderImpl.REGISTER_INSTANCE_TAG;
+import static com.yahoo.container.jdisc.athenz.impl.AthenzIdentityProviderImpl.TIMEOUT_INITIAL_WAIT_TAG;
+import static com.yahoo.container.jdisc.athenz.impl.AthenzIdentityProviderImpl.UPDATE_CREDENTIALS_TAG;
+import static com.yahoo.container.jdisc.athenz.impl.AthenzIdentityProviderImpl.UPDATE_PERIOD;
+import static org.junit.Assert.assertEquals;
import static org.mockito.Matchers.any;
import static org.mockito.Matchers.anyString;
import static org.mockito.Mockito.mock;
@@ -15,25 +34,112 @@ import static org.mockito.Mockito.when;
/**
* @author mortent
+ * @author bjorncs
*/
public class AthenzIdentityProviderImplTest {
+ private static final IdentityConfig IDENTITY_CONFIG =
+ new IdentityConfig(new IdentityConfig.Builder()
+ .service("tenantService").domain("tenantDomain").loadBalancerAddress("cfg"));
+
+ @Test
+ public void athenz_credentials_are_retrieved_after_component_contruction_completed() {
+ IdentityDocumentService identityDocumentService = mock(IdentityDocumentService.class);
+ AthenzService athenzService = mock(AthenzService.class);
+ ManualClock clock = new ManualClock(Instant.EPOCH);
+ MockScheduler scheduler = new MockScheduler(clock);
+
+ when(identityDocumentService.getSignedIdentityDocument()).thenReturn(getIdentityDocument());
+ when(athenzService.sendInstanceRegisterRequest(any(), any())).thenReturn(
+ new InstanceIdentity(null, "TOKEN"));
+ AthenzCredentialsService credentialService =
+ new AthenzCredentialsService(IDENTITY_CONFIG, identityDocumentService, athenzService, clock);
+
+ AthenzIdentityProvider identityProvider =
+ new AthenzIdentityProviderImpl(IDENTITY_CONFIG, credentialService, scheduler, clock);
+
+ List<MockScheduler.CompletedTask> expectedTasks =
+ Arrays.asList(
+ new MockScheduler.CompletedTask(REGISTER_INSTANCE_TAG, Duration.ZERO),
+ new MockScheduler.CompletedTask(TIMEOUT_INITIAL_WAIT_TAG, INITIAL_WAIT_NTOKEN));
+ // Don't run update credential tasks, otherwise infinite loop
+ List<MockScheduler.CompletedTask> completedTasks =
+ scheduler.runAllTasks(task -> !task.tag().equals(UPDATE_CREDENTIALS_TAG));
+ assertEquals(expectedTasks, completedTasks);
+ assertEquals("TOKEN", identityProvider.getNToken());
+ }
+
@Test
- public void ntoken_fetched_on_init() throws IOException {
- IdentityConfig config = new IdentityConfig(new IdentityConfig.Builder().service("tenantService").domain("tenantDomain").loadBalancerAddress("cfg"));
- ServiceProviderApi serviceProviderApi = mock(ServiceProviderApi.class);
+ public void register_instance_uses_exponential_backoff() {
+ AthenzCredentialsService credentialService = mock(AthenzCredentialsService.class);
+ when(credentialService.registerInstance())
+ .thenThrow(new RuntimeException("#1"))
+ .thenThrow(new RuntimeException("#2"))
+ .thenThrow(new RuntimeException("#3"))
+ .thenThrow(new RuntimeException("#4"))
+ .thenThrow(new RuntimeException("#5"))
+ .thenReturn(new AthenzCredentials("TOKEN", null, null, null, Instant.now()));
+
+ ManualClock clock = new ManualClock(Instant.EPOCH);
+ MockScheduler scheduler = new MockScheduler(clock);
+ AthenzIdentityProvider identityProvider =
+ new AthenzIdentityProviderImpl(IDENTITY_CONFIG, credentialService, scheduler, clock);
+
+ List<MockScheduler.CompletedTask> expectedTasks =
+ Arrays.asList(
+ new MockScheduler.CompletedTask(REGISTER_INSTANCE_TAG, Duration.ZERO),
+ new MockScheduler.CompletedTask(REGISTER_INSTANCE_TAG, INITIAL_BACKOFF_DELAY),
+ new MockScheduler.CompletedTask(TIMEOUT_INITIAL_WAIT_TAG, INITIAL_WAIT_NTOKEN),
+ new MockScheduler.CompletedTask(REGISTER_INSTANCE_TAG, INITIAL_BACKOFF_DELAY.multipliedBy(2)),
+ new MockScheduler.CompletedTask(REGISTER_INSTANCE_TAG, INITIAL_BACKOFF_DELAY.multipliedBy(4)),
+ new MockScheduler.CompletedTask(REGISTER_INSTANCE_TAG, INITIAL_BACKOFF_DELAY.multipliedBy(8)),
+ new MockScheduler.CompletedTask(REGISTER_INSTANCE_TAG, MAX_REGISTER_BACKOFF_DELAY));
+ // Don't run update credential tasks, otherwise infinite loop
+ List<MockScheduler.CompletedTask> completedTasks =
+ scheduler.runAllTasks(task -> !task.tag().equals(UPDATE_CREDENTIALS_TAG));
+ assertEquals(expectedTasks, completedTasks);
+ assertEquals("TOKEN", identityProvider.getNToken());
+ }
+
+ @Test
+ public void failed_credentials_updates_will_schedule_retries() {
+ IdentityDocumentService identityDocumentService = mock(IdentityDocumentService.class);
AthenzService athenzService = mock(AthenzService.class);
+ ManualClock clock = new ManualClock(Instant.EPOCH);
+ MockScheduler scheduler = new MockScheduler(clock);
- when(serviceProviderApi.getSignedIdentityDocument()).thenReturn(getIdentityDocument());
- when(athenzService.sendInstanceRegisterRequest(any(), anyString())).thenReturn(
- new InstanceIdentity(null, null, null, null, null, null, null, null, "TOKEN"));
+ when(identityDocumentService.getSignedIdentityDocument()).thenReturn(getIdentityDocument());
+ when(athenzService.sendInstanceRegisterRequest(any(), any())).thenReturn(
+ new InstanceIdentity(null, "TOKEN"));
+ when(athenzService.sendInstanceRefreshRequest(anyString(), anyString(), anyString(),
+ anyString(), any(), any(), any(), any()))
+ .thenThrow(new RuntimeException("#1"))
+ .thenThrow(new RuntimeException("#2"))
+ .thenThrow(new RuntimeException("#3"))
+ .thenReturn(new InstanceIdentity(null, "TOKEN"));
+ AthenzCredentialsService credentialService =
+ new AthenzCredentialsService(IDENTITY_CONFIG, identityDocumentService, athenzService, clock);
- AthenzIdentityProvider identityProvider = new AthenzIdentityProviderImpl(config, serviceProviderApi, athenzService);
+ AthenzIdentityProvider identityProvider =
+ new AthenzIdentityProviderImpl(IDENTITY_CONFIG, credentialService, scheduler, clock);
- Assert.assertEquals("TOKEN", identityProvider.getNToken());
+ List<MockScheduler.CompletedTask> expectedTasks =
+ Arrays.asList(
+ new MockScheduler.CompletedTask(REGISTER_INSTANCE_TAG, Duration.ZERO),
+ new MockScheduler.CompletedTask(TIMEOUT_INITIAL_WAIT_TAG, INITIAL_WAIT_NTOKEN),
+ new MockScheduler.CompletedTask(UPDATE_CREDENTIALS_TAG, UPDATE_PERIOD),
+ new MockScheduler.CompletedTask(UPDATE_CREDENTIALS_TAG, UPDATE_PERIOD),
+ new MockScheduler.CompletedTask(UPDATE_CREDENTIALS_TAG, REDUCED_UPDATE_PERIOD),
+ new MockScheduler.CompletedTask(UPDATE_CREDENTIALS_TAG, REDUCED_UPDATE_PERIOD),
+ new MockScheduler.CompletedTask(UPDATE_CREDENTIALS_TAG, UPDATE_PERIOD));
+ AtomicInteger counter = new AtomicInteger(0);
+ List<MockScheduler.CompletedTask> completedTasks =
+ scheduler.runAllTasks(task -> counter.getAndIncrement() < 7); // 1 registration + 1 timeout + 5 update tasks
+ assertEquals(expectedTasks, completedTasks);
+ assertEquals("TOKEN", identityProvider.getNToken());
}
- private String getIdentityDocument() {
+ private static String getIdentityDocument() {
return "{\n" +
" \"identity-document\": \"eyJwcm92aWRlci11bmlxdWUtaWQiOnsidGVuYW50IjoidGVuYW50IiwiYXBwbGljYXRpb24iOiJhcHBsaWNhdGlvbiIsImVudmlyb25tZW50IjoiZGV2IiwicmVnaW9uIjoidXMtbm9ydGgtMSIsImluc3RhbmNlIjoiZGVmYXVsdCIsImNsdXN0ZXItaWQiOiJkZWZhdWx0IiwiY2x1c3Rlci1pbmRleCI6MH0sImNvbmZpZ3NlcnZlci1ob3N0bmFtZSI6ImxvY2FsaG9zdCIsImluc3RhbmNlLWhvc3RuYW1lIjoieC55LmNvbSIsImNyZWF0ZWQtYXQiOjE1MDg3NDgyODUuNzQyMDAwMDAwfQ==\",\n" +
" \"signature\": \"kkEJB/98cy1FeXxzSjtvGH2a6BFgZu/9/kzCcAqRMZjENxnw5jyO1/bjZVzw2Sz4YHPsWSx2uxb32hiQ0U8rMP0zfA9nERIalSP0jB/hMU8laezGhdpk6VKZPJRC6YKAB9Bsv2qUIfMsSxkMqf66GUvjZAGaYsnNa2yHc1jIYHOGMeJO+HNPYJjGv26xPfAOPIKQzs3RmKrc3FoweTCsIwm5oblqekdJvVWYe0obwlOSB5uwc1zpq3Ie1QBFtJRuCGMVHg1pDPxXKBHLClGIrEvzLmICy6IRdHszSO5qiwujUD7sbrbM0sB/u0cYucxbcsGRUmBvme3UAw2mW9POVQ==\",\n" +
@@ -46,4 +152,82 @@ public class AthenzIdentityProviderImplTest {
"}";
}
+
+ private static class MockScheduler implements Scheduler {
+
+ private final PriorityQueue<DelayedTask> tasks = new PriorityQueue<>();
+ private final ManualClock clock;
+
+ MockScheduler(ManualClock clock) {
+ this.clock = clock;
+ }
+
+ @Override
+ public void schedule(RunnableWithTag task, Duration delay) {
+ tasks.offer(new DelayedTask(task, delay, clock.instant().plus(delay)));
+ }
+
+ List<CompletedTask> runAllTasks(Predicate<RunnableWithTag> filter) {
+ List<CompletedTask> completedTasks = new ArrayList<>();
+ while (!tasks.isEmpty()) {
+ DelayedTask task = tasks.poll();
+ RunnableWithTag runnable = task.runnableWithTag;
+ if (filter.test(runnable)) {
+ clock.setInstant(task.startTime);
+ runnable.run();
+ completedTasks.add(new CompletedTask(runnable.tag(), task.delay));
+ }
+ }
+ return completedTasks;
+ }
+
+ private static class DelayedTask implements Comparable<DelayedTask> {
+ final RunnableWithTag runnableWithTag;
+ final Duration delay;
+ final Instant startTime;
+
+ DelayedTask(RunnableWithTag runnableWithTag, Duration delay, Instant startTime) {
+ this.runnableWithTag = runnableWithTag;
+ this.delay = delay;
+ this.startTime = startTime;
+ }
+
+ @Override
+ public int compareTo(DelayedTask other) {
+ return this.startTime.compareTo(other.startTime);
+ }
+ }
+
+ private static class CompletedTask {
+ final String tag;
+ final Duration delay;
+
+ CompletedTask(String tag, Duration delay) {
+ this.tag = tag;
+ this.delay = delay;
+ }
+
+ @Override
+ public boolean equals(Object o) {
+ if (this == o) return true;
+ if (o == null || getClass() != o.getClass()) return false;
+ CompletedTask that = (CompletedTask) o;
+ return Objects.equals(tag, that.tag) &&
+ Objects.equals(delay, that.delay);
+ }
+
+ @Override
+ public int hashCode() {
+ return Objects.hash(tag, delay);
+ }
+
+ @Override
+ public String toString() {
+ return "CompletedTask{" +
+ "tag='" + tag + '\'' +
+ ", delay=" + delay +
+ '}';
+ }
+ }
+ }
}