aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorMartin Polden <mpolden@mpolden.no>2019-12-31 13:47:29 +0100
committerMartin Polden <mpolden@mpolden.no>2019-12-31 13:49:02 +0100
commitbf6a8dc976bd100030240935e4047bb58008a02b (patch)
tree61e84a5cf5048bc7ac0ef1198ad1c1cde1471f60
parentb511d4c207c06621b2d3585ac33fa8d97acfcac0 (diff)
Fix deadlock
-rw-r--r--cache/cache.go8
-rw-r--r--cache/cache_test.go43
2 files changed, 34 insertions, 17 deletions
diff --git a/cache/cache.go b/cache/cache.go
index 011b6c3..80a9460 100644
--- a/cache/cache.go
+++ b/cache/cache.go
@@ -87,7 +87,7 @@ func (c *Cache) getValue(key uint64) (*Value, bool) {
}
if c.isExpired(v) {
if !c.prefetch() {
- go c.evict(key)
+ go c.evictWithLock(key)
return nil, false
}
// Refresh and return a stale value
@@ -167,9 +167,13 @@ func (c *Cache) refresh(key uint64, old *dns.Msg) {
}
}
-func (c *Cache) evict(key uint64) {
+func (c *Cache) evictWithLock(key uint64) {
c.mu.Lock()
defer c.mu.Unlock()
+ c.evict(key)
+}
+
+func (c *Cache) evict(key uint64) {
delete(c.values, key)
var keys []uint64
for _, k := range c.keys {
diff --git a/cache/cache_test.go b/cache/cache_test.go
index 74e0b16..541ae10 100644
--- a/cache/cache_test.go
+++ b/cache/cache_test.go
@@ -42,7 +42,7 @@ func reverse(msgs []*dns.Msg) []*dns.Msg {
return reversed
}
-func awaitExpiry(t *testing.T, c *Cache, k uint64) {
+func awaitExpiry(t *testing.T, i int, c *Cache, k uint64) {
now := time.Now()
for { // Loop until k is removed by maintainer
c.mu.RLock()
@@ -53,12 +53,12 @@ func awaitExpiry(t *testing.T, c *Cache, k uint64) {
}
time.Sleep(10 * time.Millisecond)
if time.Since(now) > 2*time.Second {
- t.Fatalf("timed out waiting for expiry of key %d", k)
+ t.Fatalf("#%d: timed out waiting for expiry of key %d", i, k)
}
}
}
-func awaitRefresh(t *testing.T, c *Cache, k uint64, u time.Time) {
+func awaitRefresh(t *testing.T, i int, c *Cache, k uint64, u time.Time) {
now := time.Now()
for { // Loop until CreatedAt of key k is after u
v, ok := c.getValue(k)
@@ -67,7 +67,7 @@ func awaitRefresh(t *testing.T, c *Cache, k uint64, u time.Time) {
}
time.Sleep(10 * time.Millisecond)
if time.Since(now) > 2*time.Second {
- t.Fatalf("timed out waiting for refresh of key %d", k)
+ t.Fatalf("#%d: timed out waiting for refresh of key %d", i, k)
}
}
}
@@ -130,7 +130,7 @@ func TestCache(t *testing.T) {
t.Errorf("#%d: getValue(%d) = (%+v, %t), want (%+v, %t)", i, k, v, ok, tt.value, tt.ok)
}
if !tt.ok {
- awaitExpiry(t, c, k)
+ awaitExpiry(t, i, c, k)
}
if _, ok := c.values[k]; ok != tt.ok {
t.Errorf("#%d: values[%d] = %t, want %t", i, k, ok, tt.ok)
@@ -245,23 +245,27 @@ func TestCachePrefetch(t *testing.T) {
var tests = []struct {
initialAnswer string
refreshAnswer string
- ttl time.Duration
+ initialTTL time.Duration
+ refreshTTL time.Duration
queryTime time.Time
answer string
ok bool
- awaitRefresh bool
+ awaitUpdate bool
}{
// Serves cached value before expiry
- {"192.0.2.1", "192.0.2.42", time.Minute, now.Add(30 * time.Second), "192.0.2.1", true, true},
+ {"192.0.2.1", "192.0.2.42", time.Minute, time.Minute, now.Add(30 * time.Second), "192.0.2.1", true, true},
// Serves stale cached value after expiry and before refresh happens
- {"192.0.2.1", "192.0.2.42", time.Minute, now.Add(61 * time.Second), "192.0.2.1", true, false},
+ {"192.0.2.1", "192.0.2.42", time.Minute, time.Minute, now.Add(61 * time.Second), "192.0.2.1", true, false},
// Serves refreshed value after expiry and refresh
- {"192.0.2.1", "192.0.2.42", time.Minute, now.Add(61 * time.Second), "192.0.2.42", true, true},
+ {"192.0.2.1", "192.0.2.42", time.Minute, time.Minute, now.Add(61 * time.Second), "192.0.2.42", true, true},
+ // Refreshed value can no longer be cached
+ {"192.0.2.1", "192.0.2.42", time.Minute, 0, now.Add(61 * time.Second), "192.0.2.42", false, true},
}
for i, tt := range tests {
- msg := newA("example.com.", uint32(tt.ttl.Seconds()), net.ParseIP(tt.initialAnswer))
+ msg := newA("example.com.", uint32(tt.initialTTL.Seconds()), net.ParseIP(tt.initialAnswer))
copy := msg.Copy()
copy.Answer[0].(*dns.A).A = net.ParseIP(tt.refreshAnswer)
+ copy.Answer[0].(*dns.A).Hdr.Ttl = uint32(tt.refreshTTL.Seconds())
exchanger.answer = copy
c.now = func() time.Time { return now }
@@ -270,13 +274,22 @@ func TestCachePrefetch(t *testing.T) {
c.now = func() time.Time { return tt.queryTime }
v, ok := c.getValue(key)
- if tt.awaitRefresh && c.isExpired(v) {
- awaitRefresh(t, c, key, v.CreatedAt)
+ if tt.awaitUpdate {
+ if !tt.ok {
+ awaitExpiry(t, i, c, key)
+ ok = false
+ } else if c.isExpired(v) {
+ awaitRefresh(t, i, c, key, v.CreatedAt)
+ }
+ }
+
+ if ok != tt.ok {
+ t.Errorf("#%d: Get(%d) = (_, %t), want (_, %t)", i, key, ok, tt.ok)
}
answers := dnsutil.Answers(v.msg)
- if answers[0] != tt.answer || ok != tt.ok {
- t.Errorf("#%d: Get(%d) = (%q, %t), want (%q, %t)", i, key, answers[0], ok, tt.answer, tt.ok)
+ if tt.ok && answers[0] != tt.answer {
+ t.Errorf("#%d: Get(%d) = (%q, _), want (%q, _)", i, key, answers[0], tt.answer)
}
}
}