aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorMartin Polden <mpolden@mpolden.no>2019-12-29 18:06:00 +0100
committerMartin Polden <mpolden@mpolden.no>2019-12-29 18:19:05 +0100
commit72ee46698a94c48527184109401e8a6725a4674b (patch)
treefb2f8c5a8fc1856bd083b895b2edc1876c3ef24d
parent935077242b9a7c57f312515fc99bc57bd48d6679 (diff)
Query all resolvers in parallel
-rw-r--r--dns/dnsutil/dnsutil.go40
-rw-r--r--dns/proxy.go19
-rw-r--r--dns/proxy_test.go8
3 files changed, 52 insertions, 15 deletions
diff --git a/dns/dnsutil/dnsutil.go b/dns/dnsutil/dnsutil.go
index 6f791ef..13cfd0a 100644
--- a/dns/dnsutil/dnsutil.go
+++ b/dns/dnsutil/dnsutil.go
@@ -1,6 +1,8 @@
package dnsutil
import (
+ "errors"
+ "sync"
"time"
"github.com/miekg/dns"
@@ -14,6 +16,44 @@ var (
RcodeToString = dns.RcodeToString
)
+// Resolver is the interface that wraps the Exchange method of a DNS client.
+type Resolver interface {
+ Exchange(*dns.Msg, string) (*dns.Msg, time.Duration, error)
+}
+
+// Exchange sends a DNS query to addr and returns the response. If more than one addr is given, all are queried and the
+// first successful response is returned.
+func Exchange(resolver Resolver, msg *dns.Msg, addr ...string) (*dns.Msg, error) {
+ done := make(chan bool)
+ c := make(chan *dns.Msg)
+ var wg sync.WaitGroup
+ wg.Add(len(addr))
+ err := errors.New("addr is empty")
+ for _, a := range addr {
+ go func(addr string) {
+ defer wg.Done()
+ r, _, err1 := resolver.Exchange(msg, addr)
+ if err1 != nil {
+ err = err1
+ return
+ }
+ c <- r
+ }(a)
+ }
+ go func() {
+ wg.Wait()
+ done <- true
+ }()
+ for {
+ select {
+ case <-done:
+ return nil, err
+ case rr := <-c:
+ return rr, nil
+ }
+ }
+}
+
// Answers returns all values in the answer section of DNS message msg.
func Answers(msg *dns.Msg) []string {
var answers []string
diff --git a/dns/proxy.go b/dns/proxy.go
index d4ded21..84f1b48 100644
--- a/dns/proxy.go
+++ b/dns/proxy.go
@@ -46,6 +46,7 @@ type Proxy struct {
logMode int
server *dns.Server
client client
+ timeout time.Duration
}
// ProxyOptions represents proxy configuration.
@@ -80,6 +81,7 @@ func NewProxy(cache *cache.Cache, logger logger, options ProxyOptions) (*Proxy,
resolvers: options.Resolvers,
logMode: options.LogMode,
client: c,
+ timeout: options.Timeout,
}, nil
}
@@ -173,21 +175,14 @@ func (p *Proxy) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
p.writeMsg(w, msg, false)
return
}
- for i, resolver := range p.resolvers {
- rr, _, err := p.client.Exchange(r, resolver)
- if err != nil {
- p.logger.Printf("resolver %s failed: %s", resolver, err)
- if i == len(p.resolvers)-1 { // No more resolvers to try
- break
- } else {
- continue
- }
- }
+ rr, err := dnsutil.Exchange(p.client, r, p.resolvers...)
+ if err == nil {
p.cache.Set(key, rr)
p.writeMsg(w, rr, false)
- return
+ } else {
+ p.logger.Printf("resolver(s) failed: %s", err)
+ dns.HandleFailed(w, r)
}
- dns.HandleFailed(w, r)
}
// ListenAndServe listens on the network address addr and uses the server to process requests.
diff --git a/dns/proxy_test.go b/dns/proxy_test.go
index c675f90..5df7611 100644
--- a/dns/proxy_test.go
+++ b/dns/proxy_test.go
@@ -68,7 +68,7 @@ func testProxy(t *testing.T) *Proxy {
if err != nil {
t.Fatal(err)
}
- proxy, err := NewProxy(cache.New(0), log, ProxyOptions{})
+ proxy, err := NewProxy(cache.New(0), log, ProxyOptions{Timeout: 2 * time.Second})
if err != nil {
t.Fatal(err)
}
@@ -159,12 +159,14 @@ func TestProxy(t *testing.T) {
func TestProxyWithResolvers(t *testing.T) {
p := testProxy(t)
- p.resolvers = []string{"resolver1"}
client := make(testClient)
p.client = client
defer p.Close()
+ // No resolvers
+ assertFailure(t, p, TypeA, "host1")
// First and only resolver responds succesfully
+ p.resolvers = []string{"resolver1"}
reply := ReplyA("host1", net.ParseIP("192.0.2.1"))
m := dns.Msg{}
m.Id = dns.Id()
@@ -217,7 +219,7 @@ func TestProxyWithCache(t *testing.T) {
func TestProxyWithLogging(t *testing.T) {
logger := &testLogger{}
- p, err := NewProxy(cache.New(0), logger, ProxyOptions{})
+ p, err := NewProxy(cache.New(0), logger, ProxyOptions{Timeout: 2 * time.Second})
if err != nil {
t.Fatal(err)
}