diff options
-rw-r--r-- | cache/cache.go | 25 | ||||
-rw-r--r-- | cache/cache_test.go | 25 |
2 files changed, 28 insertions, 22 deletions
diff --git a/cache/cache.go b/cache/cache.go index 014d3fe..47db598 100644 --- a/cache/cache.go +++ b/cache/cache.go @@ -9,14 +9,14 @@ import ( "github.com/miekg/dns" ) -// Cache represents a cache of DNS entries. Use New to initialize a new cache. +// Cache represents a cache of DNS responses. Use New to initialize a new cache. type Cache struct { capacity int now func() time.Time maintainer *maintainer mu sync.RWMutex wg sync.WaitGroup - entries map[uint32]*Value + values map[uint32]*Value keys []uint32 } @@ -61,7 +61,7 @@ type Value struct { // TTL returns the TTL of this cache value. func (v *Value) TTL() time.Duration { return minTTL(v.msg) } -// New creates a new cache of given capacity. Stale cache entries are removed at expiryInterval. +// New creates a new cache of given capacity. Stale cache values are removed every expiryInterval. func New(capacity int, expiryInterval time.Duration) *Cache { if capacity < 0 { capacity = 0 @@ -72,7 +72,7 @@ func New(capacity int, expiryInterval time.Duration) *Cache { cache := &Cache{ now: time.Now, capacity: capacity, - entries: make(map[uint32]*Value, capacity), + values: make(map[uint32]*Value, capacity), } maintain(cache, expiryInterval) return cache @@ -106,7 +106,7 @@ func (c *Cache) Get(k uint32) (*dns.Msg, bool) { func (c *Cache) getValue(k uint32) (*Value, bool) { c.mu.RLock() - v, ok := c.entries[k] + v, ok := c.values[k] c.mu.RUnlock() if !ok || c.isExpired(v) { return nil, false @@ -122,7 +122,10 @@ func (c *Cache) List(n int) []*Value { if len(values) == n { break } - v, _ := c.getValue(c.keys[i]) + v, ok := c.getValue(c.keys[i]) + if !ok { + continue + } values = append(values, v) } c.mu.RUnlock() @@ -140,11 +143,11 @@ func (c *Cache) Set(k uint32, msg *dns.Msg) { } now := c.now() c.mu.Lock() - if len(c.entries) == c.capacity && c.capacity > 0 { - delete(c.entries, c.keys[0]) + if len(c.values) == c.capacity && c.capacity > 0 { + delete(c.values, c.keys[0]) c.keys = c.keys[1:] } - c.entries[k] = &Value{ + c.values[k] = &Value{ Question: question(msg), Answers: answers(msg), Qtype: qtype(msg), @@ -175,9 +178,9 @@ func answers(msg *dns.Msg) []string { } func (c *Cache) deleteExpired() { c.mu.Lock() - for k, v := range c.entries { + for k, v := range c.values { if c.isExpired(v) { - delete(c.entries, k) + delete(c.values, k) } } c.mu.Unlock() diff --git a/cache/cache_test.go b/cache/cache_test.go index f2e00e7..71e5e58 100644 --- a/cache/cache_test.go +++ b/cache/cache_test.go @@ -66,7 +66,7 @@ func awaitExpiry(t *testing.T, c *Cache, k uint32) { now := time.Now() for { c.mu.RLock() - if _, ok := c.entries[k]; !ok { + if _, ok := c.values[k]; !ok { break } c.mu.RUnlock() @@ -131,7 +131,7 @@ func TestCache(t *testing.T) { if !tt.ok { awaitExpiry(t, c, k) } - if _, ok := c.entries[k]; ok != tt.ok { + if _, ok := c.values[k]; ok != tt.ok { t.Errorf("#%d: Cache[%d] = %t, want %t", i, k, ok, tt.ok) } } @@ -156,8 +156,8 @@ func TestCacheCapacity(t *testing.T) { msgs = append(msgs, m) c.Set(k, m) } - if got := len(c.entries); got != tt.size { - t.Errorf("#%d: len(entries) = %d, want %d", i, got, tt.size) + if got := len(c.values); got != tt.size { + t.Errorf("#%d: len(values) = %d, want %d", i, got, tt.size) } if tt.capacity > 0 && tt.addCount > tt.capacity && tt.capacity == tt.size { lastAdded := msgs[tt.addCount-1].Question[0] @@ -177,12 +177,14 @@ func TestCacheCapacity(t *testing.T) { func TestCacheList(t *testing.T) { var tests = []struct { addCount, listCount, wantCount int + expire bool }{ - {0, 0, 0}, - {1, 0, 0}, - {1, 1, 1}, - {2, 1, 1}, - {2, 3, 2}, + {0, 0, 0, false}, + {1, 0, 0, false}, + {1, 1, 1, false}, + {2, 1, 1, false}, + {2, 3, 2, false}, + {2, 0, 0, true}, } for i, tt := range tests { c := New(1024, 10*time.Minute) @@ -194,7 +196,9 @@ func TestCacheList(t *testing.T) { msgs = append(msgs, m) c.Set(k, m) } - + if tt.expire { + c.now = func() time.Time { return time.Now().Add(time.Minute).Add(time.Second) } + } values := c.List(tt.listCount) if got := len(values); got != tt.wantCount { t.Errorf("#%d: len(List(%d)) = %d, want %d", i, tt.listCount, got, tt.wantCount) @@ -203,7 +207,6 @@ func TestCacheList(t *testing.T) { for _, v := range values { gotMsgs = append(gotMsgs, v.msg) } - msgs = reverse(msgs) want := msgs[:tt.wantCount] if !reflect.DeepEqual(want, gotMsgs) { |