diff options
-rw-r--r-- | cache/cache.go | 50 | ||||
-rw-r--r-- | cache/cache_test.go | 12 | ||||
-rw-r--r-- | dns/proxy.go | 34 | ||||
-rw-r--r-- | dnsutil/dnsutil.go | 63 | ||||
-rw-r--r-- | dnsutil/dnsutil_test.go | 57 | ||||
-rw-r--r-- | http/http.go | 8 |
6 files changed, 132 insertions, 92 deletions
diff --git a/cache/cache.go b/cache/cache.go index b6c03dc..f3bb337 100644 --- a/cache/cache.go +++ b/cache/cache.go @@ -7,6 +7,7 @@ import ( "time" "github.com/miekg/dns" + "github.com/mpolden/zdns/dnsutil" ) // Cache is a cache of DNS messages. @@ -35,25 +36,10 @@ func (v *Value) Question() string { return v.msg.Question[0].Name } func (v *Value) Qtype() uint16 { return v.msg.Question[0].Qtype } // Answers returns the answers of the cached value v. -func (v *Value) Answers() []string { - var answers []string - for _, answer := range v.msg.Answer { - switch v := answer.(type) { - case *dns.A: - answers = append(answers, v.A.String()) - case *dns.AAAA: - answers = append(answers, v.AAAA.String()) - case *dns.MX: - answers = append(answers, v.Mx) - case *dns.PTR: - answers = append(answers, v.Ptr) - } - } - return answers -} +func (v *Value) Answers() []string { return dnsutil.Answers(v.msg) } // TTL returns the time to live of the cached value v. -func (v *Value) TTL() time.Duration { return minTTL(v.msg) } +func (v *Value) TTL() time.Duration { return dnsutil.MinTTL(v.msg) } // New creates a new cache of given capacity. func New(capacity int) *Cache { return newCache(capacity, time.Minute, time.Now) } @@ -190,38 +176,12 @@ func (c *Cache) evictExpired() { } func (c *Cache) isExpired(v *Value) bool { - expiresAt := v.CreatedAt.Add(minTTL(v.msg)) + expiresAt := v.CreatedAt.Add(dnsutil.MinTTL(v.msg)) return c.now().After(expiresAt) } -func min(x, y uint32) uint32 { - if x < y { - return x - } - return y -} - -func minTTL(m *dns.Msg) time.Duration { - var ttl uint32 = 1<<32 - 1 // avoid importing math - // Choose the lowest TTL of answer, authority and additional sections. - for _, answer := range m.Answer { - ttl = min(answer.Header().Ttl, ttl) - } - for _, ns := range m.Ns { - ttl = min(ns.Header().Ttl, ttl) - } - for _, extra := range m.Extra { - // OPT (EDNS) is a pseudo record which uses TTL field for extended RCODE and flags - if extra.Header().Rrtype == dns.TypeOPT { - continue - } - ttl = min(extra.Header().Ttl, ttl) - } - return time.Duration(ttl) * time.Second -} - func canCache(m *dns.Msg) bool { - if minTTL(m) == 0 { + if dnsutil.MinTTL(m) == 0 { return false } return m.Rcode == dns.RcodeSuccess || m.Rcode == dns.RcodeNameError diff --git a/cache/cache_test.go b/cache/cache_test.go index 33f3a6b..b9ede8a 100644 --- a/cache/cache_test.go +++ b/cache/cache_test.go @@ -81,14 +81,6 @@ func TestCache(t *testing.T) { msgNameError.Id = dns.Id() msgNameError.SetQuestion(dns.Fqdn("r4."), dns.TypeA) msgNameError.Rcode = dns.RcodeNameError - msgLowerNsTTL := newA("r5.", 60, net.ParseIP("192.0.2.1")) - msgLowerNsTTL.Ns = []dns.RR{&dns.NS{Hdr: dns.RR_Header{Name: "ns1.r5.", Rrtype: dns.TypeNS, Class: dns.ClassINET, Ttl: 30}}} - msgLowerExtraTTL := newA("r6.", 3600, net.ParseIP("192.0.2.1")) - msgLowerExtraTTL.Ns = []dns.RR{&dns.NS{Hdr: dns.RR_Header{Name: "ns1.r6.", Rrtype: dns.TypeNS, Class: dns.ClassINET, Ttl: 60}}} - msgLowerExtraTTL.Extra = []dns.RR{ - &dns.OPT{Hdr: dns.RR_Header{Name: "EDNS", Rrtype: dns.TypeOPT, Class: dns.ClassINET, Ttl: 10}}, // Ignored - &dns.A{Hdr: dns.RR_Header{Name: "ns1.r6.", Rrtype: dns.TypeA, Class: dns.ClassINET, Ttl: 30}}, - } now := date(2019, 1, 1) nowFn := func() time.Time { return now } @@ -104,9 +96,7 @@ func TestCache(t *testing.T) { {msg, now.Add(30 * time.Second), true, &Value{CreatedAt: now, msg: msg}}, // Not expired when below TTL {msg, now.Add(60 * time.Second), true, &Value{CreatedAt: now, msg: msg}}, // Not expired until TTL exceeds {msgNameError, now, true, &Value{CreatedAt: now, msg: msgNameError}}, // NXDOMAIN is cached - {msg, now.Add(61 * time.Second), false, nil}, // Expired due to answer TTL - {msgLowerNsTTL, now.Add(31 * time.Second), false, nil}, // Expired due to lower NS TTL - {msgLowerExtraTTL, now.Add(31 * time.Second), false, nil}, // Expired due to lower Extra TTL + {msg, now.Add(61 * time.Second), false, nil}, // Expired due to TTL exceeded {msgWithZeroTTL, now, false, nil}, // 0 TTL is not cached {msgFailure, now, false, nil}, // Non-cacheable rcode } diff --git a/dns/proxy.go b/dns/proxy.go index dc325a9..e5e5472 100644 --- a/dns/proxy.go +++ b/dns/proxy.go @@ -9,6 +9,7 @@ import ( "github.com/miekg/dns" "github.com/mpolden/zdns/cache" "github.com/mpolden/zdns/dns/http" + "github.com/mpolden/zdns/dnsutil" ) const ( @@ -24,14 +25,6 @@ const ( LogHijacked ) -var ( - // TypeToString contains a mapping of DNS request type to string. - TypeToString = dns.TypeToString - - // RcodeToString contains a mapping of Mapping DNS return code to string. - RcodeToString = dns.RcodeToString -) - // Request represents a simplified DNS request. type Request struct { Type uint16 @@ -151,28 +144,6 @@ func (p *Proxy) Close() error { return nil } -func answers(msg *dns.Msg) []string { - var answers []string - for _, answer := range msg.Answer { - // Log answers for the following DNS types. Other types are still logged, but their answers are not. - switch v := answer.(type) { - case *dns.A: - answers = append(answers, v.A.String()) - case *dns.AAAA: - answers = append(answers, v.AAAA.String()) - case *dns.MX: - answers = append(answers, v.Mx) - case *dns.PTR: - answers = append(answers, v.Ptr) - case *dns.NS: - answers = append(answers, v.Ns) - case *dns.CNAME: - answers = append(answers, v.Target) - } - } - return answers -} - func (p *Proxy) writeMsg(w dns.ResponseWriter, msg *dns.Msg, hijacked bool) { if p.logMode == LogAll || (hijacked && p.logMode == LogHijacked) { var ip net.IP @@ -184,8 +155,7 @@ func (p *Proxy) writeMsg(w dns.ResponseWriter, msg *dns.Msg, hijacked bool) { default: panic(fmt.Sprintf("unexpected remote address type %T", v)) } - answers := answers(msg) - p.logger.Record(ip, hijacked, msg.Question[0].Qtype, msg.Question[0].Name, answers...) + p.logger.Record(ip, hijacked, msg.Question[0].Qtype, msg.Question[0].Name, dnsutil.Answers(msg)...) } w.WriteMsg(msg) } diff --git a/dnsutil/dnsutil.go b/dnsutil/dnsutil.go new file mode 100644 index 0000000..6f791ef --- /dev/null +++ b/dnsutil/dnsutil.go @@ -0,0 +1,63 @@ +package dnsutil + +import ( + "time" + + "github.com/miekg/dns" +) + +var ( + // TypeToString contains a mapping of DNS request type to string. + TypeToString = dns.TypeToString + + // RcodeToString contains a mapping of Mapping DNS response code to string. + RcodeToString = dns.RcodeToString +) + +// Answers returns all values in the answer section of DNS message msg. +func Answers(msg *dns.Msg) []string { + var answers []string + for _, answer := range msg.Answer { + switch v := answer.(type) { + case *dns.A: + answers = append(answers, v.A.String()) + case *dns.AAAA: + answers = append(answers, v.AAAA.String()) + case *dns.MX: + answers = append(answers, v.Mx) + case *dns.PTR: + answers = append(answers, v.Ptr) + case *dns.NS: + answers = append(answers, v.Ns) + case *dns.CNAME: + answers = append(answers, v.Target) + } + } + return answers +} + +// MinTTL returns the lowest TTL of of answer, authority and additional sections. +func MinTTL(msg *dns.Msg) time.Duration { + var ttl uint32 = (1 << 31) - 1 // Maximum TTL from RFC 2181 + for _, answer := range msg.Answer { + ttl = min(answer.Header().Ttl, ttl) + } + for _, ns := range msg.Ns { + ttl = min(ns.Header().Ttl, ttl) + } + for _, extra := range msg.Extra { + // OPT (EDNS) is a pseudo record which uses TTL field for extended RCODE and flags + if extra.Header().Rrtype == dns.TypeOPT { + continue + } + ttl = min(extra.Header().Ttl, ttl) + } + return time.Duration(ttl) * time.Second +} + +func min(x, y uint32) uint32 { + if x < y { + return x + } + return y +} diff --git a/dnsutil/dnsutil_test.go b/dnsutil/dnsutil_test.go new file mode 100644 index 0000000..305485f --- /dev/null +++ b/dnsutil/dnsutil_test.go @@ -0,0 +1,57 @@ +package dnsutil + +import ( + "testing" + "time" + + "github.com/miekg/dns" +) + +func TestMinTTL(t *testing.T) { + var tests = []struct { + answer []dns.RR + ns []dns.RR + extra []dns.RR + ttl time.Duration + }{ + { + []dns.RR{ + &dns.A{Hdr: dns.RR_Header{Ttl: 3600}}, + &dns.A{Hdr: dns.RR_Header{Ttl: 60}}, + }, + nil, + nil, + time.Minute, + }, + { + []dns.RR{&dns.A{Hdr: dns.RR_Header{Ttl: 60}}}, + []dns.RR{&dns.NS{Hdr: dns.RR_Header{Ttl: 30}}}, + nil, + 30 * time.Second, + }, + { + []dns.RR{&dns.A{Hdr: dns.RR_Header{Ttl: 60}}}, + []dns.RR{&dns.NS{Hdr: dns.RR_Header{Ttl: 30}}}, + []dns.RR{&dns.NS{Hdr: dns.RR_Header{Ttl: 10}}}, + 10 * time.Second, + }, + { + []dns.RR{&dns.A{Hdr: dns.RR_Header{Ttl: 60}}}, + nil, + []dns.RR{ + &dns.OPT{Hdr: dns.RR_Header{Ttl: 10, Rrtype: dns.TypeOPT}}, // Ignored + &dns.A{Hdr: dns.RR_Header{Ttl: 30}}, + }, + 30 * time.Second, + }, + } + for i, tt := range tests { + msg := dns.Msg{} + msg.Answer = tt.answer + msg.Ns = tt.ns + msg.Extra = tt.extra + if got := MinTTL(&msg); got != tt.ttl { + t.Errorf("#%d: MinTTL(\n%s) = %s, want %s", i, msg.String(), got, tt.ttl) + } + } +} diff --git a/http/http.go b/http/http.go index d4c3b42..36496cd 100644 --- a/http/http.go +++ b/http/http.go @@ -8,7 +8,7 @@ import ( "time" "github.com/mpolden/zdns/cache" - "github.com/mpolden/zdns/dns" + "github.com/mpolden/zdns/dnsutil" "github.com/mpolden/zdns/log" ) @@ -74,10 +74,10 @@ func (s *Server) cacheHandler(w http.ResponseWriter, r *http.Request) (interface entries = append(entries, entry{ Time: v.CreatedAt.UTC().Format(time.RFC3339), TTL: int64(v.TTL().Truncate(time.Second).Seconds()), - Qtype: dns.TypeToString[v.Qtype()], + Qtype: dnsutil.TypeToString[v.Qtype()], Question: v.Question(), Answers: v.Answers(), - Rcode: dns.RcodeToString[v.Rcode()], + Rcode: dnsutil.RcodeToString[v.Rcode()], }) } return entries, nil @@ -107,7 +107,7 @@ func (s *Server) logHandler(w http.ResponseWriter, r *http.Request) (interface{} Time: le.Time.UTC().Format(time.RFC3339), RemoteAddr: le.RemoteAddr, Hijacked: &hijacked, - Qtype: dns.TypeToString[le.Qtype], + Qtype: dnsutil.TypeToString[le.Qtype], Question: le.Question, Answers: le.Answers, }) |