From 72ee46698a94c48527184109401e8a6725a4674b Mon Sep 17 00:00:00 2001 From: Martin Polden Date: Sun, 29 Dec 2019 18:06:00 +0100 Subject: Query all resolvers in parallel --- dns/dnsutil/dnsutil.go | 40 ++++++++++++++++++++++++++++++++++++++++ dns/proxy.go | 19 +++++++------------ dns/proxy_test.go | 8 +++++--- 3 files changed, 52 insertions(+), 15 deletions(-) diff --git a/dns/dnsutil/dnsutil.go b/dns/dnsutil/dnsutil.go index 6f791ef..13cfd0a 100644 --- a/dns/dnsutil/dnsutil.go +++ b/dns/dnsutil/dnsutil.go @@ -1,6 +1,8 @@ package dnsutil import ( + "errors" + "sync" "time" "github.com/miekg/dns" @@ -14,6 +16,44 @@ var ( RcodeToString = dns.RcodeToString ) +// Resolver is the interface that wraps the Exchange method of a DNS client. +type Resolver interface { + Exchange(*dns.Msg, string) (*dns.Msg, time.Duration, error) +} + +// Exchange sends a DNS query to addr and returns the response. If more than one addr is given, all are queried and the +// first successful response is returned. +func Exchange(resolver Resolver, msg *dns.Msg, addr ...string) (*dns.Msg, error) { + done := make(chan bool) + c := make(chan *dns.Msg) + var wg sync.WaitGroup + wg.Add(len(addr)) + err := errors.New("addr is empty") + for _, a := range addr { + go func(addr string) { + defer wg.Done() + r, _, err1 := resolver.Exchange(msg, addr) + if err1 != nil { + err = err1 + return + } + c <- r + }(a) + } + go func() { + wg.Wait() + done <- true + }() + for { + select { + case <-done: + return nil, err + case rr := <-c: + return rr, nil + } + } +} + // Answers returns all values in the answer section of DNS message msg. func Answers(msg *dns.Msg) []string { var answers []string diff --git a/dns/proxy.go b/dns/proxy.go index d4ded21..84f1b48 100644 --- a/dns/proxy.go +++ b/dns/proxy.go @@ -46,6 +46,7 @@ type Proxy struct { logMode int server *dns.Server client client + timeout time.Duration } // ProxyOptions represents proxy configuration. @@ -80,6 +81,7 @@ func NewProxy(cache *cache.Cache, logger logger, options ProxyOptions) (*Proxy, resolvers: options.Resolvers, logMode: options.LogMode, client: c, + timeout: options.Timeout, }, nil } @@ -173,21 +175,14 @@ func (p *Proxy) ServeDNS(w dns.ResponseWriter, r *dns.Msg) { p.writeMsg(w, msg, false) return } - for i, resolver := range p.resolvers { - rr, _, err := p.client.Exchange(r, resolver) - if err != nil { - p.logger.Printf("resolver %s failed: %s", resolver, err) - if i == len(p.resolvers)-1 { // No more resolvers to try - break - } else { - continue - } - } + rr, err := dnsutil.Exchange(p.client, r, p.resolvers...) + if err == nil { p.cache.Set(key, rr) p.writeMsg(w, rr, false) - return + } else { + p.logger.Printf("resolver(s) failed: %s", err) + dns.HandleFailed(w, r) } - dns.HandleFailed(w, r) } // ListenAndServe listens on the network address addr and uses the server to process requests. diff --git a/dns/proxy_test.go b/dns/proxy_test.go index c675f90..5df7611 100644 --- a/dns/proxy_test.go +++ b/dns/proxy_test.go @@ -68,7 +68,7 @@ func testProxy(t *testing.T) *Proxy { if err != nil { t.Fatal(err) } - proxy, err := NewProxy(cache.New(0), log, ProxyOptions{}) + proxy, err := NewProxy(cache.New(0), log, ProxyOptions{Timeout: 2 * time.Second}) if err != nil { t.Fatal(err) } @@ -159,12 +159,14 @@ func TestProxy(t *testing.T) { func TestProxyWithResolvers(t *testing.T) { p := testProxy(t) - p.resolvers = []string{"resolver1"} client := make(testClient) p.client = client defer p.Close() + // No resolvers + assertFailure(t, p, TypeA, "host1") // First and only resolver responds succesfully + p.resolvers = []string{"resolver1"} reply := ReplyA("host1", net.ParseIP("192.0.2.1")) m := dns.Msg{} m.Id = dns.Id() @@ -217,7 +219,7 @@ func TestProxyWithCache(t *testing.T) { func TestProxyWithLogging(t *testing.T) { logger := &testLogger{} - p, err := NewProxy(cache.New(0), logger, ProxyOptions{}) + p, err := NewProxy(cache.New(0), logger, ProxyOptions{Timeout: 2 * time.Second}) if err != nil { t.Fatal(err) } -- cgit v1.2.3