aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorMartin Polden <mpolden@mpolden.no>2020-01-20 22:24:11 +0100
committerMartin Polden <mpolden@mpolden.no>2020-01-20 22:28:09 +0100
commit551febd6d6e1e63750d136e229397c12bf1ec54a (patch)
tree6e00311640b4e23a5457ffd5912f81cc6fb36787
parent8531d8a6653b3fffeee4d6d6699e9f7052f86422 (diff)
Wait for signal handlers to complete on exit
-rw-r--r--cmd/zdns/main.go4
-rw-r--r--cmd/zdns/main_test.go6
-rw-r--r--signal/signal.go25
3 files changed, 26 insertions, 9 deletions
diff --git a/cmd/zdns/main.go b/cmd/zdns/main.go
index b0537f9..5af70c6 100644
--- a/cmd/zdns/main.go
+++ b/cmd/zdns/main.go
@@ -29,6 +29,7 @@ type server interface{ ListenAndServe() error }
type cli struct {
servers []server
+ sh *signal.Handler
wg sync.WaitGroup
}
@@ -151,7 +152,7 @@ func newCli(out io.Writer, args []string, configFile string, sig chan os.Signal)
// ... and finally the server itself
sigHandler.OnClose(dnsSrv)
- return &cli{servers: servers}, nil
+ return &cli{servers: servers, sh: sigHandler}, nil
}
func (c *cli) run() {
@@ -159,6 +160,7 @@ func (c *cli) run() {
c.runServer(s)
}
c.wg.Wait()
+ c.sh.Close()
}
func main() {
diff --git a/cmd/zdns/main_test.go b/cmd/zdns/main_test.go
index da0bfcb..4b4d6ad 100644
--- a/cmd/zdns/main_test.go
+++ b/cmd/zdns/main_test.go
@@ -3,6 +3,7 @@ package main
import (
"io/ioutil"
"os"
+ "syscall"
"testing"
)
@@ -37,8 +38,11 @@ hijack_mode = "zero"
}
defer os.Remove(f)
- _, err = newCli(ioutil.Discard, []string{"-f", f}, f, make(chan os.Signal, 1))
+ sig := make(chan os.Signal, 1)
+ cli, err := newCli(ioutil.Discard, []string{"-f", f}, f, sig)
if err != nil {
t.Fatal(err)
}
+ sig <- syscall.SIGTERM
+ cli.sh.Close()
}
diff --git a/signal/signal.go b/signal/signal.go
index 99b2e6f..a6a04b5 100644
--- a/signal/signal.go
+++ b/signal/signal.go
@@ -5,6 +5,7 @@ import (
"log"
"os"
"os/signal"
+ "sync"
"syscall"
)
@@ -18,34 +19,44 @@ type Handler struct {
signal chan os.Signal
reloaders []Reloader
closers []io.Closer
+ wg sync.WaitGroup
}
// NewHandler creates a new handler for handling operating system signals.
func NewHandler(c chan os.Signal) *Handler {
h := &Handler{signal: c}
signal.Notify(h.signal)
+ h.wg.Add(1)
go h.readSignal()
return h
}
// OnReload registers a reloader to call for the signal SIGHUP.
-func (s *Handler) OnReload(r Reloader) { s.reloaders = append(s.reloaders, r) }
+func (h *Handler) OnReload(r Reloader) { h.reloaders = append(h.reloaders, r) }
// OnClose registers a closer to call for signals SIGTERM and SIGINT.
-func (s *Handler) OnClose(c io.Closer) { s.closers = append(s.closers, c) }
+func (h *Handler) OnClose(c io.Closer) { h.closers = append(h.closers, c) }
-func (s *Handler) readSignal() {
- for sig := range s.signal {
+// Close stops handling any new signals and completes processing of pending signals before returning.
+func (h *Handler) Close() error {
+ signal.Stop(h.signal)
+ close(h.signal)
+ h.wg.Wait()
+ return nil
+}
+
+func (h *Handler) readSignal() {
+ defer h.wg.Done()
+ for sig := range h.signal {
switch sig {
case syscall.SIGHUP:
log.Printf("received signal %s: reloading", sig)
- for _, r := range s.reloaders {
+ for _, r := range h.reloaders {
r.Reload()
}
case syscall.SIGTERM, syscall.SIGINT:
- signal.Stop(s.signal)
log.Printf("received signal %s: shutting down", sig)
- for _, c := range s.closers {
+ for _, c := range h.closers {
if err := c.Close(); err != nil {
log.Printf("close of %T failed: %s", c, err)
}