summaryrefslogtreecommitdiffstats
path: root/container-disc
diff options
context:
space:
mode:
authorBjørn Christian Seime <bjorncs@oath.com>2017-11-06 17:45:32 +0100
committerBjørn Christian Seime <bjorncs@oath.com>2017-11-06 17:49:17 +0100
commitc14153031e48b71bf4b7c66fab88cb37b8d49788 (patch)
treebb326f9e6c8ff7c01625c598c496f77d1f7ee2b3 /container-disc
parent145d7af877637a1c1bd024cc9fedf6de644ba584 (diff)
Add token refresh to AthenzIdentityProvider implementation
Includes logic for retry: exponential backoff for instance registration, linear backoff for credential updates. Moved instance registration + credentials update logic to new class AthenzCredentialsService.
Diffstat (limited to 'container-disc')
-rw-r--r--container-disc/pom.xml6
-rw-r--r--container-disc/src/main/java/com/yahoo/container/jdisc/athenz/AthenzIdentityProvider.java10
-rw-r--r--container-disc/src/main/java/com/yahoo/container/jdisc/athenz/AthenzIdentityProviderException.java16
-rw-r--r--container-disc/src/main/java/com/yahoo/container/jdisc/athenz/impl/AthenzCredentials.java51
-rw-r--r--container-disc/src/main/java/com/yahoo/container/jdisc/athenz/impl/AthenzCredentialsService.java95
-rw-r--r--container-disc/src/main/java/com/yahoo/container/jdisc/athenz/impl/AthenzIdentityProviderImpl.java231
-rw-r--r--container-disc/src/main/java/com/yahoo/container/jdisc/athenz/impl/CryptoUtils.java4
-rw-r--r--container-disc/src/main/java/com/yahoo/container/jdisc/athenz/impl/SignedIdentityDocument.java6
-rw-r--r--container-disc/src/test/java/com/yahoo/container/jdisc/athenz/impl/AthenzIdentityProviderImplTest.java199
9 files changed, 564 insertions, 54 deletions
diff --git a/container-disc/pom.xml b/container-disc/pom.xml
index bd8a3340622..952db36367b 100644
--- a/container-disc/pom.xml
+++ b/container-disc/pom.xml
@@ -20,6 +20,12 @@
<scope>test</scope>
</dependency>
<dependency>
+ <groupId>com.yahoo.vespa</groupId>
+ <artifactId>testutil</artifactId>
+ <version>${project.version}</version>
+ <scope>test</scope>
+ </dependency>
+ <dependency>
<groupId>org.mockito</groupId>
<artifactId>mockito-core</artifactId>
<scope>test</scope>
diff --git a/container-disc/src/main/java/com/yahoo/container/jdisc/athenz/AthenzIdentityProvider.java b/container-disc/src/main/java/com/yahoo/container/jdisc/athenz/AthenzIdentityProvider.java
index 418b3511ebb..033b396bc9b 100644
--- a/container-disc/src/main/java/com/yahoo/container/jdisc/athenz/AthenzIdentityProvider.java
+++ b/container-disc/src/main/java/com/yahoo/container/jdisc/athenz/AthenzIdentityProvider.java
@@ -1,15 +1,11 @@
// Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
package com.yahoo.container.jdisc.athenz;
-import java.security.cert.X509Certificate;
-
/**
* @author mortent
*/
public interface AthenzIdentityProvider {
-
- String getNToken();
- X509Certificate getX509Cert();
- String domain();
- String service();
+ String getNToken() throws AthenzIdentityProviderException;
+ String getDomain();
+ String getService();
}
diff --git a/container-disc/src/main/java/com/yahoo/container/jdisc/athenz/AthenzIdentityProviderException.java b/container-disc/src/main/java/com/yahoo/container/jdisc/athenz/AthenzIdentityProviderException.java
new file mode 100644
index 00000000000..fd5839bfc45
--- /dev/null
+++ b/container-disc/src/main/java/com/yahoo/container/jdisc/athenz/AthenzIdentityProviderException.java
@@ -0,0 +1,16 @@
+// Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+package com.yahoo.container.jdisc.athenz;
+
+/**
+ * @author bjorncs
+ */
+public class AthenzIdentityProviderException extends RuntimeException {
+
+ public AthenzIdentityProviderException(String message) {
+ super(message);
+ }
+
+ public AthenzIdentityProviderException(String message, Throwable cause) {
+ super(message, cause);
+ }
+}
diff --git a/container-disc/src/main/java/com/yahoo/container/jdisc/athenz/impl/AthenzCredentials.java b/container-disc/src/main/java/com/yahoo/container/jdisc/athenz/impl/AthenzCredentials.java
new file mode 100644
index 00000000000..790a7c54333
--- /dev/null
+++ b/container-disc/src/main/java/com/yahoo/container/jdisc/athenz/impl/AthenzCredentials.java
@@ -0,0 +1,51 @@
+// Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+package com.yahoo.container.jdisc.athenz.impl;
+
+import java.security.KeyPair;
+import java.security.cert.X509Certificate;
+import java.time.Instant;
+
+/**
+ * @author bjorncs
+ */
+class AthenzCredentials {
+
+ private final String nToken;
+ private final X509Certificate certificate;
+ private final KeyPair keyPair;
+ private final SignedIdentityDocument identityDocument;
+ private final Instant createdAt;
+
+ AthenzCredentials(String nToken,
+ X509Certificate certificate,
+ KeyPair keyPair,
+ SignedIdentityDocument identityDocument,
+ Instant createdAt) {
+ this.nToken = nToken;
+ this.certificate = certificate;
+ this.keyPair = keyPair;
+ this.identityDocument = identityDocument;
+ this.createdAt = createdAt;
+ }
+
+ String getNToken() {
+ return nToken;
+ }
+
+ X509Certificate getCertificate() {
+ return certificate;
+ }
+
+ KeyPair getKeyPair() {
+ return keyPair;
+ }
+
+ SignedIdentityDocument getIdentityDocument() {
+ return identityDocument;
+ }
+
+ Instant getCreatedAt() {
+ return createdAt;
+ }
+
+}
diff --git a/container-disc/src/main/java/com/yahoo/container/jdisc/athenz/impl/AthenzCredentialsService.java b/container-disc/src/main/java/com/yahoo/container/jdisc/athenz/impl/AthenzCredentialsService.java
new file mode 100644
index 00000000000..ea9e50cbb95
--- /dev/null
+++ b/container-disc/src/main/java/com/yahoo/container/jdisc/athenz/impl/AthenzCredentialsService.java
@@ -0,0 +1,95 @@
+// Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+package com.yahoo.container.jdisc.athenz.impl;
+
+import com.fasterxml.jackson.databind.ObjectMapper;
+import com.yahoo.container.core.identity.IdentityConfig;
+import org.bouncycastle.pkcs.PKCS10CertificationRequest;
+
+import java.io.IOException;
+import java.io.UncheckedIOException;
+import java.security.KeyPair;
+import java.security.cert.X509Certificate;
+import java.time.Clock;
+
+/**
+ * @author bjorncs
+ */
+class AthenzCredentialsService {
+
+ private static final ObjectMapper mapper = new ObjectMapper();
+
+ private final IdentityConfig identityConfig;
+ private final ServiceProviderApi serviceProviderApi;
+ private final AthenzService athenzService;
+ private final Clock clock;
+
+ AthenzCredentialsService(IdentityConfig identityConfig,
+ ServiceProviderApi serviceProviderApi,
+ AthenzService athenzService,
+ Clock clock) {
+ this.identityConfig = identityConfig;
+ this.serviceProviderApi = serviceProviderApi;
+ this.athenzService = athenzService;
+ this.clock = clock;
+ }
+
+ AthenzCredentials registerInstance() {
+ KeyPair keyPair = CryptoUtils.createKeyPair();
+ String rawDocument = serviceProviderApi.getSignedIdentityDocument();
+ SignedIdentityDocument document = parseSignedIdentityDocument(rawDocument);
+ PKCS10CertificationRequest csr = CryptoUtils.createCSR(identityConfig.domain(),
+ identityConfig.service(),
+ document.dnsSuffix,
+ document.providerUniqueId,
+ keyPair);
+ InstanceRegisterInformation instanceRegisterInformation =
+ new InstanceRegisterInformation(document.providerService,
+ identityConfig.domain(),
+ identityConfig.service(),
+ rawDocument,
+ CryptoUtils.toPem(csr),
+ true);
+ InstanceIdentity instanceIdentity = athenzService.sendInstanceRegisterRequest(instanceRegisterInformation,
+ document.ztsEndpoint);
+ return toAthenzCredentials(instanceIdentity, keyPair, document);
+ }
+
+ AthenzCredentials updateCredentials(AthenzCredentials currentCredentials) {
+ SignedIdentityDocument document = currentCredentials.getIdentityDocument();
+ KeyPair newKeyPair = CryptoUtils.createKeyPair();
+ PKCS10CertificationRequest csr = CryptoUtils.createCSR(identityConfig.domain(),
+ identityConfig.service(),
+ document.dnsSuffix,
+ document.providerUniqueId,
+ newKeyPair);
+ InstanceRefreshInformation refreshInfo = new InstanceRefreshInformation(CryptoUtils.toPem(csr),
+ /*requestServiceToken*/true);
+ InstanceIdentity instanceIdentity =
+ athenzService.sendInstanceRefreshRequest(document.providerService,
+ identityConfig.domain(),
+ identityConfig.service(),
+ document.providerUniqueId,
+ refreshInfo,
+ document.ztsEndpoint,
+ currentCredentials.getCertificate(),
+ currentCredentials.getKeyPair().getPrivate());
+ return toAthenzCredentials(instanceIdentity, newKeyPair, document);
+ }
+
+ private AthenzCredentials toAthenzCredentials(InstanceIdentity instanceIdentity,
+ KeyPair keyPair,
+ SignedIdentityDocument identityDocument) {
+ X509Certificate certificate = instanceIdentity.getX509Certificate();
+ String serviceToken = instanceIdentity.getServiceToken();
+ return new AthenzCredentials(serviceToken, certificate, keyPair, identityDocument, clock.instant());
+ }
+
+ private static SignedIdentityDocument parseSignedIdentityDocument(String rawDocument) {
+ try {
+ return mapper.readValue(rawDocument, SignedIdentityDocument.class);
+ } catch (IOException e) {
+ throw new UncheckedIOException(e);
+ }
+ }
+
+}
diff --git a/container-disc/src/main/java/com/yahoo/container/jdisc/athenz/impl/AthenzIdentityProviderImpl.java b/container-disc/src/main/java/com/yahoo/container/jdisc/athenz/impl/AthenzIdentityProviderImpl.java
index 83c001eaab7..2f98d852a95 100644
--- a/container-disc/src/main/java/com/yahoo/container/jdisc/athenz/impl/AthenzIdentityProviderImpl.java
+++ b/container-disc/src/main/java/com/yahoo/container/jdisc/athenz/impl/AthenzIdentityProviderImpl.java
@@ -1,78 +1,237 @@
// Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
package com.yahoo.container.jdisc.athenz.impl;
-import com.fasterxml.jackson.databind.ObjectMapper;
import com.google.inject.Inject;
import com.yahoo.component.AbstractComponent;
import com.yahoo.container.core.identity.IdentityConfig;
import com.yahoo.container.jdisc.athenz.AthenzIdentityProvider;
+import com.yahoo.container.jdisc.athenz.AthenzIdentityProviderException;
+import com.yahoo.log.LogLevel;
-import java.io.IOException;
-import java.net.URI;
-import java.security.KeyPair;
-import java.security.cert.X509Certificate;
+import java.time.Clock;
+import java.time.Duration;
+import java.time.Instant;
+import java.util.concurrent.CountDownLatch;
+import java.util.concurrent.Executors;
+import java.util.concurrent.ScheduledExecutorService;
+import java.util.concurrent.TimeUnit;
+import java.util.concurrent.atomic.AtomicReference;
+import java.util.logging.Logger;
/**
* @author mortent
+ * @author bjorncs
*/
public final class AthenzIdentityProviderImpl extends AbstractComponent implements AthenzIdentityProvider {
- private final ObjectMapper objectMapper = new ObjectMapper();
+ private static final Logger log = Logger.getLogger(AthenzIdentityProviderImpl.class.getName());
- private InstanceIdentity instanceIdentity;
+ // TODO Make some of these values configurable through config. Match requested expiration of register/update requests.
+ // TODO These should match the requested expiration
+ static final Duration EXPIRES_AFTER = Duration.ofDays(1);
+ static final Duration EXPIRATION_MARGIN = Duration.ofMinutes(30);
+ static final Duration INITIAL_WAIT_NTOKEN = Duration.ofMinutes(5);
+ static final Duration UPDATE_PERIOD = EXPIRES_AFTER.dividedBy(3);
+ static final Duration REDUCED_UPDATE_PERIOD = Duration.ofMinutes(30);
+ static final Duration INITIAL_BACKOFF_DELAY = Duration.ofMinutes(4);
+ static final Duration MAX_REGISTER_BACKOFF_DELAY = Duration.ofHours(1);
+ static final int BACKOFF_DELAY_MULTIPLIER = 2;
+ static final Duration AWAIT_TERMINTATION_TIMEOUT = Duration.ofSeconds(90);
- private final String dnsSuffix;
- private final String providerUniqueId;
+
+ static final String REGISTER_INSTANCE_TAG = "register-instance";
+ static final String UPDATE_CREDENTIALS_TAG = "update-credentials";
+ static final String TIMEOUT_INITIAL_WAIT_TAG = "timeout-initial-wait";
+
+
+ private final AtomicReference<AthenzCredentials> credentials = new AtomicReference<>();
+ private final AtomicReference<Throwable> lastThrowable = new AtomicReference<>();
+ private final CountDownLatch credentialsRetrievedSignal = new CountDownLatch(1);
+ private final AthenzCredentialsService athenzCredentialsService;
+ private final Scheduler scheduler;
+ private final Clock clock;
private final String domain;
private final String service;
@Inject
- public AthenzIdentityProviderImpl(IdentityConfig config) throws IOException {
- this(config, new ServiceProviderApi(config.loadBalancerAddress()), new AthenzService());
+ public AthenzIdentityProviderImpl(IdentityConfig config) {
+ this(config,
+ new AthenzCredentialsService(config,
+ new ServiceProviderApi(config.loadBalancerAddress()),
+ new AthenzService(),
+ Clock.systemUTC()),
+ new ThreadPoolScheduler(),
+ Clock.systemUTC());
}
// Test only
AthenzIdentityProviderImpl(IdentityConfig config,
- ServiceProviderApi serviceProviderApi,
- AthenzService athenzService) throws IOException {
- KeyPair keyPair = CryptoUtils.createKeyPair();
+ AthenzCredentialsService athenzCredentialsService,
+ Scheduler scheduler,
+ Clock clock) {
+ this.athenzCredentialsService = athenzCredentialsService;
+ this.scheduler = scheduler;
+ this.clock = clock;
this.domain = config.domain();
this.service = config.service();
- String rawDocument = serviceProviderApi.getSignedIdentityDocument();
- SignedIdentityDocument document = objectMapper.readValue(rawDocument, SignedIdentityDocument.class);
- this.dnsSuffix = document.dnsSuffix;
- this.providerUniqueId = document.providerUniqueId;
-
- InstanceRegisterInformation instanceRegisterInformation = new InstanceRegisterInformation(
- document.providerService,
- this.domain,
- this.service,
- rawDocument,
- CryptoUtils.toPem(CryptoUtils.createCSR(domain, service, dnsSuffix, providerUniqueId, keyPair)),
- true
- );
- instanceIdentity = athenzService.sendInstanceRegisterRequest(instanceRegisterInformation,
- URI.create(document.ztsEndpoint));
+ scheduler.submit(new RegisterInstanceTask());
+ scheduler.schedule(new TimeoutInitialWaitTask(), INITIAL_WAIT_NTOKEN);
}
@Override
public String getNToken() {
- return instanceIdentity.getServiceToken();
+ try {
+ credentialsRetrievedSignal.await();
+ AthenzCredentials credentialsSnapshot = credentials.get();
+ if (credentialsSnapshot == null) {
+ throw new AthenzIdentityProviderException("Could not retrieve Athenz credentials", lastThrowable.get());
+ }
+ if (isExpired(credentialsSnapshot)) {
+ throw new AthenzIdentityProviderException("Athenz credentials are expired", lastThrowable.get());
+ }
+ return credentialsSnapshot.getNToken();
+ } catch (InterruptedException e) {
+ throw new AthenzIdentityProviderException("Failed to register instance credentials", lastThrowable.get());
+ }
}
@Override
- public X509Certificate getX509Cert() {
- return instanceIdentity.getX509Certificate();
+ public String getDomain() {
+ return domain;
}
@Override
- public String domain() {
- return domain;
+ public String getService() {
+ return service;
}
@Override
- public String service() {
- return service;
+ public void deconstruct() {
+ scheduler.shutdown(AWAIT_TERMINTATION_TIMEOUT);
+ }
+
+ private boolean isExpired(AthenzCredentials credentials) {
+ return clock.instant().isAfter(getExpirationTime(credentials));
+ }
+
+ private static Instant getExpirationTime(AthenzCredentials credentials) {
+ return credentials.getCreatedAt().plus(EXPIRES_AFTER).minus(EXPIRATION_MARGIN);
+ }
+
+ private class RegisterInstanceTask implements RunnableWithTag {
+
+ private final Duration backoffDelay;
+
+ RegisterInstanceTask() {
+ this(INITIAL_BACKOFF_DELAY);
+ }
+
+ RegisterInstanceTask(Duration backoffDelay) {
+ this.backoffDelay = backoffDelay;
+ }
+
+ @Override
+ public void run() {
+ try {
+ credentials.set(athenzCredentialsService.registerInstance());
+ credentialsRetrievedSignal.countDown();
+ scheduler.schedule(new UpdateCredentialsTask(), UPDATE_PERIOD);
+ } catch (Throwable t) {
+ log.log(LogLevel.ERROR, "Failed to register instance: " + t.getMessage(), t);
+ lastThrowable.set(t);
+ Duration nextBackoffDelay = backoffDelay.multipliedBy(BACKOFF_DELAY_MULTIPLIER);
+ if (nextBackoffDelay.compareTo(MAX_REGISTER_BACKOFF_DELAY) > 0) {
+ nextBackoffDelay = MAX_REGISTER_BACKOFF_DELAY;
+ }
+ scheduler.schedule(new RegisterInstanceTask(nextBackoffDelay), backoffDelay);
+ }
+ }
+
+ @Override
+ public String tag() {
+ return REGISTER_INSTANCE_TAG;
+ }
+ }
+
+ private class UpdateCredentialsTask implements RunnableWithTag {
+ @Override
+ public void run() {
+ AthenzCredentials currentCredentials = credentials.get();
+ try {
+ AthenzCredentials newCredentials = isExpired(currentCredentials)
+ ? athenzCredentialsService.registerInstance()
+ : athenzCredentialsService.updateCredentials(currentCredentials);
+ credentials.set(newCredentials);
+ scheduler.schedule(new UpdateCredentialsTask(), UPDATE_PERIOD);
+ } catch (Throwable t) {
+ log.log(LogLevel.ERROR, "Failed to update credentials: " + t.getMessage(), t);
+ lastThrowable.set(t);
+ Duration timeToExpiration = Duration.between(clock.instant(), getExpirationTime(currentCredentials));
+ // NOTE: Update period might be after timeToExpiration, still we do not want to DDoS Athenz.
+ Duration updatePeriod =
+ timeToExpiration.compareTo(UPDATE_PERIOD) > 0 ? UPDATE_PERIOD : REDUCED_UPDATE_PERIOD;
+ scheduler.schedule(new UpdateCredentialsTask(), updatePeriod);
+ }
+ }
+
+ @Override
+ public String tag() {
+ return UPDATE_CREDENTIALS_TAG;
+ }
+ }
+
+ private class TimeoutInitialWaitTask implements RunnableWithTag {
+ @Override
+ public void run() {
+ credentialsRetrievedSignal.countDown();
+ }
+
+ @Override
+ public String tag() {
+ return TIMEOUT_INITIAL_WAIT_TAG;
+ }
}
+
+ private static class ThreadPoolScheduler implements Scheduler {
+
+ private static final Logger log = Logger.getLogger(ThreadPoolScheduler.class.getName());
+
+ private final ScheduledExecutorService executor = Executors.newScheduledThreadPool(0);
+
+ @Override
+ public void schedule(RunnableWithTag runnable, Duration delay) {
+ log.log(LogLevel.FINE, String.format("Scheduling task '%s' in '%s'", runnable.tag(), delay));
+ executor.schedule(runnable, delay.getSeconds(), TimeUnit.SECONDS);
+ }
+
+ @Override
+ public void submit(RunnableWithTag runnable) {
+ log.log(LogLevel.FINE, String.format("Scheduling task '%s' now", runnable.tag()));
+ executor.submit(runnable);
+ }
+
+ @Override
+ public void shutdown(Duration timeout) {
+ try {
+ executor.shutdownNow();
+ executor.awaitTermination(AWAIT_TERMINTATION_TIMEOUT.getSeconds(), TimeUnit.SECONDS);
+ } catch (InterruptedException e) {
+ throw new RuntimeException(e);
+ }
+ }
+
+ }
+
+ public interface Scheduler {
+ void schedule(RunnableWithTag runnable, Duration delay);
+ default void submit(RunnableWithTag runnable) { schedule(runnable, Duration.ZERO); }
+ default void shutdown(Duration timeout) {}
+ }
+
+ public interface RunnableWithTag extends Runnable {
+
+ String tag();
+ }
+
}
diff --git a/container-disc/src/main/java/com/yahoo/container/jdisc/athenz/impl/CryptoUtils.java b/container-disc/src/main/java/com/yahoo/container/jdisc/athenz/impl/CryptoUtils.java
index 1b109e4bacb..6ff7857df4a 100644
--- a/container-disc/src/main/java/com/yahoo/container/jdisc/athenz/impl/CryptoUtils.java
+++ b/container-disc/src/main/java/com/yahoo/container/jdisc/athenz/impl/CryptoUtils.java
@@ -45,7 +45,7 @@ class CryptoUtils {
String identityService,
String dnsSuffix,
String providerUniqueId,
- KeyPair keyPair) throws IOException {
+ KeyPair keyPair) {
try {
// Add SAN dnsname <service>.<domain-with-dashes>.<provider-dnsname-suffix>
// and SAN dnsname <provider-unique-instance-id>.instanceid.athenz.<provider-dnsname-suffix>
@@ -71,6 +71,8 @@ class CryptoUtils {
return requestBuilder.build(new JcaContentSignerBuilder("SHA256withRSA").build(keyPair.getPrivate()));
} catch (OperatorCreationException e) {
throw new RuntimeException(e);
+ } catch (IOException e) {
+ throw new UncheckedIOException(e);
}
}
diff --git a/container-disc/src/main/java/com/yahoo/container/jdisc/athenz/impl/SignedIdentityDocument.java b/container-disc/src/main/java/com/yahoo/container/jdisc/athenz/impl/SignedIdentityDocument.java
index d302b3d96ce..d9b9bdd5c0d 100644
--- a/container-disc/src/main/java/com/yahoo/container/jdisc/athenz/impl/SignedIdentityDocument.java
+++ b/container-disc/src/main/java/com/yahoo/container/jdisc/athenz/impl/SignedIdentityDocument.java
@@ -5,6 +5,8 @@ import com.fasterxml.jackson.annotation.JsonIgnoreProperties;
import com.fasterxml.jackson.annotation.JsonInclude;
import com.fasterxml.jackson.annotation.JsonProperty;
+import java.net.URI;
+
/**
* @author bjorncs
*/
@@ -14,12 +16,12 @@ class SignedIdentityDocument {
public final String providerUniqueId;
public final String dnsSuffix;
public final String providerService;
- public final String ztsEndpoint;
+ public final URI ztsEndpoint;
public SignedIdentityDocument(@JsonProperty("provider-unique-id") String providerUniqueId,
@JsonProperty("dns-suffix") String dnsSuffix,
@JsonProperty("provider-service") String providerService,
- @JsonProperty("zts-endpoint") String ztsEndpoint) {
+ @JsonProperty("zts-endpoint") URI ztsEndpoint) {
this.providerUniqueId = providerUniqueId;
this.dnsSuffix = dnsSuffix;
this.providerService = providerService;
diff --git a/container-disc/src/test/java/com/yahoo/container/jdisc/athenz/impl/AthenzIdentityProviderImplTest.java b/container-disc/src/test/java/com/yahoo/container/jdisc/athenz/impl/AthenzIdentityProviderImplTest.java
index b1c7699364d..d13d86553c6 100644
--- a/container-disc/src/test/java/com/yahoo/container/jdisc/athenz/impl/AthenzIdentityProviderImplTest.java
+++ b/container-disc/src/test/java/com/yahoo/container/jdisc/athenz/impl/AthenzIdentityProviderImplTest.java
@@ -3,35 +3,140 @@ package com.yahoo.container.jdisc.athenz.impl;
import com.yahoo.container.core.identity.IdentityConfig;
import com.yahoo.container.jdisc.athenz.AthenzIdentityProvider;
-import org.junit.Assert;
+import com.yahoo.container.jdisc.athenz.impl.AthenzIdentityProviderImpl.RunnableWithTag;
+import com.yahoo.container.jdisc.athenz.impl.AthenzIdentityProviderImpl.Scheduler;
+import com.yahoo.test.ManualClock;
import org.junit.Test;
-import java.io.IOException;
+import java.time.Duration;
+import java.time.Instant;
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.List;
+import java.util.Objects;
+import java.util.PriorityQueue;
+import java.util.concurrent.atomic.AtomicInteger;
+import java.util.function.Predicate;
+import static com.yahoo.container.jdisc.athenz.impl.AthenzIdentityProviderImpl.INITIAL_BACKOFF_DELAY;
+import static com.yahoo.container.jdisc.athenz.impl.AthenzIdentityProviderImpl.INITIAL_WAIT_NTOKEN;
+import static com.yahoo.container.jdisc.athenz.impl.AthenzIdentityProviderImpl.MAX_REGISTER_BACKOFF_DELAY;
+import static com.yahoo.container.jdisc.athenz.impl.AthenzIdentityProviderImpl.REDUCED_UPDATE_PERIOD;
+import static com.yahoo.container.jdisc.athenz.impl.AthenzIdentityProviderImpl.REGISTER_INSTANCE_TAG;
+import static com.yahoo.container.jdisc.athenz.impl.AthenzIdentityProviderImpl.TIMEOUT_INITIAL_WAIT_TAG;
+import static com.yahoo.container.jdisc.athenz.impl.AthenzIdentityProviderImpl.UPDATE_CREDENTIALS_TAG;
+import static com.yahoo.container.jdisc.athenz.impl.AthenzIdentityProviderImpl.UPDATE_PERIOD;
+import static org.junit.Assert.assertEquals;
import static org.mockito.Matchers.any;
+import static org.mockito.Matchers.anyString;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.when;
/**
* @author mortent
+ * @author bjorncs
*/
public class AthenzIdentityProviderImplTest {
+ private static final IdentityConfig IDENTITY_CONFIG =
+ new IdentityConfig(new IdentityConfig.Builder()
+ .service("tenantService").domain("tenantDomain").loadBalancerAddress("cfg"));
+
+ @Test
+ public void athenz_credentials_are_retrieved_after_component_contruction_completed() {
+ ServiceProviderApi serviceProviderApi = mock(ServiceProviderApi.class);
+ AthenzService athenzService = mock(AthenzService.class);
+ MockScheduler scheduler = new MockScheduler();
+
+ when(serviceProviderApi.getSignedIdentityDocument()).thenReturn(getIdentityDocument());
+ when(athenzService.sendInstanceRegisterRequest(any(), any())).thenReturn(
+ new InstanceIdentity(null, "TOKEN"));
+ AthenzCredentialsService credentialService =
+ new AthenzCredentialsService(IDENTITY_CONFIG, serviceProviderApi, athenzService, scheduler.clock());
+
+ AthenzIdentityProvider identityProvider =
+ new AthenzIdentityProviderImpl(IDENTITY_CONFIG, credentialService, scheduler, scheduler.clock());
+
+ List<MockScheduler.CompletedTask> expectedTasks =
+ Arrays.asList(
+ new MockScheduler.CompletedTask(REGISTER_INSTANCE_TAG, Duration.ZERO),
+ new MockScheduler.CompletedTask(TIMEOUT_INITIAL_WAIT_TAG, INITIAL_WAIT_NTOKEN));
+ // Don't run update credential tasks, otherwise infinite loop
+ List<MockScheduler.CompletedTask> completedTasks =
+ scheduler.runAllTasks(task -> !task.tag().equals(UPDATE_CREDENTIALS_TAG));
+ assertEquals(expectedTasks, completedTasks);
+ assertEquals("TOKEN", identityProvider.getNToken());
+ }
+
+ @Test
+ public void register_instance_uses_exponential_backoff() {
+ AthenzCredentialsService credentialService = mock(AthenzCredentialsService.class);
+ when(credentialService.registerInstance())
+ .thenThrow(new RuntimeException("#1"))
+ .thenThrow(new RuntimeException("#2"))
+ .thenThrow(new RuntimeException("#3"))
+ .thenThrow(new RuntimeException("#4"))
+ .thenThrow(new RuntimeException("#5"))
+ .thenReturn(new AthenzCredentials("TOKEN", null, null, null, Instant.now()));
+
+ MockScheduler scheduler = new MockScheduler();
+ AthenzIdentityProvider identityProvider =
+ new AthenzIdentityProviderImpl(IDENTITY_CONFIG, credentialService, scheduler, scheduler.clock());
+
+ List<MockScheduler.CompletedTask> expectedTasks =
+ Arrays.asList(
+ new MockScheduler.CompletedTask(REGISTER_INSTANCE_TAG, Duration.ZERO),
+ new MockScheduler.CompletedTask(REGISTER_INSTANCE_TAG, INITIAL_BACKOFF_DELAY),
+ new MockScheduler.CompletedTask(TIMEOUT_INITIAL_WAIT_TAG, INITIAL_WAIT_NTOKEN),
+ new MockScheduler.CompletedTask(REGISTER_INSTANCE_TAG, INITIAL_BACKOFF_DELAY.multipliedBy(2)),
+ new MockScheduler.CompletedTask(REGISTER_INSTANCE_TAG, INITIAL_BACKOFF_DELAY.multipliedBy(4)),
+ new MockScheduler.CompletedTask(REGISTER_INSTANCE_TAG, INITIAL_BACKOFF_DELAY.multipliedBy(8)),
+ new MockScheduler.CompletedTask(REGISTER_INSTANCE_TAG, MAX_REGISTER_BACKOFF_DELAY));
+ // Don't run update credential tasks, otherwise infinite loop
+ List<MockScheduler.CompletedTask> completedTasks =
+ scheduler.runAllTasks(task -> !task.tag().equals(UPDATE_CREDENTIALS_TAG));
+ assertEquals(expectedTasks, completedTasks);
+ assertEquals("TOKEN", identityProvider.getNToken());
+ }
+
@Test
- public void ntoken_fetched_on_init() throws IOException {
- IdentityConfig config = new IdentityConfig(new IdentityConfig.Builder().service("tenantService").domain("tenantDomain").loadBalancerAddress("cfg"));
+ public void failed_credentials_updates_will_schedule_retries() {
ServiceProviderApi serviceProviderApi = mock(ServiceProviderApi.class);
AthenzService athenzService = mock(AthenzService.class);
+ MockScheduler scheduler = new MockScheduler();
when(serviceProviderApi.getSignedIdentityDocument()).thenReturn(getIdentityDocument());
- when(athenzService.sendInstanceRegisterRequest(any(), any())).thenReturn(new InstanceIdentity(null, "TOKEN"));
+ when(athenzService.sendInstanceRegisterRequest(any(), any())).thenReturn(
+ new InstanceIdentity(null, "TOKEN"));
+ when(athenzService.sendInstanceRefreshRequest(anyString(), anyString(), anyString(),
+ anyString(), any(), any(), any(), any()))
+ .thenThrow(new RuntimeException("#1"))
+ .thenThrow(new RuntimeException("#2"))
+ .thenThrow(new RuntimeException("#3"))
+ .thenReturn(new InstanceIdentity(null, "TOKEN"));
+ AthenzCredentialsService credentialService =
+ new AthenzCredentialsService(IDENTITY_CONFIG, serviceProviderApi, athenzService, scheduler.clock());
- AthenzIdentityProvider identityProvider = new AthenzIdentityProviderImpl(config, serviceProviderApi, athenzService);
+ AthenzIdentityProvider identityProvider =
+ new AthenzIdentityProviderImpl(IDENTITY_CONFIG, credentialService, scheduler, scheduler.clock());
- Assert.assertEquals("TOKEN", identityProvider.getNToken());
+ List<MockScheduler.CompletedTask> expectedTasks =
+ Arrays.asList(
+ new MockScheduler.CompletedTask(REGISTER_INSTANCE_TAG, Duration.ZERO),
+ new MockScheduler.CompletedTask(TIMEOUT_INITIAL_WAIT_TAG, INITIAL_WAIT_NTOKEN),
+ new MockScheduler.CompletedTask(UPDATE_CREDENTIALS_TAG, UPDATE_PERIOD),
+ new MockScheduler.CompletedTask(UPDATE_CREDENTIALS_TAG, UPDATE_PERIOD),
+ new MockScheduler.CompletedTask(UPDATE_CREDENTIALS_TAG, REDUCED_UPDATE_PERIOD),
+ new MockScheduler.CompletedTask(UPDATE_CREDENTIALS_TAG, REDUCED_UPDATE_PERIOD),
+ new MockScheduler.CompletedTask(UPDATE_CREDENTIALS_TAG, UPDATE_PERIOD));
+ AtomicInteger counter = new AtomicInteger(0);
+ List<MockScheduler.CompletedTask> completedTasks =
+ scheduler.runAllTasks(task -> counter.getAndIncrement() < 7); // 1 registration + 1 timeout + 5 update tasks
+ assertEquals(expectedTasks, completedTasks);
+ assertEquals("TOKEN", identityProvider.getNToken());
}
- private String getIdentityDocument() {
+ private static String getIdentityDocument() {
return "{\n" +
" \"identity-document\": \"eyJwcm92aWRlci11bmlxdWUtaWQiOnsidGVuYW50IjoidGVuYW50IiwiYXBwbGljYXRpb24iOiJhcHBsaWNhdGlvbiIsImVudmlyb25tZW50IjoiZGV2IiwicmVnaW9uIjoidXMtbm9ydGgtMSIsImluc3RhbmNlIjoiZGVmYXVsdCIsImNsdXN0ZXItaWQiOiJkZWZhdWx0IiwiY2x1c3Rlci1pbmRleCI6MH0sImNvbmZpZ3NlcnZlci1ob3N0bmFtZSI6ImxvY2FsaG9zdCIsImluc3RhbmNlLWhvc3RuYW1lIjoieC55LmNvbSIsImNyZWF0ZWQtYXQiOjE1MDg3NDgyODUuNzQyMDAwMDAwfQ==\",\n" +
" \"signature\": \"kkEJB/98cy1FeXxzSjtvGH2a6BFgZu/9/kzCcAqRMZjENxnw5jyO1/bjZVzw2Sz4YHPsWSx2uxb32hiQ0U8rMP0zfA9nERIalSP0jB/hMU8laezGhdpk6VKZPJRC6YKAB9Bsv2qUIfMsSxkMqf66GUvjZAGaYsnNa2yHc1jIYHOGMeJO+HNPYJjGv26xPfAOPIKQzs3RmKrc3FoweTCsIwm5oblqekdJvVWYe0obwlOSB5uwc1zpq3Ie1QBFtJRuCGMVHg1pDPxXKBHLClGIrEvzLmICy6IRdHszSO5qiwujUD7sbrbM0sB/u0cYucxbcsGRUmBvme3UAw2mW9POVQ==\",\n" +
@@ -44,4 +149,82 @@ public class AthenzIdentityProviderImplTest {
"}";
}
+
+ private static class MockScheduler implements Scheduler {
+
+ private final PriorityQueue<DelayedTask> tasks = new PriorityQueue<>();
+ private final ManualClock clock = new ManualClock(Instant.EPOCH);
+
+ @Override
+ public void schedule(RunnableWithTag task, Duration delay) {
+ tasks.offer(new DelayedTask(task, delay, clock.instant().plus(delay)));
+ }
+
+ List<CompletedTask> runAllTasks(Predicate<RunnableWithTag> filter) {
+ List<CompletedTask> completedTasks = new ArrayList<>();
+ while (!tasks.isEmpty()) {
+ DelayedTask task = tasks.poll();
+ RunnableWithTag runnable = task.runnableWithTag;
+ if (filter.test(runnable)) {
+ clock.setInstant(task.startTime);
+ runnable.run();
+ completedTasks.add(new CompletedTask(runnable.tag(), task.delay));
+ }
+ }
+ return completedTasks;
+ }
+
+ public ManualClock clock() {
+ return clock;
+ }
+
+ private static class DelayedTask implements Comparable<DelayedTask> {
+ final RunnableWithTag runnableWithTag;
+ final Duration delay;
+ final Instant startTime;
+
+ DelayedTask(RunnableWithTag runnableWithTag, Duration delay, Instant startTime) {
+ this.runnableWithTag = runnableWithTag;
+ this.delay = delay;
+ this.startTime = startTime;
+ }
+
+ @Override
+ public int compareTo(DelayedTask other) {
+ return this.startTime.compareTo(other.startTime);
+ }
+ }
+
+ private static class CompletedTask {
+ final String tag;
+ final Duration delay;
+
+ CompletedTask(String tag, Duration delay) {
+ this.tag = tag;
+ this.delay = delay;
+ }
+
+ @Override
+ public boolean equals(Object o) {
+ if (this == o) return true;
+ if (o == null || getClass() != o.getClass()) return false;
+ CompletedTask that = (CompletedTask) o;
+ return Objects.equals(tag, that.tag) &&
+ Objects.equals(delay, that.delay);
+ }
+
+ @Override
+ public int hashCode() {
+ return Objects.hash(tag, delay);
+ }
+
+ @Override
+ public String toString() {
+ return "CompletedTask{" +
+ "tag='" + tag + '\'' +
+ ", delay=" + delay +
+ '}';
+ }
+ }
+ }
}