aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--cache/cache.go27
-rw-r--r--cache/cache_test.go52
2 files changed, 50 insertions, 29 deletions
diff --git a/cache/cache.go b/cache/cache.go
index 0f4dde4..efa404d 100644
--- a/cache/cache.go
+++ b/cache/cache.go
@@ -38,6 +38,8 @@ type Cache struct {
keys []uint32
mu sync.RWMutex
now func() time.Time
+ queue chan func()
+ wg sync.WaitGroup
}
// Value wraps a DNS message stored in the cache.
@@ -139,8 +141,10 @@ func newCache(capacity int, client *dnsutil.Client, backend Backend, now func()
now: now,
capacity: capacity,
values: make(map[uint32]Value, capacity),
+ queue: make(chan func(), 1024),
}
c.load(backend)
+ go c.readQueue()
return c
}
@@ -176,6 +180,12 @@ func (c *Cache) load(backend Backend) {
c.backend = backend
}
+// Close consumes any outstanding cache operations.
+func (c *Cache) Close() error {
+ c.wg.Wait()
+ return nil
+}
+
// Get returns the DNS message associated with key.
func (c *Cache) Get(key uint32) (*dns.Msg, bool) {
v, ok := c.getValue(key)
@@ -194,11 +204,10 @@ func (c *Cache) getValue(key uint32) (*Value, bool) {
}
if c.isExpired(&v) {
if !c.prefetch() {
- go c.evictWithLock(key)
+ c.enqueue(func() { c.evictWithLock(key) })
return nil, false
}
- // Refresh and return a stale value
- go c.refresh(key, v.msg)
+ c.enqueue(func() { c.refresh(key, v.msg) })
}
return &v, true
}
@@ -320,6 +329,18 @@ func (c *Cache) isExpired(v *Value) bool {
return c.now().After(expiresAt)
}
+func (c *Cache) enqueue(op func()) {
+ c.wg.Add(1)
+ c.queue <- op
+}
+
+func (c *Cache) readQueue() {
+ for op := range c.queue {
+ op()
+ c.wg.Done()
+ }
+}
+
func canCache(msg *dns.Msg) bool {
if dnsutil.MinTTL(msg) == 0 {
return false
diff --git a/cache/cache_test.go b/cache/cache_test.go
index e571f73..af88819 100644
--- a/cache/cache_test.go
+++ b/cache/cache_test.go
@@ -107,11 +107,10 @@ func awaitRefresh(t *testing.T, i int, c *Cache, k uint32, u time.Time) {
for { // Loop until CreatedAt of key k is after u
c.mu.RLock()
v, ok := c.values[k]
- if ok && v.CreatedAt.After(u) {
- c.mu.RUnlock()
+ c.mu.RUnlock()
+ if ok && !v.CreatedAt.Before(u) {
break
}
- c.mu.RUnlock()
time.Sleep(10 * time.Millisecond)
if time.Since(now) > 2*time.Second {
t.Fatalf("#%d: timed out waiting for refresh of key %d", i, k)
@@ -290,25 +289,24 @@ func TestCachePrefetch(t *testing.T) {
client := &dnsutil.Client{Exchanger: exchanger, Addresses: []string{"resolver"}}
now := time.Now()
c := newCache(10, client, &defaultBackend{}, func() time.Time { return now })
-
var tests = []struct {
initialAnswer string
refreshAnswer string
initialTTL time.Duration
refreshTTL time.Duration
- queryTime time.Time
+ readDelay time.Duration
answer string
ok bool
- awaitUpdate bool
+ refetch bool
}{
// Serves cached value before expiry
- {"192.0.2.1", "192.0.2.42", time.Minute, time.Minute, now.Add(30 * time.Second), "192.0.2.1", true, true},
+ {"192.0.2.1", "192.0.2.42", time.Minute, time.Minute, 30 * time.Second, "192.0.2.1", true, true},
// Serves stale cached value after expiry and before refresh happens
- {"192.0.2.1", "192.0.2.42", time.Minute, time.Minute, now.Add(61 * time.Second), "192.0.2.1", true, false},
+ {"192.0.2.1", "192.0.2.42", time.Minute, time.Minute, 61 * time.Second, "192.0.2.1", true, false},
// Serves refreshed value after expiry and refresh
- {"192.0.2.1", "192.0.2.42", time.Minute, time.Minute, now.Add(61 * time.Second), "192.0.2.42", true, true},
+ {"192.0.2.1", "192.0.2.42", time.Minute, time.Minute, 61 * time.Second, "192.0.2.42", true, true},
// Refreshed value can no longer be cached
- {"192.0.2.1", "192.0.2.42", time.Minute, 0, now.Add(61 * time.Second), "192.0.2.42", false, true},
+ {"192.0.2.1", "192.0.2.42", time.Minute, 0, 61 * time.Second, "192.0.2.42", false, true},
}
for i, tt := range tests {
msg := newA("example.com.", uint32(tt.initialTTL.Seconds()), net.ParseIP(tt.initialAnswer))
@@ -317,30 +315,32 @@ func TestCachePrefetch(t *testing.T) {
copy.Answer[0].(*dns.A).Hdr.Ttl = uint32(tt.refreshTTL.Seconds())
exchanger.reset()
exchanger.setAnswer(copy)
- c.now = func() time.Time { return now }
+ // Add new value now
+ c.mu.Lock()
+ c.now = func() time.Time { return now }
+ c.mu.Unlock()
var key uint32 = 1
c.Set(key, msg)
- c.now = func() time.Time { return tt.queryTime }
+
+ // Read value at some point in the future
+ c.mu.Lock()
+ c.now = func() time.Time { return now.Add(tt.readDelay) }
+ c.mu.Unlock()
v, ok := c.getValue(key)
+ c.Close() // Flush queued operations
- if tt.awaitUpdate {
- if !tt.ok {
- awaitExpiry(t, i, c, key)
- ok = false
- } else if c.isExpired(v) {
- awaitRefresh(t, i, c, key, v.CreatedAt)
- v, ok = c.getValue(key)
- }
+ if tt.refetch {
+ v, ok = c.getValue(key)
}
-
if ok != tt.ok {
t.Errorf("#%d: Get(%d) = (_, %t), want (_, %t)", i, key, ok, tt.ok)
}
-
- answers := dnsutil.Answers(v.msg)
- if tt.ok && answers[0] != tt.answer {
- t.Errorf("#%d: Get(%d) = (%q, _), want (%q, _)", i, key, answers[0], tt.answer)
+ if tt.ok {
+ answers := dnsutil.Answers(v.msg)
+ if answers[0] != tt.answer {
+ t.Errorf("#%d: Get(%d) = (%q, _), want (%q, _)", i, key, answers[0], tt.answer)
+ }
}
}
}
@@ -371,7 +371,7 @@ func TestCacheEvictAndUpdate(t *testing.T) {
c.Get(key)
// Last query refreshes key
- awaitRefresh(t, 0, c, key, c.now().Add(-time.Second))
+ awaitRefresh(t, 0, c, key, c.now())
keyExists := false
for _, k := range c.keys {
if k == key {