aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorMartin Polden <mpolden@mpolden.no>2020-01-11 13:22:57 +0100
committerMartin Polden <mpolden@mpolden.no>2020-01-11 15:14:15 +0100
commit541ec0a86c0d29c469bc48bce34493daf20a0ee4 (patch)
tree043fdfda105a700ab1b94089e0e2580e0f9e1c8b
parentbf8b103057d1492f62e561a0124c8ad7bce36af9 (diff)
Decouple std logger and query logger
-rw-r--r--cmd/zdns/main.go44
-rw-r--r--config.go8
-rw-r--r--dns/proxy.go29
-rw-r--r--dns/proxy_test.go19
-rw-r--r--http/http.go21
-rw-r--r--http/http_test.go15
-rw-r--r--server.go4
-rw-r--r--server_test.go22
-rw-r--r--signal/signal.go3
-rw-r--r--signal/signal_test.go8
-rw-r--r--sql/logger.go (renamed from log/logger.go)72
-rw-r--r--sql/logger_test.go (renamed from log/logger_test.go)41
12 files changed, 141 insertions, 145 deletions
diff --git a/cmd/zdns/main.go b/cmd/zdns/main.go
index d116fff..bcc2934 100644
--- a/cmd/zdns/main.go
+++ b/cmd/zdns/main.go
@@ -3,6 +3,7 @@ package main
import (
"fmt"
"io"
+ "log"
"os"
"path/filepath"
"sync"
@@ -14,8 +15,8 @@ import (
"github.com/mpolden/zdns/dns"
"github.com/mpolden/zdns/dns/dnsutil"
"github.com/mpolden/zdns/http"
- "github.com/mpolden/zdns/log"
"github.com/mpolden/zdns/signal"
+ "github.com/mpolden/zdns/sql"
)
const (
@@ -75,42 +76,45 @@ func newCli(out io.Writer, args []string, configFile string, sig chan os.Signal)
config, err := readConfig(*confFile)
fatal(err)
- // Logger
- logger, err := log.New(out, logPrefix, log.RecordOptions{
- Mode: config.DNS.LogMode,
- Database: config.DNS.LogDatabase,
- TTL: config.DNS.LogTTL,
- })
- fatal(err)
-
- // Signal handling
+ // Logging and signal handling
+ logger := log.New(out, logPrefix, log.Lshortfile)
sigHandler := signal.NewHandler(sig, logger)
- sigHandler.OnClose(logger)
- // Client
- client := dnsutil.NewClient(config.Resolver.Protocol, config.Resolver.Timeout, config.DNS.Resolvers...)
+ // SQL backends
+ var sqlLogger *sql.Logger
+ if config.DNS.LogDatabase != "" {
+ sqlClient, err := sql.New(config.DNS.LogDatabase)
+ fatal(err)
+
+ // Logger
+ sqlLogger = sql.NewLogger(sqlClient, config.DNS.LogMode, config.DNS.LogTTL)
+ sigHandler.OnClose(sqlLogger)
+ }
+
+ // DNS client
+ dnsClient := dnsutil.NewClient(config.Resolver.Protocol, config.Resolver.Timeout, config.DNS.Resolvers...)
// Cache
- var cclient *dnsutil.Client
+ var cacheDNS *dnsutil.Client
if config.DNS.CachePrefetch {
- cclient = client
+ cacheDNS = dnsClient
}
- cache := cache.New(config.DNS.CacheSize, cclient)
+ cache := cache.New(config.DNS.CacheSize, cacheDNS)
// DNS server
- proxy, err := dns.NewProxy(cache, client, logger)
+ proxy, err := dns.NewProxy(cache, dnsClient, logger, sqlLogger)
fatal(err)
sigHandler.OnClose(proxy)
- dnsSrv, err := zdns.NewServer(logger, proxy, config)
+ dnsSrv, err := zdns.NewServer(proxy, config, logger)
fatal(err)
sigHandler.OnReload(dnsSrv)
sigHandler.OnClose(dnsSrv)
-
servers := []server{dnsSrv}
+
// HTTP server
if config.DNS.ListenHTTP != "" {
- httpSrv := http.NewServer(logger, cache, config.DNS.ListenHTTP)
+ httpSrv := http.NewServer(cache, sqlLogger, logger, config.DNS.ListenHTTP)
sigHandler.OnClose(httpSrv)
servers = append(servers, httpSrv)
}
diff --git a/config.go b/config.go
index 3d581f1..b691291 100644
--- a/config.go
+++ b/config.go
@@ -10,7 +10,7 @@ import (
"github.com/BurntSushi/toml"
"github.com/mpolden/zdns/hosts"
- "github.com/mpolden/zdns/log"
+ "github.com/mpolden/zdns/sql"
)
// Config specifies is the zdns configuration parameters.
@@ -180,11 +180,11 @@ func (c *Config) load() error {
}
switch c.DNS.LogModeString {
case "":
- c.DNS.LogMode = log.ModeDiscard
+ c.DNS.LogMode = sql.LogDiscard
case "all":
- c.DNS.LogMode = log.ModeAll
+ c.DNS.LogMode = sql.LogAll
case "hijacked":
- c.DNS.LogMode = log.ModeHijacked
+ c.DNS.LogMode = sql.LogHijacked
default:
return fmt.Errorf("invalid log mode: %s", c.DNS.LogModeString)
}
diff --git a/dns/proxy.go b/dns/proxy.go
index 90164ab..9e7607e 100644
--- a/dns/proxy.go
+++ b/dns/proxy.go
@@ -2,6 +2,7 @@ package dns
import (
"fmt"
+ "log"
"net"
"strings"
"sync"
@@ -32,27 +33,27 @@ type Handler func(*Request) *Reply
// Proxy represents a DNS proxy.
type Proxy struct {
- Handler Handler
- cache *cache.Cache
- logger logger
- server *dns.Server
- client *dnsutil.Client
- mu sync.RWMutex
+ Handler Handler
+ cache *cache.Cache
+ logger *log.Logger
+ dnsLogger logger
+ server *dns.Server
+ client *dnsutil.Client
+ mu sync.RWMutex
}
type logger interface {
- Print(...interface{})
- Printf(string, ...interface{})
Record(net.IP, bool, uint16, string, ...string)
Close() error
}
// NewProxy creates a new DNS proxy.
-func NewProxy(cache *cache.Cache, client *dnsutil.Client, logger logger) (*Proxy, error) {
+func NewProxy(cache *cache.Cache, client *dnsutil.Client, logger *log.Logger, dnsLogger logger) (*Proxy, error) {
return &Proxy{
- logger: logger,
- cache: cache,
- client: client,
+ logger: logger,
+ dnsLogger: dnsLogger,
+ cache: cache,
+ client: client,
}, nil
}
@@ -129,7 +130,9 @@ func (p *Proxy) writeMsg(w dns.ResponseWriter, msg *dns.Msg, hijacked bool) {
default:
panic(fmt.Sprintf("unexpected remote address type %T", v))
}
- p.logger.Record(ip, hijacked, msg.Question[0].Qtype, msg.Question[0].Name, dnsutil.Answers(msg)...)
+ if p.dnsLogger != nil {
+ p.dnsLogger.Record(ip, hijacked, msg.Question[0].Qtype, msg.Question[0].Name, dnsutil.Answers(msg)...)
+ }
w.WriteMsg(msg)
}
diff --git a/dns/proxy_test.go b/dns/proxy_test.go
index 7d2a894..9b6dc35 100644
--- a/dns/proxy_test.go
+++ b/dns/proxy_test.go
@@ -3,6 +3,7 @@ package dns
import (
"fmt"
"io/ioutil"
+ "log"
"net"
"reflect"
"sync"
@@ -12,7 +13,6 @@ import (
"github.com/miekg/dns"
"github.com/mpolden/zdns/cache"
"github.com/mpolden/zdns/dns/dnsutil"
- "github.com/mpolden/zdns/log"
)
type dnsWriter struct{ lastReply *dns.Msg }
@@ -63,15 +63,16 @@ func (e *testExchanger) Exchange(msg *dns.Msg, addr string) (*dns.Msg, time.Dura
return r.answer, time.Second, nil
}
+type testLogger struct{}
+
+func (l *testLogger) Record(net.IP, bool, uint16, string, ...string) {}
+
+func (l *testLogger) Close() error { return nil }
+
func testProxy(t *testing.T) *Proxy {
- log, err := log.New(ioutil.Discard, "", log.RecordOptions{})
- if err != nil {
- t.Fatal(err)
- }
- if err != nil {
- t.Fatal(err)
- }
- proxy, err := NewProxy(cache.New(0, nil), nil, log)
+ logger := log.New(ioutil.Discard, "", 0)
+ dnsLogger := &testLogger{}
+ proxy, err := NewProxy(cache.New(0, nil), nil, logger, dnsLogger)
if err != nil {
t.Fatal(err)
}
diff --git a/http/http.go b/http/http.go
index ac2e9b0..3887f9b 100644
--- a/http/http.go
+++ b/http/http.go
@@ -2,6 +2,7 @@ package http
import (
"context"
+ "log"
"net"
"net/http"
_ "net/http/pprof" // Registers debug handlers as a side effect.
@@ -10,15 +11,16 @@ import (
"github.com/mpolden/zdns/cache"
"github.com/mpolden/zdns/dns/dnsutil"
- "github.com/mpolden/zdns/log"
+ "github.com/mpolden/zdns/sql"
)
// A Server defines paramaters for running an HTTP server. The HTTP server serves an API for inspecting cache contents
// and request log.
type Server struct {
- cache *cache.Cache
- logger *log.Logger
- server *http.Server
+ cache *cache.Cache
+ logger *log.Logger
+ sqlLogger *sql.Logger
+ server *http.Server
}
type entry struct {
@@ -39,12 +41,13 @@ type httpError struct {
}
// NewServer creates a new HTTP server, serving logs from the given logger and listening on addr.
-func NewServer(logger *log.Logger, cache *cache.Cache, addr string) *Server {
+func NewServer(cache *cache.Cache, sqlLogger *sql.Logger, logger *log.Logger, addr string) *Server {
server := &http.Server{Addr: addr}
s := &Server{
- cache: cache,
- logger: logger,
- server: server,
+ cache: cache,
+ logger: logger,
+ sqlLogger: sqlLogger,
+ server: server,
}
s.server.Handler = s.handler()
return s
@@ -94,7 +97,7 @@ func (s *Server) cacheResetHandler(w http.ResponseWriter, r *http.Request) (inte
}
func (s *Server) logHandler(w http.ResponseWriter, r *http.Request) (interface{}, *httpError) {
- logEntries, err := s.logger.Get(listCountFrom(r))
+ logEntries, err := s.sqlLogger.Get(listCountFrom(r))
if err != nil {
return nil, &httpError{
err: err,
diff --git a/http/http_test.go b/http/http_test.go
index 0d43b46..7853b16 100644
--- a/http/http_test.go
+++ b/http/http_test.go
@@ -2,6 +2,7 @@ package http
import (
"io/ioutil"
+ "log"
"net"
"net/http"
"net/http/httptest"
@@ -11,7 +12,7 @@ import (
"github.com/miekg/dns"
"github.com/mpolden/zdns/cache"
- "github.com/mpolden/zdns/log"
+ "github.com/mpolden/zdns/sql"
)
func newA(name string, ttl uint32, ipAddr ...net.IP) *dns.Msg {
@@ -30,12 +31,14 @@ func newA(name string, ttl uint32, ipAddr ...net.IP) *dns.Msg {
}
func testServer() (*httptest.Server, *Server) {
- logger, err := log.New(ioutil.Discard, "", log.RecordOptions{Mode: log.ModeAll, Database: ":memory:"})
+ db, err := sql.New(":memory:")
if err != nil {
panic(err)
}
+ stdLogger := log.New(ioutil.Discard, "", 0)
+ logger := sql.NewLogger(db, sql.LogAll, 0)
cache := cache.New(10, nil)
- server := Server{logger: logger, cache: cache}
+ server := Server{logger: stdLogger, sqlLogger: logger, cache: cache}
return httptest.NewServer(server.handler()), &server
}
@@ -76,9 +79,9 @@ func httpDelete(url, body string) (*http.Response, string, error) {
func TestRequests(t *testing.T) {
httpSrv, srv := testServer()
defer httpSrv.Close()
- srv.logger.Record(net.IPv4(127, 0, 0, 42), false, 1, "example.com.", "192.0.2.100", "192.0.2.101")
- srv.logger.Record(net.IPv4(127, 0, 0, 254), true, 28, "example.com.", "2001:db8::1")
- srv.logger.Close() // Flush
+ srv.sqlLogger.Record(net.IPv4(127, 0, 0, 42), false, 1, "example.com.", "192.0.2.100", "192.0.2.101")
+ srv.sqlLogger.Record(net.IPv4(127, 0, 0, 254), true, 28, "example.com.", "2001:db8::1")
+ srv.sqlLogger.Close() // Flush
srv.cache.Set(1, newA("1.example.com.", 60, net.IPv4(192, 0, 2, 200)))
srv.cache.Set(2, newA("2.example.com.", 30, net.IPv4(192, 0, 2, 201)))
diff --git a/server.go b/server.go
index 79a2447..5fc36dd 100644
--- a/server.go
+++ b/server.go
@@ -3,6 +3,7 @@ package zdns
import (
"fmt"
"io"
+ "log"
"net"
"net/http"
"net/url"
@@ -13,7 +14,6 @@ import (
"github.com/cenkalti/backoff/v4"
"github.com/mpolden/zdns/dns"
"github.com/mpolden/zdns/hosts"
- "github.com/mpolden/zdns/log"
)
const (
@@ -37,7 +37,7 @@ type Server struct {
}
// NewServer returns a new server configured according to config.
-func NewServer(logger *log.Logger, proxy *dns.Proxy, config Config) (*Server, error) {
+func NewServer(proxy *dns.Proxy, config Config, logger *log.Logger) (*Server, error) {
server := &Server{
Config: config,
done: make(chan bool, 1),
diff --git a/server_test.go b/server_test.go
index 6aa16ad..73fc35b 100644
--- a/server_test.go
+++ b/server_test.go
@@ -2,6 +2,7 @@ package zdns
import (
"io/ioutil"
+ "log"
"net"
"net/http"
"net/http/httptest"
@@ -13,7 +14,6 @@ import (
"github.com/mpolden/zdns/cache"
"github.com/mpolden/zdns/dns"
"github.com/mpolden/zdns/hosts"
- "github.com/mpolden/zdns/log"
)
const hostsFile1 = `
@@ -29,6 +29,12 @@ const hostsFile2 = `
192.0.2.6 badhost6
`
+type testLogger struct{}
+
+func (l *testLogger) Record(net.IP, bool, uint16, string, ...string) {}
+
+func (l *testLogger) Close() error { return nil }
+
func httpHandler(t *testing.T, response string) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if _, err := w.Write([]byte(response)); err != nil {
@@ -96,15 +102,13 @@ func testServer(t *testing.T, refreshInterval time.Duration) (*Server, func()) {
if err := config.load(); err != nil {
t.Fatal(err)
}
- logger, err := log.New(ioutil.Discard, "", log.RecordOptions{})
- if err != nil {
- t.Fatal(err)
- }
- proxy, err := dns.NewProxy(cache.New(0, nil), nil, logger)
+ logger := log.New(ioutil.Discard, "", 0)
+ queryLogger := &testLogger{}
+ proxy, err := dns.NewProxy(cache.New(0, nil), nil, logger, queryLogger)
if err != nil {
t.Fatal(err)
}
- srv, err = NewServer(logger, proxy, config)
+ srv, err = NewServer(proxy, config, logger)
if err != nil {
defer cleanup()
t.Fatal(err)
@@ -168,7 +172,7 @@ func TestNonFqdn(t *testing.T) {
}
func TestHijack(t *testing.T) {
- log, _ := log.New(ioutil.Discard, "", log.RecordOptions{})
+ logger := log.New(ioutil.Discard, "", 0)
s := &Server{
Config: Config{},
hosts: hosts.Hosts{
@@ -177,7 +181,7 @@ func TestHijack(t *testing.T) {
{IP: net.ParseIP("2001:db8::1")},
},
},
- logger: log,
+ logger: logger,
}
var tests = []struct {
diff --git a/signal/signal.go b/signal/signal.go
index 96ef5f7..d2f50c2 100644
--- a/signal/signal.go
+++ b/signal/signal.go
@@ -2,11 +2,10 @@ package signal
import (
"io"
+ "log"
"os"
"os/signal"
"syscall"
-
- "github.com/mpolden/zdns/log"
)
// Reloader is the interface for types that need to act on a reload signal.
diff --git a/signal/signal_test.go b/signal/signal_test.go
index 7010ebd..fe0cef5 100644
--- a/signal/signal_test.go
+++ b/signal/signal_test.go
@@ -2,13 +2,12 @@ package signal
import (
"io/ioutil"
+ "log"
"os"
"sync"
"syscall"
"testing"
"time"
-
- "github.com/mpolden/zdns/log"
)
type reloaderCloser struct {
@@ -47,10 +46,7 @@ func (rc *reloaderCloser) reset() {
}
func TestHandler(t *testing.T) {
- logger, err := log.New(ioutil.Discard, "", log.RecordOptions{})
- if err != nil {
- t.Fatal(err)
- }
+ logger := log.New(ioutil.Discard, "", 0)
h := NewHandler(make(chan os.Signal, 1), logger)
rc := &reloaderCloser{}
diff --git a/log/logger.go b/sql/logger.go
index d8dc4af..f6c7215 100644
--- a/log/logger.go
+++ b/sql/logger.go
@@ -1,32 +1,29 @@
-package log
+package sql
import (
- "io"
"log"
"net"
"sync"
"time"
-
- "github.com/mpolden/zdns/sql"
)
const (
- // ModeDiscard disables logging of DNS requests.
- ModeDiscard = iota
- // ModeAll logs all DNS requests.
- ModeAll
- // ModeHijacked only logs hijacked DNS requests.
- ModeHijacked
+ // LogDiscard disables logging of DNS requests.
+ LogDiscard = iota
+ // LogAll logs all DNS requests.
+ LogAll
+ // LogHijacked only logs hijacked DNS requests.
+ LogHijacked
)
-// Logger wraps a standard log.Logger and an optional log database.
+// Logger is a logs DNS requests to a SQL database.
type Logger struct {
- *log.Logger
- mode int
- queue chan Entry
- db *sql.Client
- wg sync.WaitGroup
- now func() time.Time
+ mode int
+ queue chan Entry
+ db *Client
+ wg sync.WaitGroup
+ now func() time.Time
+ Logger *log.Logger
}
// RecordOptions configures recording of DNS requests.
@@ -46,25 +43,19 @@ type Entry struct {
Answers []string
}
-// New creates a new logger, writing log output to writer w prefixed with prefix. Persisted logging behaviour is
-// controller by options.
-func New(w io.Writer, prefix string, options RecordOptions) (*Logger, error) {
+// NewLogger creates a new logger. Persisted entries are kept according to ttl.
+func NewLogger(db *Client, mode int, ttl time.Duration) *Logger {
logger := &Logger{
- Logger: log.New(w, prefix, 0),
- queue: make(chan Entry, 100),
- now: time.Now,
- mode: options.Mode,
+ db: db,
+ queue: make(chan Entry, 100),
+ now: time.Now,
+ mode: mode,
}
- var err error
- if options.Database != "" {
- logger.db, err = sql.New(options.Database)
- if err != nil {
- return nil, err
- }
+ if mode != LogDiscard {
+ logger.wg.Add(1)
+ go logger.readQueue(ttl)
}
- logger.wg.Add(1)
- go logger.readQueue(options.TTL)
- return logger, nil
+ return logger
}
// Close consumes any outstanding log requests and closes the logger.
@@ -79,10 +70,10 @@ func (l *Logger) Record(remoteAddr net.IP, hijacked bool, qtype uint16, question
if l.db == nil {
return
}
- if l.mode == ModeDiscard {
+ if l.mode == LogDiscard {
return
}
- if l.mode == ModeHijacked && !hijacked {
+ if l.mode == LogHijacked && !hijacked {
return
}
l.queue <- Entry{
@@ -124,16 +115,23 @@ func (l *Logger) Get(n int) ([]Entry, error) {
return entries, nil
}
+func (l *Logger) printf(format string, v ...interface{}) {
+ if l.Logger == nil {
+ return
+ }
+ l.Logger.Printf(format, v...)
+}
+
func (l *Logger) readQueue(ttl time.Duration) {
defer l.wg.Done()
for e := range l.queue {
if err := l.db.WriteLog(e.Time, e.RemoteAddr, e.Hijacked, e.Qtype, e.Question, e.Answers...); err != nil {
- l.Printf("write failed: %+v: %s", e, err)
+ l.printf("write failed: %+v: %s", e, err)
}
if ttl > 0 {
t := l.now().Add(-ttl)
if err := l.db.DeleteLogBefore(t); err != nil {
- l.Printf("deleting log entries before %v failed: %s", t, err)
+ l.printf("deleting log entries before %v failed: %s", t, err)
}
}
}
diff --git a/log/logger_test.go b/sql/logger_test.go
index f3a91f1..fea858b 100644
--- a/log/logger_test.go
+++ b/sql/logger_test.go
@@ -1,18 +1,15 @@
-package log
+package sql
import (
"net"
- "os"
"reflect"
"testing"
"time"
)
func TestRecord(t *testing.T) {
- logger, err := New(os.Stderr, "test: ", RecordOptions{Database: ":memory:", Mode: ModeAll})
- if err != nil {
- t.Fatal(err)
- }
+ client := testClient()
+ logger := NewLogger(client, LogAll, 0)
logger.Record(net.IPv4(192, 0, 2, 100), false, 1, "example.com.", "192.0.2.1", "192.0.2.2")
// Flush queue
if err := logger.Close(); err != nil {
@@ -37,18 +34,15 @@ func TestMode(t *testing.T) {
mode int
log bool
}{
- {badHost, net.IPv4(192, 0, 2, 100), true, ModeAll, true},
- {goodHost, net.IPv4(192, 0, 2, 100), true, ModeAll, true},
- {badHost, net.IPv4(192, 0, 2, 100), true, ModeHijacked, true},
- {goodHost, net.IPv4(192, 0, 2, 100), false, ModeHijacked, false},
- {badHost, net.IPv4(192, 0, 2, 100), true, ModeDiscard, false},
- {goodHost, net.IPv4(192, 0, 2, 100), false, ModeDiscard, false},
+ {badHost, net.IPv4(192, 0, 2, 100), true, LogAll, true},
+ {goodHost, net.IPv4(192, 0, 2, 100), true, LogAll, true},
+ {badHost, net.IPv4(192, 0, 2, 100), true, LogHijacked, true},
+ {goodHost, net.IPv4(192, 0, 2, 100), false, LogHijacked, false},
+ {badHost, net.IPv4(192, 0, 2, 100), true, LogDiscard, false},
+ {goodHost, net.IPv4(192, 0, 2, 100), false, LogDiscard, false},
}
for i, tt := range tests {
- logger, err := New(os.Stderr, "test: ", RecordOptions{Database: ":memory:", Mode: tt.mode})
- if err != nil {
- t.Fatal(err)
- }
+ logger := NewLogger(testClient(), tt.mode, 0)
logger.mode = tt.mode
logger.Record(tt.remoteAddr, tt.hijacked, 1, tt.question)
if err := logger.Close(); err != nil { // Flush
@@ -65,10 +59,7 @@ func TestMode(t *testing.T) {
}
func TestAnswerMerging(t *testing.T) {
- logger, err := New(os.Stderr, "test: ", RecordOptions{Database: ":memory:", Mode: ModeAll})
- if err != nil {
- t.Fatal(err)
- }
+ logger := NewLogger(testClient(), LogAll, 0)
now := time.Date(2006, 1, 2, 15, 4, 5, 0, time.UTC)
logger.now = func() time.Time { return now }
logger.Record(net.IPv4(192, 0, 2, 100), true, 1, "example.com.", "192.0.2.1", "192.0.2.2")
@@ -104,14 +95,7 @@ func TestAnswerMerging(t *testing.T) {
}
func TestLogPruning(t *testing.T) {
- logger, err := New(os.Stderr, "test: ", RecordOptions{
- Mode: ModeAll,
- Database: ":memory:",
- TTL: time.Hour,
- })
- if err != nil {
- t.Fatal(err)
- }
+ logger := NewLogger(testClient(), LogAll, time.Hour)
defer logger.Close()
tt := time.Now()
logger.now = func() time.Time { return tt }
@@ -120,6 +104,7 @@ func TestLogPruning(t *testing.T) {
// Wait until queue is flushed
ts := time.Now()
var entries []Entry
+ var err error
for len(entries) == 0 {
entries, err = logger.Get(1)
if err != nil {