diff options
author | Martin Polden <mpolden@mpolden.no> | 2019-08-13 20:30:50 +0200 |
---|---|---|
committer | Martin Polden <mpolden@mpolden.no> | 2019-08-13 20:30:50 +0200 |
commit | 491a0e1cbe9a617ac3e29535c6e71ca82ebfe5a7 (patch) | |
tree | 25299940e48b44bed36d7719fabcb5fed451a7b3 | |
parent | 7361fcc0fa6d2f7241b9f71c67657896f0199c4b (diff) |
Do not cache 0 TTL
-rw-r--r-- | cache/cache.go | 33 | ||||
-rw-r--r-- | cache/cache_test.go | 12 |
2 files changed, 39 insertions, 6 deletions
diff --git a/cache/cache.go b/cache/cache.go index cfb80b5..281920a 100644 --- a/cache/cache.go +++ b/cache/cache.go @@ -97,11 +97,14 @@ func (c *Cache) Get(k uint32) (*dns.Msg, bool) { } // Set associated key k with the DNS message v. Message v will expire from the cache according to its TTL. Setting a -// new key in a cache that has its maximum size will remove the first key. +// new key in a cache that has reached its maximum size will remove the first key. func (c *Cache) Set(k uint32, v *dns.Msg) { if c.maxSize == 0 { return } + if !isCacheable(v) { + return + } now := c.now() c.mu.Lock() if len(c.entries) == c.maxSize && c.maxSize > 0 { @@ -133,6 +136,34 @@ func (c *Cache) isExpired(v *value) bool { return false } +func min(x, y uint32) uint32 { + if x < y { + return x + } + return y +} + +func minTTL(m *dns.Msg) time.Duration { + var ttl uint32 = 1<<32 - 1 // avoids importing math.MaxUint32 + for _, answer := range m.Answer { + ttl = min(answer.Header().Ttl, ttl) + } + for _, ns := range m.Ns { + ttl = min(ns.Header().Ttl, ttl) + } + for _, extra := range m.Extra { + ttl = min(extra.Header().Ttl, ttl) + } + return time.Duration(ttl) * time.Second +} + +func isCacheable(m *dns.Msg) bool { + if minTTL(m) == 0 { + return false + } + return true +} + func ttl(rr dns.RR) time.Duration { ttlSecs := rr.Header().Ttl return time.Duration(ttlSecs) * time.Second diff --git a/cache/cache_test.go b/cache/cache_test.go index a81cef1..d607aa8 100644 --- a/cache/cache_test.go +++ b/cache/cache_test.go @@ -67,7 +67,8 @@ func awaitExpiry(t *testing.T, c *Cache, k uint32) { } func TestCache(t *testing.T) { - m := newA("foo.", 60, net.ParseIP("192.0.2.1")) + msg := newA("foo.", 60, net.ParseIP("192.0.2.1")) + msgWithZeroTTL := newA("bar.", 0, net.ParseIP("192.0.2.2")) tt := date(2019, 1, 1) c, err := New(100, time.Duration(10*time.Millisecond)) if err != nil { @@ -79,10 +80,11 @@ func TestCache(t *testing.T) { createdAt, queriedAt time.Time ok bool }{ - {m, tt, tt, true}, // Not expired when query time == create time - {m, tt, tt.Add(30 * time.Second), true}, // Not expired when below TTL - {m, tt, tt.Add(60 * time.Second), true}, // Not expired until TTL exceeds - {m, tt, tt.Add(61 * time.Second), false}, // Expired + {msg, tt, tt, true}, // Not expired when query time == create time + {msg, tt, tt.Add(30 * time.Second), true}, // Not expired when below TTL + {msg, tt, tt.Add(60 * time.Second), true}, // Not expired until TTL exceeds + {msg, tt, tt.Add(61 * time.Second), false}, // Expired + {msgWithZeroTTL, tt, tt, false}, // 0 TTL is not cached } for i, tt := range tests { c.now = func() time.Time { return tt.createdAt } |