aboutsummaryrefslogtreecommitdiffstats
path: root/http-utils/src/main/java/ai/vespa/util/http/retry/DelayedHttpRequestRetryHandler.java
blob: 72bb171c4c72e7d0cadfa16a2fb367c31eab6391 (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
// Copyright 2020 Oath Inc. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
package ai.vespa.util.http.retry;

import org.apache.http.annotation.Contract;
import org.apache.http.annotation.ThreadingBehavior;
import org.apache.http.client.HttpRequestRetryHandler;
import org.apache.http.client.protocol.HttpClientContext;
import org.apache.http.protocol.HttpContext;

import java.io.IOException;
import java.time.Duration;
import java.util.ArrayList;
import java.util.List;
import java.util.function.BiPredicate;
import java.util.function.Predicate;
import java.util.logging.Logger;

/**
 * A {@link HttpRequestRetryHandler} that supports delayed retries.
 *
 * <p>For each failed execution the handler first consults the configured {@link RetryPredicate}; if the
 * exception is retryable and the retry limit has not been exceeded, it notifies the {@link RetryConsumer},
 * sleeps for the delay produced by the configured delay strategy and then allows the retry. When the limit
 * is exceeded the {@link RetryFailedConsumer} is notified instead and the retry is refused.
 *
 * <p>Instances are created through {@link Builder}.
 *
 * @author bjorncs
 */
@Contract(threading = ThreadingBehavior.IMMUTABLE)
public class DelayedHttpRequestRetryHandler implements HttpRequestRetryHandler {

    // Named after this class (not the HttpRequestRetryHandler interface) so logging config can target it.
    private static final Logger log = Logger.getLogger(DelayedHttpRequestRetryHandler.class.getName());

    /** Callback invoked right before the handler sleeps ahead of a retry. */
    @FunctionalInterface
    public interface RetryConsumer {
        void onRetry(IOException exception, Duration delay, int executionCount, HttpClientContext context);
    }

    /** Callback invoked when a retryable failure is given up on because the retry limit was exceeded. */
    @FunctionalInterface
    public interface RetryFailedConsumer {
        void onRetryFailed(IOException exception, int executionCount, HttpClientContext context);
    }

    /** Decides whether an exception, in the given request context, is eligible for retry at all. */
    @FunctionalInterface
    public interface RetryPredicate extends BiPredicate<IOException, HttpClientContext> {}

    private final DelaySupplier delaySupplier;
    private final int maxRetries;
    private final RetryPredicate predicate;
    private final RetryConsumer retryConsumer;
    private final RetryFailedConsumer retryFailedConsumer;
    private final Sleeper sleeper;

    private DelayedHttpRequestRetryHandler(
            DelaySupplier delaySupplier,
            int maxRetries,
            RetryPredicate predicate,
            RetryConsumer retryConsumer,
            RetryFailedConsumer retryFailedConsumer,
            Sleeper sleeper) {
        this.delaySupplier = delaySupplier;
        this.maxRetries = maxRetries;
        this.predicate = predicate;
        this.retryConsumer = retryConsumer;
        this.retryFailedConsumer = retryFailedConsumer;
        this.sleeper = sleeper;
    }

    /**
     * Decides whether a failed request should be retried, sleeping for the configured delay before returning
     * {@code true}.
     *
     * @param exception      the failure of the last execution
     * @param executionCount the execution count passed by the HTTP client; retries are refused once it exceeds
     *                       {@code maxRetries}
     * @param ctx            the request context
     * @return {@code true} if the request should be retried (after the delay has elapsed), {@code false} otherwise
     */
    @Override
    public boolean retryRequest(IOException exception, int executionCount, HttpContext ctx) {
        log.fine(() -> String.format("retryRequest(exception='%s', executionCount='%d', ctx='%s')",
                                     exception.getClass().getName(), executionCount, ctx));
        HttpClientContext clientCtx = HttpClientContext.adapt(ctx);
        if (!predicate.test(exception, clientCtx)) {
            log.fine(() -> String.format("Not retrying for '%s'", ctx));
            return false;
        }
        if (executionCount > maxRetries) {
            log.fine(() -> String.format("Max retries exceeded for '%s'", ctx));
            retryFailedConsumer.onRetryFailed(exception, executionCount, clientCtx);
            return false;
        }
        Duration delay = delaySupplier.getDelay(executionCount);
        log.fine(() -> String.format("Retrying after %s for '%s'", delay, ctx));
        retryConsumer.onRetry(exception, delay, executionCount, clientCtx);
        sleeper.sleep(delay);
        return true;
    }

    /**
     * Builder for {@link DelayedHttpRequestRetryHandler}. By default every {@link IOException} is retried and
     * the retry/retry-failed callbacks do nothing.
     */
    public static class Builder {

        private final DelaySupplier delaySupplier;
        private final int maxRetries;
        private RetryPredicate predicate = (ioException, ctx) -> true;
        private RetryConsumer retryConsumer = (exception, delay, count, ctx) -> {};
        private RetryFailedConsumer retryFailedConsumer = (exception, count, ctx) -> {};
        private Sleeper sleeper = new DefaultSleeper();

        private Builder(DelaySupplier delaySupplier, int maxRetries) {
            this.delaySupplier = delaySupplier;
            this.maxRetries = maxRetries;
        }

        /** Creates a builder that waits the same {@code delay} before every retry. */
        public static Builder withFixedDelay(Duration delay, int maxRetries) {
            return new Builder(executionCount -> delay, maxRetries);
        }

        /**
         * Creates a builder whose delay starts at {@code startDelay} and doubles for each subsequent execution,
         * capped at {@code maxDelay}.
         */
        public static Builder withExponentialBackoff(Duration startDelay, Duration maxDelay, int maxRetries) {
            return new Builder(
                    executionCount -> {
                        Duration nextDelay = startDelay;
                        // Stop doubling once the cap is reached: further doubling cannot change the result and
                        // would eventually overflow Duration (ArithmeticException) for large execution counts.
                        for (int i = 1; i < executionCount && nextDelay.compareTo(maxDelay) < 0; ++i) {
                            nextDelay = nextDelay.multipliedBy(2);
                        }
                        return maxDelay.compareTo(nextDelay) > 0 ? nextDelay : maxDelay;
                    },
                    maxRetries);
        }

        /**
         * Retries only exceptions that are instances of one of {@code exceptionTypes}. Replaces any previously
         * configured retry predicate.
         */
        public Builder retryForExceptions(List<Class<? extends IOException>> exceptionTypes) {
            // Snapshot the list so later mutation by the caller cannot change the (immutable) handler's behavior.
            List<Class<? extends IOException>> types = new ArrayList<>(exceptionTypes);
            this.predicate = (ioException, ctx) -> types.stream().anyMatch(type -> type.isInstance(ioException));
            return this;
        }

        /** Retries only exceptions accepted by {@code predicate}. Replaces any previously configured retry predicate. */
        public Builder retryForExceptions(Predicate<IOException> predicate) {
            this.predicate = (ioException, ctx) -> predicate.test(ioException);
            return this;
        }

        /**
         * Retries only when {@code predicate} accepts the exception and request context. Replaces any previously
         * configured retry predicate.
         */
        public Builder retryFor(RetryPredicate predicate) {
            this.predicate = predicate;
            return this;
        }

        /** Sets the callback invoked before each delayed retry. */
        public Builder onRetry(RetryConsumer consumer) {
            this.retryConsumer = consumer;
            return this;
        }

        /** Sets the callback invoked when the retry limit is exceeded. */
        public Builder onRetryFailed(RetryFailedConsumer consumer) {
            this.retryFailedConsumer = consumer;
            return this;
        }

        // For unit testing
        Builder withSleeper(Sleeper sleeper) {
            this.sleeper = sleeper;
            return this;
        }

        /** Creates the handler from the current builder configuration. */
        public DelayedHttpRequestRetryHandler build() {
            return new DelayedHttpRequestRetryHandler(delaySupplier, maxRetries, predicate, retryConsumer, retryFailedConsumer, sleeper);
        }

        /** Sleeps on the calling thread using {@link Thread#sleep(long)}. */
        private static class DefaultSleeper implements Sleeper {
            @Override
            public void sleep(Duration duration) {
                try {
                    Thread.sleep(duration.toMillis());
                } catch (InterruptedException e) {
                    // Restore the interrupt status so callers further up the stack can still observe it.
                    Thread.currentThread().interrupt();
                    throw new RuntimeException(e);
                }
            }
        }
    }

    // For unit testing
    interface Sleeper {
        void sleep(Duration duration);
    }

    /** Computes the delay to wait before retrying, given the execution count passed to {@code retryRequest}. */
    @FunctionalInterface
    private interface DelaySupplier {
        Duration getDelay(int executionCount);
    }
}