diff options
author | Martin Polden <mpolden@mpolden.no> | 2019-12-30 18:41:11 +0100 |
---|---|---|
committer | Martin Polden <mpolden@mpolden.no> | 2019-12-30 18:41:11 +0100 |
commit | 908c837c9de142ebbbdfe125d91a78ede6730b6c (patch) | |
tree | 7b0e94021a6327e148d48f069cf02fec55db4eb4 | |
parent | 959641dc2eef343afe8a563fec895a68e1d6c6b8 (diff) |
Simplify hosts reloading
-rw-r--r-- | server.go | 15 | ||||
-rw-r--r-- | server_test.go | 9 |
2 files changed, 6 insertions, 18 deletions
@@ -31,7 +31,6 @@ type Server struct { hosts hosts.Hosts logger *log.Logger proxy *dns.Proxy - ticker *time.Ticker done chan bool mu sync.RWMutex httpClient *http.Client @@ -50,8 +49,7 @@ func NewServer(logger *log.Logger, proxy *dns.Proxy, config Config) (*Server, er // Periodically refresh hosts if t := config.DNS.refreshInterval; t > 0 { - server.ticker = time.NewTicker(t) - go server.reloadHosts() + go server.reloadHosts(config.DNS.refreshInterval) } // Load initial hosts @@ -113,13 +111,12 @@ func nonFqdn(s string) string { return s } -func (s *Server) reloadHosts() { +func (s *Server) reloadHosts(interval time.Duration) { for { select { case <-s.done: - s.ticker.Stop() return - case <-s.ticker.C: + case <-time.After(interval): s.loadHosts() } } @@ -163,14 +160,12 @@ func (s *Server) loadHosts() { s.logger.Printf("loaded %d hosts in total", len(hs)) } -// Reload reloads the configuration of this server. +// Reload updates hosts entries of Server s. func (s *Server) Reload() { s.loadHosts() } // Close terminates all active operations and shuts down the DNS server. func (s *Server) Close() error { - if s.ticker != nil { - s.done <- true - } + s.done <- true return nil } diff --git a/server_test.go b/server_test.go index 17883e3..9b0aaf0 100644 --- a/server_test.go +++ b/server_test.go @@ -29,12 +29,6 @@ const hostsFile2 = ` 192.0.2.6 badhost6 ` -func handleErr(t *testing.T, fn func() error) { - if err := fn(); err != nil { - t.Fatal(err) - } -} - func httpHandler(t *testing.T, response string) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { if _, err := w.Write([]byte(response)); err != nil { @@ -52,7 +46,7 @@ func tempFile(t *testing.T, s string) (string, error) { if err != nil { return "", err } - defer handleErr(t, f.Close) + defer f.Close() if err := ioutil.WriteFile(f.Name(), []byte(s), 0644); err != nil { return "", err } @@ -178,7 +172,6 @@ func TestHijack(t *testing.T) { }, logger: log, } - defer handleErr(t, s.Close) var tests = []struct { rtype uint16 |