diff options
author | Martin Polden <mpolden@mpolden.no> | 2019-12-22 16:21:23 +0100 |
---|---|---|
committer | Martin Polden <mpolden@mpolden.no> | 2019-12-22 18:01:40 +0100 |
commit | 1578925c557dab6b6c25af657771742c45dd9303 (patch) | |
tree | c15b3dace22fd7901ab2c2d0ed7eec2407668e11 /signal | |
parent | eff1bca2529f448b2bc6435dc56dc8c885065b54 (diff) |
Extract signal handling to separate package
Diffstat (limited to 'signal')
-rw-r--r-- | signal/signal.go | 60 | ||||
-rw-r--r-- | signal/signal_test.go | 61 |
2 files changed, 121 insertions, 0 deletions
diff --git a/signal/signal.go b/signal/signal.go new file mode 100644 index 0000000..96ef5f7 --- /dev/null +++ b/signal/signal.go @@ -0,0 +1,60 @@ +package signal + +import ( + "io" + "os" + "os/signal" + "syscall" + + "github.com/mpolden/zdns/log" +) + +// Reloader is the interface for types that need to act on a reload signal. +type Reloader interface { + Reload() +} + +// Handler represents a signal handler and holds references to types that should act on operating system signals. +type Handler struct { + logger *log.Logger + signal chan os.Signal + reloaders []Reloader + closers []io.Closer +} + +// NewHandler creates a new handler for handling operating system signals. +func NewHandler(c chan os.Signal, logger *log.Logger) *Handler { + h := &Handler{logger: logger, signal: c} + signal.Notify(h.signal) + go h.readSignal() + return h +} + +// OnReload registers a reloader to call for the signal SIGHUP. +func (s *Handler) OnReload(r Reloader) { s.reloaders = append(s.reloaders, r) } + +// OnClose registers a closer to call for signals SIGTERM and SIGINT. +func (s *Handler) OnClose(c io.Closer) { s.closers = append(s.closers, c) } + +func (s *Handler) readSignal() { + for sig := range s.signal { + switch sig { + case syscall.SIGHUP: + s.logger.Printf("received signal %s: reloading", sig) + for _, r := range s.reloaders { + r.Reload() + } + case syscall.SIGTERM, syscall.SIGINT: + signal.Stop(s.signal) + s.logger.Printf("received signal %s: shutting down", sig) + for _, c := range s.closers { + if err := c.Close(); err != nil { + s.logger.Printf("close failed: %s", err) + } + } + + default: + s.logger.Printf("received signal %s: ignoring", sig) + } + } +} diff --git a/signal/signal_test.go b/signal/signal_test.go new file mode 100644 index 0000000..544b9a0 --- /dev/null +++ b/signal/signal_test.go @@ -0,0 +1,61 @@ +package signal + +import ( + "io/ioutil" + "os" + "syscall" + "testing" + "time" + + "github.com/mpolden/zdns/log" +) + +type reloaderCloser struct { + reloaded bool + closed bool +} + +func (rc *reloaderCloser) Reload() { rc.reloaded = true } +func (rc *reloaderCloser) Close() error { + rc.closed = true + return nil +} +func (rc *reloaderCloser) isReloaded() bool { return rc.reloaded } +func (rc *reloaderCloser) isClosed() bool { return rc.closed } +func (rc *reloaderCloser) reset() { + rc.reloaded = false + rc.closed = false +} + +func TestHandler(t *testing.T) { + logger, err := log.New(ioutil.Discard, "", log.RecordOptions{}) + if err != nil { + t.Fatal(err) + } + h := NewHandler(make(chan os.Signal, 1), logger) + + rc := &reloaderCloser{} + h.OnReload(rc) + h.OnClose(rc) + + var tests = []struct { + signal syscall.Signal + value func() bool + }{ + {syscall.SIGHUP, rc.isReloaded}, + {syscall.SIGTERM, rc.isClosed}, + {syscall.SIGINT, rc.isClosed}, + } + + for _, tt := range tests { + rc.reset() + h.signal <- tt.signal + ts := time.Now() + for !tt.value() { + time.Sleep(10 * time.Millisecond) + if time.Since(ts) > 2*time.Second { + t.Fatalf("timed out waiting for handler of signal %s", tt.signal) + } + } + } +} |