aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorMartin Polden <mpolden@mpolden.no>2019-08-13 20:30:50 +0200
committerMartin Polden <mpolden@mpolden.no>2019-08-13 20:30:50 +0200
commit491a0e1cbe9a617ac3e29535c6e71ca82ebfe5a7 (patch)
tree25299940e48b44bed36d7719fabcb5fed451a7b3
parent7361fcc0fa6d2f7241b9f71c67657896f0199c4b (diff)
Do not cache 0 TTL
-rw-r--r--cache/cache.go33
-rw-r--r--cache/cache_test.go12
2 files changed, 39 insertions, 6 deletions
diff --git a/cache/cache.go b/cache/cache.go
index cfb80b5..281920a 100644
--- a/cache/cache.go
+++ b/cache/cache.go
@@ -97,11 +97,14 @@ func (c *Cache) Get(k uint32) (*dns.Msg, bool) {
}
// Set associated key k with the DNS message v. Message v will expire from the cache according to its TTL. Setting a
-// new key in a cache that has its maximum size will remove the first key.
+// new key in a cache that has reached its maximum size will remove the first key.
func (c *Cache) Set(k uint32, v *dns.Msg) {
if c.maxSize == 0 {
return
}
+ if !isCacheable(v) {
+ return
+ }
now := c.now()
c.mu.Lock()
if len(c.entries) == c.maxSize && c.maxSize > 0 {
@@ -133,6 +136,34 @@ func (c *Cache) isExpired(v *value) bool {
return false
}
+func min(x, y uint32) uint32 {
+ if x < y {
+ return x
+ }
+ return y
+}
+
+func minTTL(m *dns.Msg) time.Duration {
+ var ttl uint32 = 1<<32 - 1 // avoids importing math.MaxUint32
+ for _, answer := range m.Answer {
+ ttl = min(answer.Header().Ttl, ttl)
+ }
+ for _, ns := range m.Ns {
+ ttl = min(ns.Header().Ttl, ttl)
+ }
+ for _, extra := range m.Extra {
+ ttl = min(extra.Header().Ttl, ttl)
+ }
+ return time.Duration(ttl) * time.Second
+}
+
+func isCacheable(m *dns.Msg) bool {
+ if minTTL(m) == 0 {
+ return false
+ }
+ return true
+}
+
func ttl(rr dns.RR) time.Duration {
ttlSecs := rr.Header().Ttl
return time.Duration(ttlSecs) * time.Second
diff --git a/cache/cache_test.go b/cache/cache_test.go
index a81cef1..d607aa8 100644
--- a/cache/cache_test.go
+++ b/cache/cache_test.go
@@ -67,7 +67,8 @@ func awaitExpiry(t *testing.T, c *Cache, k uint32) {
}
func TestCache(t *testing.T) {
- m := newA("foo.", 60, net.ParseIP("192.0.2.1"))
+ msg := newA("foo.", 60, net.ParseIP("192.0.2.1"))
+ msgWithZeroTTL := newA("bar.", 0, net.ParseIP("192.0.2.2"))
tt := date(2019, 1, 1)
c, err := New(100, time.Duration(10*time.Millisecond))
if err != nil {
@@ -79,10 +80,11 @@ func TestCache(t *testing.T) {
createdAt, queriedAt time.Time
ok bool
}{
- {m, tt, tt, true}, // Not expired when query time == create time
- {m, tt, tt.Add(30 * time.Second), true}, // Not expired when below TTL
- {m, tt, tt.Add(60 * time.Second), true}, // Not expired until TTL exceeds
- {m, tt, tt.Add(61 * time.Second), false}, // Expired
+ {msg, tt, tt, true}, // Not expired when query time == create time
+ {msg, tt, tt.Add(30 * time.Second), true}, // Not expired when below TTL
+ {msg, tt, tt.Add(60 * time.Second), true}, // Not expired until TTL exceeds
+ {msg, tt, tt.Add(61 * time.Second), false}, // Expired
+ {msgWithZeroTTL, tt, tt, false}, // 0 TTL is not cached
}
for i, tt := range tests {
c.now = func() time.Time { return tt.createdAt }