aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorMartin Polden <mpolden@mpolden.no>2019-12-26 13:14:00 +0100
committerMartin Polden <mpolden@mpolden.no>2019-12-26 13:14:00 +0100
commit69963c803f2bb28437a997646ab60214de5151ea (patch)
treeaeaddc0b83afb7e94e72c7e9d7d5051396a43d81
parente77437a045ffc4e99665646a46326bd64f49b6a7 (diff)
Decouple dependencies
-rw-r--r--cache/cache.go7
-rw-r--r--cache/cache_test.go20
-rw-r--r--cmd/zdns/main.go19
-rw-r--r--cmd/zdns/main_test.go2
-rw-r--r--config.go70
-rw-r--r--config_test.go6
-rw-r--r--dns/proxy.go29
-rw-r--r--dns/proxy_test.go41
-rw-r--r--server.go22
-rw-r--r--server_test.go15
10 files changed, 109 insertions, 122 deletions
diff --git a/cache/cache.go b/cache/cache.go
index 9a5d72a..bf7cfae 100644
--- a/cache/cache.go
+++ b/cache/cache.go
@@ -2,7 +2,6 @@ package cache
import (
"encoding/binary"
- "fmt"
"hash/fnv"
"net"
"sync"
@@ -64,9 +63,9 @@ type Value struct {
func (v *Value) TTL() time.Duration { return minTTL(v.msg) }
// New creates a new cache of given capacity. Stale cache entries are removed at expiryInterval.
-func New(capacity int, expiryInterval time.Duration) (*Cache, error) {
+func New(capacity int, expiryInterval time.Duration) *Cache {
if capacity < 0 {
- return nil, fmt.Errorf("invalid capacity: %d", capacity)
+ capacity = 0
}
if expiryInterval == 0 {
expiryInterval = 10 * time.Minute
@@ -77,7 +76,7 @@ func New(capacity int, expiryInterval time.Duration) (*Cache, error) {
entries: make(map[uint32]*Value, capacity),
}
maintain(cache, expiryInterval)
- return cache, nil
+ return cache
}
// NewKey creates a new cache key for the DNS name, qtype and qclass
diff --git a/cache/cache_test.go b/cache/cache_test.go
index f0d61e6..11936eb 100644
--- a/cache/cache_test.go
+++ b/cache/cache_test.go
@@ -84,10 +84,7 @@ func TestCache(t *testing.T) {
msgFailure.Rcode = dns.RcodeServerFailure
tt := date(2019, 1, 1)
- c, err := New(100, time.Duration(10*time.Millisecond))
- if err != nil {
- t.Fatal(err)
- }
+ c := New(100, time.Duration(10*time.Millisecond))
defer handleErr(t, c.Close)
var tests = []struct {
msg *dns.Msg
@@ -129,10 +126,7 @@ func TestCacheCapacity(t *testing.T) {
{3, 2, 2},
}
for i, tt := range tests {
- c, err := New(tt.capacity, 10*time.Minute)
- if err != nil {
- t.Fatal(err)
- }
+ c := New(tt.capacity, 10*time.Minute)
defer handleErr(t, c.Close)
var msgs []*dns.Msg
for i := 0; i < tt.addCount; i++ {
@@ -170,10 +164,7 @@ func TestCacheList(t *testing.T) {
{2, 3, 2},
}
for i, tt := range tests {
- c, err := New(1024, 10*time.Minute)
- if err != nil {
- t.Fatal(err)
- }
+ c := New(1024, 10*time.Minute)
defer handleErr(t, c.Close)
var msgs []*dns.Msg
for i := 0; i < tt.addCount; i++ {
@@ -207,10 +198,7 @@ func BenchmarkNewKey(b *testing.B) {
}
func BenchmarkCache(b *testing.B) {
- c, err := New(1000, 10*time.Minute)
- if err != nil {
- b.Fatal(err)
- }
+ c := New(1000, 10*time.Minute)
b.ResetTimer()
for n := 0; n < b.N; n++ {
c.Set(uint32(n), &dns.Msg{})
diff --git a/cmd/zdns/main.go b/cmd/zdns/main.go
index 623175c..0a9dacf 100644
--- a/cmd/zdns/main.go
+++ b/cmd/zdns/main.go
@@ -10,6 +10,8 @@ import (
"flag"
"github.com/mpolden/zdns"
+ "github.com/mpolden/zdns/cache"
+ "github.com/mpolden/zdns/dns"
"github.com/mpolden/zdns/http"
"github.com/mpolden/zdns/log"
"github.com/mpolden/zdns/signal"
@@ -71,18 +73,33 @@ func (c *cli) run() {
return
}
+ // Config
config, err := readConfig(*confFile)
fatal(err)
+ // Logger
logger, err := log.New(c.out, logPrefix, log.RecordOptions{
Database: config.DNS.LogDatabase,
TTL: config.DNS.LogTTL,
})
fatal(err)
+ // Signal handling
sigHandler := signal.NewHandler(c.signal, logger)
- dnsSrv, err := zdns.NewServer(logger, config)
+ // Cache
+ cache := cache.New(config.DNS.CacheSize, config.DNS.CacheExpiryInterval)
+
+ // DNS server
+ proxy, err := dns.NewProxy(cache, logger, dns.ProxyOptions{
+ Resolvers: config.DNS.Resolvers,
+ LogMode: config.DNS.LogMode,
+ Network: config.Resolver.Protocol,
+ Timeout: config.Resolver.Timeout,
+ })
+ fatal(err)
+
+ dnsSrv, err := zdns.NewServer(logger, proxy, config)
fatal(err)
sigHandler.OnReload(dnsSrv)
sigHandler.OnClose(dnsSrv)
diff --git a/cmd/zdns/main_test.go b/cmd/zdns/main_test.go
index f2dfa25..83bf1aa 100644
--- a/cmd/zdns/main_test.go
+++ b/cmd/zdns/main_test.go
@@ -29,7 +29,7 @@ func tempFile(t *testing.T, s string) (string, error) {
func TestMain(t *testing.T) {
conf := `
[dns]
-listen = "0.0.0.0:0"
+listen = "127.0.0.1:0"
[resolver]
protocol = "udp"
diff --git a/config.go b/config.go
index 0335a18..7781c42 100644
--- a/config.go
+++ b/config.go
@@ -22,29 +22,29 @@ type Config struct {
// DNSOptions controlers the behaviour of the DNS server.
type DNSOptions struct {
- Listen string
- Protocol string `toml:"protocol"`
- CacheExpiryInterval string `toml:"cache_expiry_interval"`
- cacheExpiryInterval time.Duration
- CacheSize int `toml:"cache_size"`
- HijackMode string `toml:"hijack_mode"`
- hijackMode int
- RefreshInterval string `toml:"hosts_refresh_interval"`
- refreshInterval time.Duration
- Resolvers []string
- LogDatabase string `toml:"log_database"`
- LogMode string `toml:"log_mode"`
- logMode int
- LogTTLString string `toml:"log_ttl"`
- LogTTL time.Duration
- ListenHTTP string `toml:"listen_http"`
+ Listen string
+ Protocol string `toml:"protocol"`
+ CacheExpiryIntervalString string `toml:"cache_expiry_interval"`
+ CacheExpiryInterval time.Duration
+ CacheSize int `toml:"cache_size"`
+ HijackMode string `toml:"hijack_mode"`
+ hijackMode int
+ RefreshInterval string `toml:"hosts_refresh_interval"`
+ refreshInterval time.Duration
+ Resolvers []string
+ LogDatabase string `toml:"log_database"`
+ LogModeString string `toml:"log_mode"`
+ LogMode int
+ LogTTLString string `toml:"log_ttl"`
+ LogTTL time.Duration
+ ListenHTTP string `toml:"listen_http"`
}
// ResolverOptions controls the behaviour of resolvers.
type ResolverOptions struct {
- Protocol string `toml:"protocol"`
- Timeout string `toml:"timeout"`
- timeout time.Duration
+ Protocol string `toml:"protocol"`
+ TimeoutString string `toml:"timeout"`
+ Timeout time.Duration
}
// Hosts controls how a hosts file should be retrieved.
@@ -63,7 +63,7 @@ func newConfig() Config {
c.DNS.Protocol = "udp"
c.DNS.CacheSize = 1024
c.DNS.RefreshInterval = "48h"
- c.Resolver.Timeout = "5s"
+ c.Resolver.TimeoutString = "5s"
c.Resolver.Protocol = "udp"
return c
}
@@ -82,10 +82,10 @@ func (c *Config) load() error {
if c.DNS.CacheSize < 0 {
return fmt.Errorf("cache size must be >= 0")
}
- if c.DNS.CacheExpiryInterval == "" {
- c.DNS.CacheExpiryInterval = "15m"
+ if c.DNS.CacheExpiryIntervalString == "" {
+ c.DNS.CacheExpiryIntervalString = "15m"
}
- c.DNS.cacheExpiryInterval, err = time.ParseDuration(c.DNS.CacheExpiryInterval)
+ c.DNS.CacheExpiryInterval, err = time.ParseDuration(c.DNS.CacheExpiryIntervalString)
if err != nil {
return fmt.Errorf("invalid cache expiry interval: %s", err)
}
@@ -159,28 +159,28 @@ func (c *Config) load() error {
default:
return fmt.Errorf("invalid resolver protocol: %s", c.Resolver.Protocol)
}
- c.Resolver.timeout, err = time.ParseDuration(c.Resolver.Timeout)
+ c.Resolver.Timeout, err = time.ParseDuration(c.Resolver.TimeoutString)
if err != nil {
- return fmt.Errorf("invalid resolver timeout: %s", c.Resolver.Timeout)
+ return fmt.Errorf("invalid resolver timeout: %s", c.Resolver.TimeoutString)
}
- if c.Resolver.timeout < 0 {
+ if c.Resolver.Timeout < 0 {
return fmt.Errorf("resolver timeout must be >= 0")
}
- if c.Resolver.timeout == 0 {
- c.Resolver.timeout = 5 * time.Second
+ if c.Resolver.Timeout == 0 {
+ c.Resolver.Timeout = 5 * time.Second
}
- switch c.DNS.LogMode {
+ switch c.DNS.LogModeString {
case "":
- c.DNS.logMode = dns.LogDiscard
+ c.DNS.LogMode = dns.LogDiscard
case "all":
- c.DNS.logMode = dns.LogAll
+ c.DNS.LogMode = dns.LogAll
case "hijacked":
- c.DNS.logMode = dns.LogHijacked
+ c.DNS.LogMode = dns.LogHijacked
default:
- return fmt.Errorf("invalid log mode: %s", c.DNS.LogMode)
+ return fmt.Errorf("invalid log mode: %s", c.DNS.LogModeString)
}
- if c.DNS.LogMode != "" && c.DNS.LogDatabase == "" {
- return fmt.Errorf("log_mode = %q requires log_database to be set", c.DNS.LogMode)
+ if c.DNS.LogModeString != "" && c.DNS.LogDatabase == "" {
+ return fmt.Errorf("log_mode = %q requires log_database to be set", c.DNS.LogModeString)
}
if c.DNS.LogTTLString == "" {
c.DNS.LogTTLString = "0"
diff --git a/config_test.go b/config_test.go
index cf2e67e..12b8b6e 100644
--- a/config_test.go
+++ b/config_test.go
@@ -56,9 +56,9 @@ hijack = false
want int
}{
{"DNS.CacheSize", conf.DNS.CacheSize, 2048},
- {"DNS.CacheExpiryInterval", int(conf.DNS.cacheExpiryInterval), int(5 * time.Minute)},
+ {"DNS.CacheExpiryInterval", int(conf.DNS.CacheExpiryInterval), int(5 * time.Minute)},
{"len(DNS.Resolvers)", len(conf.DNS.Resolvers), 2},
- {"Resolver.Timeout", int(conf.Resolver.timeout), int(time.Second)},
+ {"Resolver.Timeout", int(conf.Resolver.Timeout), int(time.Second)},
{"DNS.RefreshInterval", int(conf.DNS.refreshInterval), int(48 * time.Hour)},
{"len(Hosts)", len(conf.Hosts), 3},
{"DNS.LogTTL", int(conf.DNS.LogTTL), int(72 * time.Hour)},
@@ -80,7 +80,7 @@ hijack = false
{"DNS.Resolvers[1]", conf.DNS.Resolvers[1], "192.0.2.2:53"},
{"DNS.HijackMode", conf.DNS.HijackMode, "zero"},
{"DNS.LogDatabase", conf.DNS.LogDatabase, "/tmp/log.db"},
- {"DNS.LogMode", conf.DNS.LogMode, "all"},
+ {"DNS.LogMode", conf.DNS.LogModeString, "all"},
{"DNS.LogTTL", conf.DNS.LogTTLString, "72h"},
{"Resolver.Protocol", conf.Resolver.Protocol, "tcp-tls"},
{"Hosts[0].Source", conf.Hosts[0].URL, "file:///home/foo/hosts-good"},
diff --git a/dns/proxy.go b/dns/proxy.go
index ad81c6d..5ea23ed 100644
--- a/dns/proxy.go
+++ b/dns/proxy.go
@@ -36,7 +36,7 @@ type Handler func(*Request) *Reply
// Proxy represents a DNS proxy.
type Proxy struct {
- handler Handler
+ Handler Handler
resolvers []string
cache *cache.Cache
logger logger
@@ -47,14 +47,10 @@ type Proxy struct {
// ProxyOptions represents proxy configuration.
type ProxyOptions struct {
- Handler Handler
- Resolvers []string
- Logger logger
- LogMode int
- Network string
- Timeout time.Duration
- CacheSize int
- CacheExpiryInterval time.Duration
+ Resolvers []string
+ LogMode int
+ Network string
+ Timeout time.Duration
}
type client interface {
@@ -68,17 +64,12 @@ type logger interface {
}
// NewProxy creates a new DNS proxy.
-func NewProxy(options ProxyOptions) (*Proxy, error) {
- cache, err := cache.New(options.CacheSize, options.CacheExpiryInterval)
- if err != nil {
- return nil, err
- }
+func NewProxy(cache *cache.Cache, logger logger, options ProxyOptions) (*Proxy, error) {
return &Proxy{
- handler: options.Handler,
+ logger: logger,
+ cache: cache,
resolvers: options.Resolvers,
- logger: options.Logger,
logMode: options.LogMode,
- cache: cache,
client: &dns.Client{Net: options.Network, Timeout: options.Timeout},
}, nil
}
@@ -119,10 +110,10 @@ func (r *Reply) String() string {
}
func (p *Proxy) reply(r *dns.Msg) *dns.Msg {
- if p.handler == nil || len(r.Question) != 1 {
+ if p.Handler == nil || len(r.Question) != 1 {
return nil
}
- reply := p.handler(&Request{
+ reply := p.Handler(&Request{
Name: r.Question[0].Name,
Type: r.Question[0].Qtype,
})
diff --git a/dns/proxy_test.go b/dns/proxy_test.go
index 2727b8e..41237a6 100644
--- a/dns/proxy_test.go
+++ b/dns/proxy_test.go
@@ -60,15 +60,15 @@ func (l *testLogger) Record(remoteAddr net.IP, qtype uint16, question, answer st
l.remoteAddr = remoteAddr
}
-func testProxy(t *testing.T) *Proxy { return testProxyWithOptions(t, ProxyOptions{}) }
-
-func testProxyWithOptions(t *testing.T, options ProxyOptions) *Proxy {
+func testProxy(t *testing.T) *Proxy {
log, err := log.New(ioutil.Discard, "", log.RecordOptions{})
if err != nil {
t.Fatal(err)
}
- options.Logger = log
- proxy, err := NewProxy(options)
+ if err != nil {
+ t.Fatal(err)
+ }
+ proxy, err := NewProxy(cache.New(0, time.Minute), log, ProxyOptions{})
if err != nil {
t.Fatal(err)
}
@@ -143,7 +143,7 @@ func TestProxy(t *testing.T) {
return nil
}
p := testProxy(t)
- p.handler = h
+ p.Handler = h
m := dns.Msg{}
m.Id = dns.Id()
@@ -191,7 +191,8 @@ func TestProxyWithResolvers(t *testing.T) {
}
func TestProxyWithCache(t *testing.T) {
- p := testProxyWithOptions(t, ProxyOptions{CacheSize: 10, CacheExpiryInterval: time.Minute})
+ p := testProxy(t)
+ p.cache = cache.New(10, time.Minute)
p.resolvers = []string{"resolver1"}
client := make(testClient)
p.client = client
@@ -212,8 +213,8 @@ func TestProxyWithCache(t *testing.T) {
}
func TestProxyWithLogging(t *testing.T) {
- log := &testLogger{}
- p, err := NewProxy(ProxyOptions{Logger: log})
+ logger := &testLogger{}
+ p, err := NewProxy(cache.New(0, time.Minute), logger, ProxyOptions{})
if err != nil {
t.Fatal(err)
}
@@ -234,7 +235,7 @@ func TestProxyWithLogging(t *testing.T) {
}
return nil
}
- p.handler = h
+ p.Handler = h
var tests = []struct {
question string
@@ -250,8 +251,8 @@ func TestProxyWithLogging(t *testing.T) {
{goodHost, net.IPv4(192, 0, 2, 100), false, LogDiscard},
}
for i, tt := range tests {
- log.question = ""
- log.remoteAddr = nil
+ logger.question = ""
+ logger.remoteAddr = nil
p.logMode = tt.logMode
m.SetQuestion(tt.question, dns.TypeA)
if tt.question == badHost {
@@ -260,18 +261,18 @@ func TestProxyWithLogging(t *testing.T) {
assertRR(t, p, &m, "192.0.2.1")
}
if tt.log {
- if log.question != tt.question {
- t.Errorf("#%d: question = %q, want %q", i, log.question, tt.question)
+ if logger.question != tt.question {
+ t.Errorf("#%d: question = %q, want %q", i, logger.question, tt.question)
}
- if log.remoteAddr.String() != tt.remoteAddr.String() {
- t.Errorf("#%d: remoteAddr = %s, want %s", i, log.remoteAddr, tt.remoteAddr)
+ if logger.remoteAddr.String() != tt.remoteAddr.String() {
+ t.Errorf("#%d: remoteAddr = %s, want %s", i, logger.remoteAddr, tt.remoteAddr)
}
} else {
- if log.question != "" {
- t.Errorf("#%d: question = %q, want %q", i, log.question, "")
+ if logger.question != "" {
+ t.Errorf("#%d: question = %q, want %q", i, logger.question, "")
}
- if log.remoteAddr != nil {
- t.Errorf("#%d: remoteAddr = %v, want %v", i, log.remoteAddr, nil)
+ if logger.remoteAddr != nil {
+ t.Errorf("#%d: remoteAddr = %v, want %v", i, logger.remoteAddr, nil)
}
}
}
diff --git a/server.go b/server.go
index 29c7c6e..eb05b51 100644
--- a/server.go
+++ b/server.go
@@ -38,36 +38,22 @@ type Server struct {
}
// NewServer returns a new server configured according to config.
-func NewServer(logger *log.Logger, config Config) (*Server, error) {
+func NewServer(logger *log.Logger, proxy *dns.Proxy, config Config) (*Server, error) {
server := &Server{
Config: config,
done: make(chan bool, 1),
logger: logger,
+ proxy: proxy,
httpClient: &http.Client{Timeout: 10 * time.Second},
}
+ proxy.Handler = server.hijack
- // Start goroutines
+ // Periodically refresh hosts
if t := config.DNS.refreshInterval; t > 0 {
server.ticker = time.NewTicker(t)
go server.reloadHosts()
}
- // Configure proxy
- var err error
- server.proxy, err = dns.NewProxy(dns.ProxyOptions{
- Handler: server.hijack,
- Resolvers: config.DNS.Resolvers,
- Logger: logger,
- LogMode: config.DNS.logMode,
- Network: config.Resolver.Protocol,
- Timeout: config.Resolver.timeout,
- CacheSize: config.DNS.CacheSize,
- CacheExpiryInterval: config.DNS.cacheExpiryInterval,
- })
- if err != nil {
- return nil, err
- }
-
// Load initial hosts
server.loadHosts()
return server, nil
diff --git a/server_test.go b/server_test.go
index 837a74e..777c8fd 100644
--- a/server_test.go
+++ b/server_test.go
@@ -10,6 +10,7 @@ import (
"testing"
"time"
+ "github.com/mpolden/zdns/cache"
"github.com/mpolden/zdns/dns"
"github.com/mpolden/zdns/hosts"
"github.com/mpolden/zdns/log"
@@ -86,26 +87,30 @@ func testServer(t *testing.T, refreshInterval time.Duration) (*Server, func()) {
defer cleanup()
t.Fatal(err)
}
- conf := Config{
+ config := Config{
DNS: DNSOptions{Listen: "0.0.0.0:53",
hijackMode: HijackZero,
refreshInterval: refreshInterval,
},
- Resolver: ResolverOptions{Timeout: "0"},
+ Resolver: ResolverOptions{TimeoutString: "0"},
Hosts: []Hosts{
{URL: httpSrv.URL, Hijack: true},
{URL: "file://" + file, Hijack: true},
{Hosts: []string{"192.0.2.5 badhost5"}},
},
}
- if err := conf.load(); err != nil {
+ if err := config.load(); err != nil {
t.Fatal(err)
}
- log, err := log.New(ioutil.Discard, "", log.RecordOptions{})
+ logger, err := log.New(ioutil.Discard, "", log.RecordOptions{})
if err != nil {
t.Fatal(err)
}
- srv, err = NewServer(log, conf)
+ proxy, err := dns.NewProxy(cache.New(0, time.Minute), logger, dns.ProxyOptions{})
+ if err != nil {
+ t.Fatal(err)
+ }
+ srv, err = NewServer(logger, proxy, config)
if err != nil {
defer cleanup()
t.Fatal(err)