aboutsummaryrefslogtreecommitdiffstats
path: root/dns/proxy.go
diff options
context:
space:
mode:
authorMartin Polden <mpolden@mpolden.no>2019-06-08 17:40:07 +0200
committerMartin Polden <mpolden@mpolden.no>2019-06-08 17:50:22 +0200
commitfa475dcc8a744739e7c40d8ba2c57385175690c2 (patch)
tree6c473d5ce817c9d326f0c198e2666256b8c60942 /dns/proxy.go
parent6544d1d8ade6331c8c04f75439b13eec9b7fbe98 (diff)
Default to tcp-tls for upstream
Diffstat (limited to 'dns/proxy.go')
-rw-r--r--dns/proxy.go24
1 files changed, 14 insertions, 10 deletions
diff --git a/dns/proxy.go b/dns/proxy.go
index 4779b66..3994bb8 100644
--- a/dns/proxy.go
+++ b/dns/proxy.go
@@ -1,6 +1,7 @@
package dns
import (
+ "log"
"net"
"strings"
"time"
@@ -28,8 +29,9 @@ type Handler func(*Request) *Reply
// Proxy represents a DNS proxy.
type Proxy struct {
- handler Handler
- resolvers []string
+ Handler Handler
+ Resolvers []string
+ logger *log.Logger
server *dns.Server
client client
}
@@ -39,11 +41,10 @@ type client interface {
}
// NewProxy creates a new DNS proxy.
-func NewProxy(handler Handler, resolvers []string, timeout time.Duration) *Proxy {
+func NewProxy(logger *log.Logger, network string, timeout time.Duration) *Proxy {
return &Proxy{
- handler: handler,
- resolvers: resolvers,
- client: &dns.Client{Timeout: timeout},
+ logger: logger,
+ client: &dns.Client{Net: network, Timeout: timeout},
}
}
@@ -83,10 +84,10 @@ func (r *Reply) String() string {
}
func (p *Proxy) reply(r *dns.Msg) *dns.Msg {
- if p.handler == nil || len(r.Question) != 1 {
+ if p.Handler == nil || len(r.Question) != 1 {
return nil
}
- reply := p.handler(&Request{
+ reply := p.Handler(&Request{
Name: r.Question[0].Name,
Type: r.Question[0].Qtype,
})
@@ -115,10 +116,13 @@ func (p *Proxy) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
_ = w.WriteMsg(reply) // TODO: Decide whether to handle write errors
return
}
- for i, resolver := range p.resolvers {
+ for i, resolver := range p.Resolvers {
rr, _, err := p.client.Exchange(r, resolver)
if err != nil {
- if i == len(p.resolvers)-1 {
+ if p.logger != nil {
+ p.logger.Printf("resolver %s failed: %s", resolver, err)
+ }
+ if i == len(p.Resolvers)-1 {
break
} else {
continue