aboutsummaryrefslogtreecommitdiffstats
path: root/cache/cache_test.go
diff options
context:
space:
mode:
Diffstat (limited to 'cache/cache_test.go')
-rw-r--r--cache/cache_test.go43
1 files changed, 28 insertions, 15 deletions
diff --git a/cache/cache_test.go b/cache/cache_test.go
index 74e0b16..541ae10 100644
--- a/cache/cache_test.go
+++ b/cache/cache_test.go
@@ -42,7 +42,7 @@ func reverse(msgs []*dns.Msg) []*dns.Msg {
return reversed
}
-func awaitExpiry(t *testing.T, c *Cache, k uint64) {
+func awaitExpiry(t *testing.T, i int, c *Cache, k uint64) {
now := time.Now()
for { // Loop until k is removed by maintainer
c.mu.RLock()
@@ -53,12 +53,12 @@ func awaitExpiry(t *testing.T, c *Cache, k uint64) {
}
time.Sleep(10 * time.Millisecond)
if time.Since(now) > 2*time.Second {
- t.Fatalf("timed out waiting for expiry of key %d", k)
+ t.Fatalf("#%d: timed out waiting for expiry of key %d", i, k)
}
}
}
-func awaitRefresh(t *testing.T, c *Cache, k uint64, u time.Time) {
+func awaitRefresh(t *testing.T, i int, c *Cache, k uint64, u time.Time) {
now := time.Now()
for { // Loop until CreatedAt of key k is after u
v, ok := c.getValue(k)
@@ -67,7 +67,7 @@ func awaitRefresh(t *testing.T, c *Cache, k uint64, u time.Time) {
}
time.Sleep(10 * time.Millisecond)
if time.Since(now) > 2*time.Second {
- t.Fatalf("timed out waiting for refresh of key %d", k)
+ t.Fatalf("#%d: timed out waiting for refresh of key %d", i, k)
}
}
}
@@ -130,7 +130,7 @@ func TestCache(t *testing.T) {
t.Errorf("#%d: getValue(%d) = (%+v, %t), want (%+v, %t)", i, k, v, ok, tt.value, tt.ok)
}
if !tt.ok {
- awaitExpiry(t, c, k)
+ awaitExpiry(t, i, c, k)
}
if _, ok := c.values[k]; ok != tt.ok {
t.Errorf("#%d: values[%d] = %t, want %t", i, k, ok, tt.ok)
@@ -245,23 +245,27 @@ func TestCachePrefetch(t *testing.T) {
var tests = []struct {
initialAnswer string
refreshAnswer string
- ttl time.Duration
+ initialTTL time.Duration
+ refreshTTL time.Duration
queryTime time.Time
answer string
ok bool
- awaitRefresh bool
+ awaitUpdate bool
}{
// Serves cached value before expiry
- {"192.0.2.1", "192.0.2.42", time.Minute, now.Add(30 * time.Second), "192.0.2.1", true, true},
+ {"192.0.2.1", "192.0.2.42", time.Minute, time.Minute, now.Add(30 * time.Second), "192.0.2.1", true, true},
// Serves stale cached value after expiry and before refresh happens
- {"192.0.2.1", "192.0.2.42", time.Minute, now.Add(61 * time.Second), "192.0.2.1", true, false},
+ {"192.0.2.1", "192.0.2.42", time.Minute, time.Minute, now.Add(61 * time.Second), "192.0.2.1", true, false},
// Serves refreshed value after expiry and refresh
- {"192.0.2.1", "192.0.2.42", time.Minute, now.Add(61 * time.Second), "192.0.2.42", true, true},
+ {"192.0.2.1", "192.0.2.42", time.Minute, time.Minute, now.Add(61 * time.Second), "192.0.2.42", true, true},
+ // Refreshed value can no longer be cached
+ {"192.0.2.1", "192.0.2.42", time.Minute, 0, now.Add(61 * time.Second), "192.0.2.42", false, true},
}
for i, tt := range tests {
- msg := newA("example.com.", uint32(tt.ttl.Seconds()), net.ParseIP(tt.initialAnswer))
+ msg := newA("example.com.", uint32(tt.initialTTL.Seconds()), net.ParseIP(tt.initialAnswer))
copy := msg.Copy()
copy.Answer[0].(*dns.A).A = net.ParseIP(tt.refreshAnswer)
+ copy.Answer[0].(*dns.A).Hdr.Ttl = uint32(tt.refreshTTL.Seconds())
exchanger.answer = copy
c.now = func() time.Time { return now }
@@ -270,13 +274,22 @@ func TestCachePrefetch(t *testing.T) {
c.now = func() time.Time { return tt.queryTime }
v, ok := c.getValue(key)
- if tt.awaitRefresh && c.isExpired(v) {
- awaitRefresh(t, c, key, v.CreatedAt)
+ if tt.awaitUpdate {
+ if !tt.ok {
+ awaitExpiry(t, i, c, key)
+ ok = false
+ } else if c.isExpired(v) {
+ awaitRefresh(t, i, c, key, v.CreatedAt)
+ }
+ }
+
+ if ok != tt.ok {
+ t.Errorf("#%d: Get(%d) = (_, %t), want (_, %t)", i, key, ok, tt.ok)
}
answers := dnsutil.Answers(v.msg)
- if answers[0] != tt.answer || ok != tt.ok {
- t.Errorf("#%d: Get(%d) = (%q, %t), want (%q, %t)", i, key, answers[0], ok, tt.answer, tt.ok)
+ if tt.ok && answers[0] != tt.answer {
+ t.Errorf("#%d: Get(%d) = (%q, _), want (%q, _)", i, key, answers[0], tt.answer)
}
}
}