diff options
Diffstat (limited to 'dns/dnsutil/dnsutil.go')
-rw-r--r-- | dns/dnsutil/dnsutil.go | 77 |
1 files changed, 47 insertions, 30 deletions
diff --git a/dns/dnsutil/dnsutil.go b/dns/dnsutil/dnsutil.go index 9cf75b6..3b989c9 100644 --- a/dns/dnsutil/dnsutil.go +++ b/dns/dnsutil/dnsutil.go @@ -17,46 +17,50 @@ var ( RcodeToString = dns.RcodeToString ) -// Exchanger is the interface that wraps the Exchange method of a DNS client. -type Exchanger interface { - Exchange(*dns.Msg, string) (*dns.Msg, time.Duration, error) +// Client is the interface of a DNS client. +type Client interface { + Exchange(*dns.Msg) (*dns.Msg, error) } -// Client wraps a DNS client and a list of server addresses. -type Client struct { - Exchanger Exchanger - Addresses []string +// Config is a structure used to configure a DNS client. +type Config struct { + Network string + Timeout time.Duration } -// NewClient creates a new Client using the named network and addresses. -func NewClient(network string, timeout time.Duration, addresses ...string) *Client { - var client Exchanger - if network == "https" { - client = http.NewClient(timeout) - } else { - client = &dns.Client{Net: network, Timeout: timeout} - } - return &Client{Exchanger: client, Addresses: addresses} +type resolver interface { + Exchange(*dns.Msg, string) (*dns.Msg, time.Duration, error) } -func multiExchange(exchanger Exchanger, msg *dns.Msg, address ...string) (*dns.Msg, error) { - if len(address) == 0 { - return nil, fmt.Errorf("no addresses to query") +type client struct { + resolver resolver + address string +} + +type mux struct{ clients []Client } + +// NewMux creates a new multiplexed client which queries all clients in parallel and returns the first successful +// response. +func NewMux(client ...Client) Client { return &mux{clients: client} } + +func (m *mux) Exchange(msg *dns.Msg) (*dns.Msg, error) { + if len(m.clients) == 0 { + return nil, fmt.Errorf("no clients to query") } - responses := make(chan *dns.Msg, len(address)) - errs := make(chan error, len(address)) + responses := make(chan *dns.Msg, len(m.clients)) + errs := make(chan error, len(m.clients)) var wg sync.WaitGroup - for _, a := range address { + for _, c := range m.clients { wg.Add(1) - go func(addr string) { + go func(client Client) { defer wg.Done() - r, _, err := exchanger.Exchange(msg, addr) + r, err := client.Exchange(msg) if err != nil { - errs <- fmt.Errorf("resolver %s failed: %w", addr, err) + errs <- err return } responses <- r - }(a) + }(c) } go func() { wg.Wait() @@ -69,10 +73,23 @@ func multiExchange(exchanger Exchanger, msg *dns.Msg, address ...string) (*dns.M return nil, <-errs } -// Exchange performs a synchronous DNS query. All addresses in Client c are queried in parallel and the first successful -// response is returned. -func (c *Client) Exchange(msg *dns.Msg) (*dns.Msg, error) { - return multiExchange(c.Exchanger, msg, c.Addresses...) +// NewClient creates a new Client for addr using config. +func NewClient(addr string, config Config) Client { + var r resolver + if config.Network == "https" { + r = http.NewClient(config.Timeout) + } else { + r = &dns.Client{Net: config.Network, Timeout: config.Timeout} + } + return &client{resolver: r, address: addr} +} + +func (c *client) Exchange(msg *dns.Msg) (*dns.Msg, error) { + r, _, err := c.resolver.Exchange(msg, c.address) + if err != nil { + return nil, fmt.Errorf("resolver %s failed: %w", c.address, err) + } + return r, err } // Answers returns all values in the answer section of DNS message msg. |