aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorMartin Polden <mpolden@mpolden.no>2019-12-26 19:20:59 +0100
committerMartin Polden <mpolden@mpolden.no>2019-12-26 19:23:04 +0100
commitf031924851c4121117b70d97d7d507e5bd593fb1 (patch)
tree02a8e4815e4552f71aca687803cb570b71bc6946
parentb1dc1672ee9c0414fa538b03edec8442cc021ec1 (diff)
Listing cache should not include expired values
-rw-r--r--cache/cache.go25
-rw-r--r--cache/cache_test.go25
2 files changed, 28 insertions, 22 deletions
diff --git a/cache/cache.go b/cache/cache.go
index 014d3fe..47db598 100644
--- a/cache/cache.go
+++ b/cache/cache.go
@@ -9,14 +9,14 @@ import (
"github.com/miekg/dns"
)
-// Cache represents a cache of DNS entries. Use New to initialize a new cache.
+// Cache represents a cache of DNS responses. Use New to initialize a new cache.
type Cache struct {
capacity int
now func() time.Time
maintainer *maintainer
mu sync.RWMutex
wg sync.WaitGroup
- entries map[uint32]*Value
+ values map[uint32]*Value
keys []uint32
}
@@ -61,7 +61,7 @@ type Value struct {
// TTL returns the TTL of this cache value.
func (v *Value) TTL() time.Duration { return minTTL(v.msg) }
-// New creates a new cache of given capacity. Stale cache entries are removed at expiryInterval.
+// New creates a new cache of given capacity. Stale cache values are removed every expiryInterval.
func New(capacity int, expiryInterval time.Duration) *Cache {
if capacity < 0 {
capacity = 0
@@ -72,7 +72,7 @@ func New(capacity int, expiryInterval time.Duration) *Cache {
cache := &Cache{
now: time.Now,
capacity: capacity,
- entries: make(map[uint32]*Value, capacity),
+ values: make(map[uint32]*Value, capacity),
}
maintain(cache, expiryInterval)
return cache
@@ -106,7 +106,7 @@ func (c *Cache) Get(k uint32) (*dns.Msg, bool) {
func (c *Cache) getValue(k uint32) (*Value, bool) {
c.mu.RLock()
- v, ok := c.entries[k]
+ v, ok := c.values[k]
c.mu.RUnlock()
if !ok || c.isExpired(v) {
return nil, false
@@ -122,7 +122,10 @@ func (c *Cache) List(n int) []*Value {
if len(values) == n {
break
}
- v, _ := c.getValue(c.keys[i])
+ v, ok := c.getValue(c.keys[i])
+ if !ok {
+ continue
+ }
values = append(values, v)
}
c.mu.RUnlock()
@@ -140,11 +143,11 @@ func (c *Cache) Set(k uint32, msg *dns.Msg) {
}
now := c.now()
c.mu.Lock()
- if len(c.entries) == c.capacity && c.capacity > 0 {
- delete(c.entries, c.keys[0])
+ if len(c.values) == c.capacity && c.capacity > 0 {
+ delete(c.values, c.keys[0])
c.keys = c.keys[1:]
}
- c.entries[k] = &Value{
+ c.values[k] = &Value{
Question: question(msg),
Answers: answers(msg),
Qtype: qtype(msg),
@@ -175,9 +178,9 @@ func answers(msg *dns.Msg) []string {
}
func (c *Cache) deleteExpired() {
c.mu.Lock()
- for k, v := range c.entries {
+ for k, v := range c.values {
if c.isExpired(v) {
- delete(c.entries, k)
+ delete(c.values, k)
}
}
c.mu.Unlock()
diff --git a/cache/cache_test.go b/cache/cache_test.go
index f2e00e7..71e5e58 100644
--- a/cache/cache_test.go
+++ b/cache/cache_test.go
@@ -66,7 +66,7 @@ func awaitExpiry(t *testing.T, c *Cache, k uint32) {
now := time.Now()
for {
c.mu.RLock()
- if _, ok := c.entries[k]; !ok {
+ if _, ok := c.values[k]; !ok {
break
}
c.mu.RUnlock()
@@ -131,7 +131,7 @@ func TestCache(t *testing.T) {
if !tt.ok {
awaitExpiry(t, c, k)
}
- if _, ok := c.entries[k]; ok != tt.ok {
+ if _, ok := c.values[k]; ok != tt.ok {
t.Errorf("#%d: Cache[%d] = %t, want %t", i, k, ok, tt.ok)
}
}
@@ -156,8 +156,8 @@ func TestCacheCapacity(t *testing.T) {
msgs = append(msgs, m)
c.Set(k, m)
}
- if got := len(c.entries); got != tt.size {
- t.Errorf("#%d: len(entries) = %d, want %d", i, got, tt.size)
+ if got := len(c.values); got != tt.size {
+ t.Errorf("#%d: len(values) = %d, want %d", i, got, tt.size)
}
if tt.capacity > 0 && tt.addCount > tt.capacity && tt.capacity == tt.size {
lastAdded := msgs[tt.addCount-1].Question[0]
@@ -177,12 +177,14 @@ func TestCacheCapacity(t *testing.T) {
func TestCacheList(t *testing.T) {
var tests = []struct {
addCount, listCount, wantCount int
+ expire bool
}{
- {0, 0, 0},
- {1, 0, 0},
- {1, 1, 1},
- {2, 1, 1},
- {2, 3, 2},
+ {0, 0, 0, false},
+ {1, 0, 0, false},
+ {1, 1, 1, false},
+ {2, 1, 1, false},
+ {2, 3, 2, false},
+ {2, 0, 0, true},
}
for i, tt := range tests {
c := New(1024, 10*time.Minute)
@@ -194,7 +196,9 @@ func TestCacheList(t *testing.T) {
msgs = append(msgs, m)
c.Set(k, m)
}
-
+ if tt.expire {
+ c.now = func() time.Time { return time.Now().Add(time.Minute).Add(time.Second) }
+ }
values := c.List(tt.listCount)
if got := len(values); got != tt.wantCount {
t.Errorf("#%d: len(List(%d)) = %d, want %d", i, tt.listCount, got, tt.wantCount)
@@ -203,7 +207,6 @@ func TestCacheList(t *testing.T) {
for _, v := range values {
gotMsgs = append(gotMsgs, v.msg)
}
-
msgs = reverse(msgs)
want := msgs[:tt.wantCount]
if !reflect.DeepEqual(want, gotMsgs) {