diff options
author | Martin Polden <mpolden@mpolden.no> | 2020-05-09 12:14:29 +0200 |
---|---|---|
committer | Martin Polden <mpolden@mpolden.no> | 2020-05-09 12:14:29 +0200 |
commit | a0499585a5b66d2a7eb2d1c3b8b7a9346307e9fa (patch) | |
tree | 74c0df116c0cdc928ee1a10369fd6286833a913a /dns | |
parent | 17887b5547d972eff47d20f586d6d45cf7d3374a (diff) |
Refactor multiplexed client
Diffstat (limited to 'dns')
-rw-r--r-- | dns/dnsutil/dnsutil.go | 77 | ||||
-rw-r--r-- | dns/dnsutil/dnsutil_test.go | 42 | ||||
-rw-r--r-- | dns/proxy.go | 4 | ||||
-rw-r--r-- | dns/proxy_test.go | 64 |
4 files changed, 90 insertions, 97 deletions
diff --git a/dns/dnsutil/dnsutil.go b/dns/dnsutil/dnsutil.go index 9cf75b6..3b989c9 100644 --- a/dns/dnsutil/dnsutil.go +++ b/dns/dnsutil/dnsutil.go @@ -17,46 +17,50 @@ var ( RcodeToString = dns.RcodeToString ) -// Exchanger is the interface that wraps the Exchange method of a DNS client. -type Exchanger interface { - Exchange(*dns.Msg, string) (*dns.Msg, time.Duration, error) +// Client is the interface of a DNS client. +type Client interface { + Exchange(*dns.Msg) (*dns.Msg, error) } -// Client wraps a DNS client and a list of server addresses. -type Client struct { - Exchanger Exchanger - Addresses []string +// Config is a structure used to configure a DNS client. +type Config struct { + Network string + Timeout time.Duration } -// NewClient creates a new Client using the named network and addresses. -func NewClient(network string, timeout time.Duration, addresses ...string) *Client { - var client Exchanger - if network == "https" { - client = http.NewClient(timeout) - } else { - client = &dns.Client{Net: network, Timeout: timeout} - } - return &Client{Exchanger: client, Addresses: addresses} +type resolver interface { + Exchange(*dns.Msg, string) (*dns.Msg, time.Duration, error) } -func multiExchange(exchanger Exchanger, msg *dns.Msg, address ...string) (*dns.Msg, error) { - if len(address) == 0 { - return nil, fmt.Errorf("no addresses to query") +type client struct { + resolver resolver + address string +} + +type mux struct{ clients []Client } + +// NewMux creates a new multiplexed client which queries all clients in parallel and returns the first successful +// response. +func NewMux(client ...Client) Client { return &mux{clients: client} } + +func (m *mux) Exchange(msg *dns.Msg) (*dns.Msg, error) { + if len(m.clients) == 0 { + return nil, fmt.Errorf("no clients to query") } - responses := make(chan *dns.Msg, len(address)) - errs := make(chan error, len(address)) + responses := make(chan *dns.Msg, len(m.clients)) + errs := make(chan error, len(m.clients)) var wg sync.WaitGroup - for _, a := range address { + for _, c := range m.clients { wg.Add(1) - go func(addr string) { + go func(client Client) { defer wg.Done() - r, _, err := exchanger.Exchange(msg, addr) + r, err := client.Exchange(msg) if err != nil { - errs <- fmt.Errorf("resolver %s failed: %w", addr, err) + errs <- err return } responses <- r - }(a) + }(c) } go func() { wg.Wait() @@ -69,10 +73,23 @@ func multiExchange(exchanger Exchanger, msg *dns.Msg, address ...string) (*dns.M return nil, <-errs } -// Exchange performs a synchronous DNS query. All addresses in Client c are queried in parallel and the first successful -// response is returned. -func (c *Client) Exchange(msg *dns.Msg) (*dns.Msg, error) { - return multiExchange(c.Exchanger, msg, c.Addresses...) +// NewClient creates a new Client for addr using config. +func NewClient(addr string, config Config) Client { + var r resolver + if config.Network == "https" { + r = http.NewClient(config.Timeout) + } else { + r = &dns.Client{Net: config.Network, Timeout: config.Timeout} + } + return &client{resolver: r, address: addr} +} + +func (c *client) Exchange(msg *dns.Msg) (*dns.Msg, error) { + r, _, err := c.resolver.Exchange(msg, c.address) + if err != nil { + return nil, fmt.Errorf("resolver %s failed: %w", c.address, err) + } + return r, err } // Answers returns all values in the answer section of DNS message msg. diff --git a/dns/dnsutil/dnsutil_test.go b/dns/dnsutil/dnsutil_test.go index 3ece7f3..3eca1a0 100644 --- a/dns/dnsutil/dnsutil_test.go +++ b/dns/dnsutil/dnsutil_test.go @@ -17,32 +17,30 @@ type response struct { mu sync.Mutex } -type testExchanger struct { - mu sync.RWMutex - responses map[string]*response +type testResolver struct { + mu sync.RWMutex + response *response } -func newTestExchanger() *testExchanger { return &testExchanger{responses: make(map[string]*response)} } - -func (e *testExchanger) setResponse(addr string, r *response) { +func (e *testResolver) setResponse(r *response) { e.mu.Lock() defer e.mu.Unlock() - e.responses[addr] = r + e.response = r } -func (e *testExchanger) Exchange(msg *dns.Msg, addr string) (*dns.Msg, time.Duration, error) { +func (e *testResolver) Exchange(msg *dns.Msg) (*dns.Msg, error) { e.mu.RLock() defer e.mu.RUnlock() - r, ok := e.responses[addr] - if !ok { - panic("no such resolver: " + addr) + r := e.response + if r == nil { + panic("no response set") } if r.fail { - return nil, 0, errors.New("error") + return nil, errors.New("error") } r.mu.Lock() defer r.mu.Unlock() - return r.answer, time.Second, nil + return r.answer, nil } func newA(name string, ttl uint32, ipAddr ...string) *dns.Msg { @@ -130,17 +128,19 @@ func TestAnswers(t *testing.T) { } func TestExchange(t *testing.T) { - addresses := []string{"addr1", "addr2"} - exchanger := newTestExchanger() + resolver1 := &testResolver{} + resolver2 := &testResolver{} // First responding resolver returns answer answer1 := newA("example.com.", 60, "192.0.2.1") answer2 := newA("example.com.", 60, "192.0.2.2") r1 := response{answer: answer1} r1.mu.Lock() // Locking first resolver so that second wins - exchanger.setResponse(addresses[0], &r1) - exchanger.setResponse(addresses[1], &response{answer: answer2}) - r, err := multiExchange(exchanger, &dns.Msg{}, addresses...) + resolver1.setResponse(&r1) + resolver2.setResponse(&response{answer: answer2}) + + mux := NewMux(resolver1, resolver2) + r, err := mux.Exchange(&dns.Msg{}) if err != nil { t.Fatal(err) } @@ -150,9 +150,9 @@ func TestExchange(t *testing.T) { r1.mu.Unlock() // All resolvers fail - exchanger.setResponse(addresses[0], &response{fail: true}) - exchanger.setResponse(addresses[1], &response{fail: true}) - _, err = multiExchange(exchanger, &dns.Msg{}, addresses...) + resolver1.setResponse(&response{fail: true}) + resolver2.setResponse(&response{fail: true}) + _, err = mux.Exchange(&dns.Msg{}) if err == nil { t.Errorf("got %s, want error", err) } diff --git a/dns/proxy.go b/dns/proxy.go index 6241edd..32f6907 100644 --- a/dns/proxy.go +++ b/dns/proxy.go @@ -38,12 +38,12 @@ type Proxy struct { cache *cache.Cache logger *sql.Logger server *dns.Server - client *dnsutil.Client + client dnsutil.Client mu sync.RWMutex } // NewProxy creates a new DNS proxy. -func NewProxy(cache *cache.Cache, client *dnsutil.Client, logger *sql.Logger) (*Proxy, error) { +func NewProxy(cache *cache.Cache, client dnsutil.Client, logger *sql.Logger) (*Proxy, error) { return &Proxy{ logger: logger, cache: cache, diff --git a/dns/proxy_test.go b/dns/proxy_test.go index 7eef51d..1b0431a 100644 --- a/dns/proxy_test.go +++ b/dns/proxy_test.go @@ -8,11 +8,9 @@ import ( "reflect" "sync" "testing" - "time" "github.com/miekg/dns" "github.com/mpolden/zdns/cache" - "github.com/mpolden/zdns/dns/dnsutil" ) func init() { @@ -41,30 +39,25 @@ type response struct { fail bool } -type testExchanger struct { - mu sync.RWMutex - responses map[string]*response +type testResolver struct { + mu sync.RWMutex + response *response } -func newTestExchanger() *testExchanger { return &testExchanger{responses: make(map[string]*response)} } - -func (e *testExchanger) setResponse(resolver string, response *response) { +func (e *testResolver) setResponse(response *response) { e.mu.Lock() defer e.mu.Unlock() - e.responses[resolver] = response + e.response = response } -func (e *testExchanger) Exchange(msg *dns.Msg, addr string) (*dns.Msg, time.Duration, error) { +func (e *testResolver) Exchange(msg *dns.Msg) (*dns.Msg, error) { e.mu.RLock() defer e.mu.RUnlock() - r, ok := e.responses[addr] - if !ok { - panic("no such resolver: " + addr) - } - if r.fail { - return nil, 0, fmt.Errorf("%s SERVFAIL", addr) + r := e.response + if r == nil || r.fail { + return nil, fmt.Errorf("SERVFAIL") } - return r.answer, time.Second, nil + return r.answer, nil } func testProxy(t *testing.T) *Proxy { @@ -157,51 +150,34 @@ func TestProxy(t *testing.T) { assertRR(t, p, &m, "::") } -func TestProxyWithResolvers(t *testing.T) { +func TestProxyWithResolver(t *testing.T) { p := testProxy(t) - exchanger := newTestExchanger() - p.client = &dnsutil.Client{Exchanger: exchanger} + r := &testResolver{} + p.client = r defer p.Close() - // No resolvers + // No response assertFailure(t, p, TypeA, "host1") - // First and only resolver responds succesfully - p.client.Addresses = []string{"resolver1"} + // Responds succesfully reply := ReplyA("host1", net.ParseIP("192.0.2.1")) m := dns.Msg{} m.Id = dns.Id() m.SetQuestion("host1.", dns.TypeA) m.Answer = reply.rr response1 := &response{answer: &m} - exchanger.setResponse("resolver1", response1) + r.setResponse(response1) assertRR(t, p, &m, "192.0.2.1") - // First and only resolver fails + // Resolver fails response1.fail = true assertFailure(t, p, TypeA, "host1") - - // First resolver fails, but second succeeds - reply = ReplyA("host1", net.ParseIP("192.0.2.2")) - p.client.Addresses = []string{"resolver1", "resolver2"} - m = dns.Msg{} - m.Id = dns.Id() - m.SetQuestion("host1.", dns.TypeA) - m.Answer = reply.rr - response2 := &response{answer: &m} - exchanger.setResponse("resolver2", response2) - assertRR(t, p, &m, "192.0.2.2") - - // All resolvers fail - response2.fail = true - assertFailure(t, p, TypeA, "host1") } func TestProxyWithCache(t *testing.T) { p := testProxy(t) p.cache = cache.New(10, nil) - exchanger := newTestExchanger() - p.client = &dnsutil.Client{Exchanger: exchanger} - p.client.Addresses = []string{"resolver1"} + r := &testResolver{} + p.client = r defer p.Close() reply := ReplyA("host1", net.ParseIP("192.0.2.1")) @@ -209,7 +185,7 @@ func TestProxyWithCache(t *testing.T) { m.Id = dns.Id() m.SetQuestion("host1.", dns.TypeA) m.Answer = reply.rr - exchanger.setResponse("resolver1", &response{answer: &m}) + r.setResponse(&response{answer: &m}) assertRR(t, p, &m, "192.0.2.1") k := cache.NewKey("host1.", dns.TypeA, dns.ClassINET) |