From 096996d806f4c3e7ad21b4a5960f55663dfbfba9 Mon Sep 17 00:00:00 2001 From: Martin Polden Date: Mon, 30 Dec 2019 18:01:44 +0100 Subject: Simplify cache eviction --- cache/cache.go | 55 ++++++++--------------------------------------------- cache/cache_test.go | 7 ++----- cmd/zdns/main.go | 1 - 3 files changed, 10 insertions(+), 53 deletions(-) diff --git a/cache/cache.go b/cache/cache.go index 48d47f6..0780a39 100644 --- a/cache/cache.go +++ b/cache/cache.go @@ -45,24 +45,20 @@ func (v *Value) TTL() time.Duration { return dnsutil.MinTTL(v.msg) } // New creates a new cache of given capacity. If client is non-nil, the cache will prefetch expired entries in an effort // to serve results faster. func New(capacity int, client *dnsutil.Client) *Cache { - return newCache(capacity, client, 10*time.Second, time.Now) + return newCache(capacity, client, time.Now) } -func newCache(capacity int, client *dnsutil.Client, interval time.Duration, now func() time.Time) *Cache { +func newCache(capacity int, client *dnsutil.Client, now func() time.Time) *Cache { if capacity < 0 { capacity = 0 } - c := &Cache{ + return &Cache{ client: client, now: now, capacity: capacity, values: make(map[uint64]*Value, capacity), done: make(chan bool), } - if !c.prefetch() { - go maintain(c, interval) - } - return c } // NewKey creates a new cache key for the DNS name, qtype and qclass @@ -74,25 +70,6 @@ func NewKey(name string, qtype, qclass uint16) uint64 { return h.Sum64() } -func maintain(cache *Cache, interval time.Duration) { - ticker := time.NewTicker(interval) - for { - select { - case <-cache.done: - ticker.Stop() - return - case <-ticker.C: - cache.evictExpired() - } - } -} - -// Close stops any outstanding maintenance tasks. -func (c *Cache) Close() error { - c.done <- true - return nil -} - // Get returns the DNS message associated with key k. Get will return nil if any TTL in the answer section of the // message is exceeded according to time t. func (c *Cache) Get(k uint64) (*dns.Msg, bool) { @@ -112,6 +89,7 @@ func (c *Cache) getValue(k uint64) (*Value, bool) { } if c.isExpired(v) { if !c.prefetch() { + go c.evict(k) return nil, false } // Refresh and return a stale value @@ -169,7 +147,6 @@ func (c *Cache) Reset() { func (c *Cache) prefetch() bool { return c.client != nil } func (c *Cache) refresh(key uint64, old *dns.Msg) { - evicted := make(map[uint64]bool, 1) q := old.Question[0] msg := dns.Msg{} msg.SetQuestion(q.Name, q.Qtype) @@ -183,33 +160,17 @@ func (c *Cache) refresh(key uint64, old *dns.Msg) { c.values[key].CreatedAt = c.now() c.values[key].msg = r } else { - delete(c.values, key) - evicted[key] = true + c.evict(key) } - c.reorderKeys(evicted) } -func (c *Cache) evictExpired() { +func (c *Cache) evict(key uint64) { c.mu.Lock() defer c.mu.Unlock() - evicted := make(map[uint64]bool) - for k, v := range c.values { - if c.isExpired(v) { - delete(c.values, k) - evicted[k] = true - } - } - c.reorderKeys(evicted) -} - -func (c *Cache) reorderKeys(evicted map[uint64]bool) { - if len(evicted) == 0 { - return - } - // At least one entry was evicted. The ordered list of keys must be updated. + delete(c.values, key) var keys []uint64 for _, k := range c.keys { - if _, ok := evicted[k]; ok { + if k == key { continue } keys = append(keys, k) diff --git a/cache/cache_test.go b/cache/cache_test.go index e24e3b2..1ccfaaf 100644 --- a/cache/cache_test.go +++ b/cache/cache_test.go @@ -89,8 +89,7 @@ func TestCache(t *testing.T) { now := time.Date(2019, 1, 1, 0, 0, 0, 0, time.UTC) nowFn := func() time.Time { return now } - c := newCache(100, nil, 10*time.Millisecond, nowFn) - defer c.Close() + c := newCache(100, nil, nowFn) var tests = []struct { msg *dns.Msg queriedAt time.Time @@ -146,7 +145,6 @@ func TestCacheCapacity(t *testing.T) { } for i, tt := range tests { c := New(tt.capacity, nil) - defer c.Close() var msgs []*dns.Msg for i := 0; i < tt.addCount; i++ { m := newA(fmt.Sprintf("r%d", i), 60, net.ParseIP(fmt.Sprintf("192.0.2.%d", i))) @@ -186,7 +184,6 @@ func TestCacheList(t *testing.T) { } for i, tt := range tests { c := New(1024, nil) - defer c.Close() var msgs []*dns.Msg for i := 0; i < tt.addCount; i++ { m := newA(fmt.Sprintf("r%d", i), 60, net.ParseIP(fmt.Sprintf("192.0.2.%d", i))) @@ -230,7 +227,7 @@ func TestCachePrefetch(t *testing.T) { client := &dnsutil.Client{Exchanger: &exchanger, Addresses: []string{"resolver"}} now := time.Now() nowFn := func() time.Time { return now } - c := newCache(10, client, time.Hour, nowFn) + c := newCache(10, client, nowFn) var key uint64 = 1 ip := net.ParseIP("192.0.2.1") diff --git a/cmd/zdns/main.go b/cmd/zdns/main.go index 03fdb26..cda841d 100644 --- a/cmd/zdns/main.go +++ b/cmd/zdns/main.go @@ -99,7 +99,6 @@ func (c *cli) run() { cclient = client } cache := cache.New(config.DNS.CacheSize, cclient) - sigHandler.OnClose(cache) // DNS server proxy, err := dns.NewProxy(cache, client, logger) -- cgit v1.2.3