diff options
Diffstat (limited to 'log')
-rw-r--r-- | log/logger.go | 18 | ||||
-rw-r--r-- | log/logger_test.go | 42 |
2 files changed, 58 insertions, 2 deletions
diff --git a/log/logger.go b/log/logger.go index 7fa4001..1a5b2d0 100644 --- a/log/logger.go +++ b/log/logger.go @@ -10,9 +10,19 @@ import ( "github.com/mpolden/zdns/sql" ) +const ( + // ModeDiscard disables logging of DNS requests. + ModeDiscard = iota + // ModeAll logs all DNS requests. + ModeAll + // ModeHijacked only logs hijacked DNS requests. + ModeHijacked +) + // Logger wraps a standard log.Logger and an optional log database. type Logger struct { *log.Logger + mode int queue chan Entry db *sql.Client wg sync.WaitGroup @@ -23,6 +33,7 @@ type Logger struct { // RecordOptions configures recording of DNS requests. type RecordOptions struct { Database string + Mode int TTL time.Duration } @@ -47,6 +58,7 @@ func newLogger(w io.Writer, prefix string, options RecordOptions, interval time. Logger: log.New(w, prefix, 0), queue: make(chan Entry, 100), now: time.Now, + mode: options.Mode, } var err error if options.Database != "" { @@ -97,6 +109,12 @@ func (l *Logger) Record(remoteAddr net.IP, hijacked bool, qtype uint16, question if l.db == nil { return } + if l.mode == ModeDiscard { + return + } + if l.mode == ModeHijacked && !hijacked { + return + } l.queue <- Entry{ Time: l.now(), RemoteAddr: remoteAddr, diff --git a/log/logger_test.go b/log/logger_test.go index e84ae06..9a58292 100644 --- a/log/logger_test.go +++ b/log/logger_test.go @@ -9,7 +9,7 @@ import ( ) func TestRecord(t *testing.T) { - logger, err := New(os.Stderr, "test: ", RecordOptions{Database: ":memory:"}) + logger, err := New(os.Stderr, "test: ", RecordOptions{Database: ":memory:", Mode: ModeAll}) if err != nil { t.Fatal(err) } @@ -27,8 +27,45 @@ func TestRecord(t *testing.T) { } } +func TestMode(t *testing.T) { + badHost := "badhost1." + goodHost := "goodhost1." + var tests = []struct { + question string + remoteAddr net.IP + hijacked bool + mode int + log bool + }{ + {badHost, net.IPv4(192, 0, 2, 100), true, ModeAll, true}, + {goodHost, net.IPv4(192, 0, 2, 100), true, ModeAll, true}, + {badHost, net.IPv4(192, 0, 2, 100), true, ModeHijacked, true}, + {goodHost, net.IPv4(192, 0, 2, 100), false, ModeHijacked, false}, + {badHost, net.IPv4(192, 0, 2, 100), true, ModeDiscard, false}, + {goodHost, net.IPv4(192, 0, 2, 100), false, ModeDiscard, false}, + } + for i, tt := range tests { + logger, err := New(os.Stderr, "test: ", RecordOptions{Database: ":memory:", Mode: tt.mode}) + if err != nil { + t.Fatal(err) + } + logger.mode = tt.mode + logger.Record(tt.remoteAddr, tt.hijacked, 1, tt.question) + if err := logger.Close(); err != nil { // Flush + t.Fatal(err) + } + entries, err := logger.Get(1) + if err != nil { + t.Fatal(err) + } + if len(entries) > 0 != tt.log { + t.Errorf("#%d: question %q (hijacked=%t) should be logged in mode %d", i, tt.question, tt.hijacked, tt.mode) + } + } +} + func TestAnswerMerging(t *testing.T) { - logger, err := New(os.Stderr, "test: ", RecordOptions{Database: ":memory:"}) + logger, err := New(os.Stderr, "test: ", RecordOptions{Database: ":memory:", Mode: ModeAll}) if err != nil { t.Fatal(err) } @@ -68,6 +105,7 @@ func TestAnswerMerging(t *testing.T) { func TestLogPruning(t *testing.T) { logger, err := newLogger(os.Stderr, "test: ", RecordOptions{ + Mode: ModeAll, Database: ":memory:", TTL: time.Hour, }, 10*time.Millisecond) |