diff options
author | Martin Polden <mpolden@mpolden.no> | 2020-05-09 12:14:29 +0200 |
---|---|---|
committer | Martin Polden <mpolden@mpolden.no> | 2020-05-09 12:14:29 +0200 |
commit | a0499585a5b66d2a7eb2d1c3b8b7a9346307e9fa (patch) | |
tree | 74c0df116c0cdc928ee1a10369fd6286833a913a | |
parent | 17887b5547d972eff47d20f586d6d45cf7d3374a (diff) |
Refactor multiplexed client
-rw-r--r-- | cache/cache.go | 8 | ||||
-rw-r--r-- | cache/cache_test.go | 28 | ||||
-rw-r--r-- | cmd/zdns/main.go | 12 | ||||
-rw-r--r-- | dns/dnsutil/dnsutil.go | 77 | ||||
-rw-r--r-- | dns/dnsutil/dnsutil_test.go | 42 | ||||
-rw-r--r-- | dns/proxy.go | 4 | ||||
-rw-r--r-- | dns/proxy_test.go | 64 |
7 files changed, 117 insertions, 118 deletions
diff --git a/cache/cache.go b/cache/cache.go index f26384d..dc2b831 100644 --- a/cache/cache.go +++ b/cache/cache.go @@ -24,7 +24,7 @@ type Backend interface { // Cache is a cache of DNS messages. type Cache struct { - client *dnsutil.Client + client dnsutil.Client backend Backend capacity int values map[uint32]Value @@ -116,16 +116,16 @@ func Unpack(value string) (Value, error) { // // - All cache write operations will be forward to the backend. // - The backed will be used to pre-populate the cache. -func New(capacity int, client *dnsutil.Client) *Cache { +func New(capacity int, client dnsutil.Client) *Cache { return NewWithBackend(capacity, client, nil) } // NewWithBackend creates a new cache that forwards entries to backend. -func NewWithBackend(capacity int, client *dnsutil.Client, backend Backend) *Cache { +func NewWithBackend(capacity int, client dnsutil.Client, backend Backend) *Cache { return newCache(capacity, client, backend, time.Now) } -func newCache(capacity int, client *dnsutil.Client, backend Backend, now func() time.Time) *Cache { +func newCache(capacity int, client dnsutil.Client, backend Backend, now func() time.Time) *Cache { if capacity < 0 { capacity = 0 } diff --git a/cache/cache_test.go b/cache/cache_test.go index 1a25721..7a23d32 100644 --- a/cache/cache_test.go +++ b/cache/cache_test.go @@ -14,32 +14,32 @@ import ( var testMsg *dns.Msg = newA("example.com.", 60, net.ParseIP("192.0.2.1")) -type testExchanger struct { +type testClient struct { mu sync.RWMutex answers chan *dns.Msg } -func newTestExchanger() *testExchanger { return &testExchanger{answers: make(chan *dns.Msg, 100)} } +func newTestClient() *testClient { return &testClient{answers: make(chan *dns.Msg, 100)} } -func (e *testExchanger) setAnswer(answer *dns.Msg) { +func (e *testClient) setAnswer(answer *dns.Msg) { e.mu.Lock() defer e.mu.Unlock() e.answers <- answer } -func (e *testExchanger) reset() { +func (e *testClient) reset() { e.mu.Lock() defer e.mu.Unlock() e.answers = make(chan *dns.Msg, 100) } -func (e *testExchanger) Exchange(msg *dns.Msg, addr string) (*dns.Msg, time.Duration, error) { +func (e *testClient) Exchange(msg *dns.Msg) (*dns.Msg, error) { e.mu.RLock() defer e.mu.RUnlock() if len(e.answers) == 0 { - return nil, 0, fmt.Errorf("no answer pending") + return nil, fmt.Errorf("no answer pending") } - return <-e.answers, time.Second, nil + return <-e.answers, nil } type testBackend struct { @@ -252,8 +252,7 @@ func TestReset(t *testing.T) { } func TestCachePrefetch(t *testing.T) { - exchanger := newTestExchanger() - client := &dnsutil.Client{Exchanger: exchanger, Addresses: []string{"resolver"}} + client := newTestClient() now := time.Now() c := newCache(10, client, nil, func() time.Time { return now }) var tests = []struct { @@ -279,8 +278,8 @@ func TestCachePrefetch(t *testing.T) { copy := testMsg.Copy() copy.Answer[0].(*dns.A).A = net.ParseIP(tt.refreshAnswer) copy.Answer[0].(*dns.A).Hdr.Ttl = uint32(tt.refreshTTL.Seconds()) - exchanger.reset() - exchanger.setAnswer(copy) + client.reset() + client.setAnswer(copy) // Add new value now c.now = func() time.Time { return now } @@ -308,8 +307,7 @@ func TestCachePrefetch(t *testing.T) { } func TestCacheEvictAndUpdate(t *testing.T) { - exchanger := newTestExchanger() - client := &dnsutil.Client{Exchanger: exchanger, Addresses: []string{"resolver"}} + client := newTestClient() now := time.Now() c := newCache(10, client, nil, func() time.Time { return now }) @@ -319,10 +317,10 @@ func TestCacheEvictAndUpdate(t *testing.T) { // Initial prefetched answer can no longer be cached copy := testMsg.Copy() copy.Answer[0].(*dns.A).Hdr.Ttl = 0 - exchanger.setAnswer(copy) + client.setAnswer(copy) copy = testMsg.Copy() copy.Answer[0].(*dns.A).Hdr.Ttl = 30 - exchanger.setAnswer(copy) + client.setAnswer(copy) // Advance time so that msg is now considered expired. Query to trigger prefetch c.now = func() time.Time { return now.Add(61 * time.Second) } diff --git a/cmd/zdns/main.go b/cmd/zdns/main.go index 5af70c6..31407ee 100644 --- a/cmd/zdns/main.go +++ b/cmd/zdns/main.go @@ -101,11 +101,19 @@ func newCli(out io.Writer, args []string, configFile string, sig chan os.Signal) } // DNS client - dnsClient := dnsutil.NewClient(config.Resolver.Protocol, config.Resolver.Timeout, config.DNS.Resolvers...) + dnsConfig := dnsutil.Config{ + Network: config.Resolver.Protocol, + Timeout: config.Resolver.Timeout, + } + dnsClients := make([]dnsutil.Client, 0, len(config.DNS.Resolvers)) + for _, addr := range config.DNS.Resolvers { + dnsClients = append(dnsClients, dnsutil.NewClient(addr, dnsConfig)) + } + dnsClient := dnsutil.NewMux(dnsClients...) // Cache var dnsCache *cache.Cache - var cacheDNS *dnsutil.Client + var cacheDNS dnsutil.Client if config.DNS.CachePrefetch { cacheDNS = dnsClient } diff --git a/dns/dnsutil/dnsutil.go b/dns/dnsutil/dnsutil.go index 9cf75b6..3b989c9 100644 --- a/dns/dnsutil/dnsutil.go +++ b/dns/dnsutil/dnsutil.go @@ -17,46 +17,50 @@ var ( RcodeToString = dns.RcodeToString ) -// Exchanger is the interface that wraps the Exchange method of a DNS client. -type Exchanger interface { - Exchange(*dns.Msg, string) (*dns.Msg, time.Duration, error) +// Client is the interface of a DNS client. +type Client interface { + Exchange(*dns.Msg) (*dns.Msg, error) } -// Client wraps a DNS client and a list of server addresses. -type Client struct { - Exchanger Exchanger - Addresses []string +// Config is a structure used to configure a DNS client. +type Config struct { + Network string + Timeout time.Duration } -// NewClient creates a new Client using the named network and addresses. -func NewClient(network string, timeout time.Duration, addresses ...string) *Client { - var client Exchanger - if network == "https" { - client = http.NewClient(timeout) - } else { - client = &dns.Client{Net: network, Timeout: timeout} - } - return &Client{Exchanger: client, Addresses: addresses} +type resolver interface { + Exchange(*dns.Msg, string) (*dns.Msg, time.Duration, error) } -func multiExchange(exchanger Exchanger, msg *dns.Msg, address ...string) (*dns.Msg, error) { - if len(address) == 0 { - return nil, fmt.Errorf("no addresses to query") +type client struct { + resolver resolver + address string +} + +type mux struct{ clients []Client } + +// NewMux creates a new multiplexed client which queries all clients in parallel and returns the first successful +// response. +func NewMux(client ...Client) Client { return &mux{clients: client} } + +func (m *mux) Exchange(msg *dns.Msg) (*dns.Msg, error) { + if len(m.clients) == 0 { + return nil, fmt.Errorf("no clients to query") } - responses := make(chan *dns.Msg, len(address)) - errs := make(chan error, len(address)) + responses := make(chan *dns.Msg, len(m.clients)) + errs := make(chan error, len(m.clients)) var wg sync.WaitGroup - for _, a := range address { + for _, c := range m.clients { wg.Add(1) - go func(addr string) { + go func(client Client) { defer wg.Done() - r, _, err := exchanger.Exchange(msg, addr) + r, err := client.Exchange(msg) if err != nil { - errs <- fmt.Errorf("resolver %s failed: %w", addr, err) + errs <- err return } responses <- r - }(a) + }(c) } go func() { wg.Wait() @@ -69,10 +73,23 @@ func multiExchange(exchanger Exchanger, msg *dns.Msg, address ...string) (*dns.M return nil, <-errs } -// Exchange performs a synchronous DNS query. All addresses in Client c are queried in parallel and the first successful -// response is returned. -func (c *Client) Exchange(msg *dns.Msg) (*dns.Msg, error) { - return multiExchange(c.Exchanger, msg, c.Addresses...) +// NewClient creates a new Client for addr using config. +func NewClient(addr string, config Config) Client { + var r resolver + if config.Network == "https" { + r = http.NewClient(config.Timeout) + } else { + r = &dns.Client{Net: config.Network, Timeout: config.Timeout} + } + return &client{resolver: r, address: addr} +} + +func (c *client) Exchange(msg *dns.Msg) (*dns.Msg, error) { + r, _, err := c.resolver.Exchange(msg, c.address) + if err != nil { + return nil, fmt.Errorf("resolver %s failed: %w", c.address, err) + } + return r, err } // Answers returns all values in the answer section of DNS message msg. diff --git a/dns/dnsutil/dnsutil_test.go b/dns/dnsutil/dnsutil_test.go index 3ece7f3..3eca1a0 100644 --- a/dns/dnsutil/dnsutil_test.go +++ b/dns/dnsutil/dnsutil_test.go @@ -17,32 +17,30 @@ type response struct { mu sync.Mutex } -type testExchanger struct { - mu sync.RWMutex - responses map[string]*response +type testResolver struct { + mu sync.RWMutex + response *response } -func newTestExchanger() *testExchanger { return &testExchanger{responses: make(map[string]*response)} } - -func (e *testExchanger) setResponse(addr string, r *response) { +func (e *testResolver) setResponse(r *response) { e.mu.Lock() defer e.mu.Unlock() - e.responses[addr] = r + e.response = r } -func (e *testExchanger) Exchange(msg *dns.Msg, addr string) (*dns.Msg, time.Duration, error) { +func (e *testResolver) Exchange(msg *dns.Msg) (*dns.Msg, error) { e.mu.RLock() defer e.mu.RUnlock() - r, ok := e.responses[addr] - if !ok { - panic("no such resolver: " + addr) + r := e.response + if r == nil { + panic("no response set") } if r.fail { - return nil, 0, errors.New("error") + return nil, errors.New("error") } r.mu.Lock() defer r.mu.Unlock() - return r.answer, time.Second, nil + return r.answer, nil } func newA(name string, ttl uint32, ipAddr ...string) *dns.Msg { @@ -130,17 +128,19 @@ func TestAnswers(t *testing.T) { } func TestExchange(t *testing.T) { - addresses := []string{"addr1", "addr2"} - exchanger := newTestExchanger() + resolver1 := &testResolver{} + resolver2 := &testResolver{} // First responding resolver returns answer answer1 := newA("example.com.", 60, "192.0.2.1") answer2 := newA("example.com.", 60, "192.0.2.2") r1 := response{answer: answer1} r1.mu.Lock() // Locking first resolver so that second wins - exchanger.setResponse(addresses[0], &r1) - exchanger.setResponse(addresses[1], &response{answer: answer2}) - r, err := multiExchange(exchanger, &dns.Msg{}, addresses...) + resolver1.setResponse(&r1) + resolver2.setResponse(&response{answer: answer2}) + + mux := NewMux(resolver1, resolver2) + r, err := mux.Exchange(&dns.Msg{}) if err != nil { t.Fatal(err) } @@ -150,9 +150,9 @@ func TestExchange(t *testing.T) { r1.mu.Unlock() // All resolvers fail - exchanger.setResponse(addresses[0], &response{fail: true}) - exchanger.setResponse(addresses[1], &response{fail: true}) - _, err = multiExchange(exchanger, &dns.Msg{}, addresses...) + resolver1.setResponse(&response{fail: true}) + resolver2.setResponse(&response{fail: true}) + _, err = mux.Exchange(&dns.Msg{}) if err == nil { t.Errorf("got %s, want error", err) } diff --git a/dns/proxy.go b/dns/proxy.go index 6241edd..32f6907 100644 --- a/dns/proxy.go +++ b/dns/proxy.go @@ -38,12 +38,12 @@ type Proxy struct { cache *cache.Cache logger *sql.Logger server *dns.Server - client *dnsutil.Client + client dnsutil.Client mu sync.RWMutex } // NewProxy creates a new DNS proxy. -func NewProxy(cache *cache.Cache, client *dnsutil.Client, logger *sql.Logger) (*Proxy, error) { +func NewProxy(cache *cache.Cache, client dnsutil.Client, logger *sql.Logger) (*Proxy, error) { return &Proxy{ logger: logger, cache: cache, diff --git a/dns/proxy_test.go b/dns/proxy_test.go index 7eef51d..1b0431a 100644 --- a/dns/proxy_test.go +++ b/dns/proxy_test.go @@ -8,11 +8,9 @@ import ( "reflect" "sync" "testing" - "time" "github.com/miekg/dns" "github.com/mpolden/zdns/cache" - "github.com/mpolden/zdns/dns/dnsutil" ) func init() { @@ -41,30 +39,25 @@ type response struct { fail bool } -type testExchanger struct { - mu sync.RWMutex - responses map[string]*response +type testResolver struct { + mu sync.RWMutex + response *response } -func newTestExchanger() *testExchanger { return &testExchanger{responses: make(map[string]*response)} } - -func (e *testExchanger) setResponse(resolver string, response *response) { +func (e *testResolver) setResponse(response *response) { e.mu.Lock() defer e.mu.Unlock() - e.responses[resolver] = response + e.response = response } -func (e *testExchanger) Exchange(msg *dns.Msg, addr string) (*dns.Msg, time.Duration, error) { +func (e *testResolver) Exchange(msg *dns.Msg) (*dns.Msg, error) { e.mu.RLock() defer e.mu.RUnlock() - r, ok := e.responses[addr] - if !ok { - panic("no such resolver: " + addr) - } - if r.fail { - return nil, 0, fmt.Errorf("%s SERVFAIL", addr) + r := e.response + if r == nil || r.fail { + return nil, fmt.Errorf("SERVFAIL") } - return r.answer, time.Second, nil + return r.answer, nil } func testProxy(t *testing.T) *Proxy { @@ -157,51 +150,34 @@ func TestProxy(t *testing.T) { assertRR(t, p, &m, "::") } -func TestProxyWithResolvers(t *testing.T) { +func TestProxyWithResolver(t *testing.T) { p := testProxy(t) - exchanger := newTestExchanger() - p.client = &dnsutil.Client{Exchanger: exchanger} + r := &testResolver{} + p.client = r defer p.Close() - // No resolvers + // No response assertFailure(t, p, TypeA, "host1") - // First and only resolver responds succesfully - p.client.Addresses = []string{"resolver1"} + // Responds succesfully reply := ReplyA("host1", net.ParseIP("192.0.2.1")) m := dns.Msg{} m.Id = dns.Id() m.SetQuestion("host1.", dns.TypeA) m.Answer = reply.rr response1 := &response{answer: &m} - exchanger.setResponse("resolver1", response1) + r.setResponse(response1) assertRR(t, p, &m, "192.0.2.1") - // First and only resolver fails + // Resolver fails response1.fail = true assertFailure(t, p, TypeA, "host1") - - // First resolver fails, but second succeeds - reply = ReplyA("host1", net.ParseIP("192.0.2.2")) - p.client.Addresses = []string{"resolver1", "resolver2"} - m = dns.Msg{} - m.Id = dns.Id() - m.SetQuestion("host1.", dns.TypeA) - m.Answer = reply.rr - response2 := &response{answer: &m} - exchanger.setResponse("resolver2", response2) - assertRR(t, p, &m, "192.0.2.2") - - // All resolvers fail - response2.fail = true - assertFailure(t, p, TypeA, "host1") } func TestProxyWithCache(t *testing.T) { p := testProxy(t) p.cache = cache.New(10, nil) - exchanger := newTestExchanger() - p.client = &dnsutil.Client{Exchanger: exchanger} - p.client.Addresses = []string{"resolver1"} + r := &testResolver{} + p.client = r defer p.Close() reply := ReplyA("host1", net.ParseIP("192.0.2.1")) @@ -209,7 +185,7 @@ func TestProxyWithCache(t *testing.T) { m.Id = dns.Id() m.SetQuestion("host1.", dns.TypeA) m.Answer = reply.rr - exchanger.setResponse("resolver1", &response{answer: &m}) + r.setResponse(&response{answer: &m}) assertRR(t, p, &m, "192.0.2.1") k := cache.NewKey("host1.", dns.TypeA, dns.ClassINET) |