diff options
-rw-r--r-- | cache/cache.go | 24 | ||||
-rw-r--r-- | cache/cache_test.go | 70 |
2 files changed, 55 insertions, 39 deletions
diff --git a/cache/cache.go b/cache/cache.go index 167ff08..087ba24 100644 --- a/cache/cache.go +++ b/cache/cache.go @@ -17,7 +17,6 @@ type Cache struct { mu sync.RWMutex done chan bool now func() time.Time - interval time.Duration } // Value wraps a DNS message stored in the cache. @@ -53,22 +52,23 @@ func (v *Value) Answers() []string { return answers } -// TTL returns the TTL of the cached value v. +// TTL returns the time to live of the cached value v. func (v *Value) TTL() time.Duration { return minTTL(v.msg) } // New creates a new cache of given capacity. -func New(capacity int) *Cache { +func New(capacity int) *Cache { return newCache(capacity, time.Minute, time.Now) } + +func newCache(capacity int, interval time.Duration, now func() time.Time) *Cache { if capacity < 0 { capacity = 0 } cache := &Cache{ - now: time.Now, + now: now, capacity: capacity, values: make(map[uint64]*Value, capacity), done: make(chan bool), - interval: time.Minute, } - go maintain(cache) + go maintain(cache, interval) return cache } @@ -81,8 +81,8 @@ func NewKey(name string, qtype, qclass uint16) uint64 { return h.Sum64() } -func maintain(cache *Cache) { - ticker := time.NewTicker(cache.interval) +func maintain(cache *Cache, interval time.Duration) { + ticker := time.NewTicker(interval) for { select { case <-cache.done: @@ -195,12 +195,20 @@ func min(x, y uint32) uint32 { func minTTL(m *dns.Msg) time.Duration { var ttl uint32 = 1<<32 - 1 // avoid importing math + // Choose the lowest TTL of answer, authority and additional sections. for _, answer := range m.Answer { ttl = min(answer.Header().Ttl, ttl) } for _, ns := range m.Ns { ttl = min(ns.Header().Ttl, ttl) } + for _, extra := range m.Extra { + // OPT (EDNS) is a pseudo record which uses TTL field for extended RCODE and flags + if extra.Header().Rrtype == dns.TypeOPT { + continue + } + ttl = min(extra.Header().Ttl, ttl) + } return time.Duration(ttl) * time.Second } diff --git a/cache/cache_test.go b/cache/cache_test.go index df05721..77f9f15 100644 --- a/cache/cache_test.go +++ b/cache/cache_test.go @@ -10,12 +10,6 @@ import ( "github.com/miekg/dns" ) -func handleErr(t *testing.T, fn func() error) { - if err := fn(); err != nil { - t.Fatal(err) - } -} - func newA(name string, ttl uint32, ipAddr ...net.IP) *dns.Msg { m := dns.Msg{} m.Id = dns.Id() @@ -66,10 +60,11 @@ func awaitExpiry(t *testing.T, c *Cache, k uint64) { now := time.Now() for { c.mu.RLock() - if _, ok := c.values[k]; !ok { + _, ok := c.values[k] + c.mu.RUnlock() + if !ok { break } - c.mu.RUnlock() time.Sleep(10 * time.Millisecond) if time.Since(now) > 2*time.Second { t.Fatalf("timed out waiting for expiry of key %d", k) @@ -78,30 +73,45 @@ func awaitExpiry(t *testing.T, c *Cache, k uint64) { } func TestCache(t *testing.T) { - msg := newA("foo.", 60, net.ParseIP("192.0.2.1"), net.ParseIP("192.0.2.2")) - msgWithZeroTTL := newA("bar.", 0, net.ParseIP("192.0.2.2")) - msgFailure := newA("baz.", 60, net.ParseIP("192.0.2.2")) + msg := newA("r1.", 60, net.ParseIP("192.0.2.1"), net.ParseIP("192.0.2.2")) + msgWithZeroTTL := newA("r2.", 0, net.ParseIP("192.0.2.2")) + msgFailure := newA("r3.", 60, net.ParseIP("192.0.2.2")) msgFailure.Rcode = dns.RcodeServerFailure + msgNameError := &dns.Msg{} + msgNameError.Id = dns.Id() + msgNameError.SetQuestion(dns.Fqdn("r4."), dns.TypeA) + msgNameError.Rcode = dns.RcodeNameError + msgLowerNsTTL := newA("r5.", 60, net.ParseIP("192.0.2.1")) + msgLowerNsTTL.Ns = []dns.RR{&dns.NS{Hdr: dns.RR_Header{Name: "ns1.r5.", Rrtype: dns.TypeNS, Class: dns.ClassINET, Ttl: 30}}} + msgLowerExtraTTL := newA("r6.", 3600, net.ParseIP("192.0.2.1")) + msgLowerExtraTTL.Ns = []dns.RR{&dns.NS{Hdr: dns.RR_Header{Name: "ns1.r6.", Rrtype: dns.TypeNS, Class: dns.ClassINET, Ttl: 60}}} + msgLowerExtraTTL.Extra = []dns.RR{ + &dns.OPT{Hdr: dns.RR_Header{Name: "EDNS", Rrtype: dns.TypeOPT, Class: dns.ClassINET, Ttl: 10}}, // Ignored + &dns.A{Hdr: dns.RR_Header{Name: "ns1.r6.", Rrtype: dns.TypeA, Class: dns.ClassINET, Ttl: 30}}, + } - createdAt := date(2019, 1, 1) - c := New(100) - c.interval = 10 * time.Millisecond - defer handleErr(t, c.Close) + now := date(2019, 1, 1) + nowFn := func() time.Time { return now } + c := newCache(100, 10*time.Millisecond, nowFn) + defer c.Close() var tests = []struct { - msg *dns.Msg - createdAt, queriedAt time.Time - ok bool - value *Value + msg *dns.Msg + queriedAt time.Time + ok bool + value *Value }{ - {msg, createdAt, createdAt, true, &Value{CreatedAt: createdAt, msg: msg}}, // Not expired when query time == create time - {msg, createdAt, createdAt.Add(30 * time.Second), true, &Value{CreatedAt: createdAt, msg: msg}}, // Not expired when below TTL - {msg, createdAt, createdAt.Add(60 * time.Second), true, &Value{CreatedAt: createdAt, msg: msg}}, // Not expired until TTL exceeds - {msg, createdAt, createdAt.Add(61 * time.Second), false, nil}, // Expired - {msgWithZeroTTL, createdAt, createdAt, false, nil}, // 0 TTL is not cached - {msgFailure, createdAt, createdAt, false, nil}, // Non-cacheable rcode + {msg, now, true, &Value{CreatedAt: now, msg: msg}}, // Not expired when query time == create time + {msg, now.Add(30 * time.Second), true, &Value{CreatedAt: now, msg: msg}}, // Not expired when below TTL + {msg, now.Add(60 * time.Second), true, &Value{CreatedAt: now, msg: msg}}, // Not expired until TTL exceeds + {msgNameError, now, true, &Value{CreatedAt: now, msg: msgNameError}}, // NXDOMAIN is cached + {msg, now.Add(61 * time.Second), false, nil}, // Expired due to answer TTL + {msgLowerNsTTL, now.Add(31 * time.Second), false, nil}, // Expired due to lower NS TTL + {msgLowerExtraTTL, now.Add(31 * time.Second), false, nil}, // Expired due to lower Extra TTL + {msgWithZeroTTL, now, false, nil}, // 0 TTL is not cached + {msgFailure, now, false, nil}, // Non-cacheable rcode } for i, tt := range tests { - c.now = func() time.Time { return tt.createdAt } + c.now = nowFn k := NewKey(tt.msg.Question[0].Name, tt.msg.Question[0].Qtype, tt.msg.Question[0].Qclass) c.Set(k, tt.msg) c.now = func() time.Time { return tt.queriedAt } @@ -125,7 +135,7 @@ func TestCache(t *testing.T) { } } if (keyIdx != -1) != tt.ok { - t.Errorf("#%d: keys[%d] = %d, should not exist", i, keyIdx, k) + t.Errorf("#%d: keys[%d] = %d, got expired key", i, keyIdx, k) } } } @@ -141,8 +151,7 @@ func TestCacheCapacity(t *testing.T) { } for i, tt := range tests { c := New(tt.capacity) - c.interval = 10 * time.Millisecond - defer handleErr(t, c.Close) + defer c.Close() var msgs []*dns.Msg for i := 0; i < tt.addCount; i++ { m := newA(fmt.Sprintf("r%d", i), 60, net.ParseIP(fmt.Sprintf("192.0.2.%d", i))) @@ -182,8 +191,7 @@ func TestCacheList(t *testing.T) { } for i, tt := range tests { c := New(1024) - c.interval = 10 * time.Millisecond - defer handleErr(t, c.Close) + defer c.Close() var msgs []*dns.Msg for i := 0; i < tt.addCount; i++ { m := newA(fmt.Sprintf("r%d", i), 60, net.ParseIP(fmt.Sprintf("192.0.2.%d", i))) |