diff options
author | Bjørn Christian Seime <bjorncs@oath.com> | 2017-11-06 17:45:32 +0100 |
---|---|---|
committer | Bjørn Christian Seime <bjorncs@oath.com> | 2017-11-06 17:49:17 +0100 |
commit | c14153031e48b71bf4b7c66fab88cb37b8d49788 (patch) | |
tree | bb326f9e6c8ff7c01625c598c496f77d1f7ee2b3 /container-disc/src/test/java | |
parent | 145d7af877637a1c1bd024cc9fedf6de644ba584 (diff) |
Add token refresh to AthenzIdentityProvider implementation
Includes logic for retry: exponential backoff for instance registration,
linear backoff for credential updates. Moved instance registration +
credentials update logic to new class AthenzCredentialsService.
Diffstat (limited to 'container-disc/src/test/java')
-rw-r--r-- | container-disc/src/test/java/com/yahoo/container/jdisc/athenz/impl/AthenzIdentityProviderImplTest.java | 199 |
1 files changed, 191 insertions, 8 deletions
diff --git a/container-disc/src/test/java/com/yahoo/container/jdisc/athenz/impl/AthenzIdentityProviderImplTest.java b/container-disc/src/test/java/com/yahoo/container/jdisc/athenz/impl/AthenzIdentityProviderImplTest.java index b1c7699364d..d13d86553c6 100644 --- a/container-disc/src/test/java/com/yahoo/container/jdisc/athenz/impl/AthenzIdentityProviderImplTest.java +++ b/container-disc/src/test/java/com/yahoo/container/jdisc/athenz/impl/AthenzIdentityProviderImplTest.java @@ -3,35 +3,140 @@ package com.yahoo.container.jdisc.athenz.impl; import com.yahoo.container.core.identity.IdentityConfig; import com.yahoo.container.jdisc.athenz.AthenzIdentityProvider; -import org.junit.Assert; +import com.yahoo.container.jdisc.athenz.impl.AthenzIdentityProviderImpl.RunnableWithTag; +import com.yahoo.container.jdisc.athenz.impl.AthenzIdentityProviderImpl.Scheduler; +import com.yahoo.test.ManualClock; import org.junit.Test; -import java.io.IOException; +import java.time.Duration; +import java.time.Instant; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.List; +import java.util.Objects; +import java.util.PriorityQueue; +import java.util.concurrent.atomic.AtomicInteger; +import java.util.function.Predicate; +import static com.yahoo.container.jdisc.athenz.impl.AthenzIdentityProviderImpl.INITIAL_BACKOFF_DELAY; +import static com.yahoo.container.jdisc.athenz.impl.AthenzIdentityProviderImpl.INITIAL_WAIT_NTOKEN; +import static com.yahoo.container.jdisc.athenz.impl.AthenzIdentityProviderImpl.MAX_REGISTER_BACKOFF_DELAY; +import static com.yahoo.container.jdisc.athenz.impl.AthenzIdentityProviderImpl.REDUCED_UPDATE_PERIOD; +import static com.yahoo.container.jdisc.athenz.impl.AthenzIdentityProviderImpl.REGISTER_INSTANCE_TAG; +import static com.yahoo.container.jdisc.athenz.impl.AthenzIdentityProviderImpl.TIMEOUT_INITIAL_WAIT_TAG; +import static com.yahoo.container.jdisc.athenz.impl.AthenzIdentityProviderImpl.UPDATE_CREDENTIALS_TAG; +import static com.yahoo.container.jdisc.athenz.impl.AthenzIdentityProviderImpl.UPDATE_PERIOD; +import static org.junit.Assert.assertEquals; import static org.mockito.Matchers.any; +import static org.mockito.Matchers.anyString; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.when; /** * @author mortent + * @author bjorncs */ public class AthenzIdentityProviderImplTest { + private static final IdentityConfig IDENTITY_CONFIG = + new IdentityConfig(new IdentityConfig.Builder() + .service("tenantService").domain("tenantDomain").loadBalancerAddress("cfg")); + + @Test + public void athenz_credentials_are_retrieved_after_component_contruction_completed() { + ServiceProviderApi serviceProviderApi = mock(ServiceProviderApi.class); + AthenzService athenzService = mock(AthenzService.class); + MockScheduler scheduler = new MockScheduler(); + + when(serviceProviderApi.getSignedIdentityDocument()).thenReturn(getIdentityDocument()); + when(athenzService.sendInstanceRegisterRequest(any(), any())).thenReturn( + new InstanceIdentity(null, "TOKEN")); + AthenzCredentialsService credentialService = + new AthenzCredentialsService(IDENTITY_CONFIG, serviceProviderApi, athenzService, scheduler.clock()); + + AthenzIdentityProvider identityProvider = + new AthenzIdentityProviderImpl(IDENTITY_CONFIG, credentialService, scheduler, scheduler.clock()); + + List<MockScheduler.CompletedTask> expectedTasks = + Arrays.asList( + new MockScheduler.CompletedTask(REGISTER_INSTANCE_TAG, Duration.ZERO), + new MockScheduler.CompletedTask(TIMEOUT_INITIAL_WAIT_TAG, INITIAL_WAIT_NTOKEN)); + // Don't run update credential tasks, otherwise infinite loop + List<MockScheduler.CompletedTask> completedTasks = + scheduler.runAllTasks(task -> !task.tag().equals(UPDATE_CREDENTIALS_TAG)); + assertEquals(expectedTasks, completedTasks); + assertEquals("TOKEN", identityProvider.getNToken()); + } + + @Test + public void register_instance_uses_exponential_backoff() { + AthenzCredentialsService credentialService = mock(AthenzCredentialsService.class); + when(credentialService.registerInstance()) + .thenThrow(new RuntimeException("#1")) + .thenThrow(new RuntimeException("#2")) + .thenThrow(new RuntimeException("#3")) + .thenThrow(new RuntimeException("#4")) + .thenThrow(new RuntimeException("#5")) + .thenReturn(new AthenzCredentials("TOKEN", null, null, null, Instant.now())); + + MockScheduler scheduler = new MockScheduler(); + AthenzIdentityProvider identityProvider = + new AthenzIdentityProviderImpl(IDENTITY_CONFIG, credentialService, scheduler, scheduler.clock()); + + List<MockScheduler.CompletedTask> expectedTasks = + Arrays.asList( + new MockScheduler.CompletedTask(REGISTER_INSTANCE_TAG, Duration.ZERO), + new MockScheduler.CompletedTask(REGISTER_INSTANCE_TAG, INITIAL_BACKOFF_DELAY), + new MockScheduler.CompletedTask(TIMEOUT_INITIAL_WAIT_TAG, INITIAL_WAIT_NTOKEN), + new MockScheduler.CompletedTask(REGISTER_INSTANCE_TAG, INITIAL_BACKOFF_DELAY.multipliedBy(2)), + new MockScheduler.CompletedTask(REGISTER_INSTANCE_TAG, INITIAL_BACKOFF_DELAY.multipliedBy(4)), + new MockScheduler.CompletedTask(REGISTER_INSTANCE_TAG, INITIAL_BACKOFF_DELAY.multipliedBy(8)), + new MockScheduler.CompletedTask(REGISTER_INSTANCE_TAG, MAX_REGISTER_BACKOFF_DELAY)); + // Don't run update credential tasks, otherwise infinite loop + List<MockScheduler.CompletedTask> completedTasks = + scheduler.runAllTasks(task -> !task.tag().equals(UPDATE_CREDENTIALS_TAG)); + assertEquals(expectedTasks, completedTasks); + assertEquals("TOKEN", identityProvider.getNToken()); + } + @Test - public void ntoken_fetched_on_init() throws IOException { - IdentityConfig config = new IdentityConfig(new IdentityConfig.Builder().service("tenantService").domain("tenantDomain").loadBalancerAddress("cfg")); + public void failed_credentials_updates_will_schedule_retries() { ServiceProviderApi serviceProviderApi = mock(ServiceProviderApi.class); AthenzService athenzService = mock(AthenzService.class); + MockScheduler scheduler = new MockScheduler(); when(serviceProviderApi.getSignedIdentityDocument()).thenReturn(getIdentityDocument()); - when(athenzService.sendInstanceRegisterRequest(any(), any())).thenReturn(new InstanceIdentity(null, "TOKEN")); + when(athenzService.sendInstanceRegisterRequest(any(), any())).thenReturn( + new InstanceIdentity(null, "TOKEN")); + when(athenzService.sendInstanceRefreshRequest(anyString(), anyString(), anyString(), + anyString(), any(), any(), any(), any())) + .thenThrow(new RuntimeException("#1")) + .thenThrow(new RuntimeException("#2")) + .thenThrow(new RuntimeException("#3")) + .thenReturn(new InstanceIdentity(null, "TOKEN")); + AthenzCredentialsService credentialService = + new AthenzCredentialsService(IDENTITY_CONFIG, serviceProviderApi, athenzService, scheduler.clock()); - AthenzIdentityProvider identityProvider = new AthenzIdentityProviderImpl(config, serviceProviderApi, athenzService); + AthenzIdentityProvider identityProvider = + new AthenzIdentityProviderImpl(IDENTITY_CONFIG, credentialService, scheduler, scheduler.clock()); - Assert.assertEquals("TOKEN", identityProvider.getNToken()); + List<MockScheduler.CompletedTask> expectedTasks = + Arrays.asList( + new MockScheduler.CompletedTask(REGISTER_INSTANCE_TAG, Duration.ZERO), + new MockScheduler.CompletedTask(TIMEOUT_INITIAL_WAIT_TAG, INITIAL_WAIT_NTOKEN), + new MockScheduler.CompletedTask(UPDATE_CREDENTIALS_TAG, UPDATE_PERIOD), + new MockScheduler.CompletedTask(UPDATE_CREDENTIALS_TAG, UPDATE_PERIOD), + new MockScheduler.CompletedTask(UPDATE_CREDENTIALS_TAG, REDUCED_UPDATE_PERIOD), + new MockScheduler.CompletedTask(UPDATE_CREDENTIALS_TAG, REDUCED_UPDATE_PERIOD), + new MockScheduler.CompletedTask(UPDATE_CREDENTIALS_TAG, UPDATE_PERIOD)); + AtomicInteger counter = new AtomicInteger(0); + List<MockScheduler.CompletedTask> completedTasks = + scheduler.runAllTasks(task -> counter.getAndIncrement() < 7); // 1 registration + 1 timeout + 5 update tasks + assertEquals(expectedTasks, completedTasks); + assertEquals("TOKEN", identityProvider.getNToken()); } - private String getIdentityDocument() { + private static String getIdentityDocument() { return "{\n" + " \"identity-document\": \"eyJwcm92aWRlci11bmlxdWUtaWQiOnsidGVuYW50IjoidGVuYW50IiwiYXBwbGljYXRpb24iOiJhcHBsaWNhdGlvbiIsImVudmlyb25tZW50IjoiZGV2IiwicmVnaW9uIjoidXMtbm9ydGgtMSIsImluc3RhbmNlIjoiZGVmYXVsdCIsImNsdXN0ZXItaWQiOiJkZWZhdWx0IiwiY2x1c3Rlci1pbmRleCI6MH0sImNvbmZpZ3NlcnZlci1ob3N0bmFtZSI6ImxvY2FsaG9zdCIsImluc3RhbmNlLWhvc3RuYW1lIjoieC55LmNvbSIsImNyZWF0ZWQtYXQiOjE1MDg3NDgyODUuNzQyMDAwMDAwfQ==\",\n" + " \"signature\": \"kkEJB/98cy1FeXxzSjtvGH2a6BFgZu/9/kzCcAqRMZjENxnw5jyO1/bjZVzw2Sz4YHPsWSx2uxb32hiQ0U8rMP0zfA9nERIalSP0jB/hMU8laezGhdpk6VKZPJRC6YKAB9Bsv2qUIfMsSxkMqf66GUvjZAGaYsnNa2yHc1jIYHOGMeJO+HNPYJjGv26xPfAOPIKQzs3RmKrc3FoweTCsIwm5oblqekdJvVWYe0obwlOSB5uwc1zpq3Ie1QBFtJRuCGMVHg1pDPxXKBHLClGIrEvzLmICy6IRdHszSO5qiwujUD7sbrbM0sB/u0cYucxbcsGRUmBvme3UAw2mW9POVQ==\",\n" + @@ -44,4 +149,82 @@ public class AthenzIdentityProviderImplTest { "}"; } + + private static class MockScheduler implements Scheduler { + + private final PriorityQueue<DelayedTask> tasks = new PriorityQueue<>(); + private final ManualClock clock = new ManualClock(Instant.EPOCH); + + @Override + public void schedule(RunnableWithTag task, Duration delay) { + tasks.offer(new DelayedTask(task, delay, clock.instant().plus(delay))); + } + + List<CompletedTask> runAllTasks(Predicate<RunnableWithTag> filter) { + List<CompletedTask> completedTasks = new ArrayList<>(); + while (!tasks.isEmpty()) { + DelayedTask task = tasks.poll(); + RunnableWithTag runnable = task.runnableWithTag; + if (filter.test(runnable)) { + clock.setInstant(task.startTime); + runnable.run(); + completedTasks.add(new CompletedTask(runnable.tag(), task.delay)); + } + } + return completedTasks; + } + + public ManualClock clock() { + return clock; + } + + private static class DelayedTask implements Comparable<DelayedTask> { + final RunnableWithTag runnableWithTag; + final Duration delay; + final Instant startTime; + + DelayedTask(RunnableWithTag runnableWithTag, Duration delay, Instant startTime) { + this.runnableWithTag = runnableWithTag; + this.delay = delay; + this.startTime = startTime; + } + + @Override + public int compareTo(DelayedTask other) { + return this.startTime.compareTo(other.startTime); + } + } + + private static class CompletedTask { + final String tag; + final Duration delay; + + CompletedTask(String tag, Duration delay) { + this.tag = tag; + this.delay = delay; + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + CompletedTask that = (CompletedTask) o; + return Objects.equals(tag, that.tag) && + Objects.equals(delay, that.delay); + } + + @Override + public int hashCode() { + return Objects.hash(tag, delay); + } + + @Override + public String toString() { + return "CompletedTask{" + + "tag='" + tag + '\'' + + ", delay=" + delay + + '}'; + } + } + } } |