diff options
-rw-r--r-- | cmd/zdns/main.go | 44 | ||||
-rw-r--r-- | config.go | 8 | ||||
-rw-r--r-- | dns/proxy.go | 29 | ||||
-rw-r--r-- | dns/proxy_test.go | 19 | ||||
-rw-r--r-- | http/http.go | 21 | ||||
-rw-r--r-- | http/http_test.go | 15 | ||||
-rw-r--r-- | server.go | 4 | ||||
-rw-r--r-- | server_test.go | 22 | ||||
-rw-r--r-- | signal/signal.go | 3 | ||||
-rw-r--r-- | signal/signal_test.go | 8 | ||||
-rw-r--r-- | sql/logger.go (renamed from log/logger.go) | 72 | ||||
-rw-r--r-- | sql/logger_test.go (renamed from log/logger_test.go) | 41 |
12 files changed, 141 insertions, 145 deletions
diff --git a/cmd/zdns/main.go b/cmd/zdns/main.go index d116fff..bcc2934 100644 --- a/cmd/zdns/main.go +++ b/cmd/zdns/main.go @@ -3,6 +3,7 @@ package main import ( "fmt" "io" + "log" "os" "path/filepath" "sync" @@ -14,8 +15,8 @@ import ( "github.com/mpolden/zdns/dns" "github.com/mpolden/zdns/dns/dnsutil" "github.com/mpolden/zdns/http" - "github.com/mpolden/zdns/log" "github.com/mpolden/zdns/signal" + "github.com/mpolden/zdns/sql" ) const ( @@ -75,42 +76,45 @@ func newCli(out io.Writer, args []string, configFile string, sig chan os.Signal) config, err := readConfig(*confFile) fatal(err) - // Logger - logger, err := log.New(out, logPrefix, log.RecordOptions{ - Mode: config.DNS.LogMode, - Database: config.DNS.LogDatabase, - TTL: config.DNS.LogTTL, - }) - fatal(err) - - // Signal handling + // Logging and signal handling + logger := log.New(out, logPrefix, log.Lshortfile) sigHandler := signal.NewHandler(sig, logger) - sigHandler.OnClose(logger) - // Client - client := dnsutil.NewClient(config.Resolver.Protocol, config.Resolver.Timeout, config.DNS.Resolvers...) + // SQL backends + var sqlLogger *sql.Logger + if config.DNS.LogDatabase != "" { + sqlClient, err := sql.New(config.DNS.LogDatabase) + fatal(err) + + // Logger + sqlLogger = sql.NewLogger(sqlClient, config.DNS.LogMode, config.DNS.LogTTL) + sigHandler.OnClose(sqlLogger) + } + + // DNS client + dnsClient := dnsutil.NewClient(config.Resolver.Protocol, config.Resolver.Timeout, config.DNS.Resolvers...) // Cache - var cclient *dnsutil.Client + var cacheDNS *dnsutil.Client if config.DNS.CachePrefetch { - cclient = client + cacheDNS = dnsClient } - cache := cache.New(config.DNS.CacheSize, cclient) + cache := cache.New(config.DNS.CacheSize, cacheDNS) // DNS server - proxy, err := dns.NewProxy(cache, client, logger) + proxy, err := dns.NewProxy(cache, dnsClient, logger, sqlLogger) fatal(err) sigHandler.OnClose(proxy) - dnsSrv, err := zdns.NewServer(logger, proxy, config) + dnsSrv, err := zdns.NewServer(proxy, config, logger) fatal(err) sigHandler.OnReload(dnsSrv) sigHandler.OnClose(dnsSrv) - servers := []server{dnsSrv} + // HTTP server if config.DNS.ListenHTTP != "" { - httpSrv := http.NewServer(logger, cache, config.DNS.ListenHTTP) + httpSrv := http.NewServer(cache, sqlLogger, logger, config.DNS.ListenHTTP) sigHandler.OnClose(httpSrv) servers = append(servers, httpSrv) } @@ -10,7 +10,7 @@ import ( "github.com/BurntSushi/toml" "github.com/mpolden/zdns/hosts" - "github.com/mpolden/zdns/log" + "github.com/mpolden/zdns/sql" ) // Config specifies is the zdns configuration parameters. @@ -180,11 +180,11 @@ func (c *Config) load() error { } switch c.DNS.LogModeString { case "": - c.DNS.LogMode = log.ModeDiscard + c.DNS.LogMode = sql.LogDiscard case "all": - c.DNS.LogMode = log.ModeAll + c.DNS.LogMode = sql.LogAll case "hijacked": - c.DNS.LogMode = log.ModeHijacked + c.DNS.LogMode = sql.LogHijacked default: return fmt.Errorf("invalid log mode: %s", c.DNS.LogModeString) } diff --git a/dns/proxy.go b/dns/proxy.go index 90164ab..9e7607e 100644 --- a/dns/proxy.go +++ b/dns/proxy.go @@ -2,6 +2,7 @@ package dns import ( "fmt" + "log" "net" "strings" "sync" @@ -32,27 +33,27 @@ type Handler func(*Request) *Reply // Proxy represents a DNS proxy. type Proxy struct { - Handler Handler - cache *cache.Cache - logger logger - server *dns.Server - client *dnsutil.Client - mu sync.RWMutex + Handler Handler + cache *cache.Cache + logger *log.Logger + dnsLogger logger + server *dns.Server + client *dnsutil.Client + mu sync.RWMutex } type logger interface { - Print(...interface{}) - Printf(string, ...interface{}) Record(net.IP, bool, uint16, string, ...string) Close() error } // NewProxy creates a new DNS proxy. -func NewProxy(cache *cache.Cache, client *dnsutil.Client, logger logger) (*Proxy, error) { +func NewProxy(cache *cache.Cache, client *dnsutil.Client, logger *log.Logger, dnsLogger logger) (*Proxy, error) { return &Proxy{ - logger: logger, - cache: cache, - client: client, + logger: logger, + dnsLogger: dnsLogger, + cache: cache, + client: client, }, nil } @@ -129,7 +130,9 @@ func (p *Proxy) writeMsg(w dns.ResponseWriter, msg *dns.Msg, hijacked bool) { default: panic(fmt.Sprintf("unexpected remote address type %T", v)) } - p.logger.Record(ip, hijacked, msg.Question[0].Qtype, msg.Question[0].Name, dnsutil.Answers(msg)...) + if p.dnsLogger != nil { + p.dnsLogger.Record(ip, hijacked, msg.Question[0].Qtype, msg.Question[0].Name, dnsutil.Answers(msg)...) + } w.WriteMsg(msg) } diff --git a/dns/proxy_test.go b/dns/proxy_test.go index 7d2a894..9b6dc35 100644 --- a/dns/proxy_test.go +++ b/dns/proxy_test.go @@ -3,6 +3,7 @@ package dns import ( "fmt" "io/ioutil" + "log" "net" "reflect" "sync" @@ -12,7 +13,6 @@ import ( "github.com/miekg/dns" "github.com/mpolden/zdns/cache" "github.com/mpolden/zdns/dns/dnsutil" - "github.com/mpolden/zdns/log" ) type dnsWriter struct{ lastReply *dns.Msg } @@ -63,15 +63,16 @@ func (e *testExchanger) Exchange(msg *dns.Msg, addr string) (*dns.Msg, time.Dura return r.answer, time.Second, nil } +type testLogger struct{} + +func (l *testLogger) Record(net.IP, bool, uint16, string, ...string) {} + +func (l *testLogger) Close() error { return nil } + func testProxy(t *testing.T) *Proxy { - log, err := log.New(ioutil.Discard, "", log.RecordOptions{}) - if err != nil { - t.Fatal(err) - } - if err != nil { - t.Fatal(err) - } - proxy, err := NewProxy(cache.New(0, nil), nil, log) + logger := log.New(ioutil.Discard, "", 0) + dnsLogger := &testLogger{} + proxy, err := NewProxy(cache.New(0, nil), nil, logger, dnsLogger) if err != nil { t.Fatal(err) } diff --git a/http/http.go b/http/http.go index ac2e9b0..3887f9b 100644 --- a/http/http.go +++ b/http/http.go @@ -2,6 +2,7 @@ package http import ( "context" + "log" "net" "net/http" _ "net/http/pprof" // Registers debug handlers as a side effect. @@ -10,15 +11,16 @@ import ( "github.com/mpolden/zdns/cache" "github.com/mpolden/zdns/dns/dnsutil" - "github.com/mpolden/zdns/log" + "github.com/mpolden/zdns/sql" ) // A Server defines paramaters for running an HTTP server. The HTTP server serves an API for inspecting cache contents // and request log. type Server struct { - cache *cache.Cache - logger *log.Logger - server *http.Server + cache *cache.Cache + logger *log.Logger + sqlLogger *sql.Logger + server *http.Server } type entry struct { @@ -39,12 +41,13 @@ type httpError struct { } // NewServer creates a new HTTP server, serving logs from the given logger and listening on addr. -func NewServer(logger *log.Logger, cache *cache.Cache, addr string) *Server { +func NewServer(cache *cache.Cache, sqlLogger *sql.Logger, logger *log.Logger, addr string) *Server { server := &http.Server{Addr: addr} s := &Server{ - cache: cache, - logger: logger, - server: server, + cache: cache, + logger: logger, + sqlLogger: sqlLogger, + server: server, } s.server.Handler = s.handler() return s @@ -94,7 +97,7 @@ func (s *Server) cacheResetHandler(w http.ResponseWriter, r *http.Request) (inte } func (s *Server) logHandler(w http.ResponseWriter, r *http.Request) (interface{}, *httpError) { - logEntries, err := s.logger.Get(listCountFrom(r)) + logEntries, err := s.sqlLogger.Get(listCountFrom(r)) if err != nil { return nil, &httpError{ err: err, diff --git a/http/http_test.go b/http/http_test.go index 0d43b46..7853b16 100644 --- a/http/http_test.go +++ b/http/http_test.go @@ -2,6 +2,7 @@ package http import ( "io/ioutil" + "log" "net" "net/http" "net/http/httptest" @@ -11,7 +12,7 @@ import ( "github.com/miekg/dns" "github.com/mpolden/zdns/cache" - "github.com/mpolden/zdns/log" + "github.com/mpolden/zdns/sql" ) func newA(name string, ttl uint32, ipAddr ...net.IP) *dns.Msg { @@ -30,12 +31,14 @@ func newA(name string, ttl uint32, ipAddr ...net.IP) *dns.Msg { } func testServer() (*httptest.Server, *Server) { - logger, err := log.New(ioutil.Discard, "", log.RecordOptions{Mode: log.ModeAll, Database: ":memory:"}) + db, err := sql.New(":memory:") if err != nil { panic(err) } + stdLogger := log.New(ioutil.Discard, "", 0) + logger := sql.NewLogger(db, sql.LogAll, 0) cache := cache.New(10, nil) - server := Server{logger: logger, cache: cache} + server := Server{logger: stdLogger, sqlLogger: logger, cache: cache} return httptest.NewServer(server.handler()), &server } @@ -76,9 +79,9 @@ func httpDelete(url, body string) (*http.Response, string, error) { func TestRequests(t *testing.T) { httpSrv, srv := testServer() defer httpSrv.Close() - srv.logger.Record(net.IPv4(127, 0, 0, 42), false, 1, "example.com.", "192.0.2.100", "192.0.2.101") - srv.logger.Record(net.IPv4(127, 0, 0, 254), true, 28, "example.com.", "2001:db8::1") - srv.logger.Close() // Flush + srv.sqlLogger.Record(net.IPv4(127, 0, 0, 42), false, 1, "example.com.", "192.0.2.100", "192.0.2.101") + srv.sqlLogger.Record(net.IPv4(127, 0, 0, 254), true, 28, "example.com.", "2001:db8::1") + srv.sqlLogger.Close() // Flush srv.cache.Set(1, newA("1.example.com.", 60, net.IPv4(192, 0, 2, 200))) srv.cache.Set(2, newA("2.example.com.", 30, net.IPv4(192, 0, 2, 201))) @@ -3,6 +3,7 @@ package zdns import ( "fmt" "io" + "log" "net" "net/http" "net/url" @@ -13,7 +14,6 @@ import ( "github.com/cenkalti/backoff/v4" "github.com/mpolden/zdns/dns" "github.com/mpolden/zdns/hosts" - "github.com/mpolden/zdns/log" ) const ( @@ -37,7 +37,7 @@ type Server struct { } // NewServer returns a new server configured according to config. -func NewServer(logger *log.Logger, proxy *dns.Proxy, config Config) (*Server, error) { +func NewServer(proxy *dns.Proxy, config Config, logger *log.Logger) (*Server, error) { server := &Server{ Config: config, done: make(chan bool, 1), diff --git a/server_test.go b/server_test.go index 6aa16ad..73fc35b 100644 --- a/server_test.go +++ b/server_test.go @@ -2,6 +2,7 @@ package zdns import ( "io/ioutil" + "log" "net" "net/http" "net/http/httptest" @@ -13,7 +14,6 @@ import ( "github.com/mpolden/zdns/cache" "github.com/mpolden/zdns/dns" "github.com/mpolden/zdns/hosts" - "github.com/mpolden/zdns/log" ) const hostsFile1 = ` @@ -29,6 +29,12 @@ const hostsFile2 = ` 192.0.2.6 badhost6 ` +type testLogger struct{} + +func (l *testLogger) Record(net.IP, bool, uint16, string, ...string) {} + +func (l *testLogger) Close() error { return nil } + func httpHandler(t *testing.T, response string) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { if _, err := w.Write([]byte(response)); err != nil { @@ -96,15 +102,13 @@ func testServer(t *testing.T, refreshInterval time.Duration) (*Server, func()) { if err := config.load(); err != nil { t.Fatal(err) } - logger, err := log.New(ioutil.Discard, "", log.RecordOptions{}) - if err != nil { - t.Fatal(err) - } - proxy, err := dns.NewProxy(cache.New(0, nil), nil, logger) + logger := log.New(ioutil.Discard, "", 0) + queryLogger := &testLogger{} + proxy, err := dns.NewProxy(cache.New(0, nil), nil, logger, queryLogger) if err != nil { t.Fatal(err) } - srv, err = NewServer(logger, proxy, config) + srv, err = NewServer(proxy, config, logger) if err != nil { defer cleanup() t.Fatal(err) @@ -168,7 +172,7 @@ func TestNonFqdn(t *testing.T) { } func TestHijack(t *testing.T) { - log, _ := log.New(ioutil.Discard, "", log.RecordOptions{}) + logger := log.New(ioutil.Discard, "", 0) s := &Server{ Config: Config{}, hosts: hosts.Hosts{ @@ -177,7 +181,7 @@ func TestHijack(t *testing.T) { {IP: net.ParseIP("2001:db8::1")}, }, }, - logger: log, + logger: logger, } var tests = []struct { diff --git a/signal/signal.go b/signal/signal.go index 96ef5f7..d2f50c2 100644 --- a/signal/signal.go +++ b/signal/signal.go @@ -2,11 +2,10 @@ package signal import ( "io" + "log" "os" "os/signal" "syscall" - - "github.com/mpolden/zdns/log" ) // Reloader is the interface for types that need to act on a reload signal. diff --git a/signal/signal_test.go b/signal/signal_test.go index 7010ebd..fe0cef5 100644 --- a/signal/signal_test.go +++ b/signal/signal_test.go @@ -2,13 +2,12 @@ package signal import ( "io/ioutil" + "log" "os" "sync" "syscall" "testing" "time" - - "github.com/mpolden/zdns/log" ) type reloaderCloser struct { @@ -47,10 +46,7 @@ func (rc *reloaderCloser) reset() { } func TestHandler(t *testing.T) { - logger, err := log.New(ioutil.Discard, "", log.RecordOptions{}) - if err != nil { - t.Fatal(err) - } + logger := log.New(ioutil.Discard, "", 0) h := NewHandler(make(chan os.Signal, 1), logger) rc := &reloaderCloser{} diff --git a/log/logger.go b/sql/logger.go index d8dc4af..f6c7215 100644 --- a/log/logger.go +++ b/sql/logger.go @@ -1,32 +1,29 @@ -package log +package sql import ( - "io" "log" "net" "sync" "time" - - "github.com/mpolden/zdns/sql" ) const ( - // ModeDiscard disables logging of DNS requests. - ModeDiscard = iota - // ModeAll logs all DNS requests. - ModeAll - // ModeHijacked only logs hijacked DNS requests. - ModeHijacked + // LogDiscard disables logging of DNS requests. + LogDiscard = iota + // LogAll logs all DNS requests. + LogAll + // LogHijacked only logs hijacked DNS requests. + LogHijacked ) -// Logger wraps a standard log.Logger and an optional log database. +// Logger is a logs DNS requests to a SQL database. type Logger struct { - *log.Logger - mode int - queue chan Entry - db *sql.Client - wg sync.WaitGroup - now func() time.Time + mode int + queue chan Entry + db *Client + wg sync.WaitGroup + now func() time.Time + Logger *log.Logger } // RecordOptions configures recording of DNS requests. @@ -46,25 +43,19 @@ type Entry struct { Answers []string } -// New creates a new logger, writing log output to writer w prefixed with prefix. Persisted logging behaviour is -// controller by options. -func New(w io.Writer, prefix string, options RecordOptions) (*Logger, error) { +// NewLogger creates a new logger. Persisted entries are kept according to ttl. +func NewLogger(db *Client, mode int, ttl time.Duration) *Logger { logger := &Logger{ - Logger: log.New(w, prefix, 0), - queue: make(chan Entry, 100), - now: time.Now, - mode: options.Mode, + db: db, + queue: make(chan Entry, 100), + now: time.Now, + mode: mode, } - var err error - if options.Database != "" { - logger.db, err = sql.New(options.Database) - if err != nil { - return nil, err - } + if mode != LogDiscard { + logger.wg.Add(1) + go logger.readQueue(ttl) } - logger.wg.Add(1) - go logger.readQueue(options.TTL) - return logger, nil + return logger } // Close consumes any outstanding log requests and closes the logger. @@ -79,10 +70,10 @@ func (l *Logger) Record(remoteAddr net.IP, hijacked bool, qtype uint16, question if l.db == nil { return } - if l.mode == ModeDiscard { + if l.mode == LogDiscard { return } - if l.mode == ModeHijacked && !hijacked { + if l.mode == LogHijacked && !hijacked { return } l.queue <- Entry{ @@ -124,16 +115,23 @@ func (l *Logger) Get(n int) ([]Entry, error) { return entries, nil } +func (l *Logger) printf(format string, v ...interface{}) { + if l.Logger == nil { + return + } + l.Logger.Printf(format, v...) +} + func (l *Logger) readQueue(ttl time.Duration) { defer l.wg.Done() for e := range l.queue { if err := l.db.WriteLog(e.Time, e.RemoteAddr, e.Hijacked, e.Qtype, e.Question, e.Answers...); err != nil { - l.Printf("write failed: %+v: %s", e, err) + l.printf("write failed: %+v: %s", e, err) } if ttl > 0 { t := l.now().Add(-ttl) if err := l.db.DeleteLogBefore(t); err != nil { - l.Printf("deleting log entries before %v failed: %s", t, err) + l.printf("deleting log entries before %v failed: %s", t, err) } } } diff --git a/log/logger_test.go b/sql/logger_test.go index f3a91f1..fea858b 100644 --- a/log/logger_test.go +++ b/sql/logger_test.go @@ -1,18 +1,15 @@ -package log +package sql import ( "net" - "os" "reflect" "testing" "time" ) func TestRecord(t *testing.T) { - logger, err := New(os.Stderr, "test: ", RecordOptions{Database: ":memory:", Mode: ModeAll}) - if err != nil { - t.Fatal(err) - } + client := testClient() + logger := NewLogger(client, LogAll, 0) logger.Record(net.IPv4(192, 0, 2, 100), false, 1, "example.com.", "192.0.2.1", "192.0.2.2") // Flush queue if err := logger.Close(); err != nil { @@ -37,18 +34,15 @@ func TestMode(t *testing.T) { mode int log bool }{ - {badHost, net.IPv4(192, 0, 2, 100), true, ModeAll, true}, - {goodHost, net.IPv4(192, 0, 2, 100), true, ModeAll, true}, - {badHost, net.IPv4(192, 0, 2, 100), true, ModeHijacked, true}, - {goodHost, net.IPv4(192, 0, 2, 100), false, ModeHijacked, false}, - {badHost, net.IPv4(192, 0, 2, 100), true, ModeDiscard, false}, - {goodHost, net.IPv4(192, 0, 2, 100), false, ModeDiscard, false}, + {badHost, net.IPv4(192, 0, 2, 100), true, LogAll, true}, + {goodHost, net.IPv4(192, 0, 2, 100), true, LogAll, true}, + {badHost, net.IPv4(192, 0, 2, 100), true, LogHijacked, true}, + {goodHost, net.IPv4(192, 0, 2, 100), false, LogHijacked, false}, + {badHost, net.IPv4(192, 0, 2, 100), true, LogDiscard, false}, + {goodHost, net.IPv4(192, 0, 2, 100), false, LogDiscard, false}, } for i, tt := range tests { - logger, err := New(os.Stderr, "test: ", RecordOptions{Database: ":memory:", Mode: tt.mode}) - if err != nil { - t.Fatal(err) - } + logger := NewLogger(testClient(), tt.mode, 0) logger.mode = tt.mode logger.Record(tt.remoteAddr, tt.hijacked, 1, tt.question) if err := logger.Close(); err != nil { // Flush @@ -65,10 +59,7 @@ func TestMode(t *testing.T) { } func TestAnswerMerging(t *testing.T) { - logger, err := New(os.Stderr, "test: ", RecordOptions{Database: ":memory:", Mode: ModeAll}) - if err != nil { - t.Fatal(err) - } + logger := NewLogger(testClient(), LogAll, 0) now := time.Date(2006, 1, 2, 15, 4, 5, 0, time.UTC) logger.now = func() time.Time { return now } logger.Record(net.IPv4(192, 0, 2, 100), true, 1, "example.com.", "192.0.2.1", "192.0.2.2") @@ -104,14 +95,7 @@ func TestAnswerMerging(t *testing.T) { } func TestLogPruning(t *testing.T) { - logger, err := New(os.Stderr, "test: ", RecordOptions{ - Mode: ModeAll, - Database: ":memory:", - TTL: time.Hour, - }) - if err != nil { - t.Fatal(err) - } + logger := NewLogger(testClient(), LogAll, time.Hour) defer logger.Close() tt := time.Now() logger.now = func() time.Time { return tt } @@ -120,6 +104,7 @@ func TestLogPruning(t *testing.T) { // Wait until queue is flushed ts := time.Now() var entries []Entry + var err error for len(entries) == 0 { entries, err = logger.Get(1) if err != nil { |