diff options
author | Martin Polden <mpolden@mpolden.no> | 2020-01-20 22:24:11 +0100 |
---|---|---|
committer | Martin Polden <mpolden@mpolden.no> | 2020-01-20 22:28:09 +0100 |
commit | 551febd6d6e1e63750d136e229397c12bf1ec54a (patch) | |
tree | 6e00311640b4e23a5457ffd5912f81cc6fb36787 | |
parent | 8531d8a6653b3fffeee4d6d6699e9f7052f86422 (diff) |
Wait for signal handlers to complete on exit
-rw-r--r-- | cmd/zdns/main.go | 4 | ||||
-rw-r--r-- | cmd/zdns/main_test.go | 6 | ||||
-rw-r--r-- | signal/signal.go | 25 |
3 files changed, 26 insertions, 9 deletions
diff --git a/cmd/zdns/main.go b/cmd/zdns/main.go index b0537f9..5af70c6 100644 --- a/cmd/zdns/main.go +++ b/cmd/zdns/main.go @@ -29,6 +29,7 @@ type server interface{ ListenAndServe() error } type cli struct { servers []server + sh *signal.Handler wg sync.WaitGroup } @@ -151,7 +152,7 @@ func newCli(out io.Writer, args []string, configFile string, sig chan os.Signal) // ... and finally the server itself sigHandler.OnClose(dnsSrv) - return &cli{servers: servers}, nil + return &cli{servers: servers, sh: sigHandler}, nil } func (c *cli) run() { @@ -159,6 +160,7 @@ func (c *cli) run() { c.runServer(s) } c.wg.Wait() + c.sh.Close() } func main() { diff --git a/cmd/zdns/main_test.go b/cmd/zdns/main_test.go index da0bfcb..4b4d6ad 100644 --- a/cmd/zdns/main_test.go +++ b/cmd/zdns/main_test.go @@ -3,6 +3,7 @@ package main import ( "io/ioutil" "os" + "syscall" "testing" ) @@ -37,8 +38,11 @@ hijack_mode = "zero" } defer os.Remove(f) - _, err = newCli(ioutil.Discard, []string{"-f", f}, f, make(chan os.Signal, 1)) + sig := make(chan os.Signal, 1) + cli, err := newCli(ioutil.Discard, []string{"-f", f}, f, sig) if err != nil { t.Fatal(err) } + sig <- syscall.SIGTERM + cli.sh.Close() } diff --git a/signal/signal.go b/signal/signal.go index 99b2e6f..a6a04b5 100644 --- a/signal/signal.go +++ b/signal/signal.go @@ -5,6 +5,7 @@ import ( "log" "os" "os/signal" + "sync" "syscall" ) @@ -18,34 +19,44 @@ type Handler struct { signal chan os.Signal reloaders []Reloader closers []io.Closer + wg sync.WaitGroup } // NewHandler creates a new handler for handling operating system signals. func NewHandler(c chan os.Signal) *Handler { h := &Handler{signal: c} signal.Notify(h.signal) + h.wg.Add(1) go h.readSignal() return h } // OnReload registers a reloader to call for the signal SIGHUP. -func (s *Handler) OnReload(r Reloader) { s.reloaders = append(s.reloaders, r) } +func (h *Handler) OnReload(r Reloader) { h.reloaders = append(h.reloaders, r) } // OnClose registers a closer to call for signals SIGTERM and SIGINT. -func (s *Handler) OnClose(c io.Closer) { s.closers = append(s.closers, c) } +func (h *Handler) OnClose(c io.Closer) { h.closers = append(h.closers, c) } -func (s *Handler) readSignal() { - for sig := range s.signal { +// Close stops handling any new signals and completes processing of pending signals before returning. +func (h *Handler) Close() error { + signal.Stop(h.signal) + close(h.signal) + h.wg.Wait() + return nil +} + +func (h *Handler) readSignal() { + defer h.wg.Done() + for sig := range h.signal { switch sig { case syscall.SIGHUP: log.Printf("received signal %s: reloading", sig) - for _, r := range s.reloaders { + for _, r := range h.reloaders { r.Reload() } case syscall.SIGTERM, syscall.SIGINT: - signal.Stop(s.signal) log.Printf("received signal %s: shutting down", sig) - for _, c := range s.closers { + for _, c := range h.closers { if err := c.Close(); err != nil { log.Printf("close of %T failed: %s", c, err) } |