diff options
-rw-r--r-- | cmd/zdns/main.go | 2 | ||||
-rw-r--r-- | http/http_test.go | 2 | ||||
-rw-r--r-- | sql/logger.go | 32 | ||||
-rw-r--r-- | sql/logger_test.go | 12 |
4 files changed, 23 insertions, 25 deletions
diff --git a/cmd/zdns/main.go b/cmd/zdns/main.go index 0fc935d..87b3994 100644 --- a/cmd/zdns/main.go +++ b/cmd/zdns/main.go @@ -90,7 +90,7 @@ func newCli(out io.Writer, args []string, configFile string, sig chan os.Signal) fatal(err) // Logger - sqlLogger = sql.NewLogger(sqlClient, config.DNS.LogMode, config.DNS.LogTTL) + sqlLogger = sql.NewLogger(sqlClient, config.DNS.LogMode, config.DNS.LogTTL, logger) sigHandler.OnClose(sqlLogger) // Cache diff --git a/http/http_test.go b/http/http_test.go index 7853b16..fa39549 100644 --- a/http/http_test.go +++ b/http/http_test.go @@ -36,7 +36,7 @@ func testServer() (*httptest.Server, *Server) { panic(err) } stdLogger := log.New(ioutil.Discard, "", 0) - logger := sql.NewLogger(db, sql.LogAll, 0) + logger := sql.NewLogger(db, sql.LogAll, 0, stdLogger) cache := cache.New(10, nil) server := Server{logger: stdLogger, sqlLogger: logger, cache: cache} return httptest.NewServer(server.handler()), &server diff --git a/sql/logger.go b/sql/logger.go index cca3de8..e226103 100644 --- a/sql/logger.go +++ b/sql/logger.go @@ -23,7 +23,7 @@ type Logger struct { db *Client wg sync.WaitGroup now func() time.Time - Logger *log.Logger + logger *log.Logger } // Entry represents a DNS request log entry. @@ -37,18 +37,19 @@ type Entry struct { } // NewLogger creates a new logger. Persisted entries are kept according to ttl. -func NewLogger(db *Client, mode int, ttl time.Duration) *Logger { - logger := &Logger{ - db: db, - queue: make(chan Entry, 100), - now: time.Now, - mode: mode, +func NewLogger(db *Client, mode int, ttl time.Duration, logger *log.Logger) *Logger { + l := &Logger{ + db: db, + queue: make(chan Entry, 100), + now: time.Now, + mode: mode, + logger: logger, } if mode != LogDiscard { - logger.wg.Add(1) - go logger.readQueue(ttl) + l.wg.Add(1) + go l.readQueue(ttl) } - return logger + return l } // Close consumes any outstanding log requests and closes the logger. @@ -108,23 +109,16 @@ func (l *Logger) Get(n int) ([]Entry, error) { return entries, nil } -func (l *Logger) printf(format string, v ...interface{}) { - if l.Logger == nil { - return - } - l.Logger.Printf(format, v...) -} - func (l *Logger) readQueue(ttl time.Duration) { defer l.wg.Done() for e := range l.queue { if err := l.db.WriteLog(e.Time, e.RemoteAddr, e.Hijacked, e.Qtype, e.Question, e.Answers...); err != nil { - l.printf("write failed: %+v: %s", e, err) + l.logger.Printf("write failed: %+v: %s", e, err) } if ttl > 0 { t := l.now().Add(-ttl) if err := l.db.DeleteLogBefore(t); err != nil { - l.printf("deleting log entries before %v failed: %s", t, err) + l.logger.Printf("deleting log entries before %v failed: %s", t, err) } } } diff --git a/sql/logger_test.go b/sql/logger_test.go index fea858b..776de31 100644 --- a/sql/logger_test.go +++ b/sql/logger_test.go @@ -1,15 +1,19 @@ package sql import ( + "io/ioutil" + "log" "net" "reflect" "testing" "time" ) +var logger = log.New(ioutil.Discard, "", 0) + func TestRecord(t *testing.T) { client := testClient() - logger := NewLogger(client, LogAll, 0) + logger := NewLogger(client, LogAll, 0, logger) logger.Record(net.IPv4(192, 0, 2, 100), false, 1, "example.com.", "192.0.2.1", "192.0.2.2") // Flush queue if err := logger.Close(); err != nil { @@ -42,7 +46,7 @@ func TestMode(t *testing.T) { {goodHost, net.IPv4(192, 0, 2, 100), false, LogDiscard, false}, } for i, tt := range tests { - logger := NewLogger(testClient(), tt.mode, 0) + logger := NewLogger(testClient(), tt.mode, 0, logger) logger.mode = tt.mode logger.Record(tt.remoteAddr, tt.hijacked, 1, tt.question) if err := logger.Close(); err != nil { // Flush @@ -59,7 +63,7 @@ func TestMode(t *testing.T) { } func TestAnswerMerging(t *testing.T) { - logger := NewLogger(testClient(), LogAll, 0) + logger := NewLogger(testClient(), LogAll, 0, logger) now := time.Date(2006, 1, 2, 15, 4, 5, 0, time.UTC) logger.now = func() time.Time { return now } logger.Record(net.IPv4(192, 0, 2, 100), true, 1, "example.com.", "192.0.2.1", "192.0.2.2") @@ -95,7 +99,7 @@ func TestAnswerMerging(t *testing.T) { } func TestLogPruning(t *testing.T) { - logger := NewLogger(testClient(), LogAll, time.Hour) + logger := NewLogger(testClient(), LogAll, time.Hour, logger) defer logger.Close() tt := time.Now() logger.now = func() time.Time { return tt } |