diff options
author | Morten Tokle <mortent@oath.com> | 2018-03-01 11:20:35 +0100 |
---|---|---|
committer | Morten Tokle <mortent@oath.com> | 2018-03-01 11:20:35 +0100 |
commit | bacea4ade2f319217b831b5d6ce7459711826af7 (patch) | |
tree | 4033fab5755ce0652d957f3ac7598d841cd2da13 /vespa-athenz | |
parent | bc3ccdb3552d0d3ff5dcc463308614e72e6abd3e (diff) |
Simplify certificate refresh
Diffstat (limited to 'vespa-athenz')
4 files changed, 93 insertions, 224 deletions
diff --git a/vespa-athenz/src/main/java/com/yahoo/vespa/athenz/identityprovider/AthenzCredentials.java b/vespa-athenz/src/main/java/com/yahoo/vespa/athenz/identityprovider/AthenzCredentials.java index c5dce1c5b1d..8127ac9feb3 100644 --- a/vespa-athenz/src/main/java/com/yahoo/vespa/athenz/identityprovider/AthenzCredentials.java +++ b/vespa-athenz/src/main/java/com/yahoo/vespa/athenz/identityprovider/AthenzCredentials.java @@ -3,7 +3,6 @@ package com.yahoo.vespa.athenz.identityprovider; import java.security.KeyPair; import java.security.cert.X509Certificate; -import java.time.Instant; /** * @author bjorncs @@ -14,18 +13,15 @@ class AthenzCredentials { private final X509Certificate certificate; private final KeyPair keyPair; private final SignedIdentityDocument identityDocument; - private final Instant createdAt; AthenzCredentials(String nToken, X509Certificate certificate, KeyPair keyPair, - SignedIdentityDocument identityDocument, - Instant createdAt) { + SignedIdentityDocument identityDocument) { this.nToken = nToken; this.certificate = certificate; this.keyPair = keyPair; this.identityDocument = identityDocument; - this.createdAt = createdAt; } String getNToken() { @@ -44,8 +40,5 @@ class AthenzCredentials { return identityDocument; } - Instant getCreatedAt() { - return createdAt; - } } diff --git a/vespa-athenz/src/main/java/com/yahoo/vespa/athenz/identityprovider/AthenzCredentialsService.java b/vespa-athenz/src/main/java/com/yahoo/vespa/athenz/identityprovider/AthenzCredentialsService.java index dd816929bfb..b9fb7e94782 100644 --- a/vespa-athenz/src/main/java/com/yahoo/vespa/athenz/identityprovider/AthenzCredentialsService.java +++ b/vespa-athenz/src/main/java/com/yahoo/vespa/athenz/identityprovider/AthenzCredentialsService.java @@ -79,7 +79,7 @@ class AthenzCredentialsService { SignedIdentityDocument identityDocument) { X509Certificate certificate = instanceIdentity.getX509Certificate(); String serviceToken = instanceIdentity.getServiceToken(); - return new AthenzCredentials(serviceToken, certificate, keyPair, identityDocument, clock.instant()); + return new AthenzCredentials(serviceToken, certificate, keyPair, identityDocument); } private static SignedIdentityDocument parseSignedIdentityDocument(String rawDocument) { diff --git a/vespa-athenz/src/main/java/com/yahoo/vespa/athenz/identityprovider/AthenzIdentityProviderImpl.java b/vespa-athenz/src/main/java/com/yahoo/vespa/athenz/identityprovider/AthenzIdentityProviderImpl.java index 95113e1b0b1..12f2ce5f074 100644 --- a/vespa-athenz/src/main/java/com/yahoo/vespa/athenz/identityprovider/AthenzIdentityProviderImpl.java +++ b/vespa-athenz/src/main/java/com/yahoo/vespa/athenz/identityprovider/AthenzIdentityProviderImpl.java @@ -19,6 +19,7 @@ import java.time.Duration; import java.time.Instant; import java.util.concurrent.Executors; import java.util.concurrent.ScheduledExecutorService; +import java.util.concurrent.ScheduledThreadPoolExecutor; import java.util.concurrent.TimeUnit; import java.util.concurrent.atomic.AtomicReference; import java.util.logging.Logger; @@ -33,35 +34,19 @@ public final class AthenzIdentityProviderImpl extends AbstractComponent implemen // TODO Make some of these values configurable through config. Match requested expiration of register/update requests. // TODO These should match the requested expiration - static final Duration EXPIRES_AFTER = Duration.ofDays(1); - static final Duration EXPIRATION_MARGIN = Duration.ofMinutes(30); - static final Duration INITIAL_WAIT_NTOKEN = Duration.ofMinutes(5); - static final Duration UPDATE_PERIOD = EXPIRES_AFTER.dividedBy(3); - static final Duration REDUCED_UPDATE_PERIOD = Duration.ofMinutes(30); - static final Duration INITIAL_BACKOFF_DELAY = Duration.ofMinutes(4); - static final Duration MAX_REGISTER_BACKOFF_DELAY = Duration.ofHours(1); - static final int BACKOFF_DELAY_MULTIPLIER = 2; + static final Duration UPDATE_PERIOD = Duration.ofDays(1); static final Duration AWAIT_TERMINTATION_TIMEOUT = Duration.ofSeconds(90); - private static final Duration CERTIFICATE_EXPIRY_METRIC_UPDATE_PERIOD = Duration.ofMinutes(5); - private static final String CERTIFICATE_EXPIRY_METRIC_NAME = "athenz-tenant-cert.expiry.seconds"; - - static final String REGISTER_INSTANCE_TAG = "register-instance"; - static final String UPDATE_CREDENTIALS_TAG = "update-credentials"; - static final String TIMEOUT_INITIAL_WAIT_TAG = "timeout-initial-wait"; - static final String METRICS_UPDATER_TAG = "metrics-updater"; - + public static final String CERTIFICATE_EXPIRY_METRIC_NAME = "athenz-tenant-cert.expiry.seconds"; private volatile AthenzCredentials credentials; - private final AtomicReference<Throwable> lastThrowable = new AtomicReference<>(); + private final Metric metric; private final AthenzCredentialsService athenzCredentialsService; private final Scheduler scheduler; private final Clock clock; private final String domain; private final String service; - private final CertificateExpiryMetricUpdater metricUpdater; - @Inject public AthenzIdentityProviderImpl(IdentityConfig config, Metric metric) { this(config, @@ -80,20 +65,20 @@ public final class AthenzIdentityProviderImpl extends AbstractComponent implemen AthenzCredentialsService athenzCredentialsService, Scheduler scheduler, Clock clock) { + this.metric = metric; this.athenzCredentialsService = athenzCredentialsService; this.scheduler = scheduler; this.clock = clock; this.domain = config.domain(); this.service = config.service(); - metricUpdater = new CertificateExpiryMetricUpdater(metric); registerInstance(); } private void registerInstance() { try { credentials = athenzCredentialsService.registerInstance(); - scheduler.schedule(new UpdateCredentialsTask(), UPDATE_PERIOD); - scheduler.submit(metricUpdater); + scheduler.scheduleAtFixedRate(new CertificateExpiryMetricUpdater(), Duration.ofMinutes(0), Duration.ofMinutes(5)); + scheduler.scheduleAtFixedRate(new UpdateCredentialsTask(), UPDATE_PERIOD, UPDATE_PERIOD); } catch (Throwable t) { throw new AthenzIdentityProviderException("Could not retrieve Athenz credentials", t); } @@ -129,96 +114,66 @@ public final class AthenzIdentityProviderImpl extends AbstractComponent implemen } private static Instant getExpirationTime(AthenzCredentials credentials) { - return credentials.getCreatedAt().plus(EXPIRES_AFTER).minus(EXPIRATION_MARGIN); + return credentials.getCertificate().getNotAfter().toInstant(); } - private class UpdateCredentialsTask implements RunnableWithTag { - @Override - public void run() { - try { - AthenzCredentials newCredentials = isExpired(credentials) - ? athenzCredentialsService.registerInstance() - : athenzCredentialsService.updateCredentials(credentials); - credentials = newCredentials; - scheduler.schedule(new UpdateCredentialsTask(), UPDATE_PERIOD); - } catch (Throwable t) { - log.log(LogLevel.WARNING, "Failed to update credentials: " + t.getMessage(), t); - lastThrowable.set(t); - Duration timeToExpiration = Duration.between(clock.instant(), getExpirationTime(credentials)); - // NOTE: Update period might be after timeToExpiration, still we do not want to DDoS Athenz. - Duration updatePeriod = - timeToExpiration.compareTo(UPDATE_PERIOD) > 0 ? UPDATE_PERIOD : REDUCED_UPDATE_PERIOD; - scheduler.schedule(new UpdateCredentialsTask(), updatePeriod); - } - } - - @Override - public String tag() { - return UPDATE_CREDENTIALS_TAG; + void refreshCertificate() { + try { + AthenzCredentials newCredentials = isExpired(credentials) + ? athenzCredentialsService.registerInstance() + : athenzCredentialsService.updateCredentials(credentials); + credentials = newCredentials; + } catch (Throwable t) { + log.log(LogLevel.WARNING, "Failed to update credentials: " + t.getMessage(), t); } } - private class CertificateExpiryMetricUpdater implements RunnableWithTag { - private final Metric metric; - - private CertificateExpiryMetricUpdater(Metric metric) { - this.metric = metric; - } - - @Override - public void run() { + void reportMetrics() { + try { Instant expirationTime = getExpirationTime(credentials); Duration remainingLifetime = Duration.between(clock.instant(), expirationTime); metric.set(CERTIFICATE_EXPIRY_METRIC_NAME, remainingLifetime.getSeconds(), null); - scheduler.schedule(this, CERTIFICATE_EXPIRY_METRIC_UPDATE_PERIOD); + } catch (Throwable t) { + log.log(LogLevel.WARNING, "Failed to update metrics: " + t.getMessage(), t); } + } + private class UpdateCredentialsTask implements Runnable { @Override - public String tag() { - return METRICS_UPDATER_TAG; + public void run() { + refreshCertificate(); } } - private static class ThreadPoolScheduler implements Scheduler { - - private static final Logger log = Logger.getLogger(ThreadPoolScheduler.class.getName()); - - private final ScheduledExecutorService executor = Executors.newScheduledThreadPool(0); - + private class CertificateExpiryMetricUpdater implements Runnable { @Override - public void schedule(RunnableWithTag runnable, Duration delay) { - log.log(LogLevel.FINE, String.format("Scheduling task '%s' in '%s'", runnable.tag(), delay)); - executor.schedule(runnable, delay.getSeconds(), TimeUnit.SECONDS); + public void run() { + reportMetrics(); } + } + + private static class ThreadPoolScheduler implements Scheduler { + private final ScheduledExecutorService executor = Executors.newScheduledThreadPool(1); @Override - public void submit(RunnableWithTag runnable) { - log.log(LogLevel.FINE, String.format("Scheduling task '%s' now", runnable.tag())); - executor.submit(runnable); + public void scheduleAtFixedRate(Runnable runnable, Duration initialDelay, Duration period) { + executor.scheduleAtFixedRate(runnable, initialDelay.getSeconds(), period.getSeconds(), TimeUnit.SECONDS); } @Override public void shutdown(Duration timeout) { try { executor.shutdownNow(); - executor.awaitTermination(AWAIT_TERMINTATION_TIMEOUT.getSeconds(), TimeUnit.SECONDS); + executor.awaitTermination(timeout.getSeconds(), TimeUnit.SECONDS); } catch (InterruptedException e) { throw new RuntimeException(e); } } - } public interface Scheduler { - void schedule(RunnableWithTag runnable, Duration delay); - default void submit(RunnableWithTag runnable) { schedule(runnable, Duration.ZERO); } - default void shutdown(Duration timeout) {} + void scheduleAtFixedRate(Runnable runnable, Duration initialDelay, Duration period); + void shutdown(Duration timeout); } - - public interface RunnableWithTag extends Runnable { - - String tag(); - } - } diff --git a/vespa-athenz/src/test/java/com/yahoo/vespa/athenz/identityprovider/AthenzIdentityProviderImplTest.java b/vespa-athenz/src/test/java/com/yahoo/vespa/athenz/identityprovider/AthenzIdentityProviderImplTest.java index d9dbd73a94e..29b122b09e5 100644 --- a/vespa-athenz/src/test/java/com/yahoo/vespa/athenz/identityprovider/AthenzIdentityProviderImplTest.java +++ b/vespa-athenz/src/test/java/com/yahoo/vespa/athenz/identityprovider/AthenzIdentityProviderImplTest.java @@ -6,30 +6,22 @@ import com.yahoo.container.jdisc.athenz.AthenzIdentityProvider; import com.yahoo.container.jdisc.athenz.AthenzIdentityProviderException; import com.yahoo.jdisc.Metric; import com.yahoo.test.ManualClock; -import com.yahoo.vespa.athenz.identityprovider.AthenzIdentityProviderImpl.RunnableWithTag; import com.yahoo.vespa.athenz.identityprovider.AthenzIdentityProviderImpl.Scheduler; import org.junit.Test; +import org.mockito.invocation.InvocationOnMock; +import org.mockito.stubbing.Answer; import java.security.cert.X509Certificate; import java.time.Duration; import java.time.Instant; -import java.util.ArrayList; -import java.util.Arrays; -import java.util.List; -import java.util.Map; -import java.util.Objects; -import java.util.PriorityQueue; -import java.util.concurrent.atomic.AtomicInteger; -import java.util.function.Predicate; - -import static com.yahoo.vespa.athenz.identityprovider.AthenzIdentityProviderImpl.METRICS_UPDATER_TAG; -import static com.yahoo.vespa.athenz.identityprovider.AthenzIdentityProviderImpl.REDUCED_UPDATE_PERIOD; -import static com.yahoo.vespa.athenz.identityprovider.AthenzIdentityProviderImpl.UPDATE_CREDENTIALS_TAG; -import static com.yahoo.vespa.athenz.identityprovider.AthenzIdentityProviderImpl.UPDATE_PERIOD; -import static org.junit.Assert.assertEquals; +import java.util.Date; +import java.util.function.Supplier; + import static org.mockito.Matchers.any; import static org.mockito.Matchers.anyString; +import static org.mockito.Matchers.eq; import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.verify; import static org.mockito.Mockito.when; /** @@ -38,72 +30,79 @@ import static org.mockito.Mockito.when; */ public class AthenzIdentityProviderImplTest { - private static final Metric DUMMY_METRIC = new Metric() { - @Override - public void set(String s, Number number, Context context) { - } - - @Override - public void add(String s, Number number, Context context) { - } - - @Override - public Context createContext(Map<String, ?> stringMap) { - return null; - } - }; + public static final Duration certificateValidity = Duration.ofDays(30); private static final IdentityConfig IDENTITY_CONFIG = new IdentityConfig(new IdentityConfig.Builder() .service("tenantService").domain("tenantDomain").loadBalancerAddress("cfg")); - @Test (expected = AthenzIdentityProviderException.class) + @Test(expected = AthenzIdentityProviderException.class) public void component_creation_fails_when_credentials_not_found() { AthenzCredentialsService credentialService = mock(AthenzCredentialsService.class); when(credentialService.registerInstance()) .thenThrow(new RuntimeException("athenz unavailable")); - ManualClock clock = new ManualClock(Instant.EPOCH); - MockScheduler scheduler = new MockScheduler(clock); - AthenzIdentityProvider identityProvider = - new AthenzIdentityProviderImpl(IDENTITY_CONFIG, DUMMY_METRIC, credentialService, scheduler, clock); + new AthenzIdentityProviderImpl(IDENTITY_CONFIG, mock(Metric.class), credentialService, mock(Scheduler.class), new ManualClock(Instant.EPOCH)); } @Test - public void failed_credentials_updates_will_schedule_retries() { + public void metrics_updated_on_refresh() { IdentityDocumentService identityDocumentService = mock(IdentityDocumentService.class); AthenzService athenzService = mock(AthenzService.class); ManualClock clock = new ManualClock(Instant.EPOCH); - MockScheduler scheduler = new MockScheduler(clock); - X509Certificate x509Certificate = mock(X509Certificate.class); + Metric metric = mock(Metric.class); when(identityDocumentService.getSignedIdentityDocument()).thenReturn(getIdentityDocument()); - when(athenzService.sendInstanceRegisterRequest(any(), any())).thenReturn( - new InstanceIdentity(null, "TOKEN")); + when(athenzService.sendInstanceRegisterRequest(any(), any())).then(new Answer<InstanceIdentity>() { + @Override + public InstanceIdentity answer(InvocationOnMock invocationOnMock) throws Throwable { + return new InstanceIdentity(getCertificate(getExpirationSupplier(clock)), "TOKEN"); + } + }); + when(athenzService.sendInstanceRefreshRequest(anyString(), anyString(), anyString(), anyString(), any(), any(), any(), any())) .thenThrow(new RuntimeException("#1")) .thenThrow(new RuntimeException("#2")) - .thenThrow(new RuntimeException("#3")) - .thenReturn(new InstanceIdentity(null, "TOKEN")); + .thenReturn(new InstanceIdentity(getCertificate(getExpirationSupplier(clock)), "TOKEN")); + AthenzCredentialsService credentialService = new AthenzCredentialsService(IDENTITY_CONFIG, identityDocumentService, athenzService, clock); - AthenzIdentityProvider identityProvider = - new AthenzIdentityProviderImpl(IDENTITY_CONFIG, DUMMY_METRIC, credentialService, scheduler, clock); - - List<MockScheduler.CompletedTask> expectedTasks = - Arrays.asList( - new MockScheduler.CompletedTask(UPDATE_CREDENTIALS_TAG, UPDATE_PERIOD), - new MockScheduler.CompletedTask(UPDATE_CREDENTIALS_TAG, UPDATE_PERIOD), - new MockScheduler.CompletedTask(UPDATE_CREDENTIALS_TAG, REDUCED_UPDATE_PERIOD), - new MockScheduler.CompletedTask(UPDATE_CREDENTIALS_TAG, REDUCED_UPDATE_PERIOD), - new MockScheduler.CompletedTask(UPDATE_CREDENTIALS_TAG, UPDATE_PERIOD)); - AtomicInteger counter = new AtomicInteger(0); - List<MockScheduler.CompletedTask> completedTasks = - scheduler.runAllTasks(task -> !task.tag().equals(METRICS_UPDATER_TAG) && - counter.getAndIncrement() < expectedTasks.size()); - assertEquals(expectedTasks, completedTasks); + AthenzIdentityProviderImpl identityProvider = + new AthenzIdentityProviderImpl(IDENTITY_CONFIG, metric, credentialService, mock(Scheduler.class), clock); + + identityProvider.reportMetrics(); + verify(metric).set(eq(AthenzIdentityProviderImpl.CERTIFICATE_EXPIRY_METRIC_NAME), eq(certificateValidity.getSeconds()), any()); + + // Advance 1 day, refresh fails, cert is 1 day old + clock.advance(Duration.ofDays(1)); + identityProvider.refreshCertificate(); + identityProvider.reportMetrics(); + verify(metric).set(eq(AthenzIdentityProviderImpl.CERTIFICATE_EXPIRY_METRIC_NAME), eq(certificateValidity.minus(Duration.ofDays(1)).getSeconds()), any()); + + // Advance 1 more day, refresh fails, cert is 2 days old + clock.advance(Duration.ofDays(1)); + identityProvider.refreshCertificate(); + identityProvider.reportMetrics(); + verify(metric).set(eq(AthenzIdentityProviderImpl.CERTIFICATE_EXPIRY_METRIC_NAME), eq(certificateValidity.minus(Duration.ofDays(2)).getSeconds()), any()); + + // Advance 1 more day, refresh succeds, cert is new + clock.advance(Duration.ofDays(1)); + identityProvider.refreshCertificate(); + identityProvider.reportMetrics(); + verify(metric).set(eq(AthenzIdentityProviderImpl.CERTIFICATE_EXPIRY_METRIC_NAME), eq(certificateValidity.getSeconds()), any()); + + } + + private Supplier<Date> getExpirationSupplier(ManualClock clock) { + return () -> new Date(clock.instant().plus(certificateValidity).toEpochMilli()); + } + + private X509Certificate getCertificate(Supplier<Date> expiry) { + X509Certificate x509Certificate = mock(X509Certificate.class); + when(x509Certificate.getNotAfter()).thenReturn(expiry.get()); + return x509Certificate; } private static String getIdentityDocument() { @@ -119,82 +118,4 @@ public class AthenzIdentityProviderImplTest { "}"; } - - private static class MockScheduler implements Scheduler { - - private final PriorityQueue<DelayedTask> tasks = new PriorityQueue<>(); - private final ManualClock clock; - - MockScheduler(ManualClock clock) { - this.clock = clock; - } - - @Override - public void schedule(RunnableWithTag task, Duration delay) { - tasks.offer(new DelayedTask(task, delay, clock.instant().plus(delay))); - } - - List<CompletedTask> runAllTasks(Predicate<RunnableWithTag> filter) { - List<CompletedTask> completedTasks = new ArrayList<>(); - while (!tasks.isEmpty()) { - DelayedTask task = tasks.poll(); - RunnableWithTag runnable = task.runnableWithTag; - if (filter.test(runnable)) { - clock.setInstant(task.startTime); - runnable.run(); - completedTasks.add(new CompletedTask(runnable.tag(), task.delay)); - } - } - return completedTasks; - } - - private static class DelayedTask implements Comparable<DelayedTask> { - final RunnableWithTag runnableWithTag; - final Duration delay; - final Instant startTime; - - DelayedTask(RunnableWithTag runnableWithTag, Duration delay, Instant startTime) { - this.runnableWithTag = runnableWithTag; - this.delay = delay; - this.startTime = startTime; - } - - @Override - public int compareTo(DelayedTask other) { - return this.startTime.compareTo(other.startTime); - } - } - - private static class CompletedTask { - final String tag; - final Duration delay; - - CompletedTask(String tag, Duration delay) { - this.tag = tag; - this.delay = delay; - } - - @Override - public boolean equals(Object o) { - if (this == o) return true; - if (o == null || getClass() != o.getClass()) return false; - CompletedTask that = (CompletedTask) o; - return Objects.equals(tag, that.tag) && - Objects.equals(delay, that.delay); - } - - @Override - public int hashCode() { - return Objects.hash(tag, delay); - } - - @Override - public String toString() { - return "CompletedTask{" + - "tag='" + tag + '\'' + - ", delay=" + delay + - '}'; - } - } - } } |