diff options
author | Martin Polden <mpolden@mpolden.no> | 2019-12-28 14:37:21 +0100 |
---|---|---|
committer | Martin Polden <mpolden@mpolden.no> | 2019-12-28 14:41:37 +0100 |
commit | 57ebed75974685799017552693618d001ae5c8a5 (patch) | |
tree | 8e61d05c32c150e9da2353446ef16e45345ba2ff | |
parent | 38d912a0324c9759ed47589d9864b6ca16ae3ef1 (diff) |
Simplify logger maintenance
-rw-r--r-- | log/logger.go | 63 | ||||
-rw-r--r-- | log/logger_test.go | 7 |
2 files changed, 27 insertions, 43 deletions
diff --git a/log/logger.go b/log/logger.go index c6b6f82..8c272eb 100644 --- a/log/logger.go +++ b/log/logger.go @@ -13,18 +13,18 @@ import ( // Logger wraps a standard log.Logger and an optional log database. type Logger struct { *log.Logger - now func() time.Time - queue chan Entry - db *sql.Client - maintainer *maintainer - wg sync.WaitGroup + queue chan Entry + db *sql.Client + wg sync.WaitGroup + done chan bool + interval time.Duration + now func() time.Time } // RecordOptions configures recording of DNS requests. type RecordOptions struct { - Database string - ExpiryInterval time.Duration - TTL time.Duration + Database string + TTL time.Duration } // Entry represents a DNS request log entry. @@ -36,18 +36,13 @@ type Entry struct { Answers []string } -type maintainer struct { - interval time.Duration - ttl time.Duration - done chan bool -} - // New creates a new logger wrapping a standard log.Logger. func New(w io.Writer, prefix string, options RecordOptions) (*Logger, error) { logger := &Logger{ - Logger: log.New(w, prefix, 0), - queue: make(chan Entry, 100), - now: time.Now, + Logger: log.New(w, prefix, 0), + queue: make(chan Entry, 100), + now: time.Now, + interval: time.Minute, } var err error if options.Database != "" { @@ -59,38 +54,26 @@ func New(w io.Writer, prefix string, options RecordOptions) (*Logger, error) { logger.wg.Add(1) go logger.readQueue() if options.TTL > 0 { - if options.ExpiryInterval <= 0 { - options.ExpiryInterval = time.Minute - } - maintain(logger, options.ExpiryInterval, options.TTL) + logger.wg.Add(1) + logger.done = make(chan bool) + go maintain(logger, options.TTL) } return logger, nil } -func maintain(logger *Logger, interval, ttl time.Duration) { - m := &maintainer{ - interval: interval, - ttl: ttl, - done: make(chan bool), - } - logger.maintainer = m - logger.wg.Add(1) - go m.run(logger) -} - -func (m *maintainer) run(logger *Logger) { - ticker := time.NewTicker(m.interval) +func maintain(logger *Logger, ttl time.Duration) { defer logger.wg.Done() + ticker := time.NewTicker(logger.interval) for { select { + case <-logger.done: + ticker.Stop() + return case <-ticker.C: - t := logger.now().Add(-m.ttl) + t := logger.now().Add(-ttl) if err := logger.db.DeleteLogBefore(t); err != nil { logger.Printf("error deleting log entries before %v: %s", t, err) } - case <-m.done: - ticker.Stop() - return } } } @@ -98,8 +81,8 @@ func (m *maintainer) run(logger *Logger) { // Close consumes any outstanding log requests and closes the logger. func (l *Logger) Close() error { close(l.queue) - if l.maintainer != nil { - l.maintainer.done <- true + if l.done != nil { + l.done <- true } l.wg.Wait() return nil diff --git a/log/logger_test.go b/log/logger_test.go index a057d94..72916b6 100644 --- a/log/logger_test.go +++ b/log/logger_test.go @@ -58,13 +58,14 @@ func TestAnswerMerging(t *testing.T) { func TestLogPruning(t *testing.T) { logger, err := New(os.Stderr, "test: ", RecordOptions{ - Database: ":memory:", - ExpiryInterval: 10 * time.Millisecond, - TTL: time.Hour, + Database: ":memory:", + TTL: time.Hour, }) if err != nil { t.Fatal(err) } + logger.interval = 10 * time.Millisecond + defer logger.Close() tt := time.Now() logger.now = func() time.Time { return tt } logger.Record(net.IPv4(192, 0, 2, 100), 1, "example.com.", "192.0.2.1") |