diff options
Diffstat (limited to 'cache/cache.go')
-rw-r--r-- | cache/cache.go | 81 |
1 files changed, 63 insertions, 18 deletions
diff --git a/cache/cache.go b/cache/cache.go index a474170..3b650e9 100644 --- a/cache/cache.go +++ b/cache/cache.go @@ -12,6 +12,7 @@ import ( // Cache is a cache of DNS messages. type Cache struct { + client *dnsutil.Client capacity int values map[uint64]*Value keys []uint64 @@ -41,14 +42,18 @@ func (v *Value) Answers() []string { return dnsutil.Answers(v.msg) } // TTL returns the time to live of the cached value v. func (v *Value) TTL() time.Duration { return dnsutil.MinTTL(v.msg) } -// New creates a new cache of given capacity. -func New(capacity int) *Cache { return newCache(capacity, time.Minute, time.Now) } +// New creates a new cache of given capacity. If client is non-nil, the cache will prefetch expired entries in an effort +// to serve results faster. +func New(capacity int, client *dnsutil.Client) *Cache { + return newCache(capacity, client, 10*time.Second, time.Now) +} -func newCache(capacity int, interval time.Duration, now func() time.Time) *Cache { +func newCache(capacity int, client *dnsutil.Client, interval time.Duration, now func() time.Time) *Cache { if capacity < 0 { capacity = 0 } cache := &Cache{ + client: client, now: now, capacity: capacity, values: make(map[uint64]*Value, capacity), @@ -75,7 +80,11 @@ func maintain(cache *Cache, interval time.Duration) { ticker.Stop() return case <-ticker.C: - cache.evictExpired() + if cache.prefetch() { + cache.refreshExpired(interval) + } else { + cache.evictExpired() + } } } } @@ -100,7 +109,7 @@ func (c *Cache) getValue(k uint64) (*Value, bool) { c.mu.RLock() v, ok := c.values[k] c.mu.RUnlock() - if !ok || c.isExpired(v) { + if !ok || (!c.prefetch() && c.isExpired(v)) { return nil, false } return v, true @@ -152,34 +161,70 @@ func (c *Cache) Reset() { c.keys = nil } +func (c *Cache) prefetch() bool { return c.client != nil } + +func (c *Cache) refreshExpired(interval time.Duration) { + c.mu.Lock() + defer c.mu.Unlock() + evicted := make(map[uint64]bool) + for k, v := range c.values { + // Value will expiry before the next interval. Refresh now + if c.isExpiredAfter(interval, v) { + q := v.msg.Question[0] + msg := dns.Msg{} + msg.SetQuestion(q.Name, q.Qtype) + r, err := c.client.Exchange(&msg) + if err != nil { + continue // Will be retried on next run + } + if canCache(r) { + c.values[k].CreatedAt = c.now() + c.values[k].msg = r + } else { + // Can no longer be cached. Evict + delete(c.values, k) + evicted[k] = true + } + } + } + c.reorderKeys(evicted) +} + func (c *Cache) evictExpired() { c.mu.Lock() defer c.mu.Unlock() - evictedKeys := make(map[uint64]bool) + evicted := make(map[uint64]bool) for k, v := range c.values { if c.isExpired(v) { delete(c.values, k) - evictedKeys[k] = true + evicted[k] = true } } - if len(evictedKeys) > 0 { - // At least one entry was evicted. The ordered list of keys must be updated. - var keys []uint64 - for _, k := range c.keys { - if _, ok := evictedKeys[k]; ok { - continue - } - keys = append(keys, k) + c.reorderKeys(evicted) +} + +func (c *Cache) reorderKeys(evicted map[uint64]bool) { + if len(evicted) == 0 { + return + } + // At least one entry was evicted. The ordered list of keys must be updated. + var keys []uint64 + for _, k := range c.keys { + if _, ok := evicted[k]; ok { + continue } - c.keys = keys + keys = append(keys, k) } + c.keys = keys } -func (c *Cache) isExpired(v *Value) bool { +func (c *Cache) isExpiredAfter(d time.Duration, v *Value) bool { expiresAt := v.CreatedAt.Add(dnsutil.MinTTL(v.msg)) - return c.now().After(expiresAt) + return c.now().Add(d).After(expiresAt) } +func (c *Cache) isExpired(v *Value) bool { return c.isExpiredAfter(0, v) } + func canCache(msg *dns.Msg) bool { if dnsutil.MinTTL(msg) == 0 { return false |