aboutsummaryrefslogtreecommitdiffstats
path: root/log
diff options
context:
space:
mode:
authorMartin Polden <mpolden@mpolden.no>2019-12-29 19:34:14 +0100
committerMartin Polden <mpolden@mpolden.no>2019-12-29 19:34:14 +0100
commit23ce61b1f6ff3b2e6e61162aadf7da6759e9d3df (patch)
treea42579b6c9b2173023f4ddafc9abfe540bde9021 /log
parent72ee46698a94c48527184109401e8a6725a4674b (diff)
Move log mode to Logger
Diffstat (limited to 'log')
-rw-r--r--log/logger.go18
-rw-r--r--log/logger_test.go42
2 files changed, 58 insertions, 2 deletions
diff --git a/log/logger.go b/log/logger.go
index 7fa4001..1a5b2d0 100644
--- a/log/logger.go
+++ b/log/logger.go
@@ -10,9 +10,19 @@ import (
"github.com/mpolden/zdns/sql"
)
+const (
+ // ModeDiscard disables logging of DNS requests.
+ ModeDiscard = iota
+ // ModeAll logs all DNS requests.
+ ModeAll
+ // ModeHijacked only logs hijacked DNS requests.
+ ModeHijacked
+)
+
// Logger wraps a standard log.Logger and an optional log database.
type Logger struct {
*log.Logger
+ mode int
queue chan Entry
db *sql.Client
wg sync.WaitGroup
@@ -23,6 +33,7 @@ type Logger struct {
// RecordOptions configures recording of DNS requests.
type RecordOptions struct {
Database string
+ Mode int
TTL time.Duration
}
@@ -47,6 +58,7 @@ func newLogger(w io.Writer, prefix string, options RecordOptions, interval time.
Logger: log.New(w, prefix, 0),
queue: make(chan Entry, 100),
now: time.Now,
+ mode: options.Mode,
}
var err error
if options.Database != "" {
@@ -97,6 +109,12 @@ func (l *Logger) Record(remoteAddr net.IP, hijacked bool, qtype uint16, question
if l.db == nil {
return
}
+ if l.mode == ModeDiscard {
+ return
+ }
+ if l.mode == ModeHijacked && !hijacked {
+ return
+ }
l.queue <- Entry{
Time: l.now(),
RemoteAddr: remoteAddr,
diff --git a/log/logger_test.go b/log/logger_test.go
index e84ae06..9a58292 100644
--- a/log/logger_test.go
+++ b/log/logger_test.go
@@ -9,7 +9,7 @@ import (
)
func TestRecord(t *testing.T) {
- logger, err := New(os.Stderr, "test: ", RecordOptions{Database: ":memory:"})
+ logger, err := New(os.Stderr, "test: ", RecordOptions{Database: ":memory:", Mode: ModeAll})
if err != nil {
t.Fatal(err)
}
@@ -27,8 +27,45 @@ func TestRecord(t *testing.T) {
}
}
+func TestMode(t *testing.T) {
+ badHost := "badhost1."
+ goodHost := "goodhost1."
+ var tests = []struct {
+ question string
+ remoteAddr net.IP
+ hijacked bool
+ mode int
+ log bool
+ }{
+ {badHost, net.IPv4(192, 0, 2, 100), true, ModeAll, true},
+ {goodHost, net.IPv4(192, 0, 2, 100), true, ModeAll, true},
+ {badHost, net.IPv4(192, 0, 2, 100), true, ModeHijacked, true},
+ {goodHost, net.IPv4(192, 0, 2, 100), false, ModeHijacked, false},
+ {badHost, net.IPv4(192, 0, 2, 100), true, ModeDiscard, false},
+ {goodHost, net.IPv4(192, 0, 2, 100), false, ModeDiscard, false},
+ }
+ for i, tt := range tests {
+ logger, err := New(os.Stderr, "test: ", RecordOptions{Database: ":memory:", Mode: tt.mode})
+ if err != nil {
+ t.Fatal(err)
+ }
+ logger.mode = tt.mode
+ logger.Record(tt.remoteAddr, tt.hijacked, 1, tt.question)
+ if err := logger.Close(); err != nil { // Flush
+ t.Fatal(err)
+ }
+ entries, err := logger.Get(1)
+ if err != nil {
+ t.Fatal(err)
+ }
+ if len(entries) > 0 != tt.log {
+ t.Errorf("#%d: question %q (hijacked=%t) should be logged in mode %d", i, tt.question, tt.hijacked, tt.mode)
+ }
+ }
+}
+
func TestAnswerMerging(t *testing.T) {
- logger, err := New(os.Stderr, "test: ", RecordOptions{Database: ":memory:"})
+ logger, err := New(os.Stderr, "test: ", RecordOptions{Database: ":memory:", Mode: ModeAll})
if err != nil {
t.Fatal(err)
}
@@ -68,6 +105,7 @@ func TestAnswerMerging(t *testing.T) {
func TestLogPruning(t *testing.T) {
logger, err := newLogger(os.Stderr, "test: ", RecordOptions{
+ Mode: ModeAll,
Database: ":memory:",
TTL: time.Hour,
}, 10*time.Millisecond)