aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorMartin Polden <mpolden@mpolden.no>2019-12-28 19:04:31 +0100
committerMartin Polden <mpolden@mpolden.no>2019-12-28 19:10:27 +0100
commit2b0fe090e54d0259d58a54e6a940b845960bbd98 (patch)
tree313d21383564e87d0a36e809c1e373ba35fd8476
parent5b8188dc1958064590db918b7cd0938139ad3b9b (diff)
Respect TTLs of additional section
-rw-r--r--cache/cache.go24
-rw-r--r--cache/cache_test.go70
2 files changed, 55 insertions, 39 deletions
diff --git a/cache/cache.go b/cache/cache.go
index 167ff08..087ba24 100644
--- a/cache/cache.go
+++ b/cache/cache.go
@@ -17,7 +17,6 @@ type Cache struct {
mu sync.RWMutex
done chan bool
now func() time.Time
- interval time.Duration
}
// Value wraps a DNS message stored in the cache.
@@ -53,22 +52,23 @@ func (v *Value) Answers() []string {
return answers
}
-// TTL returns the TTL of the cached value v.
+// TTL returns the time to live of the cached value v.
func (v *Value) TTL() time.Duration { return minTTL(v.msg) }
// New creates a new cache of given capacity.
-func New(capacity int) *Cache {
+func New(capacity int) *Cache { return newCache(capacity, time.Minute, time.Now) }
+
+func newCache(capacity int, interval time.Duration, now func() time.Time) *Cache {
if capacity < 0 {
capacity = 0
}
cache := &Cache{
- now: time.Now,
+ now: now,
capacity: capacity,
values: make(map[uint64]*Value, capacity),
done: make(chan bool),
- interval: time.Minute,
}
- go maintain(cache)
+ go maintain(cache, interval)
return cache
}
@@ -81,8 +81,8 @@ func NewKey(name string, qtype, qclass uint16) uint64 {
return h.Sum64()
}
-func maintain(cache *Cache) {
- ticker := time.NewTicker(cache.interval)
+func maintain(cache *Cache, interval time.Duration) {
+ ticker := time.NewTicker(interval)
for {
select {
case <-cache.done:
@@ -195,12 +195,20 @@ func min(x, y uint32) uint32 {
func minTTL(m *dns.Msg) time.Duration {
var ttl uint32 = 1<<32 - 1 // avoid importing math
+ // Choose the lowest TTL of answer, authority and additional sections.
for _, answer := range m.Answer {
ttl = min(answer.Header().Ttl, ttl)
}
for _, ns := range m.Ns {
ttl = min(ns.Header().Ttl, ttl)
}
+ for _, extra := range m.Extra {
+ // OPT (EDNS) is a pseudo record which uses TTL field for extended RCODE and flags
+ if extra.Header().Rrtype == dns.TypeOPT {
+ continue
+ }
+ ttl = min(extra.Header().Ttl, ttl)
+ }
return time.Duration(ttl) * time.Second
}
diff --git a/cache/cache_test.go b/cache/cache_test.go
index df05721..77f9f15 100644
--- a/cache/cache_test.go
+++ b/cache/cache_test.go
@@ -10,12 +10,6 @@ import (
"github.com/miekg/dns"
)
-func handleErr(t *testing.T, fn func() error) {
- if err := fn(); err != nil {
- t.Fatal(err)
- }
-}
-
func newA(name string, ttl uint32, ipAddr ...net.IP) *dns.Msg {
m := dns.Msg{}
m.Id = dns.Id()
@@ -66,10 +60,11 @@ func awaitExpiry(t *testing.T, c *Cache, k uint64) {
now := time.Now()
for {
c.mu.RLock()
- if _, ok := c.values[k]; !ok {
+ _, ok := c.values[k]
+ c.mu.RUnlock()
+ if !ok {
break
}
- c.mu.RUnlock()
time.Sleep(10 * time.Millisecond)
if time.Since(now) > 2*time.Second {
t.Fatalf("timed out waiting for expiry of key %d", k)
@@ -78,30 +73,45 @@ func awaitExpiry(t *testing.T, c *Cache, k uint64) {
}
func TestCache(t *testing.T) {
- msg := newA("foo.", 60, net.ParseIP("192.0.2.1"), net.ParseIP("192.0.2.2"))
- msgWithZeroTTL := newA("bar.", 0, net.ParseIP("192.0.2.2"))
- msgFailure := newA("baz.", 60, net.ParseIP("192.0.2.2"))
+ msg := newA("r1.", 60, net.ParseIP("192.0.2.1"), net.ParseIP("192.0.2.2"))
+ msgWithZeroTTL := newA("r2.", 0, net.ParseIP("192.0.2.2"))
+ msgFailure := newA("r3.", 60, net.ParseIP("192.0.2.2"))
msgFailure.Rcode = dns.RcodeServerFailure
+ msgNameError := &dns.Msg{}
+ msgNameError.Id = dns.Id()
+ msgNameError.SetQuestion(dns.Fqdn("r4."), dns.TypeA)
+ msgNameError.Rcode = dns.RcodeNameError
+ msgLowerNsTTL := newA("r5.", 60, net.ParseIP("192.0.2.1"))
+ msgLowerNsTTL.Ns = []dns.RR{&dns.NS{Hdr: dns.RR_Header{Name: "ns1.r5.", Rrtype: dns.TypeNS, Class: dns.ClassINET, Ttl: 30}}}
+ msgLowerExtraTTL := newA("r6.", 3600, net.ParseIP("192.0.2.1"))
+ msgLowerExtraTTL.Ns = []dns.RR{&dns.NS{Hdr: dns.RR_Header{Name: "ns1.r6.", Rrtype: dns.TypeNS, Class: dns.ClassINET, Ttl: 60}}}
+ msgLowerExtraTTL.Extra = []dns.RR{
+ &dns.OPT{Hdr: dns.RR_Header{Name: "EDNS", Rrtype: dns.TypeOPT, Class: dns.ClassINET, Ttl: 10}}, // Ignored
+ &dns.A{Hdr: dns.RR_Header{Name: "ns1.r6.", Rrtype: dns.TypeA, Class: dns.ClassINET, Ttl: 30}},
+ }
- createdAt := date(2019, 1, 1)
- c := New(100)
- c.interval = 10 * time.Millisecond
- defer handleErr(t, c.Close)
+ now := date(2019, 1, 1)
+ nowFn := func() time.Time { return now }
+ c := newCache(100, 10*time.Millisecond, nowFn)
+ defer c.Close()
var tests = []struct {
- msg *dns.Msg
- createdAt, queriedAt time.Time
- ok bool
- value *Value
+ msg *dns.Msg
+ queriedAt time.Time
+ ok bool
+ value *Value
}{
- {msg, createdAt, createdAt, true, &Value{CreatedAt: createdAt, msg: msg}}, // Not expired when query time == create time
- {msg, createdAt, createdAt.Add(30 * time.Second), true, &Value{CreatedAt: createdAt, msg: msg}}, // Not expired when below TTL
- {msg, createdAt, createdAt.Add(60 * time.Second), true, &Value{CreatedAt: createdAt, msg: msg}}, // Not expired until TTL exceeds
- {msg, createdAt, createdAt.Add(61 * time.Second), false, nil}, // Expired
- {msgWithZeroTTL, createdAt, createdAt, false, nil}, // 0 TTL is not cached
- {msgFailure, createdAt, createdAt, false, nil}, // Non-cacheable rcode
+ {msg, now, true, &Value{CreatedAt: now, msg: msg}}, // Not expired when query time == create time
+ {msg, now.Add(30 * time.Second), true, &Value{CreatedAt: now, msg: msg}}, // Not expired when below TTL
+ {msg, now.Add(60 * time.Second), true, &Value{CreatedAt: now, msg: msg}}, // Not expired until TTL exceeds
+ {msgNameError, now, true, &Value{CreatedAt: now, msg: msgNameError}}, // NXDOMAIN is cached
+ {msg, now.Add(61 * time.Second), false, nil}, // Expired due to answer TTL
+ {msgLowerNsTTL, now.Add(31 * time.Second), false, nil}, // Expired due to lower NS TTL
+ {msgLowerExtraTTL, now.Add(31 * time.Second), false, nil}, // Expired due to lower Extra TTL
+ {msgWithZeroTTL, now, false, nil}, // 0 TTL is not cached
+ {msgFailure, now, false, nil}, // Non-cacheable rcode
}
for i, tt := range tests {
- c.now = func() time.Time { return tt.createdAt }
+ c.now = nowFn
k := NewKey(tt.msg.Question[0].Name, tt.msg.Question[0].Qtype, tt.msg.Question[0].Qclass)
c.Set(k, tt.msg)
c.now = func() time.Time { return tt.queriedAt }
@@ -125,7 +135,7 @@ func TestCache(t *testing.T) {
}
}
if (keyIdx != -1) != tt.ok {
- t.Errorf("#%d: keys[%d] = %d, should not exist", i, keyIdx, k)
+ t.Errorf("#%d: keys[%d] = %d, got expired key", i, keyIdx, k)
}
}
}
@@ -141,8 +151,7 @@ func TestCacheCapacity(t *testing.T) {
}
for i, tt := range tests {
c := New(tt.capacity)
- c.interval = 10 * time.Millisecond
- defer handleErr(t, c.Close)
+ defer c.Close()
var msgs []*dns.Msg
for i := 0; i < tt.addCount; i++ {
m := newA(fmt.Sprintf("r%d", i), 60, net.ParseIP(fmt.Sprintf("192.0.2.%d", i)))
@@ -182,8 +191,7 @@ func TestCacheList(t *testing.T) {
}
for i, tt := range tests {
c := New(1024)
- c.interval = 10 * time.Millisecond
- defer handleErr(t, c.Close)
+ defer c.Close()
var msgs []*dns.Msg
for i := 0; i < tt.addCount; i++ {
m := newA(fmt.Sprintf("r%d", i), 60, net.ParseIP(fmt.Sprintf("192.0.2.%d", i)))