aboutsummaryrefslogtreecommitdiffstats
path: root/cache/cache_test.go
diff options
context:
space:
mode:
Diffstat (limited to 'cache/cache_test.go')
-rw-r--r--cache/cache_test.go70
1 files changed, 39 insertions, 31 deletions
diff --git a/cache/cache_test.go b/cache/cache_test.go
index df05721..77f9f15 100644
--- a/cache/cache_test.go
+++ b/cache/cache_test.go
@@ -10,12 +10,6 @@ import (
"github.com/miekg/dns"
)
-func handleErr(t *testing.T, fn func() error) {
- if err := fn(); err != nil {
- t.Fatal(err)
- }
-}
-
func newA(name string, ttl uint32, ipAddr ...net.IP) *dns.Msg {
m := dns.Msg{}
m.Id = dns.Id()
@@ -66,10 +60,11 @@ func awaitExpiry(t *testing.T, c *Cache, k uint64) {
now := time.Now()
for {
c.mu.RLock()
- if _, ok := c.values[k]; !ok {
+ _, ok := c.values[k]
+ c.mu.RUnlock()
+ if !ok {
break
}
- c.mu.RUnlock()
time.Sleep(10 * time.Millisecond)
if time.Since(now) > 2*time.Second {
t.Fatalf("timed out waiting for expiry of key %d", k)
@@ -78,30 +73,45 @@ func awaitExpiry(t *testing.T, c *Cache, k uint64) {
}
func TestCache(t *testing.T) {
- msg := newA("foo.", 60, net.ParseIP("192.0.2.1"), net.ParseIP("192.0.2.2"))
- msgWithZeroTTL := newA("bar.", 0, net.ParseIP("192.0.2.2"))
- msgFailure := newA("baz.", 60, net.ParseIP("192.0.2.2"))
+ msg := newA("r1.", 60, net.ParseIP("192.0.2.1"), net.ParseIP("192.0.2.2"))
+ msgWithZeroTTL := newA("r2.", 0, net.ParseIP("192.0.2.2"))
+ msgFailure := newA("r3.", 60, net.ParseIP("192.0.2.2"))
msgFailure.Rcode = dns.RcodeServerFailure
+ msgNameError := &dns.Msg{}
+ msgNameError.Id = dns.Id()
+ msgNameError.SetQuestion(dns.Fqdn("r4."), dns.TypeA)
+ msgNameError.Rcode = dns.RcodeNameError
+ msgLowerNsTTL := newA("r5.", 60, net.ParseIP("192.0.2.1"))
+ msgLowerNsTTL.Ns = []dns.RR{&dns.NS{Hdr: dns.RR_Header{Name: "ns1.r5.", Rrtype: dns.TypeNS, Class: dns.ClassINET, Ttl: 30}}}
+ msgLowerExtraTTL := newA("r6.", 3600, net.ParseIP("192.0.2.1"))
+ msgLowerExtraTTL.Ns = []dns.RR{&dns.NS{Hdr: dns.RR_Header{Name: "ns1.r6.", Rrtype: dns.TypeNS, Class: dns.ClassINET, Ttl: 60}}}
+ msgLowerExtraTTL.Extra = []dns.RR{
+ &dns.OPT{Hdr: dns.RR_Header{Name: "EDNS", Rrtype: dns.TypeOPT, Class: dns.ClassINET, Ttl: 10}}, // Ignored
+ &dns.A{Hdr: dns.RR_Header{Name: "ns1.r6.", Rrtype: dns.TypeA, Class: dns.ClassINET, Ttl: 30}},
+ }
- createdAt := date(2019, 1, 1)
- c := New(100)
- c.interval = 10 * time.Millisecond
- defer handleErr(t, c.Close)
+ now := date(2019, 1, 1)
+ nowFn := func() time.Time { return now }
+ c := newCache(100, 10*time.Millisecond, nowFn)
+ defer c.Close()
var tests = []struct {
- msg *dns.Msg
- createdAt, queriedAt time.Time
- ok bool
- value *Value
+ msg *dns.Msg
+ queriedAt time.Time
+ ok bool
+ value *Value
}{
- {msg, createdAt, createdAt, true, &Value{CreatedAt: createdAt, msg: msg}}, // Not expired when query time == create time
- {msg, createdAt, createdAt.Add(30 * time.Second), true, &Value{CreatedAt: createdAt, msg: msg}}, // Not expired when below TTL
- {msg, createdAt, createdAt.Add(60 * time.Second), true, &Value{CreatedAt: createdAt, msg: msg}}, // Not expired until TTL exceeds
- {msg, createdAt, createdAt.Add(61 * time.Second), false, nil}, // Expired
- {msgWithZeroTTL, createdAt, createdAt, false, nil}, // 0 TTL is not cached
- {msgFailure, createdAt, createdAt, false, nil}, // Non-cacheable rcode
+ {msg, now, true, &Value{CreatedAt: now, msg: msg}}, // Not expired when query time == create time
+ {msg, now.Add(30 * time.Second), true, &Value{CreatedAt: now, msg: msg}}, // Not expired when below TTL
+ {msg, now.Add(60 * time.Second), true, &Value{CreatedAt: now, msg: msg}}, // Not expired until TTL exceeds
+ {msgNameError, now, true, &Value{CreatedAt: now, msg: msgNameError}}, // NXDOMAIN is cached
+ {msg, now.Add(61 * time.Second), false, nil}, // Expired due to answer TTL
+ {msgLowerNsTTL, now.Add(31 * time.Second), false, nil}, // Expired due to lower NS TTL
+ {msgLowerExtraTTL, now.Add(31 * time.Second), false, nil}, // Expired due to lower Extra TTL
+ {msgWithZeroTTL, now, false, nil}, // 0 TTL is not cached
+ {msgFailure, now, false, nil}, // Non-cacheable rcode
}
for i, tt := range tests {
- c.now = func() time.Time { return tt.createdAt }
+ c.now = nowFn
k := NewKey(tt.msg.Question[0].Name, tt.msg.Question[0].Qtype, tt.msg.Question[0].Qclass)
c.Set(k, tt.msg)
c.now = func() time.Time { return tt.queriedAt }
@@ -125,7 +135,7 @@ func TestCache(t *testing.T) {
}
}
if (keyIdx != -1) != tt.ok {
- t.Errorf("#%d: keys[%d] = %d, should not exist", i, keyIdx, k)
+ t.Errorf("#%d: keys[%d] = %d, got expired key", i, keyIdx, k)
}
}
}
@@ -141,8 +151,7 @@ func TestCacheCapacity(t *testing.T) {
}
for i, tt := range tests {
c := New(tt.capacity)
- c.interval = 10 * time.Millisecond
- defer handleErr(t, c.Close)
+ defer c.Close()
var msgs []*dns.Msg
for i := 0; i < tt.addCount; i++ {
m := newA(fmt.Sprintf("r%d", i), 60, net.ParseIP(fmt.Sprintf("192.0.2.%d", i)))
@@ -182,8 +191,7 @@ func TestCacheList(t *testing.T) {
}
for i, tt := range tests {
c := New(1024)
- c.interval = 10 * time.Millisecond
- defer handleErr(t, c.Close)
+ defer c.Close()
var msgs []*dns.Msg
for i := 0; i < tt.addCount; i++ {
m := newA(fmt.Sprintf("r%d", i), 60, net.ParseIP(fmt.Sprintf("192.0.2.%d", i)))