aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--cmd/zdns/main.go2
-rw-r--r--http/http_test.go2
-rw-r--r--sql/logger.go32
-rw-r--r--sql/logger_test.go12
4 files changed, 23 insertions, 25 deletions
diff --git a/cmd/zdns/main.go b/cmd/zdns/main.go
index 0fc935d..87b3994 100644
--- a/cmd/zdns/main.go
+++ b/cmd/zdns/main.go
@@ -90,7 +90,7 @@ func newCli(out io.Writer, args []string, configFile string, sig chan os.Signal)
fatal(err)
// Logger
- sqlLogger = sql.NewLogger(sqlClient, config.DNS.LogMode, config.DNS.LogTTL)
+ sqlLogger = sql.NewLogger(sqlClient, config.DNS.LogMode, config.DNS.LogTTL, logger)
sigHandler.OnClose(sqlLogger)
// Cache
diff --git a/http/http_test.go b/http/http_test.go
index 7853b16..fa39549 100644
--- a/http/http_test.go
+++ b/http/http_test.go
@@ -36,7 +36,7 @@ func testServer() (*httptest.Server, *Server) {
panic(err)
}
stdLogger := log.New(ioutil.Discard, "", 0)
- logger := sql.NewLogger(db, sql.LogAll, 0)
+ logger := sql.NewLogger(db, sql.LogAll, 0, stdLogger)
cache := cache.New(10, nil)
server := Server{logger: stdLogger, sqlLogger: logger, cache: cache}
return httptest.NewServer(server.handler()), &server
diff --git a/sql/logger.go b/sql/logger.go
index cca3de8..e226103 100644
--- a/sql/logger.go
+++ b/sql/logger.go
@@ -23,7 +23,7 @@ type Logger struct {
db *Client
wg sync.WaitGroup
now func() time.Time
- Logger *log.Logger
+ logger *log.Logger
}
// Entry represents a DNS request log entry.
@@ -37,18 +37,19 @@ type Entry struct {
}
// NewLogger creates a new logger. Persisted entries are kept according to ttl.
-func NewLogger(db *Client, mode int, ttl time.Duration) *Logger {
- logger := &Logger{
- db: db,
- queue: make(chan Entry, 100),
- now: time.Now,
- mode: mode,
+func NewLogger(db *Client, mode int, ttl time.Duration, logger *log.Logger) *Logger {
+ l := &Logger{
+ db: db,
+ queue: make(chan Entry, 100),
+ now: time.Now,
+ mode: mode,
+ logger: logger,
}
if mode != LogDiscard {
- logger.wg.Add(1)
- go logger.readQueue(ttl)
+ l.wg.Add(1)
+ go l.readQueue(ttl)
}
- return logger
+ return l
}
// Close consumes any outstanding log requests and closes the logger.
@@ -108,23 +109,16 @@ func (l *Logger) Get(n int) ([]Entry, error) {
return entries, nil
}
-func (l *Logger) printf(format string, v ...interface{}) {
- if l.Logger == nil {
- return
- }
- l.Logger.Printf(format, v...)
-}
-
func (l *Logger) readQueue(ttl time.Duration) {
defer l.wg.Done()
for e := range l.queue {
if err := l.db.WriteLog(e.Time, e.RemoteAddr, e.Hijacked, e.Qtype, e.Question, e.Answers...); err != nil {
- l.printf("write failed: %+v: %s", e, err)
+ l.logger.Printf("write failed: %+v: %s", e, err)
}
if ttl > 0 {
t := l.now().Add(-ttl)
if err := l.db.DeleteLogBefore(t); err != nil {
- l.printf("deleting log entries before %v failed: %s", t, err)
+ l.logger.Printf("deleting log entries before %v failed: %s", t, err)
}
}
}
diff --git a/sql/logger_test.go b/sql/logger_test.go
index fea858b..776de31 100644
--- a/sql/logger_test.go
+++ b/sql/logger_test.go
@@ -1,15 +1,19 @@
package sql
import (
+ "io/ioutil"
+ "log"
"net"
"reflect"
"testing"
"time"
)
+var logger = log.New(ioutil.Discard, "", 0)
+
func TestRecord(t *testing.T) {
client := testClient()
- logger := NewLogger(client, LogAll, 0)
+ logger := NewLogger(client, LogAll, 0, logger)
logger.Record(net.IPv4(192, 0, 2, 100), false, 1, "example.com.", "192.0.2.1", "192.0.2.2")
// Flush queue
if err := logger.Close(); err != nil {
@@ -42,7 +46,7 @@ func TestMode(t *testing.T) {
{goodHost, net.IPv4(192, 0, 2, 100), false, LogDiscard, false},
}
for i, tt := range tests {
- logger := NewLogger(testClient(), tt.mode, 0)
+ logger := NewLogger(testClient(), tt.mode, 0, logger)
logger.mode = tt.mode
logger.Record(tt.remoteAddr, tt.hijacked, 1, tt.question)
if err := logger.Close(); err != nil { // Flush
@@ -59,7 +63,7 @@ func TestMode(t *testing.T) {
}
func TestAnswerMerging(t *testing.T) {
- logger := NewLogger(testClient(), LogAll, 0)
+ logger := NewLogger(testClient(), LogAll, 0, logger)
now := time.Date(2006, 1, 2, 15, 4, 5, 0, time.UTC)
logger.now = func() time.Time { return now }
logger.Record(net.IPv4(192, 0, 2, 100), true, 1, "example.com.", "192.0.2.1", "192.0.2.2")
@@ -95,7 +99,7 @@ func TestAnswerMerging(t *testing.T) {
}
func TestLogPruning(t *testing.T) {
- logger := NewLogger(testClient(), LogAll, time.Hour)
+ logger := NewLogger(testClient(), LogAll, time.Hour, logger)
defer logger.Close()
tt := time.Now()
logger.now = func() time.Time { return tt }