aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorMartin Polden <mpolden@mpolden.no>2019-12-31 14:01:18 +0100
committerMartin Polden <mpolden@mpolden.no>2019-12-31 14:40:17 +0100
commitc591ee76bcad95345d6302a12302a9e4f090005d (patch)
tree0dad51a82570bf8bd7d29fa0e867db63d099f4f0
parentbf6a8dc976bd100030240935e4047bb58008a02b (diff)
Clean up racy test code
-rw-r--r--cache/cache_test.go19
-rw-r--r--dns/proxy_test.go36
-rw-r--r--signal/signal_test.go25
3 files changed, 65 insertions, 15 deletions
diff --git a/cache/cache_test.go b/cache/cache_test.go
index 541ae10..b00e5d9 100644
--- a/cache/cache_test.go
+++ b/cache/cache_test.go
@@ -4,6 +4,7 @@ import (
"fmt"
"net"
"reflect"
+ "sync"
"testing"
"time"
@@ -12,10 +13,19 @@ import (
)
type testExchanger struct {
+ mu sync.RWMutex
answer *dns.Msg
}
+func (e *testExchanger) setAnswer(answer *dns.Msg) {
+ e.mu.Lock()
+ defer e.mu.Unlock()
+ e.answer = answer
+}
+
func (e *testExchanger) Exchange(msg *dns.Msg, addr string) (*dns.Msg, time.Duration, error) {
+ e.mu.RLock()
+ defer e.mu.RUnlock()
return e.answer, time.Second, nil
}
@@ -61,10 +71,13 @@ func awaitExpiry(t *testing.T, i int, c *Cache, k uint64) {
func awaitRefresh(t *testing.T, i int, c *Cache, k uint64, u time.Time) {
now := time.Now()
for { // Loop until CreatedAt of key k is after u
- v, ok := c.getValue(k)
+ c.mu.RLock()
+ v, ok := c.values[k]
if ok && v.CreatedAt.After(u) {
+ c.mu.RUnlock()
break
}
+ c.mu.RUnlock()
time.Sleep(10 * time.Millisecond)
if time.Since(now) > 2*time.Second {
t.Fatalf("#%d: timed out waiting for refresh of key %d", i, k)
@@ -132,6 +145,7 @@ func TestCache(t *testing.T) {
if !tt.ok {
awaitExpiry(t, i, c, k)
}
+ c.mu.RLock()
if _, ok := c.values[k]; ok != tt.ok {
t.Errorf("#%d: values[%d] = %t, want %t", i, k, ok, tt.ok)
}
@@ -142,6 +156,7 @@ func TestCache(t *testing.T) {
break
}
}
+ c.mu.RUnlock()
if (keyIdx != -1) != tt.ok {
t.Errorf("#%d: keys[%d] = %d, found expired key", i, keyIdx, k)
}
@@ -266,7 +281,7 @@ func TestCachePrefetch(t *testing.T) {
copy := msg.Copy()
copy.Answer[0].(*dns.A).A = net.ParseIP(tt.refreshAnswer)
copy.Answer[0].(*dns.A).Hdr.Ttl = uint32(tt.refreshTTL.Seconds())
- exchanger.answer = copy
+ exchanger.setAnswer(copy)
c.now = func() time.Time { return now }
var key uint64 = 1
diff --git a/dns/proxy_test.go b/dns/proxy_test.go
index f9783bf..7d2a894 100644
--- a/dns/proxy_test.go
+++ b/dns/proxy_test.go
@@ -5,6 +5,7 @@ import (
"io/ioutil"
"net"
"reflect"
+ "sync"
"testing"
"time"
@@ -36,10 +37,23 @@ type response struct {
fail bool
}
-type testExchanger map[string]*response
+type testExchanger struct {
+ mu sync.RWMutex
+ responses map[string]*response
+}
+
+func newTestExchanger() *testExchanger { return &testExchanger{responses: make(map[string]*response)} }
+
+func (e *testExchanger) setResponse(resolver string, response *response) {
+ e.mu.Lock()
+ defer e.mu.Unlock()
+ e.responses[resolver] = response
+}
-func (e testExchanger) Exchange(msg *dns.Msg, addr string) (*dns.Msg, time.Duration, error) {
- r, ok := e[addr]
+func (e *testExchanger) Exchange(msg *dns.Msg, addr string) (*dns.Msg, time.Duration, error) {
+ e.mu.RLock()
+ defer e.mu.RUnlock()
+ r, ok := e.responses[addr]
if !ok {
panic("no such resolver: " + addr)
}
@@ -148,7 +162,7 @@ func TestProxy(t *testing.T) {
func TestProxyWithResolvers(t *testing.T) {
p := testProxy(t)
- exchanger := make(testExchanger)
+ exchanger := newTestExchanger()
p.client = &dnsutil.Client{Exchanger: exchanger}
defer p.Close()
// No resolvers
@@ -161,11 +175,12 @@ func TestProxyWithResolvers(t *testing.T) {
m.Id = dns.Id()
m.SetQuestion("host1.", dns.TypeA)
m.Answer = reply.rr
- exchanger["resolver1"] = &response{answer: &m}
+ response1 := &response{answer: &m}
+ exchanger.setResponse("resolver1", response1)
assertRR(t, p, &m, "192.0.2.1")
// First and only resolver fails
- exchanger["resolver1"].fail = true
+ response1.fail = true
assertFailure(t, p, TypeA, "host1")
// First resolver fails, but second succeeds
@@ -175,18 +190,19 @@ func TestProxyWithResolvers(t *testing.T) {
m.Id = dns.Id()
m.SetQuestion("host1.", dns.TypeA)
m.Answer = reply.rr
- exchanger["resolver2"] = &response{answer: &m}
+ response2 := &response{answer: &m}
+ exchanger.setResponse("resolver2", response2)
assertRR(t, p, &m, "192.0.2.2")
// All resolvers fail
- exchanger["resolver2"].fail = true
+ response2.fail = true
assertFailure(t, p, TypeA, "host1")
}
func TestProxyWithCache(t *testing.T) {
p := testProxy(t)
p.cache = cache.New(10, nil)
- exchanger := make(testExchanger)
+ exchanger := newTestExchanger()
p.client = &dnsutil.Client{Exchanger: exchanger}
p.client.Addresses = []string{"resolver1"}
defer p.Close()
@@ -196,7 +212,7 @@ func TestProxyWithCache(t *testing.T) {
m.Id = dns.Id()
m.SetQuestion("host1.", dns.TypeA)
m.Answer = reply.rr
- exchanger["resolver1"] = &response{answer: &m}
+ exchanger.setResponse("resolver1", &response{answer: &m})
assertRR(t, p, &m, "192.0.2.1")
k := cache.NewKey("host1.", dns.TypeA, dns.ClassINET)
diff --git a/signal/signal_test.go b/signal/signal_test.go
index 544b9a0..7010ebd 100644
--- a/signal/signal_test.go
+++ b/signal/signal_test.go
@@ -3,6 +3,7 @@ package signal
import (
"io/ioutil"
"os"
+ "sync"
"syscall"
"testing"
"time"
@@ -11,18 +12,36 @@ import (
)
type reloaderCloser struct {
+ mu sync.RWMutex
reloaded bool
closed bool
}
-func (rc *reloaderCloser) Reload() { rc.reloaded = true }
+func (rc *reloaderCloser) Reload() {
+ rc.mu.Lock()
+ defer rc.mu.Unlock()
+ rc.reloaded = true
+}
func (rc *reloaderCloser) Close() error {
+ rc.mu.Lock()
+ defer rc.mu.Unlock()
rc.closed = true
return nil
}
-func (rc *reloaderCloser) isReloaded() bool { return rc.reloaded }
-func (rc *reloaderCloser) isClosed() bool { return rc.closed }
+func (rc *reloaderCloser) isReloaded() bool {
+ rc.mu.RLock()
+ defer rc.mu.RUnlock()
+ return rc.reloaded
+}
+func (rc *reloaderCloser) isClosed() bool {
+ rc.mu.RLock()
+ defer rc.mu.RUnlock()
+ return rc.closed
+}
+
func (rc *reloaderCloser) reset() {
+ rc.mu.Lock()
+ defer rc.mu.Unlock()
rc.reloaded = false
rc.closed = false
}