aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--cache/cache.go50
-rw-r--r--cache/cache_test.go12
-rw-r--r--dns/proxy.go34
-rw-r--r--dnsutil/dnsutil.go63
-rw-r--r--dnsutil/dnsutil_test.go57
-rw-r--r--http/http.go8
6 files changed, 132 insertions, 92 deletions
diff --git a/cache/cache.go b/cache/cache.go
index b6c03dc..f3bb337 100644
--- a/cache/cache.go
+++ b/cache/cache.go
@@ -7,6 +7,7 @@ import (
"time"
"github.com/miekg/dns"
+ "github.com/mpolden/zdns/dnsutil"
)
// Cache is a cache of DNS messages.
@@ -35,25 +36,10 @@ func (v *Value) Question() string { return v.msg.Question[0].Name }
func (v *Value) Qtype() uint16 { return v.msg.Question[0].Qtype }
// Answers returns the answers of the cached value v.
-func (v *Value) Answers() []string {
- var answers []string
- for _, answer := range v.msg.Answer {
- switch v := answer.(type) {
- case *dns.A:
- answers = append(answers, v.A.String())
- case *dns.AAAA:
- answers = append(answers, v.AAAA.String())
- case *dns.MX:
- answers = append(answers, v.Mx)
- case *dns.PTR:
- answers = append(answers, v.Ptr)
- }
- }
- return answers
-}
+func (v *Value) Answers() []string { return dnsutil.Answers(v.msg) }
// TTL returns the time to live of the cached value v.
-func (v *Value) TTL() time.Duration { return minTTL(v.msg) }
+func (v *Value) TTL() time.Duration { return dnsutil.MinTTL(v.msg) }
// New creates a new cache of given capacity.
func New(capacity int) *Cache { return newCache(capacity, time.Minute, time.Now) }
@@ -190,38 +176,12 @@ func (c *Cache) evictExpired() {
}
func (c *Cache) isExpired(v *Value) bool {
- expiresAt := v.CreatedAt.Add(minTTL(v.msg))
+ expiresAt := v.CreatedAt.Add(dnsutil.MinTTL(v.msg))
return c.now().After(expiresAt)
}
-func min(x, y uint32) uint32 {
- if x < y {
- return x
- }
- return y
-}
-
-func minTTL(m *dns.Msg) time.Duration {
- var ttl uint32 = 1<<32 - 1 // avoid importing math
- // Choose the lowest TTL of answer, authority and additional sections.
- for _, answer := range m.Answer {
- ttl = min(answer.Header().Ttl, ttl)
- }
- for _, ns := range m.Ns {
- ttl = min(ns.Header().Ttl, ttl)
- }
- for _, extra := range m.Extra {
- // OPT (EDNS) is a pseudo record which uses TTL field for extended RCODE and flags
- if extra.Header().Rrtype == dns.TypeOPT {
- continue
- }
- ttl = min(extra.Header().Ttl, ttl)
- }
- return time.Duration(ttl) * time.Second
-}
-
func canCache(m *dns.Msg) bool {
- if minTTL(m) == 0 {
+ if dnsutil.MinTTL(m) == 0 {
return false
}
return m.Rcode == dns.RcodeSuccess || m.Rcode == dns.RcodeNameError
diff --git a/cache/cache_test.go b/cache/cache_test.go
index 33f3a6b..b9ede8a 100644
--- a/cache/cache_test.go
+++ b/cache/cache_test.go
@@ -81,14 +81,6 @@ func TestCache(t *testing.T) {
msgNameError.Id = dns.Id()
msgNameError.SetQuestion(dns.Fqdn("r4."), dns.TypeA)
msgNameError.Rcode = dns.RcodeNameError
- msgLowerNsTTL := newA("r5.", 60, net.ParseIP("192.0.2.1"))
- msgLowerNsTTL.Ns = []dns.RR{&dns.NS{Hdr: dns.RR_Header{Name: "ns1.r5.", Rrtype: dns.TypeNS, Class: dns.ClassINET, Ttl: 30}}}
- msgLowerExtraTTL := newA("r6.", 3600, net.ParseIP("192.0.2.1"))
- msgLowerExtraTTL.Ns = []dns.RR{&dns.NS{Hdr: dns.RR_Header{Name: "ns1.r6.", Rrtype: dns.TypeNS, Class: dns.ClassINET, Ttl: 60}}}
- msgLowerExtraTTL.Extra = []dns.RR{
- &dns.OPT{Hdr: dns.RR_Header{Name: "EDNS", Rrtype: dns.TypeOPT, Class: dns.ClassINET, Ttl: 10}}, // Ignored
- &dns.A{Hdr: dns.RR_Header{Name: "ns1.r6.", Rrtype: dns.TypeA, Class: dns.ClassINET, Ttl: 30}},
- }
now := date(2019, 1, 1)
nowFn := func() time.Time { return now }
@@ -104,9 +96,7 @@ func TestCache(t *testing.T) {
{msg, now.Add(30 * time.Second), true, &Value{CreatedAt: now, msg: msg}}, // Not expired when below TTL
{msg, now.Add(60 * time.Second), true, &Value{CreatedAt: now, msg: msg}}, // Not expired until TTL exceeds
{msgNameError, now, true, &Value{CreatedAt: now, msg: msgNameError}}, // NXDOMAIN is cached
- {msg, now.Add(61 * time.Second), false, nil}, // Expired due to answer TTL
- {msgLowerNsTTL, now.Add(31 * time.Second), false, nil}, // Expired due to lower NS TTL
- {msgLowerExtraTTL, now.Add(31 * time.Second), false, nil}, // Expired due to lower Extra TTL
+ {msg, now.Add(61 * time.Second), false, nil}, // Expired due to TTL exceeded
{msgWithZeroTTL, now, false, nil}, // 0 TTL is not cached
{msgFailure, now, false, nil}, // Non-cacheable rcode
}
diff --git a/dns/proxy.go b/dns/proxy.go
index dc325a9..e5e5472 100644
--- a/dns/proxy.go
+++ b/dns/proxy.go
@@ -9,6 +9,7 @@ import (
"github.com/miekg/dns"
"github.com/mpolden/zdns/cache"
"github.com/mpolden/zdns/dns/http"
+ "github.com/mpolden/zdns/dnsutil"
)
const (
@@ -24,14 +25,6 @@ const (
LogHijacked
)
-var (
- // TypeToString contains a mapping of DNS request type to string.
- TypeToString = dns.TypeToString
-
- // RcodeToString contains a mapping of Mapping DNS return code to string.
- RcodeToString = dns.RcodeToString
-)
-
// Request represents a simplified DNS request.
type Request struct {
Type uint16
@@ -151,28 +144,6 @@ func (p *Proxy) Close() error {
return nil
}
-func answers(msg *dns.Msg) []string {
- var answers []string
- for _, answer := range msg.Answer {
- // Log answers for the following DNS types. Other types are still logged, but their answers are not.
- switch v := answer.(type) {
- case *dns.A:
- answers = append(answers, v.A.String())
- case *dns.AAAA:
- answers = append(answers, v.AAAA.String())
- case *dns.MX:
- answers = append(answers, v.Mx)
- case *dns.PTR:
- answers = append(answers, v.Ptr)
- case *dns.NS:
- answers = append(answers, v.Ns)
- case *dns.CNAME:
- answers = append(answers, v.Target)
- }
- }
- return answers
-}
-
func (p *Proxy) writeMsg(w dns.ResponseWriter, msg *dns.Msg, hijacked bool) {
if p.logMode == LogAll || (hijacked && p.logMode == LogHijacked) {
var ip net.IP
@@ -184,8 +155,7 @@ func (p *Proxy) writeMsg(w dns.ResponseWriter, msg *dns.Msg, hijacked bool) {
default:
panic(fmt.Sprintf("unexpected remote address type %T", v))
}
- answers := answers(msg)
- p.logger.Record(ip, hijacked, msg.Question[0].Qtype, msg.Question[0].Name, answers...)
+ p.logger.Record(ip, hijacked, msg.Question[0].Qtype, msg.Question[0].Name, dnsutil.Answers(msg)...)
}
w.WriteMsg(msg)
}
diff --git a/dnsutil/dnsutil.go b/dnsutil/dnsutil.go
new file mode 100644
index 0000000..6f791ef
--- /dev/null
+++ b/dnsutil/dnsutil.go
@@ -0,0 +1,63 @@
+package dnsutil
+
+import (
+ "time"
+
+ "github.com/miekg/dns"
+)
+
+var (
+ // TypeToString contains a mapping of DNS request type to string.
+ TypeToString = dns.TypeToString
+
+ // RcodeToString contains a mapping of Mapping DNS response code to string.
+ RcodeToString = dns.RcodeToString
+)
+
+// Answers returns all values in the answer section of DNS message msg.
+func Answers(msg *dns.Msg) []string {
+ var answers []string
+ for _, answer := range msg.Answer {
+ switch v := answer.(type) {
+ case *dns.A:
+ answers = append(answers, v.A.String())
+ case *dns.AAAA:
+ answers = append(answers, v.AAAA.String())
+ case *dns.MX:
+ answers = append(answers, v.Mx)
+ case *dns.PTR:
+ answers = append(answers, v.Ptr)
+ case *dns.NS:
+ answers = append(answers, v.Ns)
+ case *dns.CNAME:
+ answers = append(answers, v.Target)
+ }
+ }
+ return answers
+}
+
+// MinTTL returns the lowest TTL of of answer, authority and additional sections.
+func MinTTL(msg *dns.Msg) time.Duration {
+ var ttl uint32 = (1 << 31) - 1 // Maximum TTL from RFC 2181
+ for _, answer := range msg.Answer {
+ ttl = min(answer.Header().Ttl, ttl)
+ }
+ for _, ns := range msg.Ns {
+ ttl = min(ns.Header().Ttl, ttl)
+ }
+ for _, extra := range msg.Extra {
+ // OPT (EDNS) is a pseudo record which uses TTL field for extended RCODE and flags
+ if extra.Header().Rrtype == dns.TypeOPT {
+ continue
+ }
+ ttl = min(extra.Header().Ttl, ttl)
+ }
+ return time.Duration(ttl) * time.Second
+}
+
+func min(x, y uint32) uint32 {
+ if x < y {
+ return x
+ }
+ return y
+}
diff --git a/dnsutil/dnsutil_test.go b/dnsutil/dnsutil_test.go
new file mode 100644
index 0000000..305485f
--- /dev/null
+++ b/dnsutil/dnsutil_test.go
@@ -0,0 +1,57 @@
+package dnsutil
+
+import (
+ "testing"
+ "time"
+
+ "github.com/miekg/dns"
+)
+
+func TestMinTTL(t *testing.T) {
+ var tests = []struct {
+ answer []dns.RR
+ ns []dns.RR
+ extra []dns.RR
+ ttl time.Duration
+ }{
+ {
+ []dns.RR{
+ &dns.A{Hdr: dns.RR_Header{Ttl: 3600}},
+ &dns.A{Hdr: dns.RR_Header{Ttl: 60}},
+ },
+ nil,
+ nil,
+ time.Minute,
+ },
+ {
+ []dns.RR{&dns.A{Hdr: dns.RR_Header{Ttl: 60}}},
+ []dns.RR{&dns.NS{Hdr: dns.RR_Header{Ttl: 30}}},
+ nil,
+ 30 * time.Second,
+ },
+ {
+ []dns.RR{&dns.A{Hdr: dns.RR_Header{Ttl: 60}}},
+ []dns.RR{&dns.NS{Hdr: dns.RR_Header{Ttl: 30}}},
+ []dns.RR{&dns.NS{Hdr: dns.RR_Header{Ttl: 10}}},
+ 10 * time.Second,
+ },
+ {
+ []dns.RR{&dns.A{Hdr: dns.RR_Header{Ttl: 60}}},
+ nil,
+ []dns.RR{
+ &dns.OPT{Hdr: dns.RR_Header{Ttl: 10, Rrtype: dns.TypeOPT}}, // Ignored
+ &dns.A{Hdr: dns.RR_Header{Ttl: 30}},
+ },
+ 30 * time.Second,
+ },
+ }
+ for i, tt := range tests {
+ msg := dns.Msg{}
+ msg.Answer = tt.answer
+ msg.Ns = tt.ns
+ msg.Extra = tt.extra
+ if got := MinTTL(&msg); got != tt.ttl {
+ t.Errorf("#%d: MinTTL(\n%s) = %s, want %s", i, msg.String(), got, tt.ttl)
+ }
+ }
+}
diff --git a/http/http.go b/http/http.go
index d4c3b42..36496cd 100644
--- a/http/http.go
+++ b/http/http.go
@@ -8,7 +8,7 @@ import (
"time"
"github.com/mpolden/zdns/cache"
- "github.com/mpolden/zdns/dns"
+ "github.com/mpolden/zdns/dnsutil"
"github.com/mpolden/zdns/log"
)
@@ -74,10 +74,10 @@ func (s *Server) cacheHandler(w http.ResponseWriter, r *http.Request) (interface
entries = append(entries, entry{
Time: v.CreatedAt.UTC().Format(time.RFC3339),
TTL: int64(v.TTL().Truncate(time.Second).Seconds()),
- Qtype: dns.TypeToString[v.Qtype()],
+ Qtype: dnsutil.TypeToString[v.Qtype()],
Question: v.Question(),
Answers: v.Answers(),
- Rcode: dns.RcodeToString[v.Rcode()],
+ Rcode: dnsutil.RcodeToString[v.Rcode()],
})
}
return entries, nil
@@ -107,7 +107,7 @@ func (s *Server) logHandler(w http.ResponseWriter, r *http.Request) (interface{}
Time: le.Time.UTC().Format(time.RFC3339),
RemoteAddr: le.RemoteAddr,
Hijacked: &hijacked,
- Qtype: dns.TypeToString[le.Qtype],
+ Qtype: dnsutil.TypeToString[le.Qtype],
Question: le.Question,
Answers: le.Answers,
})