diff options
-rw-r--r-- | cache/cache.go | 27 | ||||
-rw-r--r-- | cache/cache_test.go | 52 |
2 files changed, 50 insertions, 29 deletions
diff --git a/cache/cache.go b/cache/cache.go index 0f4dde4..efa404d 100644 --- a/cache/cache.go +++ b/cache/cache.go @@ -38,6 +38,8 @@ type Cache struct { keys []uint32 mu sync.RWMutex now func() time.Time + queue chan func() + wg sync.WaitGroup } // Value wraps a DNS message stored in the cache. @@ -139,8 +141,10 @@ func newCache(capacity int, client *dnsutil.Client, backend Backend, now func() now: now, capacity: capacity, values: make(map[uint32]Value, capacity), + queue: make(chan func(), 1024), } c.load(backend) + go c.readQueue() return c } @@ -176,6 +180,12 @@ func (c *Cache) load(backend Backend) { c.backend = backend } +// Close consumes any outstanding cache operations. +func (c *Cache) Close() error { + c.wg.Wait() + return nil +} + // Get returns the DNS message associated with key. func (c *Cache) Get(key uint32) (*dns.Msg, bool) { v, ok := c.getValue(key) @@ -194,11 +204,10 @@ func (c *Cache) getValue(key uint32) (*Value, bool) { } if c.isExpired(&v) { if !c.prefetch() { - go c.evictWithLock(key) + c.enqueue(func() { c.evictWithLock(key) }) return nil, false } - // Refresh and return a stale value - go c.refresh(key, v.msg) + c.enqueue(func() { c.refresh(key, v.msg) }) } return &v, true } @@ -320,6 +329,18 @@ func (c *Cache) isExpired(v *Value) bool { return c.now().After(expiresAt) } +func (c *Cache) enqueue(op func()) { + c.wg.Add(1) + c.queue <- op +} + +func (c *Cache) readQueue() { + for op := range c.queue { + op() + c.wg.Done() + } +} + func canCache(msg *dns.Msg) bool { if dnsutil.MinTTL(msg) == 0 { return false diff --git a/cache/cache_test.go b/cache/cache_test.go index e571f73..af88819 100644 --- a/cache/cache_test.go +++ b/cache/cache_test.go @@ -107,11 +107,10 @@ func awaitRefresh(t *testing.T, i int, c *Cache, k uint32, u time.Time) { for { // Loop until CreatedAt of key k is after u c.mu.RLock() v, ok := c.values[k] - if ok && v.CreatedAt.After(u) { - c.mu.RUnlock() + c.mu.RUnlock() + if ok && !v.CreatedAt.Before(u) { break } - c.mu.RUnlock() time.Sleep(10 * time.Millisecond) if time.Since(now) > 2*time.Second { t.Fatalf("#%d: timed out waiting for refresh of key %d", i, k) @@ -290,25 +289,24 @@ func TestCachePrefetch(t *testing.T) { client := &dnsutil.Client{Exchanger: exchanger, Addresses: []string{"resolver"}} now := time.Now() c := newCache(10, client, &defaultBackend{}, func() time.Time { return now }) - var tests = []struct { initialAnswer string refreshAnswer string initialTTL time.Duration refreshTTL time.Duration - queryTime time.Time + readDelay time.Duration answer string ok bool - awaitUpdate bool + refetch bool }{ // Serves cached value before expiry - {"192.0.2.1", "192.0.2.42", time.Minute, time.Minute, now.Add(30 * time.Second), "192.0.2.1", true, true}, + {"192.0.2.1", "192.0.2.42", time.Minute, time.Minute, 30 * time.Second, "192.0.2.1", true, true}, // Serves stale cached value after expiry and before refresh happens - {"192.0.2.1", "192.0.2.42", time.Minute, time.Minute, now.Add(61 * time.Second), "192.0.2.1", true, false}, + {"192.0.2.1", "192.0.2.42", time.Minute, time.Minute, 61 * time.Second, "192.0.2.1", true, false}, // Serves refreshed value after expiry and refresh - {"192.0.2.1", "192.0.2.42", time.Minute, time.Minute, now.Add(61 * time.Second), "192.0.2.42", true, true}, + {"192.0.2.1", "192.0.2.42", time.Minute, time.Minute, 61 * time.Second, "192.0.2.42", true, true}, // Refreshed value can no longer be cached - {"192.0.2.1", "192.0.2.42", time.Minute, 0, now.Add(61 * time.Second), "192.0.2.42", false, true}, + {"192.0.2.1", "192.0.2.42", time.Minute, 0, 61 * time.Second, "192.0.2.42", false, true}, } for i, tt := range tests { msg := newA("example.com.", uint32(tt.initialTTL.Seconds()), net.ParseIP(tt.initialAnswer)) @@ -317,30 +315,32 @@ func TestCachePrefetch(t *testing.T) { copy.Answer[0].(*dns.A).Hdr.Ttl = uint32(tt.refreshTTL.Seconds()) exchanger.reset() exchanger.setAnswer(copy) - c.now = func() time.Time { return now } + // Add new value now + c.mu.Lock() + c.now = func() time.Time { return now } + c.mu.Unlock() var key uint32 = 1 c.Set(key, msg) - c.now = func() time.Time { return tt.queryTime } + + // Read value at some point in the future + c.mu.Lock() + c.now = func() time.Time { return now.Add(tt.readDelay) } + c.mu.Unlock() v, ok := c.getValue(key) + c.Close() // Flush queued operations - if tt.awaitUpdate { - if !tt.ok { - awaitExpiry(t, i, c, key) - ok = false - } else if c.isExpired(v) { - awaitRefresh(t, i, c, key, v.CreatedAt) - v, ok = c.getValue(key) - } + if tt.refetch { + v, ok = c.getValue(key) } - if ok != tt.ok { t.Errorf("#%d: Get(%d) = (_, %t), want (_, %t)", i, key, ok, tt.ok) } - - answers := dnsutil.Answers(v.msg) - if tt.ok && answers[0] != tt.answer { - t.Errorf("#%d: Get(%d) = (%q, _), want (%q, _)", i, key, answers[0], tt.answer) + if tt.ok { + answers := dnsutil.Answers(v.msg) + if answers[0] != tt.answer { + t.Errorf("#%d: Get(%d) = (%q, _), want (%q, _)", i, key, answers[0], tt.answer) + } } } } @@ -371,7 +371,7 @@ func TestCacheEvictAndUpdate(t *testing.T) { c.Get(key) // Last query refreshes key - awaitRefresh(t, 0, c, key, c.now().Add(-time.Second)) + awaitRefresh(t, 0, c, key, c.now()) keyExists := false for _, k := range c.keys { if k == key { |