aboutsummaryrefslogtreecommitdiffstats
path: root/dns
diff options
context:
space:
mode:
authorMartin Polden <mpolden@mpolden.no>2020-05-09 12:14:29 +0200
committerMartin Polden <mpolden@mpolden.no>2020-05-09 12:14:29 +0200
commita0499585a5b66d2a7eb2d1c3b8b7a9346307e9fa (patch)
tree74c0df116c0cdc928ee1a10369fd6286833a913a /dns
parent17887b5547d972eff47d20f586d6d45cf7d3374a (diff)
Refactor multiplexed client
Diffstat (limited to 'dns')
-rw-r--r--dns/dnsutil/dnsutil.go77
-rw-r--r--dns/dnsutil/dnsutil_test.go42
-rw-r--r--dns/proxy.go4
-rw-r--r--dns/proxy_test.go64
4 files changed, 90 insertions, 97 deletions
diff --git a/dns/dnsutil/dnsutil.go b/dns/dnsutil/dnsutil.go
index 9cf75b6..3b989c9 100644
--- a/dns/dnsutil/dnsutil.go
+++ b/dns/dnsutil/dnsutil.go
@@ -17,46 +17,50 @@ var (
RcodeToString = dns.RcodeToString
)
-// Exchanger is the interface that wraps the Exchange method of a DNS client.
-type Exchanger interface {
- Exchange(*dns.Msg, string) (*dns.Msg, time.Duration, error)
+// Client is the interface of a DNS client.
+type Client interface {
+ Exchange(*dns.Msg) (*dns.Msg, error)
}
-// Client wraps a DNS client and a list of server addresses.
-type Client struct {
- Exchanger Exchanger
- Addresses []string
+// Config is a structure used to configure a DNS client.
+type Config struct {
+ Network string
+ Timeout time.Duration
}
-// NewClient creates a new Client using the named network and addresses.
-func NewClient(network string, timeout time.Duration, addresses ...string) *Client {
- var client Exchanger
- if network == "https" {
- client = http.NewClient(timeout)
- } else {
- client = &dns.Client{Net: network, Timeout: timeout}
- }
- return &Client{Exchanger: client, Addresses: addresses}
+type resolver interface {
+ Exchange(*dns.Msg, string) (*dns.Msg, time.Duration, error)
}
-func multiExchange(exchanger Exchanger, msg *dns.Msg, address ...string) (*dns.Msg, error) {
- if len(address) == 0 {
- return nil, fmt.Errorf("no addresses to query")
+type client struct {
+ resolver resolver
+ address string
+}
+
+type mux struct{ clients []Client }
+
+// NewMux creates a new multiplexed client which queries all clients in parallel and returns the first successful
+// response.
+func NewMux(client ...Client) Client { return &mux{clients: client} }
+
+func (m *mux) Exchange(msg *dns.Msg) (*dns.Msg, error) {
+ if len(m.clients) == 0 {
+ return nil, fmt.Errorf("no clients to query")
}
- responses := make(chan *dns.Msg, len(address))
- errs := make(chan error, len(address))
+ responses := make(chan *dns.Msg, len(m.clients))
+ errs := make(chan error, len(m.clients))
var wg sync.WaitGroup
- for _, a := range address {
+ for _, c := range m.clients {
wg.Add(1)
- go func(addr string) {
+ go func(client Client) {
defer wg.Done()
- r, _, err := exchanger.Exchange(msg, addr)
+ r, err := client.Exchange(msg)
if err != nil {
- errs <- fmt.Errorf("resolver %s failed: %w", addr, err)
+ errs <- err
return
}
responses <- r
- }(a)
+ }(c)
}
go func() {
wg.Wait()
@@ -69,10 +73,23 @@ func multiExchange(exchanger Exchanger, msg *dns.Msg, address ...string) (*dns.M
return nil, <-errs
}
-// Exchange performs a synchronous DNS query. All addresses in Client c are queried in parallel and the first successful
-// response is returned.
-func (c *Client) Exchange(msg *dns.Msg) (*dns.Msg, error) {
- return multiExchange(c.Exchanger, msg, c.Addresses...)
+// NewClient creates a new Client for addr using config.
+func NewClient(addr string, config Config) Client {
+ var r resolver
+ if config.Network == "https" {
+ r = http.NewClient(config.Timeout)
+ } else {
+ r = &dns.Client{Net: config.Network, Timeout: config.Timeout}
+ }
+ return &client{resolver: r, address: addr}
+}
+
+func (c *client) Exchange(msg *dns.Msg) (*dns.Msg, error) {
+ r, _, err := c.resolver.Exchange(msg, c.address)
+ if err != nil {
+ return nil, fmt.Errorf("resolver %s failed: %w", c.address, err)
+ }
+ return r, err
}
// Answers returns all values in the answer section of DNS message msg.
diff --git a/dns/dnsutil/dnsutil_test.go b/dns/dnsutil/dnsutil_test.go
index 3ece7f3..3eca1a0 100644
--- a/dns/dnsutil/dnsutil_test.go
+++ b/dns/dnsutil/dnsutil_test.go
@@ -17,32 +17,30 @@ type response struct {
mu sync.Mutex
}
-type testExchanger struct {
- mu sync.RWMutex
- responses map[string]*response
+type testResolver struct {
+ mu sync.RWMutex
+ response *response
}
-func newTestExchanger() *testExchanger { return &testExchanger{responses: make(map[string]*response)} }
-
-func (e *testExchanger) setResponse(addr string, r *response) {
+func (e *testResolver) setResponse(r *response) {
e.mu.Lock()
defer e.mu.Unlock()
- e.responses[addr] = r
+ e.response = r
}
-func (e *testExchanger) Exchange(msg *dns.Msg, addr string) (*dns.Msg, time.Duration, error) {
+func (e *testResolver) Exchange(msg *dns.Msg) (*dns.Msg, error) {
e.mu.RLock()
defer e.mu.RUnlock()
- r, ok := e.responses[addr]
- if !ok {
- panic("no such resolver: " + addr)
+ r := e.response
+ if r == nil {
+ panic("no response set")
}
if r.fail {
- return nil, 0, errors.New("error")
+ return nil, errors.New("error")
}
r.mu.Lock()
defer r.mu.Unlock()
- return r.answer, time.Second, nil
+ return r.answer, nil
}
func newA(name string, ttl uint32, ipAddr ...string) *dns.Msg {
@@ -130,17 +128,19 @@ func TestAnswers(t *testing.T) {
}
func TestExchange(t *testing.T) {
- addresses := []string{"addr1", "addr2"}
- exchanger := newTestExchanger()
+ resolver1 := &testResolver{}
+ resolver2 := &testResolver{}
// First responding resolver returns answer
answer1 := newA("example.com.", 60, "192.0.2.1")
answer2 := newA("example.com.", 60, "192.0.2.2")
r1 := response{answer: answer1}
r1.mu.Lock() // Locking first resolver so that second wins
- exchanger.setResponse(addresses[0], &r1)
- exchanger.setResponse(addresses[1], &response{answer: answer2})
- r, err := multiExchange(exchanger, &dns.Msg{}, addresses...)
+ resolver1.setResponse(&r1)
+ resolver2.setResponse(&response{answer: answer2})
+
+ mux := NewMux(resolver1, resolver2)
+ r, err := mux.Exchange(&dns.Msg{})
if err != nil {
t.Fatal(err)
}
@@ -150,9 +150,9 @@ func TestExchange(t *testing.T) {
r1.mu.Unlock()
// All resolvers fail
- exchanger.setResponse(addresses[0], &response{fail: true})
- exchanger.setResponse(addresses[1], &response{fail: true})
- _, err = multiExchange(exchanger, &dns.Msg{}, addresses...)
+ resolver1.setResponse(&response{fail: true})
+ resolver2.setResponse(&response{fail: true})
+ _, err = mux.Exchange(&dns.Msg{})
if err == nil {
t.Errorf("got %s, want error", err)
}
diff --git a/dns/proxy.go b/dns/proxy.go
index 6241edd..32f6907 100644
--- a/dns/proxy.go
+++ b/dns/proxy.go
@@ -38,12 +38,12 @@ type Proxy struct {
cache *cache.Cache
logger *sql.Logger
server *dns.Server
- client *dnsutil.Client
+ client dnsutil.Client
mu sync.RWMutex
}
// NewProxy creates a new DNS proxy.
-func NewProxy(cache *cache.Cache, client *dnsutil.Client, logger *sql.Logger) (*Proxy, error) {
+func NewProxy(cache *cache.Cache, client dnsutil.Client, logger *sql.Logger) (*Proxy, error) {
return &Proxy{
logger: logger,
cache: cache,
diff --git a/dns/proxy_test.go b/dns/proxy_test.go
index 7eef51d..1b0431a 100644
--- a/dns/proxy_test.go
+++ b/dns/proxy_test.go
@@ -8,11 +8,9 @@ import (
"reflect"
"sync"
"testing"
- "time"
"github.com/miekg/dns"
"github.com/mpolden/zdns/cache"
- "github.com/mpolden/zdns/dns/dnsutil"
)
func init() {
@@ -41,30 +39,25 @@ type response struct {
fail bool
}
-type testExchanger struct {
- mu sync.RWMutex
- responses map[string]*response
+type testResolver struct {
+ mu sync.RWMutex
+ response *response
}
-func newTestExchanger() *testExchanger { return &testExchanger{responses: make(map[string]*response)} }
-
-func (e *testExchanger) setResponse(resolver string, response *response) {
+func (e *testResolver) setResponse(response *response) {
e.mu.Lock()
defer e.mu.Unlock()
- e.responses[resolver] = response
+ e.response = response
}
-func (e *testExchanger) Exchange(msg *dns.Msg, addr string) (*dns.Msg, time.Duration, error) {
+func (e *testResolver) Exchange(msg *dns.Msg) (*dns.Msg, error) {
e.mu.RLock()
defer e.mu.RUnlock()
- r, ok := e.responses[addr]
- if !ok {
- panic("no such resolver: " + addr)
- }
- if r.fail {
- return nil, 0, fmt.Errorf("%s SERVFAIL", addr)
+ r := e.response
+ if r == nil || r.fail {
+ return nil, fmt.Errorf("SERVFAIL")
}
- return r.answer, time.Second, nil
+ return r.answer, nil
}
func testProxy(t *testing.T) *Proxy {
@@ -157,51 +150,34 @@ func TestProxy(t *testing.T) {
assertRR(t, p, &m, "::")
}
-func TestProxyWithResolvers(t *testing.T) {
+func TestProxyWithResolver(t *testing.T) {
p := testProxy(t)
- exchanger := newTestExchanger()
- p.client = &dnsutil.Client{Exchanger: exchanger}
+ r := &testResolver{}
+ p.client = r
defer p.Close()
- // No resolvers
+ // No response
assertFailure(t, p, TypeA, "host1")
- // First and only resolver responds succesfully
- p.client.Addresses = []string{"resolver1"}
+ // Responds succesfully
reply := ReplyA("host1", net.ParseIP("192.0.2.1"))
m := dns.Msg{}
m.Id = dns.Id()
m.SetQuestion("host1.", dns.TypeA)
m.Answer = reply.rr
response1 := &response{answer: &m}
- exchanger.setResponse("resolver1", response1)
+ r.setResponse(response1)
assertRR(t, p, &m, "192.0.2.1")
- // First and only resolver fails
+ // Resolver fails
response1.fail = true
assertFailure(t, p, TypeA, "host1")
-
- // First resolver fails, but second succeeds
- reply = ReplyA("host1", net.ParseIP("192.0.2.2"))
- p.client.Addresses = []string{"resolver1", "resolver2"}
- m = dns.Msg{}
- m.Id = dns.Id()
- m.SetQuestion("host1.", dns.TypeA)
- m.Answer = reply.rr
- response2 := &response{answer: &m}
- exchanger.setResponse("resolver2", response2)
- assertRR(t, p, &m, "192.0.2.2")
-
- // All resolvers fail
- response2.fail = true
- assertFailure(t, p, TypeA, "host1")
}
func TestProxyWithCache(t *testing.T) {
p := testProxy(t)
p.cache = cache.New(10, nil)
- exchanger := newTestExchanger()
- p.client = &dnsutil.Client{Exchanger: exchanger}
- p.client.Addresses = []string{"resolver1"}
+ r := &testResolver{}
+ p.client = r
defer p.Close()
reply := ReplyA("host1", net.ParseIP("192.0.2.1"))
@@ -209,7 +185,7 @@ func TestProxyWithCache(t *testing.T) {
m.Id = dns.Id()
m.SetQuestion("host1.", dns.TypeA)
m.Answer = reply.rr
- exchanger.setResponse("resolver1", &response{answer: &m})
+ r.setResponse(&response{answer: &m})
assertRR(t, p, &m, "192.0.2.1")
k := cache.NewKey("host1.", dns.TypeA, dns.ClassINET)