From 541ec0a86c0d29c469bc48bce34493daf20a0ee4 Mon Sep 17 00:00:00 2001 From: Martin Polden Date: Sat, 11 Jan 2020 13:22:57 +0100 Subject: Decouple std logger and query logger --- cmd/zdns/main.go | 44 ++++++++------- config.go | 8 +-- dns/proxy.go | 29 +++++----- dns/proxy_test.go | 19 ++++--- http/http.go | 21 ++++--- http/http_test.go | 15 +++-- log/logger.go | 140 ----------------------------------------------- log/logger_test.go | 148 -------------------------------------------------- server.go | 4 +- server_test.go | 22 +++++--- signal/signal.go | 3 +- signal/signal_test.go | 8 +-- sql/logger.go | 138 ++++++++++++++++++++++++++++++++++++++++++++++ sql/logger_test.go | 133 +++++++++++++++++++++++++++++++++++++++++++++ 14 files changed, 364 insertions(+), 368 deletions(-) delete mode 100644 log/logger.go delete mode 100644 log/logger_test.go create mode 100644 sql/logger.go create mode 100644 sql/logger_test.go diff --git a/cmd/zdns/main.go b/cmd/zdns/main.go index d116fff..bcc2934 100644 --- a/cmd/zdns/main.go +++ b/cmd/zdns/main.go @@ -3,6 +3,7 @@ package main import ( "fmt" "io" + "log" "os" "path/filepath" "sync" @@ -14,8 +15,8 @@ import ( "github.com/mpolden/zdns/dns" "github.com/mpolden/zdns/dns/dnsutil" "github.com/mpolden/zdns/http" - "github.com/mpolden/zdns/log" "github.com/mpolden/zdns/signal" + "github.com/mpolden/zdns/sql" ) const ( @@ -75,42 +76,45 @@ func newCli(out io.Writer, args []string, configFile string, sig chan os.Signal) config, err := readConfig(*confFile) fatal(err) - // Logger - logger, err := log.New(out, logPrefix, log.RecordOptions{ - Mode: config.DNS.LogMode, - Database: config.DNS.LogDatabase, - TTL: config.DNS.LogTTL, - }) - fatal(err) - - // Signal handling + // Logging and signal handling + logger := log.New(out, logPrefix, log.Lshortfile) sigHandler := signal.NewHandler(sig, logger) - sigHandler.OnClose(logger) - // Client - client := dnsutil.NewClient(config.Resolver.Protocol, config.Resolver.Timeout, config.DNS.Resolvers...) + // SQL backends + var sqlLogger *sql.Logger + if config.DNS.LogDatabase != "" { + sqlClient, err := sql.New(config.DNS.LogDatabase) + fatal(err) + + // Logger + sqlLogger = sql.NewLogger(sqlClient, config.DNS.LogMode, config.DNS.LogTTL) + sigHandler.OnClose(sqlLogger) + } + + // DNS client + dnsClient := dnsutil.NewClient(config.Resolver.Protocol, config.Resolver.Timeout, config.DNS.Resolvers...) // Cache - var cclient *dnsutil.Client + var cacheDNS *dnsutil.Client if config.DNS.CachePrefetch { - cclient = client + cacheDNS = dnsClient } - cache := cache.New(config.DNS.CacheSize, cclient) + cache := cache.New(config.DNS.CacheSize, cacheDNS) // DNS server - proxy, err := dns.NewProxy(cache, client, logger) + proxy, err := dns.NewProxy(cache, dnsClient, logger, sqlLogger) fatal(err) sigHandler.OnClose(proxy) - dnsSrv, err := zdns.NewServer(logger, proxy, config) + dnsSrv, err := zdns.NewServer(proxy, config, logger) fatal(err) sigHandler.OnReload(dnsSrv) sigHandler.OnClose(dnsSrv) - servers := []server{dnsSrv} + // HTTP server if config.DNS.ListenHTTP != "" { - httpSrv := http.NewServer(logger, cache, config.DNS.ListenHTTP) + httpSrv := http.NewServer(cache, sqlLogger, logger, config.DNS.ListenHTTP) sigHandler.OnClose(httpSrv) servers = append(servers, httpSrv) } diff --git a/config.go b/config.go index 3d581f1..b691291 100644 --- a/config.go +++ b/config.go @@ -10,7 +10,7 @@ import ( "github.com/BurntSushi/toml" "github.com/mpolden/zdns/hosts" - "github.com/mpolden/zdns/log" + "github.com/mpolden/zdns/sql" ) // Config specifies is the zdns configuration parameters. @@ -180,11 +180,11 @@ func (c *Config) load() error { } switch c.DNS.LogModeString { case "": - c.DNS.LogMode = log.ModeDiscard + c.DNS.LogMode = sql.LogDiscard case "all": - c.DNS.LogMode = log.ModeAll + c.DNS.LogMode = sql.LogAll case "hijacked": - c.DNS.LogMode = log.ModeHijacked + c.DNS.LogMode = sql.LogHijacked default: return fmt.Errorf("invalid log mode: %s", c.DNS.LogModeString) } diff --git a/dns/proxy.go b/dns/proxy.go index 90164ab..9e7607e 100644 --- a/dns/proxy.go +++ b/dns/proxy.go @@ -2,6 +2,7 @@ package dns import ( "fmt" + "log" "net" "strings" "sync" @@ -32,27 +33,27 @@ type Handler func(*Request) *Reply // Proxy represents a DNS proxy. type Proxy struct { - Handler Handler - cache *cache.Cache - logger logger - server *dns.Server - client *dnsutil.Client - mu sync.RWMutex + Handler Handler + cache *cache.Cache + logger *log.Logger + dnsLogger logger + server *dns.Server + client *dnsutil.Client + mu sync.RWMutex } type logger interface { - Print(...interface{}) - Printf(string, ...interface{}) Record(net.IP, bool, uint16, string, ...string) Close() error } // NewProxy creates a new DNS proxy. -func NewProxy(cache *cache.Cache, client *dnsutil.Client, logger logger) (*Proxy, error) { +func NewProxy(cache *cache.Cache, client *dnsutil.Client, logger *log.Logger, dnsLogger logger) (*Proxy, error) { return &Proxy{ - logger: logger, - cache: cache, - client: client, + logger: logger, + dnsLogger: dnsLogger, + cache: cache, + client: client, }, nil } @@ -129,7 +130,9 @@ func (p *Proxy) writeMsg(w dns.ResponseWriter, msg *dns.Msg, hijacked bool) { default: panic(fmt.Sprintf("unexpected remote address type %T", v)) } - p.logger.Record(ip, hijacked, msg.Question[0].Qtype, msg.Question[0].Name, dnsutil.Answers(msg)...) + if p.dnsLogger != nil { + p.dnsLogger.Record(ip, hijacked, msg.Question[0].Qtype, msg.Question[0].Name, dnsutil.Answers(msg)...) + } w.WriteMsg(msg) } diff --git a/dns/proxy_test.go b/dns/proxy_test.go index 7d2a894..9b6dc35 100644 --- a/dns/proxy_test.go +++ b/dns/proxy_test.go @@ -3,6 +3,7 @@ package dns import ( "fmt" "io/ioutil" + "log" "net" "reflect" "sync" @@ -12,7 +13,6 @@ import ( "github.com/miekg/dns" "github.com/mpolden/zdns/cache" "github.com/mpolden/zdns/dns/dnsutil" - "github.com/mpolden/zdns/log" ) type dnsWriter struct{ lastReply *dns.Msg } @@ -63,15 +63,16 @@ func (e *testExchanger) Exchange(msg *dns.Msg, addr string) (*dns.Msg, time.Dura return r.answer, time.Second, nil } +type testLogger struct{} + +func (l *testLogger) Record(net.IP, bool, uint16, string, ...string) {} + +func (l *testLogger) Close() error { return nil } + func testProxy(t *testing.T) *Proxy { - log, err := log.New(ioutil.Discard, "", log.RecordOptions{}) - if err != nil { - t.Fatal(err) - } - if err != nil { - t.Fatal(err) - } - proxy, err := NewProxy(cache.New(0, nil), nil, log) + logger := log.New(ioutil.Discard, "", 0) + dnsLogger := &testLogger{} + proxy, err := NewProxy(cache.New(0, nil), nil, logger, dnsLogger) if err != nil { t.Fatal(err) } diff --git a/http/http.go b/http/http.go index ac2e9b0..3887f9b 100644 --- a/http/http.go +++ b/http/http.go @@ -2,6 +2,7 @@ package http import ( "context" + "log" "net" "net/http" _ "net/http/pprof" // Registers debug handlers as a side effect. @@ -10,15 +11,16 @@ import ( "github.com/mpolden/zdns/cache" "github.com/mpolden/zdns/dns/dnsutil" - "github.com/mpolden/zdns/log" + "github.com/mpolden/zdns/sql" ) // A Server defines paramaters for running an HTTP server. The HTTP server serves an API for inspecting cache contents // and request log. type Server struct { - cache *cache.Cache - logger *log.Logger - server *http.Server + cache *cache.Cache + logger *log.Logger + sqlLogger *sql.Logger + server *http.Server } type entry struct { @@ -39,12 +41,13 @@ type httpError struct { } // NewServer creates a new HTTP server, serving logs from the given logger and listening on addr. -func NewServer(logger *log.Logger, cache *cache.Cache, addr string) *Server { +func NewServer(cache *cache.Cache, sqlLogger *sql.Logger, logger *log.Logger, addr string) *Server { server := &http.Server{Addr: addr} s := &Server{ - cache: cache, - logger: logger, - server: server, + cache: cache, + logger: logger, + sqlLogger: sqlLogger, + server: server, } s.server.Handler = s.handler() return s @@ -94,7 +97,7 @@ func (s *Server) cacheResetHandler(w http.ResponseWriter, r *http.Request) (inte } func (s *Server) logHandler(w http.ResponseWriter, r *http.Request) (interface{}, *httpError) { - logEntries, err := s.logger.Get(listCountFrom(r)) + logEntries, err := s.sqlLogger.Get(listCountFrom(r)) if err != nil { return nil, &httpError{ err: err, diff --git a/http/http_test.go b/http/http_test.go index 0d43b46..7853b16 100644 --- a/http/http_test.go +++ b/http/http_test.go @@ -2,6 +2,7 @@ package http import ( "io/ioutil" + "log" "net" "net/http" "net/http/httptest" @@ -11,7 +12,7 @@ import ( "github.com/miekg/dns" "github.com/mpolden/zdns/cache" - "github.com/mpolden/zdns/log" + "github.com/mpolden/zdns/sql" ) func newA(name string, ttl uint32, ipAddr ...net.IP) *dns.Msg { @@ -30,12 +31,14 @@ func newA(name string, ttl uint32, ipAddr ...net.IP) *dns.Msg { } func testServer() (*httptest.Server, *Server) { - logger, err := log.New(ioutil.Discard, "", log.RecordOptions{Mode: log.ModeAll, Database: ":memory:"}) + db, err := sql.New(":memory:") if err != nil { panic(err) } + stdLogger := log.New(ioutil.Discard, "", 0) + logger := sql.NewLogger(db, sql.LogAll, 0) cache := cache.New(10, nil) - server := Server{logger: logger, cache: cache} + server := Server{logger: stdLogger, sqlLogger: logger, cache: cache} return httptest.NewServer(server.handler()), &server } @@ -76,9 +79,9 @@ func httpDelete(url, body string) (*http.Response, string, error) { func TestRequests(t *testing.T) { httpSrv, srv := testServer() defer httpSrv.Close() - srv.logger.Record(net.IPv4(127, 0, 0, 42), false, 1, "example.com.", "192.0.2.100", "192.0.2.101") - srv.logger.Record(net.IPv4(127, 0, 0, 254), true, 28, "example.com.", "2001:db8::1") - srv.logger.Close() // Flush + srv.sqlLogger.Record(net.IPv4(127, 0, 0, 42), false, 1, "example.com.", "192.0.2.100", "192.0.2.101") + srv.sqlLogger.Record(net.IPv4(127, 0, 0, 254), true, 28, "example.com.", "2001:db8::1") + srv.sqlLogger.Close() // Flush srv.cache.Set(1, newA("1.example.com.", 60, net.IPv4(192, 0, 2, 200))) srv.cache.Set(2, newA("2.example.com.", 30, net.IPv4(192, 0, 2, 201))) diff --git a/log/logger.go b/log/logger.go deleted file mode 100644 index d8dc4af..0000000 --- a/log/logger.go +++ /dev/null @@ -1,140 +0,0 @@ -package log - -import ( - "io" - "log" - "net" - "sync" - "time" - - "github.com/mpolden/zdns/sql" -) - -const ( - // ModeDiscard disables logging of DNS requests. - ModeDiscard = iota - // ModeAll logs all DNS requests. - ModeAll - // ModeHijacked only logs hijacked DNS requests. - ModeHijacked -) - -// Logger wraps a standard log.Logger and an optional log database. -type Logger struct { - *log.Logger - mode int - queue chan Entry - db *sql.Client - wg sync.WaitGroup - now func() time.Time -} - -// RecordOptions configures recording of DNS requests. -type RecordOptions struct { - Database string - Mode int - TTL time.Duration -} - -// Entry represents a DNS request log entry. -type Entry struct { - Time time.Time - RemoteAddr net.IP - Hijacked bool - Qtype uint16 - Question string - Answers []string -} - -// New creates a new logger, writing log output to writer w prefixed with prefix. Persisted logging behaviour is -// controller by options. -func New(w io.Writer, prefix string, options RecordOptions) (*Logger, error) { - logger := &Logger{ - Logger: log.New(w, prefix, 0), - queue: make(chan Entry, 100), - now: time.Now, - mode: options.Mode, - } - var err error - if options.Database != "" { - logger.db, err = sql.New(options.Database) - if err != nil { - return nil, err - } - } - logger.wg.Add(1) - go logger.readQueue(options.TTL) - return logger, nil -} - -// Close consumes any outstanding log requests and closes the logger. -func (l *Logger) Close() error { - close(l.queue) - l.wg.Wait() - return nil -} - -// Record records the given DNS request to the log database. -func (l *Logger) Record(remoteAddr net.IP, hijacked bool, qtype uint16, question string, answers ...string) { - if l.db == nil { - return - } - if l.mode == ModeDiscard { - return - } - if l.mode == ModeHijacked && !hijacked { - return - } - l.queue <- Entry{ - Time: l.now(), - RemoteAddr: remoteAddr, - Hijacked: hijacked, - Qtype: qtype, - Question: question, - Answers: answers, - } -} - -// Get returns the n most recent persisted log entries. -func (l *Logger) Get(n int) ([]Entry, error) { - logEntries, err := l.db.ReadLog(n) - if err != nil { - return nil, err - } - ids := make(map[int64]*Entry) - entries := make([]Entry, 0, len(logEntries)) - for _, le := range logEntries { - entry, ok := ids[le.ID] - if !ok { - newEntry := Entry{ - Time: time.Unix(le.Time, 0).UTC(), - RemoteAddr: le.RemoteAddr, - Hijacked: le.Hijacked, - Qtype: le.Qtype, - Question: le.Question, - } - entries = append(entries, newEntry) - entry = &entries[len(entries)-1] - ids[le.ID] = entry - } - if le.Answer != "" { - entry.Answers = append(entry.Answers, le.Answer) - } - } - return entries, nil -} - -func (l *Logger) readQueue(ttl time.Duration) { - defer l.wg.Done() - for e := range l.queue { - if err := l.db.WriteLog(e.Time, e.RemoteAddr, e.Hijacked, e.Qtype, e.Question, e.Answers...); err != nil { - l.Printf("write failed: %+v: %s", e, err) - } - if ttl > 0 { - t := l.now().Add(-ttl) - if err := l.db.DeleteLogBefore(t); err != nil { - l.Printf("deleting log entries before %v failed: %s", t, err) - } - } - } -} diff --git a/log/logger_test.go b/log/logger_test.go deleted file mode 100644 index f3a91f1..0000000 --- a/log/logger_test.go +++ /dev/null @@ -1,148 +0,0 @@ -package log - -import ( - "net" - "os" - "reflect" - "testing" - "time" -) - -func TestRecord(t *testing.T) { - logger, err := New(os.Stderr, "test: ", RecordOptions{Database: ":memory:", Mode: ModeAll}) - if err != nil { - t.Fatal(err) - } - logger.Record(net.IPv4(192, 0, 2, 100), false, 1, "example.com.", "192.0.2.1", "192.0.2.2") - // Flush queue - if err := logger.Close(); err != nil { - t.Fatal(err) - } - logEntries, err := logger.db.ReadLog(1) - if err != nil { - t.Fatal(err) - } - if want, got := 2, len(logEntries); want != got { - t.Errorf("len(entries) = %d, want %d", got, want) - } -} - -func TestMode(t *testing.T) { - badHost := "badhost1." - goodHost := "goodhost1." - var tests = []struct { - question string - remoteAddr net.IP - hijacked bool - mode int - log bool - }{ - {badHost, net.IPv4(192, 0, 2, 100), true, ModeAll, true}, - {goodHost, net.IPv4(192, 0, 2, 100), true, ModeAll, true}, - {badHost, net.IPv4(192, 0, 2, 100), true, ModeHijacked, true}, - {goodHost, net.IPv4(192, 0, 2, 100), false, ModeHijacked, false}, - {badHost, net.IPv4(192, 0, 2, 100), true, ModeDiscard, false}, - {goodHost, net.IPv4(192, 0, 2, 100), false, ModeDiscard, false}, - } - for i, tt := range tests { - logger, err := New(os.Stderr, "test: ", RecordOptions{Database: ":memory:", Mode: tt.mode}) - if err != nil { - t.Fatal(err) - } - logger.mode = tt.mode - logger.Record(tt.remoteAddr, tt.hijacked, 1, tt.question) - if err := logger.Close(); err != nil { // Flush - t.Fatal(err) - } - entries, err := logger.Get(1) - if err != nil { - t.Fatal(err) - } - if len(entries) > 0 != tt.log { - t.Errorf("#%d: question %q (hijacked=%t) should be logged in mode %d", i, tt.question, tt.hijacked, tt.mode) - } - } -} - -func TestAnswerMerging(t *testing.T) { - logger, err := New(os.Stderr, "test: ", RecordOptions{Database: ":memory:", Mode: ModeAll}) - if err != nil { - t.Fatal(err) - } - now := time.Date(2006, 1, 2, 15, 4, 5, 0, time.UTC) - logger.now = func() time.Time { return now } - logger.Record(net.IPv4(192, 0, 2, 100), true, 1, "example.com.", "192.0.2.1", "192.0.2.2") - logger.Record(net.IPv4(192, 0, 2, 100), true, 1, "2.example.com.") - // Flush queue - if err := logger.Close(); err != nil { - t.Fatal(err) - } - // Multi-answer log entries are merged - got, err := logger.Get(2) - if err != nil { - t.Fatal(err) - } - want := []Entry{ - { - Time: now, - RemoteAddr: net.IPv4(192, 0, 2, 100), - Hijacked: true, - Qtype: 1, - Question: "example.com.", - Answers: []string{"192.0.2.2", "192.0.2.1"}, - }, - { - Time: now, - RemoteAddr: net.IPv4(192, 0, 2, 100), - Hijacked: true, - Qtype: 1, - Question: "2.example.com.", - }} - if !reflect.DeepEqual(want, got) { - t.Errorf("Get(1) = %+v, want %+v", got, want) - } -} - -func TestLogPruning(t *testing.T) { - logger, err := New(os.Stderr, "test: ", RecordOptions{ - Mode: ModeAll, - Database: ":memory:", - TTL: time.Hour, - }) - if err != nil { - t.Fatal(err) - } - defer logger.Close() - tt := time.Now() - logger.now = func() time.Time { return tt } - logger.Record(net.IPv4(192, 0, 2, 100), false, 1, "example.com.", "192.0.2.1") - - // Wait until queue is flushed - ts := time.Now() - var entries []Entry - for len(entries) == 0 { - entries, err = logger.Get(1) - if err != nil { - t.Fatal(err) - } - time.Sleep(10 * time.Millisecond) - if time.Since(ts) > 2*time.Second { - t.Fatal("timed out waiting for log entry to be written") - } - } - - // Advance time beyond log TTL - tt = tt.Add(time.Hour).Add(time.Second) - // Trigger pruning by recording another entry - logger.Record(net.IPv4(192, 0, 2, 100), false, 1, "2.example.com.", "192.0.2.2") - for len(entries) > 1 { - entries, err = logger.Get(2) - if err != nil { - t.Fatal(err) - } - time.Sleep(10 * time.Millisecond) - if time.Since(ts) > 2*time.Second { - t.Fatal("timed out waiting for log entry to be removed") - } - } -} diff --git a/server.go b/server.go index 79a2447..5fc36dd 100644 --- a/server.go +++ b/server.go @@ -3,6 +3,7 @@ package zdns import ( "fmt" "io" + "log" "net" "net/http" "net/url" @@ -13,7 +14,6 @@ import ( "github.com/cenkalti/backoff/v4" "github.com/mpolden/zdns/dns" "github.com/mpolden/zdns/hosts" - "github.com/mpolden/zdns/log" ) const ( @@ -37,7 +37,7 @@ type Server struct { } // NewServer returns a new server configured according to config. -func NewServer(logger *log.Logger, proxy *dns.Proxy, config Config) (*Server, error) { +func NewServer(proxy *dns.Proxy, config Config, logger *log.Logger) (*Server, error) { server := &Server{ Config: config, done: make(chan bool, 1), diff --git a/server_test.go b/server_test.go index 6aa16ad..73fc35b 100644 --- a/server_test.go +++ b/server_test.go @@ -2,6 +2,7 @@ package zdns import ( "io/ioutil" + "log" "net" "net/http" "net/http/httptest" @@ -13,7 +14,6 @@ import ( "github.com/mpolden/zdns/cache" "github.com/mpolden/zdns/dns" "github.com/mpolden/zdns/hosts" - "github.com/mpolden/zdns/log" ) const hostsFile1 = ` @@ -29,6 +29,12 @@ const hostsFile2 = ` 192.0.2.6 badhost6 ` +type testLogger struct{} + +func (l *testLogger) Record(net.IP, bool, uint16, string, ...string) {} + +func (l *testLogger) Close() error { return nil } + func httpHandler(t *testing.T, response string) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { if _, err := w.Write([]byte(response)); err != nil { @@ -96,15 +102,13 @@ func testServer(t *testing.T, refreshInterval time.Duration) (*Server, func()) { if err := config.load(); err != nil { t.Fatal(err) } - logger, err := log.New(ioutil.Discard, "", log.RecordOptions{}) - if err != nil { - t.Fatal(err) - } - proxy, err := dns.NewProxy(cache.New(0, nil), nil, logger) + logger := log.New(ioutil.Discard, "", 0) + queryLogger := &testLogger{} + proxy, err := dns.NewProxy(cache.New(0, nil), nil, logger, queryLogger) if err != nil { t.Fatal(err) } - srv, err = NewServer(logger, proxy, config) + srv, err = NewServer(proxy, config, logger) if err != nil { defer cleanup() t.Fatal(err) @@ -168,7 +172,7 @@ func TestNonFqdn(t *testing.T) { } func TestHijack(t *testing.T) { - log, _ := log.New(ioutil.Discard, "", log.RecordOptions{}) + logger := log.New(ioutil.Discard, "", 0) s := &Server{ Config: Config{}, hosts: hosts.Hosts{ @@ -177,7 +181,7 @@ func TestHijack(t *testing.T) { {IP: net.ParseIP("2001:db8::1")}, }, }, - logger: log, + logger: logger, } var tests = []struct { diff --git a/signal/signal.go b/signal/signal.go index 96ef5f7..d2f50c2 100644 --- a/signal/signal.go +++ b/signal/signal.go @@ -2,11 +2,10 @@ package signal import ( "io" + "log" "os" "os/signal" "syscall" - - "github.com/mpolden/zdns/log" ) // Reloader is the interface for types that need to act on a reload signal. diff --git a/signal/signal_test.go b/signal/signal_test.go index 7010ebd..fe0cef5 100644 --- a/signal/signal_test.go +++ b/signal/signal_test.go @@ -2,13 +2,12 @@ package signal import ( "io/ioutil" + "log" "os" "sync" "syscall" "testing" "time" - - "github.com/mpolden/zdns/log" ) type reloaderCloser struct { @@ -47,10 +46,7 @@ func (rc *reloaderCloser) reset() { } func TestHandler(t *testing.T) { - logger, err := log.New(ioutil.Discard, "", log.RecordOptions{}) - if err != nil { - t.Fatal(err) - } + logger := log.New(ioutil.Discard, "", 0) h := NewHandler(make(chan os.Signal, 1), logger) rc := &reloaderCloser{} diff --git a/sql/logger.go b/sql/logger.go new file mode 100644 index 0000000..f6c7215 --- /dev/null +++ b/sql/logger.go @@ -0,0 +1,138 @@ +package sql + +import ( + "log" + "net" + "sync" + "time" +) + +const ( + // LogDiscard disables logging of DNS requests. + LogDiscard = iota + // LogAll logs all DNS requests. + LogAll + // LogHijacked only logs hijacked DNS requests. + LogHijacked +) + +// Logger is a logs DNS requests to a SQL database. +type Logger struct { + mode int + queue chan Entry + db *Client + wg sync.WaitGroup + now func() time.Time + Logger *log.Logger +} + +// RecordOptions configures recording of DNS requests. +type RecordOptions struct { + Database string + Mode int + TTL time.Duration +} + +// Entry represents a DNS request log entry. +type Entry struct { + Time time.Time + RemoteAddr net.IP + Hijacked bool + Qtype uint16 + Question string + Answers []string +} + +// NewLogger creates a new logger. Persisted entries are kept according to ttl. +func NewLogger(db *Client, mode int, ttl time.Duration) *Logger { + logger := &Logger{ + db: db, + queue: make(chan Entry, 100), + now: time.Now, + mode: mode, + } + if mode != LogDiscard { + logger.wg.Add(1) + go logger.readQueue(ttl) + } + return logger +} + +// Close consumes any outstanding log requests and closes the logger. +func (l *Logger) Close() error { + close(l.queue) + l.wg.Wait() + return nil +} + +// Record records the given DNS request to the log database. +func (l *Logger) Record(remoteAddr net.IP, hijacked bool, qtype uint16, question string, answers ...string) { + if l.db == nil { + return + } + if l.mode == LogDiscard { + return + } + if l.mode == LogHijacked && !hijacked { + return + } + l.queue <- Entry{ + Time: l.now(), + RemoteAddr: remoteAddr, + Hijacked: hijacked, + Qtype: qtype, + Question: question, + Answers: answers, + } +} + +// Get returns the n most recent persisted log entries. +func (l *Logger) Get(n int) ([]Entry, error) { + logEntries, err := l.db.ReadLog(n) + if err != nil { + return nil, err + } + ids := make(map[int64]*Entry) + entries := make([]Entry, 0, len(logEntries)) + for _, le := range logEntries { + entry, ok := ids[le.ID] + if !ok { + newEntry := Entry{ + Time: time.Unix(le.Time, 0).UTC(), + RemoteAddr: le.RemoteAddr, + Hijacked: le.Hijacked, + Qtype: le.Qtype, + Question: le.Question, + } + entries = append(entries, newEntry) + entry = &entries[len(entries)-1] + ids[le.ID] = entry + } + if le.Answer != "" { + entry.Answers = append(entry.Answers, le.Answer) + } + } + return entries, nil +} + +func (l *Logger) printf(format string, v ...interface{}) { + if l.Logger == nil { + return + } + l.Logger.Printf(format, v...) +} + +func (l *Logger) readQueue(ttl time.Duration) { + defer l.wg.Done() + for e := range l.queue { + if err := l.db.WriteLog(e.Time, e.RemoteAddr, e.Hijacked, e.Qtype, e.Question, e.Answers...); err != nil { + l.printf("write failed: %+v: %s", e, err) + } + if ttl > 0 { + t := l.now().Add(-ttl) + if err := l.db.DeleteLogBefore(t); err != nil { + l.printf("deleting log entries before %v failed: %s", t, err) + } + } + } +} diff --git a/sql/logger_test.go b/sql/logger_test.go new file mode 100644 index 0000000..fea858b --- /dev/null +++ b/sql/logger_test.go @@ -0,0 +1,133 @@ +package sql + +import ( + "net" + "reflect" + "testing" + "time" +) + +func TestRecord(t *testing.T) { + client := testClient() + logger := NewLogger(client, LogAll, 0) + logger.Record(net.IPv4(192, 0, 2, 100), false, 1, "example.com.", "192.0.2.1", "192.0.2.2") + // Flush queue + if err := logger.Close(); err != nil { + t.Fatal(err) + } + logEntries, err := logger.db.ReadLog(1) + if err != nil { + t.Fatal(err) + } + if want, got := 2, len(logEntries); want != got { + t.Errorf("len(entries) = %d, want %d", got, want) + } +} + +func TestMode(t *testing.T) { + badHost := "badhost1." + goodHost := "goodhost1." + var tests = []struct { + question string + remoteAddr net.IP + hijacked bool + mode int + log bool + }{ + {badHost, net.IPv4(192, 0, 2, 100), true, LogAll, true}, + {goodHost, net.IPv4(192, 0, 2, 100), true, LogAll, true}, + {badHost, net.IPv4(192, 0, 2, 100), true, LogHijacked, true}, + {goodHost, net.IPv4(192, 0, 2, 100), false, LogHijacked, false}, + {badHost, net.IPv4(192, 0, 2, 100), true, LogDiscard, false}, + {goodHost, net.IPv4(192, 0, 2, 100), false, LogDiscard, false}, + } + for i, tt := range tests { + logger := NewLogger(testClient(), tt.mode, 0) + logger.mode = tt.mode + logger.Record(tt.remoteAddr, tt.hijacked, 1, tt.question) + if err := logger.Close(); err != nil { // Flush + t.Fatal(err) + } + entries, err := logger.Get(1) + if err != nil { + t.Fatal(err) + } + if len(entries) > 0 != tt.log { + t.Errorf("#%d: question %q (hijacked=%t) should be logged in mode %d", i, tt.question, tt.hijacked, tt.mode) + } + } +} + +func TestAnswerMerging(t *testing.T) { + logger := NewLogger(testClient(), LogAll, 0) + now := time.Date(2006, 1, 2, 15, 4, 5, 0, time.UTC) + logger.now = func() time.Time { return now } + logger.Record(net.IPv4(192, 0, 2, 100), true, 1, "example.com.", "192.0.2.1", "192.0.2.2") + logger.Record(net.IPv4(192, 0, 2, 100), true, 1, "2.example.com.") + // Flush queue + if err := logger.Close(); err != nil { + t.Fatal(err) + } + // Multi-answer log entries are merged + got, err := logger.Get(2) + if err != nil { + t.Fatal(err) + } + want := []Entry{ + { + Time: now, + RemoteAddr: net.IPv4(192, 0, 2, 100), + Hijacked: true, + Qtype: 1, + Question: "example.com.", + Answers: []string{"192.0.2.2", "192.0.2.1"}, + }, + { + Time: now, + RemoteAddr: net.IPv4(192, 0, 2, 100), + Hijacked: true, + Qtype: 1, + Question: "2.example.com.", + }} + if !reflect.DeepEqual(want, got) { + t.Errorf("Get(1) = %+v, want %+v", got, want) + } +} + +func TestLogPruning(t *testing.T) { + logger := NewLogger(testClient(), LogAll, time.Hour) + defer logger.Close() + tt := time.Now() + logger.now = func() time.Time { return tt } + logger.Record(net.IPv4(192, 0, 2, 100), false, 1, "example.com.", "192.0.2.1") + + // Wait until queue is flushed + ts := time.Now() + var entries []Entry + var err error + for len(entries) == 0 { + entries, err = logger.Get(1) + if err != nil { + t.Fatal(err) + } + time.Sleep(10 * time.Millisecond) + if time.Since(ts) > 2*time.Second { + t.Fatal("timed out waiting for log entry to be written") + } + } + + // Advance time beyond log TTL + tt = tt.Add(time.Hour).Add(time.Second) + // Trigger pruning by recording another entry + logger.Record(net.IPv4(192, 0, 2, 100), false, 1, "2.example.com.", "192.0.2.2") + for len(entries) > 1 { + entries, err = logger.Get(2) + if err != nil { + t.Fatal(err) + } + time.Sleep(10 * time.Millisecond) + if time.Since(ts) > 2*time.Second { + t.Fatal("timed out waiting for log entry to be removed") + } + } +} -- cgit v1.2.3