aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorMartin Polden <mpolden@mpolden.no>2019-12-30 18:41:11 +0100
committerMartin Polden <mpolden@mpolden.no>2019-12-30 18:41:11 +0100
commit908c837c9de142ebbbdfe125d91a78ede6730b6c (patch)
tree7b0e94021a6327e148d48f069cf02fec55db4eb4
parent959641dc2eef343afe8a563fec895a68e1d6c6b8 (diff)
Simplify hosts reloading
-rw-r--r--server.go15
-rw-r--r--server_test.go9
2 files changed, 6 insertions, 18 deletions
diff --git a/server.go b/server.go
index 2e17997..25c48ef 100644
--- a/server.go
+++ b/server.go
@@ -31,7 +31,6 @@ type Server struct {
hosts hosts.Hosts
logger *log.Logger
proxy *dns.Proxy
- ticker *time.Ticker
done chan bool
mu sync.RWMutex
httpClient *http.Client
@@ -50,8 +49,7 @@ func NewServer(logger *log.Logger, proxy *dns.Proxy, config Config) (*Server, er
// Periodically refresh hosts
if t := config.DNS.refreshInterval; t > 0 {
- server.ticker = time.NewTicker(t)
- go server.reloadHosts()
+ go server.reloadHosts(config.DNS.refreshInterval)
}
// Load initial hosts
@@ -113,13 +111,12 @@ func nonFqdn(s string) string {
return s
}
-func (s *Server) reloadHosts() {
+func (s *Server) reloadHosts(interval time.Duration) {
for {
select {
case <-s.done:
- s.ticker.Stop()
return
- case <-s.ticker.C:
+ case <-time.After(interval):
s.loadHosts()
}
}
@@ -163,14 +160,12 @@ func (s *Server) loadHosts() {
s.logger.Printf("loaded %d hosts in total", len(hs))
}
-// Reload reloads the configuration of this server.
+// Reload updates hosts entries of Server s.
func (s *Server) Reload() { s.loadHosts() }
// Close terminates all active operations and shuts down the DNS server.
func (s *Server) Close() error {
- if s.ticker != nil {
- s.done <- true
- }
+ s.done <- true
return nil
}
diff --git a/server_test.go b/server_test.go
index 17883e3..9b0aaf0 100644
--- a/server_test.go
+++ b/server_test.go
@@ -29,12 +29,6 @@ const hostsFile2 = `
192.0.2.6 badhost6
`
-func handleErr(t *testing.T, fn func() error) {
- if err := fn(); err != nil {
- t.Fatal(err)
- }
-}
-
func httpHandler(t *testing.T, response string) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if _, err := w.Write([]byte(response)); err != nil {
@@ -52,7 +46,7 @@ func tempFile(t *testing.T, s string) (string, error) {
if err != nil {
return "", err
}
- defer handleErr(t, f.Close)
+ defer f.Close()
if err := ioutil.WriteFile(f.Name(), []byte(s), 0644); err != nil {
return "", err
}
@@ -178,7 +172,6 @@ func TestHijack(t *testing.T) {
},
logger: log,
}
- defer handleErr(t, s.Close)
var tests = []struct {
rtype uint16