aboutsummaryrefslogtreecommitdiffstats
path: root/cmd
diff options
context:
space:
mode:
authorMartin Polden <mpolden@mpolden.no>2019-12-22 16:21:23 +0100
committerMartin Polden <mpolden@mpolden.no>2019-12-22 18:01:40 +0100
commit1578925c557dab6b6c25af657771742c45dd9303 (patch)
treec15b3dace22fd7901ab2c2d0ed7eec2407668e11 /cmd
parenteff1bca2529f448b2bc6435dc56dc8c885065b54 (diff)
Extract signal handling to separate package
Diffstat (limited to 'cmd')
-rw-r--r--cmd/zdns/main.go91
-rw-r--r--cmd/zdns/main_test.go24
2 files changed, 81 insertions, 34 deletions
diff --git a/cmd/zdns/main.go b/cmd/zdns/main.go
index 91db1be..5565ef4 100644
--- a/cmd/zdns/main.go
+++ b/cmd/zdns/main.go
@@ -5,11 +5,13 @@ import (
"io"
"os"
"path/filepath"
+ "sync"
"flag"
"github.com/mpolden/zdns"
"github.com/mpolden/zdns/log"
+ "github.com/mpolden/zdns/signal"
)
const (
@@ -18,46 +20,81 @@ const (
configName = "." + name + "rc"
)
+type server interface{ ListenAndServe() error }
+
+type cli struct {
+ wg sync.WaitGroup
+ configFile string
+ out io.Writer
+ args []string
+ signal chan os.Signal
+}
+
func defaultConfigFile() string { return filepath.Join(os.Getenv("HOME"), configName) }
-func newServer(out io.Writer, confFile string) (*zdns.Server, error) {
- f, err := os.Open(confFile)
- if err != nil {
- return nil, err
- }
- defer func() { _ = f.Close() }()
- conf, err := zdns.ReadConfig(f)
- if err != nil {
- return nil, err
- }
- log, err := log.New(out, logPrefix, log.RecordOptions{
- Database: conf.DNS.LogDatabase,
- TTL: conf.DNS.LogTTL,
- })
+func readConfig(file string) (zdns.Config, error) {
+ f, err := os.Open(file)
if err != nil {
- return nil, err
+ return zdns.Config{}, err
}
- return zdns.NewServer(log, conf)
+ defer f.Close()
+ return zdns.ReadConfig(f)
}
func fatal(err error) {
+ if err == nil {
+ return
+ }
fmt.Fprintf(os.Stderr, "%s: %s\n", logPrefix, err)
os.Exit(1)
}
-func main() {
- confFile := flag.String("f", defaultConfigFile(), "config file `path`")
- help := flag.Bool("h", false, "print usage")
- flag.Parse()
+func (c *cli) runServer(server server) {
+ c.wg.Add(1)
+ go func() {
+ defer c.wg.Done()
+ if err := server.ListenAndServe(); err != nil {
+ fatal(err)
+ }
+ }()
+}
+
+func (c *cli) run() {
+ f := flag.CommandLine
+ f.SetOutput(c.out)
+ confFile := f.String("f", c.configFile, "config file `path`")
+ help := f.Bool("h", false, "print usage")
+ f.Parse(c.args)
if *help {
- flag.Usage()
+ f.Usage()
return
}
- srv, err := newServer(os.Stderr, *confFile)
- if err != nil {
- fatal(err)
- }
- if err := srv.ListenAndServe(); err != nil {
- fatal(err)
+
+ config, err := readConfig(*confFile)
+ fatal(err)
+
+ logger, err := log.New(c.out, logPrefix, log.RecordOptions{
+ Database: config.DNS.LogDatabase,
+ TTL: config.DNS.LogTTL,
+ })
+ fatal(err)
+
+ dnsSrv, err := zdns.NewServer(logger, config)
+ fatal(err)
+
+ sigHandler := signal.NewHandler(c.signal, logger)
+ sigHandler.OnReload(dnsSrv)
+ sigHandler.OnClose(dnsSrv)
+ c.runServer(dnsSrv)
+ c.wg.Wait()
+}
+
+func main() {
+ c := cli{
+ out: os.Stderr,
+ configFile: defaultConfigFile(),
+ args: os.Args[1:],
+ signal: make(chan os.Signal, 1),
}
+ c.run()
}
diff --git a/cmd/zdns/main_test.go b/cmd/zdns/main_test.go
index 851468b..f2dfa25 100644
--- a/cmd/zdns/main_test.go
+++ b/cmd/zdns/main_test.go
@@ -3,6 +3,8 @@ package main
import (
"io/ioutil"
"os"
+ "sync"
+ "syscall"
"testing"
)
@@ -40,12 +42,20 @@ hijack_mode = "zero"
if err != nil {
t.Fatal(err)
}
- defer handleErr(t, func() error { return os.Remove(f) })
- srv, err := newServer(ioutil.Discard, f)
- if err != nil {
- t.Fatal(err)
- }
- if srv == nil {
- t.Error("want non-nil server")
+ defer os.Remove(f)
+
+ main := cli{
+ out: ioutil.Discard,
+ configFile: f,
+ args: []string{"-f", f},
+ signal: make(chan os.Signal, 1),
}
+ wg := sync.WaitGroup{}
+ wg.Add(1)
+ go func() {
+ defer wg.Done()
+ main.run()
+ }()
+ main.signal <- syscall.SIGTERM
+ wg.Wait()
}