diff options
author | Martin Polden <mpolden@mpolden.no> | 2019-12-31 13:47:29 +0100 |
---|---|---|
committer | Martin Polden <mpolden@mpolden.no> | 2019-12-31 13:49:02 +0100 |
commit | bf6a8dc976bd100030240935e4047bb58008a02b (patch) | |
tree | 61e84a5cf5048bc7ac0ef1198ad1c1cde1471f60 | |
parent | b511d4c207c06621b2d3585ac33fa8d97acfcac0 (diff) |
Fix deadlock
-rw-r--r-- | cache/cache.go | 8 | ||||
-rw-r--r-- | cache/cache_test.go | 43 |
2 files changed, 34 insertions, 17 deletions
diff --git a/cache/cache.go b/cache/cache.go index 011b6c3..80a9460 100644 --- a/cache/cache.go +++ b/cache/cache.go @@ -87,7 +87,7 @@ func (c *Cache) getValue(key uint64) (*Value, bool) { } if c.isExpired(v) { if !c.prefetch() { - go c.evict(key) + go c.evictWithLock(key) return nil, false } // Refresh and return a stale value @@ -167,9 +167,13 @@ func (c *Cache) refresh(key uint64, old *dns.Msg) { } } -func (c *Cache) evict(key uint64) { +func (c *Cache) evictWithLock(key uint64) { c.mu.Lock() defer c.mu.Unlock() + c.evict(key) +} + +func (c *Cache) evict(key uint64) { delete(c.values, key) var keys []uint64 for _, k := range c.keys { diff --git a/cache/cache_test.go b/cache/cache_test.go index 74e0b16..541ae10 100644 --- a/cache/cache_test.go +++ b/cache/cache_test.go @@ -42,7 +42,7 @@ func reverse(msgs []*dns.Msg) []*dns.Msg { return reversed } -func awaitExpiry(t *testing.T, c *Cache, k uint64) { +func awaitExpiry(t *testing.T, i int, c *Cache, k uint64) { now := time.Now() for { // Loop until k is removed by maintainer c.mu.RLock() @@ -53,12 +53,12 @@ func awaitExpiry(t *testing.T, c *Cache, k uint64) { } time.Sleep(10 * time.Millisecond) if time.Since(now) > 2*time.Second { - t.Fatalf("timed out waiting for expiry of key %d", k) + t.Fatalf("#%d: timed out waiting for expiry of key %d", i, k) } } } -func awaitRefresh(t *testing.T, c *Cache, k uint64, u time.Time) { +func awaitRefresh(t *testing.T, i int, c *Cache, k uint64, u time.Time) { now := time.Now() for { // Loop until CreatedAt of key k is after u v, ok := c.getValue(k) @@ -67,7 +67,7 @@ func awaitRefresh(t *testing.T, c *Cache, k uint64, u time.Time) { } time.Sleep(10 * time.Millisecond) if time.Since(now) > 2*time.Second { - t.Fatalf("timed out waiting for refresh of key %d", k) + t.Fatalf("#%d: timed out waiting for refresh of key %d", i, k) } } } @@ -130,7 +130,7 @@ func TestCache(t *testing.T) { t.Errorf("#%d: getValue(%d) = (%+v, %t), want (%+v, %t)", i, k, v, ok, tt.value, tt.ok) } if !tt.ok { - awaitExpiry(t, c, k) + awaitExpiry(t, i, c, k) } if _, ok := c.values[k]; ok != tt.ok { t.Errorf("#%d: values[%d] = %t, want %t", i, k, ok, tt.ok) @@ -245,23 +245,27 @@ func TestCachePrefetch(t *testing.T) { var tests = []struct { initialAnswer string refreshAnswer string - ttl time.Duration + initialTTL time.Duration + refreshTTL time.Duration queryTime time.Time answer string ok bool - awaitRefresh bool + awaitUpdate bool }{ // Serves cached value before expiry - {"192.0.2.1", "192.0.2.42", time.Minute, now.Add(30 * time.Second), "192.0.2.1", true, true}, + {"192.0.2.1", "192.0.2.42", time.Minute, time.Minute, now.Add(30 * time.Second), "192.0.2.1", true, true}, // Serves stale cached value after expiry and before refresh happens - {"192.0.2.1", "192.0.2.42", time.Minute, now.Add(61 * time.Second), "192.0.2.1", true, false}, + {"192.0.2.1", "192.0.2.42", time.Minute, time.Minute, now.Add(61 * time.Second), "192.0.2.1", true, false}, // Serves refreshed value after expiry and refresh - {"192.0.2.1", "192.0.2.42", time.Minute, now.Add(61 * time.Second), "192.0.2.42", true, true}, + {"192.0.2.1", "192.0.2.42", time.Minute, time.Minute, now.Add(61 * time.Second), "192.0.2.42", true, true}, + // Refreshed value can no longer be cached + {"192.0.2.1", "192.0.2.42", time.Minute, 0, now.Add(61 * time.Second), "192.0.2.42", false, true}, } for i, tt := range tests { - msg := newA("example.com.", uint32(tt.ttl.Seconds()), net.ParseIP(tt.initialAnswer)) + msg := newA("example.com.", uint32(tt.initialTTL.Seconds()), net.ParseIP(tt.initialAnswer)) copy := msg.Copy() copy.Answer[0].(*dns.A).A = net.ParseIP(tt.refreshAnswer) + copy.Answer[0].(*dns.A).Hdr.Ttl = uint32(tt.refreshTTL.Seconds()) exchanger.answer = copy c.now = func() time.Time { return now } @@ -270,13 +274,22 @@ func TestCachePrefetch(t *testing.T) { c.now = func() time.Time { return tt.queryTime } v, ok := c.getValue(key) - if tt.awaitRefresh && c.isExpired(v) { - awaitRefresh(t, c, key, v.CreatedAt) + if tt.awaitUpdate { + if !tt.ok { + awaitExpiry(t, i, c, key) + ok = false + } else if c.isExpired(v) { + awaitRefresh(t, i, c, key, v.CreatedAt) + } + } + + if ok != tt.ok { + t.Errorf("#%d: Get(%d) = (_, %t), want (_, %t)", i, key, ok, tt.ok) } answers := dnsutil.Answers(v.msg) - if answers[0] != tt.answer || ok != tt.ok { - t.Errorf("#%d: Get(%d) = (%q, %t), want (%q, %t)", i, key, answers[0], ok, tt.answer, tt.ok) + if tt.ok && answers[0] != tt.answer { + t.Errorf("#%d: Get(%d) = (%q, _), want (%q, _)", i, key, answers[0], tt.answer) } } } |