From c591ee76bcad95345d6302a12302a9e4f090005d Mon Sep 17 00:00:00 2001 From: Martin Polden Date: Tue, 31 Dec 2019 14:01:18 +0100 Subject: Clean up racy test code --- cache/cache_test.go | 19 +++++++++++++++++-- dns/proxy_test.go | 36 ++++++++++++++++++++++++++---------- signal/signal_test.go | 25 ++++++++++++++++++++++--- 3 files changed, 65 insertions(+), 15 deletions(-) diff --git a/cache/cache_test.go b/cache/cache_test.go index 541ae10..b00e5d9 100644 --- a/cache/cache_test.go +++ b/cache/cache_test.go @@ -4,6 +4,7 @@ import ( "fmt" "net" "reflect" + "sync" "testing" "time" @@ -12,10 +13,19 @@ import ( ) type testExchanger struct { + mu sync.RWMutex answer *dns.Msg } +func (e *testExchanger) setAnswer(answer *dns.Msg) { + e.mu.Lock() + defer e.mu.Unlock() + e.answer = answer +} + func (e *testExchanger) Exchange(msg *dns.Msg, addr string) (*dns.Msg, time.Duration, error) { + e.mu.RLock() + defer e.mu.RUnlock() return e.answer, time.Second, nil } @@ -61,10 +71,13 @@ func awaitExpiry(t *testing.T, i int, c *Cache, k uint64) { func awaitRefresh(t *testing.T, i int, c *Cache, k uint64, u time.Time) { now := time.Now() for { // Loop until CreatedAt of key k is after u - v, ok := c.getValue(k) + c.mu.RLock() + v, ok := c.values[k] if ok && v.CreatedAt.After(u) { + c.mu.RUnlock() break } + c.mu.RUnlock() time.Sleep(10 * time.Millisecond) if time.Since(now) > 2*time.Second { t.Fatalf("#%d: timed out waiting for refresh of key %d", i, k) @@ -132,6 +145,7 @@ func TestCache(t *testing.T) { if !tt.ok { awaitExpiry(t, i, c, k) } + c.mu.RLock() if _, ok := c.values[k]; ok != tt.ok { t.Errorf("#%d: values[%d] = %t, want %t", i, k, ok, tt.ok) } @@ -142,6 +156,7 @@ func TestCache(t *testing.T) { break } } + c.mu.RUnlock() if (keyIdx != -1) != tt.ok { t.Errorf("#%d: keys[%d] = %d, found expired key", i, keyIdx, k) } @@ -266,7 +281,7 @@ func TestCachePrefetch(t *testing.T) { copy := msg.Copy() copy.Answer[0].(*dns.A).A = net.ParseIP(tt.refreshAnswer) copy.Answer[0].(*dns.A).Hdr.Ttl = uint32(tt.refreshTTL.Seconds()) - exchanger.answer = copy + exchanger.setAnswer(copy) c.now = func() time.Time { return now } var key uint64 = 1 diff --git a/dns/proxy_test.go b/dns/proxy_test.go index f9783bf..7d2a894 100644 --- a/dns/proxy_test.go +++ b/dns/proxy_test.go @@ -5,6 +5,7 @@ import ( "io/ioutil" "net" "reflect" + "sync" "testing" "time" @@ -36,10 +37,23 @@ type response struct { fail bool } -type testExchanger map[string]*response +type testExchanger struct { + mu sync.RWMutex + responses map[string]*response +} + +func newTestExchanger() *testExchanger { return &testExchanger{responses: make(map[string]*response)} } + +func (e *testExchanger) setResponse(resolver string, response *response) { + e.mu.Lock() + defer e.mu.Unlock() + e.responses[resolver] = response +} -func (e testExchanger) Exchange(msg *dns.Msg, addr string) (*dns.Msg, time.Duration, error) { - r, ok := e[addr] +func (e *testExchanger) Exchange(msg *dns.Msg, addr string) (*dns.Msg, time.Duration, error) { + e.mu.RLock() + defer e.mu.RUnlock() + r, ok := e.responses[addr] if !ok { panic("no such resolver: " + addr) } @@ -148,7 +162,7 @@ func TestProxy(t *testing.T) { func TestProxyWithResolvers(t *testing.T) { p := testProxy(t) - exchanger := make(testExchanger) + exchanger := newTestExchanger() p.client = &dnsutil.Client{Exchanger: exchanger} defer p.Close() // No resolvers @@ -161,11 +175,12 @@ func TestProxyWithResolvers(t *testing.T) { m.Id = dns.Id() m.SetQuestion("host1.", dns.TypeA) m.Answer = reply.rr - exchanger["resolver1"] = &response{answer: &m} + response1 := &response{answer: &m} + exchanger.setResponse("resolver1", response1) assertRR(t, p, &m, "192.0.2.1") // First and only resolver fails - exchanger["resolver1"].fail = true + response1.fail = true assertFailure(t, p, TypeA, "host1") // First resolver fails, but second succeeds @@ -175,18 +190,19 @@ func TestProxyWithResolvers(t *testing.T) { m.Id = dns.Id() m.SetQuestion("host1.", dns.TypeA) m.Answer = reply.rr - exchanger["resolver2"] = &response{answer: &m} + response2 := &response{answer: &m} + exchanger.setResponse("resolver2", response2) assertRR(t, p, &m, "192.0.2.2") // All resolvers fail - exchanger["resolver2"].fail = true + response2.fail = true assertFailure(t, p, TypeA, "host1") } func TestProxyWithCache(t *testing.T) { p := testProxy(t) p.cache = cache.New(10, nil) - exchanger := make(testExchanger) + exchanger := newTestExchanger() p.client = &dnsutil.Client{Exchanger: exchanger} p.client.Addresses = []string{"resolver1"} defer p.Close() @@ -196,7 +212,7 @@ func TestProxyWithCache(t *testing.T) { m.Id = dns.Id() m.SetQuestion("host1.", dns.TypeA) m.Answer = reply.rr - exchanger["resolver1"] = &response{answer: &m} + exchanger.setResponse("resolver1", &response{answer: &m}) assertRR(t, p, &m, "192.0.2.1") k := cache.NewKey("host1.", dns.TypeA, dns.ClassINET) diff --git a/signal/signal_test.go b/signal/signal_test.go index 544b9a0..7010ebd 100644 --- a/signal/signal_test.go +++ b/signal/signal_test.go @@ -3,6 +3,7 @@ package signal import ( "io/ioutil" "os" + "sync" "syscall" "testing" "time" @@ -11,18 +12,36 @@ import ( ) type reloaderCloser struct { + mu sync.RWMutex reloaded bool closed bool } -func (rc *reloaderCloser) Reload() { rc.reloaded = true } +func (rc *reloaderCloser) Reload() { + rc.mu.Lock() + defer rc.mu.Unlock() + rc.reloaded = true +} func (rc *reloaderCloser) Close() error { + rc.mu.Lock() + defer rc.mu.Unlock() rc.closed = true return nil } -func (rc *reloaderCloser) isReloaded() bool { return rc.reloaded } -func (rc *reloaderCloser) isClosed() bool { return rc.closed } +func (rc *reloaderCloser) isReloaded() bool { + rc.mu.RLock() + defer rc.mu.RUnlock() + return rc.reloaded +} +func (rc *reloaderCloser) isClosed() bool { + rc.mu.RLock() + defer rc.mu.RUnlock() + return rc.closed +} + func (rc *reloaderCloser) reset() { + rc.mu.Lock() + defer rc.mu.Unlock() rc.reloaded = false rc.closed = false } -- cgit v1.2.3