aboutsummaryrefslogtreecommitdiffstats
path: root/server.go
blob: 73044b09187553f3dc796abfd59da611be30cf03 (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
package zdns

import (
	"fmt"
	"io"
	"log"
	"net"
	"net/http"
	"net/url"
	"os"
	"os/signal"
	"sync"
	"syscall"
	"time"

	"github.com/mpolden/zdns/dns"
	"github.com/mpolden/zdns/hosts"
)

// A Server defines parameters for running a DNS server.
//
// The unexported fields are initialized by NewServer.
type Server struct {
	// Config is the configuration the server was created with.
	Config  Config
	// Logger receives diagnostic messages; when nil, nothing is logged.
	Logger  *log.Logger
	// proxy is created with handler as its request callback and is what
	// ListenAndServe runs.
	proxy   *dns.Proxy
	// matcher holds the currently loaded filters. Guarded by mu: it is
	// replaced under mu.Lock and read under mu.RLock.
	matcher *hosts.Matcher
	// ticker drives periodic filter reloads; nil when
	// Config.Filter.RefreshInterval is not positive.
	ticker  *time.Ticker
	// done tells the background goroutines (readSignal and, when ticker is
	// set, reloadHosts) to stop.
	done    chan bool
	// signal receives OS signals relayed by signal.Notify; consumed by
	// readSignal.
	signal  chan os.Signal
	// mu guards matcher.
	mu      sync.RWMutex
}

// NewServer returns a new server configured according to config.
//
// The configured filters are loaded once before NewServer returns, so the
// server filters requests from the start. Loading may perform network I/O
// with a 10 second timeout per remote filter. Because Logger is not yet set
// at that point, messages from this initial load are not logged. Filters
// are reloaded on SIGHUP, and also periodically when
// config.Filter.RefreshInterval is positive.
func NewServer(config Config) (*Server, error) {
	server := &Server{
		Config: config,
		signal: make(chan os.Signal, 1),
		done:   make(chan bool, 1),
	}
	// Create the proxy before starting any goroutine. A SIGTERM/SIGINT makes
	// readSignal call Close, which closes the proxy and must not see it nil.
	server.proxy = dns.NewProxy(server.handler, config.Resolvers, config.Resolver.Timeout.Duration)

	// Without an initial load the matcher stays unset until the first
	// refresh tick or SIGHUP, i.e. nothing would be filtered.
	server.loadHosts()

	if config.Filter.RefreshInterval.Duration > 0 {
		server.ticker = time.NewTicker(config.Filter.RefreshInterval.Duration)
		go server.reloadHosts()
	}

	// Subscribe only to the signals readSignal acts on. signal.Notify with
	// no signals relays every incoming signal, including SIGURG, which the
	// Go runtime sends frequently for goroutine preemption.
	signal.Notify(server.signal, syscall.SIGHUP, syscall.SIGTERM, syscall.SIGINT)
	go server.readSignal()
	return server, nil
}

// readHosts reads and parses a hosts list from name.
//
// name must be a URL with a "file", "http" or "https" scheme; any other
// scheme is an error. For http(s), the request uses a 10 second timeout, and
// any response status other than 200 OK is an error. This prevents an error
// page from being parsed as a hosts list.
func readHosts(name string) (hosts.Hosts, error) {
	u, err := url.Parse(name)
	if err != nil {
		return nil, err
	}
	var rc io.ReadCloser
	switch u.Scheme {
	case "file":
		f, err := os.Open(u.Path)
		if err != nil {
			return nil, err
		}
		rc = f
	case "http", "https":
		client := http.Client{Timeout: 10 * time.Second}
		res, err := client.Get(u.String())
		if err != nil {
			return nil, err
		}
		if res.StatusCode != http.StatusOK {
			// The body is discarded; a close error adds nothing useful here.
			_ = res.Body.Close()
			return nil, fmt.Errorf("%s: unexpected status: %s", u, res.Status)
		}
		rc = res.Body
	default:
		return nil, fmt.Errorf("%s: invalid scheme: %s", u, u.Scheme)
	}
	defer rc.Close()
	return hosts.Parse(rc)
}

// nonFqdn strips a single trailing dot from s, turning a fully qualified name
// such as "example.com." into "example.com". Names without a trailing dot,
// including the empty string, are returned unchanged.
func nonFqdn(s string) string {
	if n := len(s); n != 0 && s[n-1] == '.' {
		return s[:n-1]
	}
	return s
}

// logf formats a message with Printf-style arguments and writes it to
// s.Logger. It does nothing when no Logger has been set.
func (s *Server) logf(format string, v ...interface{}) {
	logger := s.Logger
	if logger == nil {
		return
	}
	logger.Printf(format, v...)
}

// readSignal handles OS signals delivered on s.signal until done fires:
// SIGHUP reloads the filters, SIGTERM and SIGINT close the server, and any
// other signal is logged and ignored. On return it stops signal delivery to
// s.signal.
func (s *Server) readSignal() {
	defer signal.Stop(s.signal)
	for {
		var sig os.Signal
		select {
		case <-s.done:
			return
		case sig = <-s.signal:
		}

		switch sig {
		case syscall.SIGHUP:
			s.logf("received signal %s: reloading filters", sig)
			s.loadHosts()
		case syscall.SIGTERM, syscall.SIGINT:
			s.logf("received signal %s: shutting down", sig)
			s.Close()
		default:
			s.logf("received signal %s: ignoring", sig)
		}
	}
}

// reloadHosts reloads the filters on every tick of s.ticker until done
// fires, then stops the ticker. It requires s.ticker to be non-nil.
func (s *Server) reloadHosts() {
	defer s.ticker.Stop()
	tick := s.ticker.C
	for {
		select {
		case <-tick:
			s.loadHosts()
		case <-s.done:
			return
		}
	}
}

// loadHosts reads every configured filter and replaces the server's matcher
// with one built from the result.
//
// Filters are processed in configuration order:
//   - A filter with Reject set adds its hosts to the reject lists.
//   - Any other filter removes its hosts from the reject lists loaded before
//     it, so it only affects reject lists that precede it.
//
// Filters that cannot be read are logged and skipped.
func (s *Server) loadHosts() {
	var rejectLists []hosts.Hosts
	var size int
	for _, f := range s.Config.Filters {
		h, err := readHosts(f.URL)
		if err != nil {
			s.logf("failed to read hosts from %s: %s", f.URL, err)
			continue
		}
		if f.Reject {
			rejectLists = append(rejectLists, h)
			s.logf("loaded %d hosts from %s", len(h), f.URL)
			size += len(h)
			continue
		}
		var removed int
		for name := range h {
			for _, rejected := range rejectLists {
				if _, ok := rejected.Get(name); ok {
					removed++
					rejected.Del(name)
				}
			}
		}
		size -= removed
		if removed > 0 {
			// Report the number of entries actually removed, not the size of
			// the allow list.
			s.logf("removed %d hosts from %s", removed, f.URL)
		}
	}
	m := hosts.NewMatcher(rejectLists...)
	s.mu.Lock()
	defer s.mu.Unlock()
	s.matcher = m
	s.logf("loaded %d hosts in total", size)
}

// Close terminates all active operations and shuts down the DNS server.
//
// It signals the background goroutines started by NewServer to stop and
// closes the DNS proxy. Close is safe to call more than once and from the
// signal-handling goroutine itself; calls after the first do nothing.
func (s *Server) Close() {
	// Closing done wakes every receiver at once and never blocks. Sending one
	// value per goroutine on a channel with a buffer of 1 deadlocks when Close
	// runs inside readSignal and no reload goroutine exists to take the
	// second value. s.mu serializes concurrent Close calls so done is closed
	// exactly once.
	s.mu.Lock()
	select {
	case <-s.done:
		s.mu.Unlock()
		return // already closed
	default:
		close(s.done)
	}
	s.mu.Unlock()
	if err := s.proxy.Close(); err != nil {
		s.logf("error during close: %s", err)
	}
}

// handler is the request callback given to the DNS proxy.
//
// It returns nil when the requested name (with any trailing dot removed) does
// not match the loaded filters, or when no filters have been loaded yet.
// Presumably the proxy then resolves such requests upstream; confirm against
// dns.Proxy. For matching names the reply depends on
// Config.Filter.RejectMode:
//   - "zero": A and AAAA requests are answered with 0.0.0.0 and :: respectively;
//     other types return nil.
//   - "no-data": an empty reply.
//   - "hosts" (not implemented yet) and any other mode: nil.
func (s *Server) handler(r *dns.Request) *dns.Reply {
	s.mu.RLock()
	defer s.mu.RUnlock()
	// matcher is nil until loadHosts has completed at least once; treat that
	// as "nothing is filtered" instead of calling Match on a nil pointer.
	if s.matcher == nil || !s.matcher.Match(nonFqdn(r.Name)) {
		return nil // No match
	}
	switch s.Config.Filter.RejectMode {
	case "zero":
		switch r.Type {
		case dns.TypeA:
			return dns.ReplyA(r.Name, net.IPv4zero)
		case dns.TypeAAAA:
			return dns.ReplyAAAA(r.Name, net.IPv6zero)
		}
	case "no-data":
		return &dns.Reply{}
	case "hosts":
		// TODO: Provide answer from hosts
	}
	return nil
}

// ListenAndServe listens on the network address addr and uses the server to process requests.
//
// Both addr and network are passed unchanged to the underlying proxy, and its
// error is returned as-is. Presumably network is "udp" or "tcp"; confirm
// against dns.Proxy.ListenAndServe.
func (s *Server) ListenAndServe(addr, network string) error {
	proxy := s.proxy
	return proxy.ListenAndServe(addr, network)
}