aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorMartin Polden <mpolden@mpolden.no>2019-12-28 14:37:21 +0100
committerMartin Polden <mpolden@mpolden.no>2019-12-28 14:41:37 +0100
commit57ebed75974685799017552693618d001ae5c8a5 (patch)
tree8e61d05c32c150e9da2353446ef16e45345ba2ff
parent38d912a0324c9759ed47589d9864b6ca16ae3ef1 (diff)
Simplify logger maintenance
-rw-r--r--log/logger.go63
-rw-r--r--log/logger_test.go7
2 files changed, 27 insertions, 43 deletions
diff --git a/log/logger.go b/log/logger.go
index c6b6f82..8c272eb 100644
--- a/log/logger.go
+++ b/log/logger.go
@@ -13,18 +13,18 @@ import (
// Logger wraps a standard log.Logger and an optional log database.
type Logger struct {
*log.Logger
- now func() time.Time
- queue chan Entry
- db *sql.Client
- maintainer *maintainer
- wg sync.WaitGroup
+ queue chan Entry
+ db *sql.Client
+ wg sync.WaitGroup
+ done chan bool
+ interval time.Duration
+ now func() time.Time
}
// RecordOptions configures recording of DNS requests.
type RecordOptions struct {
- Database string
- ExpiryInterval time.Duration
- TTL time.Duration
+ Database string
+ TTL time.Duration
}
// Entry represents a DNS request log entry.
@@ -36,18 +36,13 @@ type Entry struct {
Answers []string
}
-type maintainer struct {
- interval time.Duration
- ttl time.Duration
- done chan bool
-}
-
// New creates a new logger wrapping a standard log.Logger.
func New(w io.Writer, prefix string, options RecordOptions) (*Logger, error) {
logger := &Logger{
- Logger: log.New(w, prefix, 0),
- queue: make(chan Entry, 100),
- now: time.Now,
+ Logger: log.New(w, prefix, 0),
+ queue: make(chan Entry, 100),
+ now: time.Now,
+ interval: time.Minute,
}
var err error
if options.Database != "" {
@@ -59,38 +54,26 @@ func New(w io.Writer, prefix string, options RecordOptions) (*Logger, error) {
logger.wg.Add(1)
go logger.readQueue()
if options.TTL > 0 {
- if options.ExpiryInterval <= 0 {
- options.ExpiryInterval = time.Minute
- }
- maintain(logger, options.ExpiryInterval, options.TTL)
+ logger.wg.Add(1)
+ logger.done = make(chan bool)
+ go maintain(logger, options.TTL)
}
return logger, nil
}
-func maintain(logger *Logger, interval, ttl time.Duration) {
- m := &maintainer{
- interval: interval,
- ttl: ttl,
- done: make(chan bool),
- }
- logger.maintainer = m
- logger.wg.Add(1)
- go m.run(logger)
-}
-
-func (m *maintainer) run(logger *Logger) {
- ticker := time.NewTicker(m.interval)
+func maintain(logger *Logger, ttl time.Duration) {
defer logger.wg.Done()
+ ticker := time.NewTicker(logger.interval)
for {
select {
+ case <-logger.done:
+ ticker.Stop()
+ return
case <-ticker.C:
- t := logger.now().Add(-m.ttl)
+ t := logger.now().Add(-ttl)
if err := logger.db.DeleteLogBefore(t); err != nil {
logger.Printf("error deleting log entries before %v: %s", t, err)
}
- case <-m.done:
- ticker.Stop()
- return
}
}
}
@@ -98,8 +81,8 @@ func (m *maintainer) run(logger *Logger) {
// Close consumes any outstanding log requests and closes the logger.
func (l *Logger) Close() error {
close(l.queue)
- if l.maintainer != nil {
- l.maintainer.done <- true
+ if l.done != nil {
+ l.done <- true
}
l.wg.Wait()
return nil
diff --git a/log/logger_test.go b/log/logger_test.go
index a057d94..72916b6 100644
--- a/log/logger_test.go
+++ b/log/logger_test.go
@@ -58,13 +58,14 @@ func TestAnswerMerging(t *testing.T) {
func TestLogPruning(t *testing.T) {
logger, err := New(os.Stderr, "test: ", RecordOptions{
- Database: ":memory:",
- ExpiryInterval: 10 * time.Millisecond,
- TTL: time.Hour,
+ Database: ":memory:",
+ TTL: time.Hour,
})
if err != nil {
t.Fatal(err)
}
+ logger.interval = 10 * time.Millisecond
+ defer logger.Close()
tt := time.Now()
logger.now = func() time.Time { return tt }
logger.Record(net.IPv4(192, 0, 2, 100), 1, "example.com.", "192.0.2.1")