aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorMartin Polden <mpolden@mpolden.no>2019-12-30 21:57:20 +0100
committerMartin Polden <mpolden@mpolden.no>2019-12-30 21:57:37 +0100
commit50159f241973d24e48b2d2c3ae5e353d9c64eb4d (patch)
treebffbc7733ff389c3e173cdbe1cf1e4eedac1a322
parent79d65083ead62220e40138f1a45197b78c498d63 (diff)
Stabilize TestMain
-rw-r--r--cmd/zdns/main.go43
-rw-r--r--cmd/zdns/main_test.go20
2 files changed, 35 insertions, 28 deletions
diff --git a/cmd/zdns/main.go b/cmd/zdns/main.go
index cda841d..cd01e9c 100644
--- a/cmd/zdns/main.go
+++ b/cmd/zdns/main.go
@@ -27,11 +27,9 @@ const (
type server interface{ ListenAndServe() error }
type cli struct {
- wg sync.WaitGroup
- configFile string
- out io.Writer
- args []string
- signal chan os.Signal
+ servers []server
+ started int
+ wg sync.WaitGroup
}
func defaultConfigFile() string { return filepath.Join(os.Getenv("HOME"), configName) }
@@ -61,17 +59,18 @@ func (c *cli) runServer(server server) {
fatal(err)
}
}()
+ c.started++
}
-func (c *cli) run() {
+func newCli(out io.Writer, args []string, configFile string, sig chan os.Signal) *cli {
f := flag.CommandLine
- f.SetOutput(c.out)
- confFile := f.String("f", c.configFile, "config file `path`")
+ f.SetOutput(out)
+ confFile := f.String("f", configFile, "config file `path`")
help := f.Bool("h", false, "print usage")
- f.Parse(c.args)
+ f.Parse(args)
if *help {
f.Usage()
- return
+ return nil
}
// Config
@@ -79,7 +78,7 @@ func (c *cli) run() {
fatal(err)
// Logger
- logger, err := log.New(c.out, logPrefix, log.RecordOptions{
+ logger, err := log.New(out, logPrefix, log.RecordOptions{
Mode: config.DNS.LogMode,
Database: config.DNS.LogDatabase,
TTL: config.DNS.LogTTL,
@@ -87,7 +86,7 @@ func (c *cli) run() {
fatal(err)
// Signal handling
- sigHandler := signal.NewHandler(c.signal, logger)
+ sigHandler := signal.NewHandler(sig, logger)
sigHandler.OnClose(logger)
// Client
@@ -109,23 +108,27 @@ func (c *cli) run() {
fatal(err)
sigHandler.OnReload(dnsSrv)
sigHandler.OnClose(dnsSrv)
- c.runServer(dnsSrv)
+ servers := []server{dnsSrv}
// HTTP server
if config.DNS.ListenHTTP != "" {
httpSrv := http.NewServer(logger, cache, config.DNS.ListenHTTP)
sigHandler.OnClose(httpSrv)
- c.runServer(httpSrv)
+ servers = append(servers, httpSrv)
+ }
+ return &cli{servers: servers}
+}
+
+func (c *cli) run() {
+ for _, srv := range c.servers {
+ c.runServer(srv)
}
c.wg.Wait()
}
func main() {
- c := cli{
- out: os.Stderr,
- configFile: defaultConfigFile(),
- args: os.Args[1:],
- signal: make(chan os.Signal, 1),
+ c := newCli(os.Stderr, os.Args[1:], defaultConfigFile(), make(chan os.Signal, 1))
+ if c != nil {
+ c.run()
}
- c.run()
}
diff --git a/cmd/zdns/main_test.go b/cmd/zdns/main_test.go
index ee24b5d..647e8cf 100644
--- a/cmd/zdns/main_test.go
+++ b/cmd/zdns/main_test.go
@@ -6,6 +6,7 @@ import (
"sync"
"syscall"
"testing"
+ "time"
)
func tempFile(t *testing.T, s string) (string, error) {
@@ -39,18 +40,21 @@ hijack_mode = "zero"
}
defer os.Remove(f)
- main := cli{
- out: ioutil.Discard,
- configFile: f,
- args: []string{"-f", f},
- signal: make(chan os.Signal, 1),
- }
+ sig := make(chan os.Signal, 1)
+ c := newCli(os.Stderr, []string{"-f", f}, f, sig)
var wg sync.WaitGroup
wg.Add(1)
go func() {
defer wg.Done()
- main.run()
+ c.run()
}()
- main.signal <- syscall.SIGTERM
+ ts := time.Now()
+ for c.started < 2 {
+ time.Sleep(10 * time.Millisecond) // Wait for servers to start
+ if time.Since(ts) > 2*time.Second {
+ t.Fatal("timed out waiting for servers to start")
+ }
+ }
+ sig <- syscall.SIGTERM
wg.Wait()
}