From bd0ec7519ebbfbd92a8c3063da0e8a823897171c Mon Sep 17 00:00:00 2001 From: Martin Polden Date: Sun, 29 Dec 2019 21:07:53 +0100 Subject: Extract Client implementation --- cmd/zdns/main.go | 10 +++++----- dns/dnsutil/dnsutil.go | 42 ++++++++++++++++++++++++++++++------------ dns/proxy.go | 43 ++++++++++--------------------------------- dns/proxy_test.go | 37 +++++++++++++++++++------------------ server_test.go | 2 +- 5 files changed, 65 insertions(+), 69 deletions(-) diff --git a/cmd/zdns/main.go b/cmd/zdns/main.go index 23e12a0..4b2697d 100644 --- a/cmd/zdns/main.go +++ b/cmd/zdns/main.go @@ -12,6 +12,7 @@ import ( "github.com/mpolden/zdns" "github.com/mpolden/zdns/cache" "github.com/mpolden/zdns/dns" + "github.com/mpolden/zdns/dns/dnsutil" "github.com/mpolden/zdns/http" "github.com/mpolden/zdns/log" "github.com/mpolden/zdns/signal" @@ -93,12 +94,11 @@ func (c *cli) run() { cache := cache.New(config.DNS.CacheSize) sigHandler.OnClose(cache) + // Client + client := dnsutil.NewClient(config.Resolver.Protocol, config.Resolver.Timeout, config.DNS.Resolvers...) + // DNS server - proxy, err := dns.NewProxy(cache, logger, dns.ProxyOptions{ - Resolvers: config.DNS.Resolvers, - Network: config.Resolver.Protocol, - Timeout: config.Resolver.Timeout, - }) + proxy, err := dns.NewProxy(cache, client, logger) fatal(err) sigHandler.OnClose(proxy) diff --git a/dns/dnsutil/dnsutil.go b/dns/dnsutil/dnsutil.go index 13cfd0a..141f5c6 100644 --- a/dns/dnsutil/dnsutil.go +++ b/dns/dnsutil/dnsutil.go @@ -6,6 +6,7 @@ import ( "time" "github.com/miekg/dns" + "github.com/mpolden/zdns/dns/http" ) var ( @@ -16,28 +17,45 @@ var ( RcodeToString = dns.RcodeToString ) -// Resolver is the interface that wraps the Exchange method of a DNS client. -type Resolver interface { +// Exchanger is the interface that wraps the Exchange method of a DNS client. +type Exchanger interface { Exchange(*dns.Msg, string) (*dns.Msg, time.Duration, error) } -// Exchange sends a DNS query to addr and returns the response. If more than one addr is given, all are queried and the -// first successful response is returned. -func Exchange(resolver Resolver, msg *dns.Msg, addr ...string) (*dns.Msg, error) { - done := make(chan bool) - c := make(chan *dns.Msg) +// Client wraps a DNS client and a list of server addresses. +type Client struct { + Exchanger Exchanger + Addresses []string +} + +// NewClient creates a new Client using the named network and addresses. +func NewClient(network string, timeout time.Duration, addresses ...string) *Client { + var client Exchanger + if network == "https" { + client = http.NewClient(timeout) + } else { + client = &dns.Client{Net: network, Timeout: timeout} + } + return &Client{Exchanger: client, Addresses: addresses} +} + +// Exchange performs a synchronous DNS query. All addresses in Client c are queried in parallel and the first successful +// response is returned. +func (c *Client) Exchange(msg *dns.Msg) (*dns.Msg, error) { + done := make(chan bool, 1) + ch := make(chan *dns.Msg, len(c.Addresses)) var wg sync.WaitGroup - wg.Add(len(addr)) + wg.Add(len(c.Addresses)) err := errors.New("addr is empty") - for _, a := range addr { + for _, a := range c.Addresses { go func(addr string) { defer wg.Done() - r, _, err1 := resolver.Exchange(msg, addr) + r, _, err1 := c.Exchanger.Exchange(msg, addr) if err1 != nil { err = err1 return } - c <- r + ch <- r }(a) } go func() { @@ -48,7 +66,7 @@ func Exchange(resolver Resolver, msg *dns.Msg, addr ...string) (*dns.Msg, error) select { case <-done: return nil, err - case rr := <-c: + case rr := <-ch: return rr, nil } } diff --git a/dns/proxy.go b/dns/proxy.go index d5b73da..060cb61 100644 --- a/dns/proxy.go +++ b/dns/proxy.go @@ -4,12 +4,10 @@ import ( "fmt" "net" "strings" - "time" "github.com/miekg/dns" "github.com/mpolden/zdns/cache" "github.com/mpolden/zdns/dns/dnsutil" - "github.com/mpolden/zdns/dns/http" ) const ( @@ -33,24 +31,11 @@ type Handler func(*Request) *Reply // Proxy represents a DNS proxy. type Proxy struct { - Handler Handler - resolvers []string - cache *cache.Cache - logger logger - server *dns.Server - client client - timeout time.Duration -} - -// ProxyOptions represents proxy configuration. -type ProxyOptions struct { - Resolvers []string - Network string - Timeout time.Duration -} - -type client interface { - Exchange(*dns.Msg, string) (*dns.Msg, time.Duration, error) + Handler Handler + cache *cache.Cache + logger logger + server *dns.Server + client *dnsutil.Client } type logger interface { @@ -60,19 +45,11 @@ type logger interface { } // NewProxy creates a new DNS proxy. -func NewProxy(cache *cache.Cache, logger logger, options ProxyOptions) (*Proxy, error) { - var c client - if options.Network == "https" { - c = http.NewClient(options.Timeout) - } else { - c = &dns.Client{Net: options.Network, Timeout: options.Timeout} - } +func NewProxy(cache *cache.Cache, client *dnsutil.Client, logger logger) (*Proxy, error) { return &Proxy{ - logger: logger, - cache: cache, - resolvers: options.Resolvers, - client: c, - timeout: options.Timeout, + logger: logger, + cache: cache, + client: client, }, nil } @@ -164,7 +141,7 @@ func (p *Proxy) ServeDNS(w dns.ResponseWriter, r *dns.Msg) { p.writeMsg(w, msg, false) return } - rr, err := dnsutil.Exchange(p.client, r, p.resolvers...) + rr, err := p.client.Exchange(r) if err == nil { p.cache.Set(key, rr) p.writeMsg(w, rr, false) diff --git a/dns/proxy_test.go b/dns/proxy_test.go index aa265d7..5fa0646 100644 --- a/dns/proxy_test.go +++ b/dns/proxy_test.go @@ -10,6 +10,7 @@ import ( "github.com/miekg/dns" "github.com/mpolden/zdns/cache" + "github.com/mpolden/zdns/dns/dnsutil" "github.com/mpolden/zdns/log" ) @@ -30,22 +31,22 @@ func (w *dnsWriter) WriteMsg(msg *dns.Msg) error { return nil } -type resolver struct { +type response struct { answer *dns.Msg fail bool } -type testClient map[string]*resolver +type testExchanger map[string]*response -func (c testClient) Exchange(msg *dns.Msg, addr string) (*dns.Msg, time.Duration, error) { - r, ok := c[addr] +func (e testExchanger) Exchange(msg *dns.Msg, addr string) (*dns.Msg, time.Duration, error) { + r, ok := e[addr] if !ok { panic("no such resolver: " + addr) } if r.fail { return nil, 0, fmt.Errorf("%s SERVFAIL", addr) } - return r.answer, time.Minute * 5, nil + return r.answer, time.Second, nil } func testProxy(t *testing.T) *Proxy { @@ -56,7 +57,7 @@ func testProxy(t *testing.T) *Proxy { if err != nil { t.Fatal(err) } - proxy, err := NewProxy(cache.New(0), log, ProxyOptions{Timeout: 2 * time.Second}) + proxy, err := NewProxy(cache.New(0), nil, log) if err != nil { t.Fatal(err) } @@ -147,47 +148,47 @@ func TestProxy(t *testing.T) { func TestProxyWithResolvers(t *testing.T) { p := testProxy(t) - client := make(testClient) - p.client = client + exchanger := make(testExchanger) + p.client = &dnsutil.Client{Exchanger: exchanger} defer p.Close() // No resolvers assertFailure(t, p, TypeA, "host1") // First and only resolver responds succesfully - p.resolvers = []string{"resolver1"} + p.client.Addresses = []string{"resolver1"} reply := ReplyA("host1", net.ParseIP("192.0.2.1")) m := dns.Msg{} m.Id = dns.Id() m.SetQuestion("host1.", dns.TypeA) m.Answer = reply.rr - client["resolver1"] = &resolver{answer: &m} + exchanger["resolver1"] = &response{answer: &m} assertRR(t, p, &m, "192.0.2.1") // First and only resolver fails - client["resolver1"].fail = true + exchanger["resolver1"].fail = true assertFailure(t, p, TypeA, "host1") // First resolver fails, but second succeeds reply = ReplyA("host1", net.ParseIP("192.0.2.2")) - p.resolvers = []string{"resolver1", "resolver2"} + p.client.Addresses = []string{"resolver1", "resolver2"} m = dns.Msg{} m.Id = dns.Id() m.SetQuestion("host1.", dns.TypeA) m.Answer = reply.rr - client["resolver2"] = &resolver{answer: &m} + exchanger["resolver2"] = &response{answer: &m} assertRR(t, p, &m, "192.0.2.2") // All resolvers fail - client["resolver2"].fail = true + exchanger["resolver2"].fail = true assertFailure(t, p, TypeA, "host1") } func TestProxyWithCache(t *testing.T) { p := testProxy(t) p.cache = cache.New(10) - p.resolvers = []string{"resolver1"} - client := make(testClient) - p.client = client + exchanger := make(testExchanger) + p.client = &dnsutil.Client{Exchanger: exchanger} + p.client.Addresses = []string{"resolver1"} defer p.Close() reply := ReplyA("host1", net.ParseIP("192.0.2.1")) @@ -195,7 +196,7 @@ func TestProxyWithCache(t *testing.T) { m.Id = dns.Id() m.SetQuestion("host1.", dns.TypeA) m.Answer = reply.rr - client["resolver1"] = &resolver{answer: &m} + exchanger["resolver1"] = &response{answer: &m} assertRR(t, p, &m, "192.0.2.1") k := cache.NewKey("host1.", dns.TypeA, dns.ClassINET) diff --git a/server_test.go b/server_test.go index 6706e3c..657f68f 100644 --- a/server_test.go +++ b/server_test.go @@ -106,7 +106,7 @@ func testServer(t *testing.T, refreshInterval time.Duration) (*Server, func()) { if err != nil { t.Fatal(err) } - proxy, err := dns.NewProxy(cache.New(0), logger, dns.ProxyOptions{}) + proxy, err := dns.NewProxy(cache.New(0), nil, logger) if err != nil { t.Fatal(err) } -- cgit v1.2.3