diff options
author | Martin Polden <mpolden@mpolden.no> | 2019-12-30 21:57:20 +0100 |
---|---|---|
committer | Martin Polden <mpolden@mpolden.no> | 2019-12-30 21:57:37 +0100 |
commit | 50159f241973d24e48b2d2c3ae5e353d9c64eb4d (patch) | |
tree | bffbc7733ff389c3e173cdbe1cf1e4eedac1a322 | |
parent | 79d65083ead62220e40138f1a45197b78c498d63 (diff) |
Stabilize TestMain
-rw-r--r-- | cmd/zdns/main.go | 43 | ||||
-rw-r--r-- | cmd/zdns/main_test.go | 20 |
2 files changed, 35 insertions, 28 deletions
diff --git a/cmd/zdns/main.go b/cmd/zdns/main.go index cda841d..cd01e9c 100644 --- a/cmd/zdns/main.go +++ b/cmd/zdns/main.go @@ -27,11 +27,9 @@ const ( type server interface{ ListenAndServe() error } type cli struct { - wg sync.WaitGroup - configFile string - out io.Writer - args []string - signal chan os.Signal + servers []server + started int + wg sync.WaitGroup } func defaultConfigFile() string { return filepath.Join(os.Getenv("HOME"), configName) } @@ -61,17 +59,18 @@ func (c *cli) runServer(server server) { fatal(err) } }() + c.started++ } -func (c *cli) run() { +func newCli(out io.Writer, args []string, configFile string, sig chan os.Signal) *cli { f := flag.CommandLine - f.SetOutput(c.out) - confFile := f.String("f", c.configFile, "config file `path`") + f.SetOutput(out) + confFile := f.String("f", configFile, "config file `path`") help := f.Bool("h", false, "print usage") - f.Parse(c.args) + f.Parse(args) if *help { f.Usage() - return + return nil } // Config @@ -79,7 +78,7 @@ func (c *cli) run() { fatal(err) // Logger - logger, err := log.New(c.out, logPrefix, log.RecordOptions{ + logger, err := log.New(out, logPrefix, log.RecordOptions{ Mode: config.DNS.LogMode, Database: config.DNS.LogDatabase, TTL: config.DNS.LogTTL, @@ -87,7 +86,7 @@ func (c *cli) run() { fatal(err) // Signal handling - sigHandler := signal.NewHandler(c.signal, logger) + sigHandler := signal.NewHandler(sig, logger) sigHandler.OnClose(logger) // Client @@ -109,23 +108,27 @@ func (c *cli) run() { fatal(err) sigHandler.OnReload(dnsSrv) sigHandler.OnClose(dnsSrv) - c.runServer(dnsSrv) + servers := []server{dnsSrv} // HTTP server if config.DNS.ListenHTTP != "" { httpSrv := http.NewServer(logger, cache, config.DNS.ListenHTTP) sigHandler.OnClose(httpSrv) - c.runServer(httpSrv) + servers = append(servers, httpSrv) + } + return &cli{servers: servers} +} + +func (c *cli) run() { + for _, srv := range c.servers { + c.runServer(srv) } c.wg.Wait() } func main() { - c := cli{ - out: os.Stderr, - configFile: defaultConfigFile(), - args: os.Args[1:], - signal: make(chan os.Signal, 1), + c := newCli(os.Stderr, os.Args[1:], defaultConfigFile(), make(chan os.Signal, 1)) + if c != nil { + c.run() } - c.run() } diff --git a/cmd/zdns/main_test.go b/cmd/zdns/main_test.go index ee24b5d..647e8cf 100644 --- a/cmd/zdns/main_test.go +++ b/cmd/zdns/main_test.go @@ -6,6 +6,7 @@ import ( "sync" "syscall" "testing" + "time" ) func tempFile(t *testing.T, s string) (string, error) { @@ -39,18 +40,21 @@ hijack_mode = "zero" } defer os.Remove(f) - main := cli{ - out: ioutil.Discard, - configFile: f, - args: []string{"-f", f}, - signal: make(chan os.Signal, 1), - } + sig := make(chan os.Signal, 1) + c := newCli(os.Stderr, []string{"-f", f}, f, sig) var wg sync.WaitGroup wg.Add(1) go func() { defer wg.Done() - main.run() + c.run() }() - main.signal <- syscall.SIGTERM + ts := time.Now() + for c.started < 2 { + time.Sleep(10 * time.Millisecond) // Wait for servers to start + if time.Since(ts) > 2*time.Second { + t.Fatal("timed out waiting for servers to start") + } + } + sig <- syscall.SIGTERM wg.Wait() } |