diff options
author | Martin Polden <mpolden@mpolden.no> | 2019-12-26 13:14:00 +0100 |
---|---|---|
committer | Martin Polden <mpolden@mpolden.no> | 2019-12-26 13:14:00 +0100 |
commit | 69963c803f2bb28437a997646ab60214de5151ea (patch) | |
tree | aeaddc0b83afb7e94e72c7e9d7d5051396a43d81 | |
parent | e77437a045ffc4e99665646a46326bd64f49b6a7 (diff) |
Decouple dependencies
-rw-r--r-- | cache/cache.go | 7 | ||||
-rw-r--r-- | cache/cache_test.go | 20 | ||||
-rw-r--r-- | cmd/zdns/main.go | 19 | ||||
-rw-r--r-- | cmd/zdns/main_test.go | 2 | ||||
-rw-r--r-- | config.go | 70 | ||||
-rw-r--r-- | config_test.go | 6 | ||||
-rw-r--r-- | dns/proxy.go | 29 | ||||
-rw-r--r-- | dns/proxy_test.go | 41 | ||||
-rw-r--r-- | server.go | 22 | ||||
-rw-r--r-- | server_test.go | 15 |
10 files changed, 109 insertions, 122 deletions
diff --git a/cache/cache.go b/cache/cache.go index 9a5d72a..bf7cfae 100644 --- a/cache/cache.go +++ b/cache/cache.go @@ -2,7 +2,6 @@ package cache import ( "encoding/binary" - "fmt" "hash/fnv" "net" "sync" @@ -64,9 +63,9 @@ type Value struct { func (v *Value) TTL() time.Duration { return minTTL(v.msg) } // New creates a new cache of given capacity. Stale cache entries are removed at expiryInterval. -func New(capacity int, expiryInterval time.Duration) (*Cache, error) { +func New(capacity int, expiryInterval time.Duration) *Cache { if capacity < 0 { - return nil, fmt.Errorf("invalid capacity: %d", capacity) + capacity = 0 } if expiryInterval == 0 { expiryInterval = 10 * time.Minute @@ -77,7 +76,7 @@ func New(capacity int, expiryInterval time.Duration) (*Cache, error) { entries: make(map[uint32]*Value, capacity), } maintain(cache, expiryInterval) - return cache, nil + return cache } // NewKey creates a new cache key for the DNS name, qtype and qclass diff --git a/cache/cache_test.go b/cache/cache_test.go index f0d61e6..11936eb 100644 --- a/cache/cache_test.go +++ b/cache/cache_test.go @@ -84,10 +84,7 @@ func TestCache(t *testing.T) { msgFailure.Rcode = dns.RcodeServerFailure tt := date(2019, 1, 1) - c, err := New(100, time.Duration(10*time.Millisecond)) - if err != nil { - t.Fatal(err) - } + c := New(100, time.Duration(10*time.Millisecond)) defer handleErr(t, c.Close) var tests = []struct { msg *dns.Msg @@ -129,10 +126,7 @@ func TestCacheCapacity(t *testing.T) { {3, 2, 2}, } for i, tt := range tests { - c, err := New(tt.capacity, 10*time.Minute) - if err != nil { - t.Fatal(err) - } + c := New(tt.capacity, 10*time.Minute) defer handleErr(t, c.Close) var msgs []*dns.Msg for i := 0; i < tt.addCount; i++ { @@ -170,10 +164,7 @@ func TestCacheList(t *testing.T) { {2, 3, 2}, } for i, tt := range tests { - c, err := New(1024, 10*time.Minute) - if err != nil { - t.Fatal(err) - } + c := New(1024, 10*time.Minute) defer handleErr(t, c.Close) var msgs []*dns.Msg for i := 0; i < tt.addCount; i++ { @@ -207,10 +198,7 @@ func BenchmarkNewKey(b *testing.B) { } func BenchmarkCache(b *testing.B) { - c, err := New(1000, 10*time.Minute) - if err != nil { - b.Fatal(err) - } + c := New(1000, 10*time.Minute) b.ResetTimer() for n := 0; n < b.N; n++ { c.Set(uint32(n), &dns.Msg{}) diff --git a/cmd/zdns/main.go b/cmd/zdns/main.go index 623175c..0a9dacf 100644 --- a/cmd/zdns/main.go +++ b/cmd/zdns/main.go @@ -10,6 +10,8 @@ import ( "flag" "github.com/mpolden/zdns" + "github.com/mpolden/zdns/cache" + "github.com/mpolden/zdns/dns" "github.com/mpolden/zdns/http" "github.com/mpolden/zdns/log" "github.com/mpolden/zdns/signal" @@ -71,18 +73,33 @@ func (c *cli) run() { return } + // Config config, err := readConfig(*confFile) fatal(err) + // Logger logger, err := log.New(c.out, logPrefix, log.RecordOptions{ Database: config.DNS.LogDatabase, TTL: config.DNS.LogTTL, }) fatal(err) + // Signal handling sigHandler := signal.NewHandler(c.signal, logger) - dnsSrv, err := zdns.NewServer(logger, config) + // Cache + cache := cache.New(config.DNS.CacheSize, config.DNS.CacheExpiryInterval) + + // DNS server + proxy, err := dns.NewProxy(cache, logger, dns.ProxyOptions{ + Resolvers: config.DNS.Resolvers, + LogMode: config.DNS.LogMode, + Network: config.Resolver.Protocol, + Timeout: config.Resolver.Timeout, + }) + fatal(err) + + dnsSrv, err := zdns.NewServer(logger, proxy, config) fatal(err) sigHandler.OnReload(dnsSrv) sigHandler.OnClose(dnsSrv) diff --git a/cmd/zdns/main_test.go b/cmd/zdns/main_test.go index f2dfa25..83bf1aa 100644 --- a/cmd/zdns/main_test.go +++ b/cmd/zdns/main_test.go @@ -29,7 +29,7 @@ func tempFile(t *testing.T, s string) (string, error) { func TestMain(t *testing.T) { conf := ` [dns] -listen = "0.0.0.0:0" +listen = "127.0.0.1:0" [resolver] protocol = "udp" @@ -22,29 +22,29 @@ type Config struct { // DNSOptions controlers the behaviour of the DNS server. type DNSOptions struct { - Listen string - Protocol string `toml:"protocol"` - CacheExpiryInterval string `toml:"cache_expiry_interval"` - cacheExpiryInterval time.Duration - CacheSize int `toml:"cache_size"` - HijackMode string `toml:"hijack_mode"` - hijackMode int - RefreshInterval string `toml:"hosts_refresh_interval"` - refreshInterval time.Duration - Resolvers []string - LogDatabase string `toml:"log_database"` - LogMode string `toml:"log_mode"` - logMode int - LogTTLString string `toml:"log_ttl"` - LogTTL time.Duration - ListenHTTP string `toml:"listen_http"` + Listen string + Protocol string `toml:"protocol"` + CacheExpiryIntervalString string `toml:"cache_expiry_interval"` + CacheExpiryInterval time.Duration + CacheSize int `toml:"cache_size"` + HijackMode string `toml:"hijack_mode"` + hijackMode int + RefreshInterval string `toml:"hosts_refresh_interval"` + refreshInterval time.Duration + Resolvers []string + LogDatabase string `toml:"log_database"` + LogModeString string `toml:"log_mode"` + LogMode int + LogTTLString string `toml:"log_ttl"` + LogTTL time.Duration + ListenHTTP string `toml:"listen_http"` } // ResolverOptions controls the behaviour of resolvers. type ResolverOptions struct { - Protocol string `toml:"protocol"` - Timeout string `toml:"timeout"` - timeout time.Duration + Protocol string `toml:"protocol"` + TimeoutString string `toml:"timeout"` + Timeout time.Duration } // Hosts controls how a hosts file should be retrieved. @@ -63,7 +63,7 @@ func newConfig() Config { c.DNS.Protocol = "udp" c.DNS.CacheSize = 1024 c.DNS.RefreshInterval = "48h" - c.Resolver.Timeout = "5s" + c.Resolver.TimeoutString = "5s" c.Resolver.Protocol = "udp" return c } @@ -82,10 +82,10 @@ func (c *Config) load() error { if c.DNS.CacheSize < 0 { return fmt.Errorf("cache size must be >= 0") } - if c.DNS.CacheExpiryInterval == "" { - c.DNS.CacheExpiryInterval = "15m" + if c.DNS.CacheExpiryIntervalString == "" { + c.DNS.CacheExpiryIntervalString = "15m" } - c.DNS.cacheExpiryInterval, err = time.ParseDuration(c.DNS.CacheExpiryInterval) + c.DNS.CacheExpiryInterval, err = time.ParseDuration(c.DNS.CacheExpiryIntervalString) if err != nil { return fmt.Errorf("invalid cache expiry interval: %s", err) } @@ -159,28 +159,28 @@ func (c *Config) load() error { default: return fmt.Errorf("invalid resolver protocol: %s", c.Resolver.Protocol) } - c.Resolver.timeout, err = time.ParseDuration(c.Resolver.Timeout) + c.Resolver.Timeout, err = time.ParseDuration(c.Resolver.TimeoutString) if err != nil { - return fmt.Errorf("invalid resolver timeout: %s", c.Resolver.Timeout) + return fmt.Errorf("invalid resolver timeout: %s", c.Resolver.TimeoutString) } - if c.Resolver.timeout < 0 { + if c.Resolver.Timeout < 0 { return fmt.Errorf("resolver timeout must be >= 0") } - if c.Resolver.timeout == 0 { - c.Resolver.timeout = 5 * time.Second + if c.Resolver.Timeout == 0 { + c.Resolver.Timeout = 5 * time.Second } - switch c.DNS.LogMode { + switch c.DNS.LogModeString { case "": - c.DNS.logMode = dns.LogDiscard + c.DNS.LogMode = dns.LogDiscard case "all": - c.DNS.logMode = dns.LogAll + c.DNS.LogMode = dns.LogAll case "hijacked": - c.DNS.logMode = dns.LogHijacked + c.DNS.LogMode = dns.LogHijacked default: - return fmt.Errorf("invalid log mode: %s", c.DNS.LogMode) + return fmt.Errorf("invalid log mode: %s", c.DNS.LogModeString) } - if c.DNS.LogMode != "" && c.DNS.LogDatabase == "" { - return fmt.Errorf("log_mode = %q requires log_database to be set", c.DNS.LogMode) + if c.DNS.LogModeString != "" && c.DNS.LogDatabase == "" { + return fmt.Errorf("log_mode = %q requires log_database to be set", c.DNS.LogModeString) } if c.DNS.LogTTLString == "" { c.DNS.LogTTLString = "0" diff --git a/config_test.go b/config_test.go index cf2e67e..12b8b6e 100644 --- a/config_test.go +++ b/config_test.go @@ -56,9 +56,9 @@ hijack = false want int }{ {"DNS.CacheSize", conf.DNS.CacheSize, 2048}, - {"DNS.CacheExpiryInterval", int(conf.DNS.cacheExpiryInterval), int(5 * time.Minute)}, + {"DNS.CacheExpiryInterval", int(conf.DNS.CacheExpiryInterval), int(5 * time.Minute)}, {"len(DNS.Resolvers)", len(conf.DNS.Resolvers), 2}, - {"Resolver.Timeout", int(conf.Resolver.timeout), int(time.Second)}, + {"Resolver.Timeout", int(conf.Resolver.Timeout), int(time.Second)}, {"DNS.RefreshInterval", int(conf.DNS.refreshInterval), int(48 * time.Hour)}, {"len(Hosts)", len(conf.Hosts), 3}, {"DNS.LogTTL", int(conf.DNS.LogTTL), int(72 * time.Hour)}, @@ -80,7 +80,7 @@ hijack = false {"DNS.Resolvers[1]", conf.DNS.Resolvers[1], "192.0.2.2:53"}, {"DNS.HijackMode", conf.DNS.HijackMode, "zero"}, {"DNS.LogDatabase", conf.DNS.LogDatabase, "/tmp/log.db"}, - {"DNS.LogMode", conf.DNS.LogMode, "all"}, + {"DNS.LogMode", conf.DNS.LogModeString, "all"}, {"DNS.LogTTL", conf.DNS.LogTTLString, "72h"}, {"Resolver.Protocol", conf.Resolver.Protocol, "tcp-tls"}, {"Hosts[0].Source", conf.Hosts[0].URL, "file:///home/foo/hosts-good"}, diff --git a/dns/proxy.go b/dns/proxy.go index ad81c6d..5ea23ed 100644 --- a/dns/proxy.go +++ b/dns/proxy.go @@ -36,7 +36,7 @@ type Handler func(*Request) *Reply // Proxy represents a DNS proxy. type Proxy struct { - handler Handler + Handler Handler resolvers []string cache *cache.Cache logger logger @@ -47,14 +47,10 @@ type Proxy struct { // ProxyOptions represents proxy configuration. type ProxyOptions struct { - Handler Handler - Resolvers []string - Logger logger - LogMode int - Network string - Timeout time.Duration - CacheSize int - CacheExpiryInterval time.Duration + Resolvers []string + LogMode int + Network string + Timeout time.Duration } type client interface { @@ -68,17 +64,12 @@ type logger interface { } // NewProxy creates a new DNS proxy. -func NewProxy(options ProxyOptions) (*Proxy, error) { - cache, err := cache.New(options.CacheSize, options.CacheExpiryInterval) - if err != nil { - return nil, err - } +func NewProxy(cache *cache.Cache, logger logger, options ProxyOptions) (*Proxy, error) { return &Proxy{ - handler: options.Handler, + logger: logger, + cache: cache, resolvers: options.Resolvers, - logger: options.Logger, logMode: options.LogMode, - cache: cache, client: &dns.Client{Net: options.Network, Timeout: options.Timeout}, }, nil } @@ -119,10 +110,10 @@ func (r *Reply) String() string { } func (p *Proxy) reply(r *dns.Msg) *dns.Msg { - if p.handler == nil || len(r.Question) != 1 { + if p.Handler == nil || len(r.Question) != 1 { return nil } - reply := p.handler(&Request{ + reply := p.Handler(&Request{ Name: r.Question[0].Name, Type: r.Question[0].Qtype, }) diff --git a/dns/proxy_test.go b/dns/proxy_test.go index 2727b8e..41237a6 100644 --- a/dns/proxy_test.go +++ b/dns/proxy_test.go @@ -60,15 +60,15 @@ func (l *testLogger) Record(remoteAddr net.IP, qtype uint16, question, answer st l.remoteAddr = remoteAddr } -func testProxy(t *testing.T) *Proxy { return testProxyWithOptions(t, ProxyOptions{}) } - -func testProxyWithOptions(t *testing.T, options ProxyOptions) *Proxy { +func testProxy(t *testing.T) *Proxy { log, err := log.New(ioutil.Discard, "", log.RecordOptions{}) if err != nil { t.Fatal(err) } - options.Logger = log - proxy, err := NewProxy(options) + if err != nil { + t.Fatal(err) + } + proxy, err := NewProxy(cache.New(0, time.Minute), log, ProxyOptions{}) if err != nil { t.Fatal(err) } @@ -143,7 +143,7 @@ func TestProxy(t *testing.T) { return nil } p := testProxy(t) - p.handler = h + p.Handler = h m := dns.Msg{} m.Id = dns.Id() @@ -191,7 +191,8 @@ func TestProxyWithResolvers(t *testing.T) { } func TestProxyWithCache(t *testing.T) { - p := testProxyWithOptions(t, ProxyOptions{CacheSize: 10, CacheExpiryInterval: time.Minute}) + p := testProxy(t) + p.cache = cache.New(10, time.Minute) p.resolvers = []string{"resolver1"} client := make(testClient) p.client = client @@ -212,8 +213,8 @@ func TestProxyWithCache(t *testing.T) { } func TestProxyWithLogging(t *testing.T) { - log := &testLogger{} - p, err := NewProxy(ProxyOptions{Logger: log}) + logger := &testLogger{} + p, err := NewProxy(cache.New(0, time.Minute), logger, ProxyOptions{}) if err != nil { t.Fatal(err) } @@ -234,7 +235,7 @@ func TestProxyWithLogging(t *testing.T) { } return nil } - p.handler = h + p.Handler = h var tests = []struct { question string @@ -250,8 +251,8 @@ func TestProxyWithLogging(t *testing.T) { {goodHost, net.IPv4(192, 0, 2, 100), false, LogDiscard}, } for i, tt := range tests { - log.question = "" - log.remoteAddr = nil + logger.question = "" + logger.remoteAddr = nil p.logMode = tt.logMode m.SetQuestion(tt.question, dns.TypeA) if tt.question == badHost { @@ -260,18 +261,18 @@ func TestProxyWithLogging(t *testing.T) { assertRR(t, p, &m, "192.0.2.1") } if tt.log { - if log.question != tt.question { - t.Errorf("#%d: question = %q, want %q", i, log.question, tt.question) + if logger.question != tt.question { + t.Errorf("#%d: question = %q, want %q", i, logger.question, tt.question) } - if log.remoteAddr.String() != tt.remoteAddr.String() { - t.Errorf("#%d: remoteAddr = %s, want %s", i, log.remoteAddr, tt.remoteAddr) + if logger.remoteAddr.String() != tt.remoteAddr.String() { + t.Errorf("#%d: remoteAddr = %s, want %s", i, logger.remoteAddr, tt.remoteAddr) } } else { - if log.question != "" { - t.Errorf("#%d: question = %q, want %q", i, log.question, "") + if logger.question != "" { + t.Errorf("#%d: question = %q, want %q", i, logger.question, "") } - if log.remoteAddr != nil { - t.Errorf("#%d: remoteAddr = %v, want %v", i, log.remoteAddr, nil) + if logger.remoteAddr != nil { + t.Errorf("#%d: remoteAddr = %v, want %v", i, logger.remoteAddr, nil) } } } @@ -38,36 +38,22 @@ type Server struct { } // NewServer returns a new server configured according to config. -func NewServer(logger *log.Logger, config Config) (*Server, error) { +func NewServer(logger *log.Logger, proxy *dns.Proxy, config Config) (*Server, error) { server := &Server{ Config: config, done: make(chan bool, 1), logger: logger, + proxy: proxy, httpClient: &http.Client{Timeout: 10 * time.Second}, } + proxy.Handler = server.hijack - // Start goroutines + // Periodically refresh hosts if t := config.DNS.refreshInterval; t > 0 { server.ticker = time.NewTicker(t) go server.reloadHosts() } - // Configure proxy - var err error - server.proxy, err = dns.NewProxy(dns.ProxyOptions{ - Handler: server.hijack, - Resolvers: config.DNS.Resolvers, - Logger: logger, - LogMode: config.DNS.logMode, - Network: config.Resolver.Protocol, - Timeout: config.Resolver.timeout, - CacheSize: config.DNS.CacheSize, - CacheExpiryInterval: config.DNS.cacheExpiryInterval, - }) - if err != nil { - return nil, err - } - // Load initial hosts server.loadHosts() return server, nil diff --git a/server_test.go b/server_test.go index 837a74e..777c8fd 100644 --- a/server_test.go +++ b/server_test.go @@ -10,6 +10,7 @@ import ( "testing" "time" + "github.com/mpolden/zdns/cache" "github.com/mpolden/zdns/dns" "github.com/mpolden/zdns/hosts" "github.com/mpolden/zdns/log" @@ -86,26 +87,30 @@ func testServer(t *testing.T, refreshInterval time.Duration) (*Server, func()) { defer cleanup() t.Fatal(err) } - conf := Config{ + config := Config{ DNS: DNSOptions{Listen: "0.0.0.0:53", hijackMode: HijackZero, refreshInterval: refreshInterval, }, - Resolver: ResolverOptions{Timeout: "0"}, + Resolver: ResolverOptions{TimeoutString: "0"}, Hosts: []Hosts{ {URL: httpSrv.URL, Hijack: true}, {URL: "file://" + file, Hijack: true}, {Hosts: []string{"192.0.2.5 badhost5"}}, }, } - if err := conf.load(); err != nil { + if err := config.load(); err != nil { t.Fatal(err) } - log, err := log.New(ioutil.Discard, "", log.RecordOptions{}) + logger, err := log.New(ioutil.Discard, "", log.RecordOptions{}) if err != nil { t.Fatal(err) } - srv, err = NewServer(log, conf) + proxy, err := dns.NewProxy(cache.New(0, time.Minute), logger, dns.ProxyOptions{}) + if err != nil { + t.Fatal(err) + } + srv, err = NewServer(logger, proxy, config) if err != nil { defer cleanup() t.Fatal(err) |