From 1578925c557dab6b6c25af657771742c45dd9303 Mon Sep 17 00:00:00 2001 From: Martin Polden Date: Sun, 22 Dec 2019 16:21:23 +0100 Subject: Extract signal handling to separate package --- cmd/zdns/main.go | 91 ++++++++++++++++++++++++++++++++++++--------------- cmd/zdns/main_test.go | 24 ++++++++++---- 2 files changed, 81 insertions(+), 34 deletions(-) (limited to 'cmd') diff --git a/cmd/zdns/main.go b/cmd/zdns/main.go index 91db1be..5565ef4 100644 --- a/cmd/zdns/main.go +++ b/cmd/zdns/main.go @@ -5,11 +5,13 @@ import ( "io" "os" "path/filepath" + "sync" "flag" "github.com/mpolden/zdns" "github.com/mpolden/zdns/log" + "github.com/mpolden/zdns/signal" ) const ( @@ -18,46 +20,81 @@ const ( configName = "." + name + "rc" ) +type server interface{ ListenAndServe() error } + +type cli struct { + wg sync.WaitGroup + configFile string + out io.Writer + args []string + signal chan os.Signal +} + func defaultConfigFile() string { return filepath.Join(os.Getenv("HOME"), configName) } -func newServer(out io.Writer, confFile string) (*zdns.Server, error) { - f, err := os.Open(confFile) - if err != nil { - return nil, err - } - defer func() { _ = f.Close() }() - conf, err := zdns.ReadConfig(f) - if err != nil { - return nil, err - } - log, err := log.New(out, logPrefix, log.RecordOptions{ - Database: conf.DNS.LogDatabase, - TTL: conf.DNS.LogTTL, - }) +func readConfig(file string) (zdns.Config, error) { + f, err := os.Open(file) if err != nil { - return nil, err + return zdns.Config{}, err } - return zdns.NewServer(log, conf) + defer f.Close() + return zdns.ReadConfig(f) } func fatal(err error) { + if err == nil { + return + } fmt.Fprintf(os.Stderr, "%s: %s\n", logPrefix, err) os.Exit(1) } -func main() { - confFile := flag.String("f", defaultConfigFile(), "config file `path`") - help := flag.Bool("h", false, "print usage") - flag.Parse() +func (c *cli) runServer(server server) { + c.wg.Add(1) + go func() { + defer c.wg.Done() + if err := server.ListenAndServe(); err != nil { + fatal(err) + } + }() +} + +func (c *cli) run() { + f := flag.CommandLine + f.SetOutput(c.out) + confFile := f.String("f", c.configFile, "config file `path`") + help := f.Bool("h", false, "print usage") + f.Parse(c.args) if *help { - flag.Usage() + f.Usage() return } - srv, err := newServer(os.Stderr, *confFile) - if err != nil { - fatal(err) - } - if err := srv.ListenAndServe(); err != nil { - fatal(err) + + config, err := readConfig(*confFile) + fatal(err) + + logger, err := log.New(c.out, logPrefix, log.RecordOptions{ + Database: config.DNS.LogDatabase, + TTL: config.DNS.LogTTL, + }) + fatal(err) + + dnsSrv, err := zdns.NewServer(logger, config) + fatal(err) + + sigHandler := signal.NewHandler(c.signal, logger) + sigHandler.OnReload(dnsSrv) + sigHandler.OnClose(dnsSrv) + c.runServer(dnsSrv) + c.wg.Wait() +} + +func main() { + c := cli{ + out: os.Stderr, + configFile: defaultConfigFile(), + args: os.Args[1:], + signal: make(chan os.Signal, 1), } + c.run() } diff --git a/cmd/zdns/main_test.go b/cmd/zdns/main_test.go index 851468b..f2dfa25 100644 --- a/cmd/zdns/main_test.go +++ b/cmd/zdns/main_test.go @@ -3,6 +3,8 @@ package main import ( "io/ioutil" "os" + "sync" + "syscall" "testing" ) @@ -40,12 +42,20 @@ hijack_mode = "zero" if err != nil { t.Fatal(err) } - defer handleErr(t, func() error { return os.Remove(f) }) - srv, err := newServer(ioutil.Discard, f) - if err != nil { - t.Fatal(err) - } - if srv == nil { - t.Error("want non-nil server") + defer os.Remove(f) + + main := cli{ + out: ioutil.Discard, + configFile: f, + args: []string{"-f", f}, + signal: make(chan os.Signal, 1), } + wg := sync.WaitGroup{} + wg.Add(1) + go func() { + defer wg.Done() + main.run() + }() + main.signal <- syscall.SIGTERM + wg.Wait() } -- cgit v1.2.3