aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorMartin Polden <mpolden@mpolden.no>2020-01-10 19:44:38 +0100
committerMartin Polden <mpolden@mpolden.no>2020-01-11 15:14:13 +0100
commitbf8b103057d1492f62e561a0124c8ad7bce36af9 (patch)
treed7758ae7c9d0616d7d714cb34499ccfbc92236cc
parent2e1aff9657b777777e8b32f374315989a111b062 (diff)
Add support for cache backend
-rw-r--r--cache/cache.go69
-rw-r--r--cache/cache_test.go74
2 files changed, 132 insertions, 11 deletions
diff --git a/cache/cache.go b/cache/cache.go
index 13b3857..a31aeff 100644
--- a/cache/cache.go
+++ b/cache/cache.go
@@ -14,9 +14,25 @@ import (
"github.com/mpolden/zdns/dns/dnsutil"
)
+// Backend is the interface for a persistent cache backend.
+type Backend interface {
+ Set(key uint32, value Value)
+ Evict(key uint32)
+ Read() []Value
+ Reset()
+}
+
+type defaultBackend struct{}
+
+func (b *defaultBackend) Set(uint32, Value) {}
+func (b *defaultBackend) Evict(uint32) {}
+func (b *defaultBackend) Read() []Value { return nil }
+func (b *defaultBackend) Reset() {}
+
// Cache is a cache of DNS messages.
type Cache struct {
client *dnsutil.Client
+ backend Backend
capacity int
values map[uint32]Value
keys []uint32
@@ -93,20 +109,29 @@ func Unpack(value string) (Value, error) {
// New creates a new cache of given capacity.
//
// If client is non-nil, the cache will prefetch expired entries in an effort to serve results faster.
+// If backend is non-nil, the cache will use it to persist cache entries.
func New(capacity int, client *dnsutil.Client) *Cache {
- return newCache(capacity, client, time.Now)
+ return NewWithBackend(capacity, client, &defaultBackend{})
}
-func newCache(capacity int, client *dnsutil.Client, now func() time.Time) *Cache {
+// NewWithBackend creates a new cache that forwards entries to backend.
+func NewWithBackend(capacity int, client *dnsutil.Client, backend Backend) *Cache {
+ return newCache(capacity, client, backend, time.Now)
+}
+
+func newCache(capacity int, client *dnsutil.Client, backend Backend, now func() time.Time) *Cache {
if capacity < 0 {
capacity = 0
}
- return &Cache{
+ c := &Cache{
client: client,
+ backend: &defaultBackend{},
now: now,
capacity: capacity,
values: make(map[uint32]Value, capacity),
}
+ c.load(backend)
+ return c
}
// NewKey creates a new cache key for the DNS name, qtype and qclass
@@ -118,6 +143,29 @@ func NewKey(name string, qtype, qclass uint16) uint32 {
return h.Sum32()
}
+func (c *Cache) load(backend Backend) {
+ if c.capacity == 0 {
+ backend.Reset()
+ return
+ }
+ values := backend.Read()
+ n := 0
+ if c.capacity < len(values) {
+ n = c.capacity
+ }
+ // Add the last n values from backend
+ for _, v := range values[n:] {
+ c.setValue(v)
+ }
+ if c.capacity < len(values) {
+ // Remove older entries from backend
+ for _, v := range values[:n] {
+ backend.Evict(v.Key)
+ }
+ }
+ c.backend = backend
+}
+
// Get returns the DNS message associated with key.
func (c *Cache) Get(key uint32) (*dns.Msg, bool) {
v, ok := c.getValue(key)
@@ -177,19 +225,22 @@ func (c *Cache) Set(key uint32, msg *dns.Msg) {
}
func (c *Cache) set(key uint32, msg *dns.Msg) bool {
- return c.setValue(key, Value{Key: key, CreatedAt: c.now(), msg: msg})
+ return c.setValue(Value{Key: key, CreatedAt: c.now(), msg: msg})
}
-func (c *Cache) setValue(key uint32, value Value) bool {
+func (c *Cache) setValue(value Value) bool {
if c.capacity == 0 || !canCache(value.msg) {
return false
}
if len(c.values) == c.capacity && c.capacity > 0 {
- delete(c.values, c.keys[0])
+ evict := c.keys[0]
+ delete(c.values, evict)
c.keys = c.keys[1:]
+ c.backend.Evict(evict)
}
- c.values[key] = value
- c.appendKey(key)
+ c.values[value.Key] = value
+ c.appendKey(value.Key)
+ c.backend.Set(value.Key, value)
return true
}
@@ -199,6 +250,7 @@ func (c *Cache) Reset() {
defer c.mu.Unlock()
c.values = make(map[uint32]Value)
c.keys = nil
+ c.backend.Reset()
}
func (c *Cache) prefetch() bool { return c.client != nil }
@@ -227,6 +279,7 @@ func (c *Cache) evictWithLock(key uint32) {
func (c *Cache) evict(key uint32) {
delete(c.values, key)
c.removeKey(key)
+ c.backend.Evict(key)
}
func (c *Cache) appendKey(key uint32) {
diff --git a/cache/cache_test.go b/cache/cache_test.go
index 4d1d16a..346df25 100644
--- a/cache/cache_test.go
+++ b/cache/cache_test.go
@@ -40,6 +40,29 @@ func (e *testExchanger) Exchange(msg *dns.Msg, addr string) (*dns.Msg, time.Dura
return <-e.answers, time.Second, nil
}
+type testBackend struct {
+ values []Value
+}
+
+func (b *testBackend) Set(key uint32, value Value) {
+ b.values = append(b.values, value)
+}
+
+func (b *testBackend) Evict(key uint32) {
+ var values []Value
+ for _, v := range b.values {
+ if v.Key == key {
+ continue
+ }
+ values = append(values, v)
+ }
+ b.values = values
+}
+
+func (b *testBackend) Reset() { b.values = nil }
+
+func (b *testBackend) Read() []Value { return b.values }
+
func newA(name string, ttl uint32, ipAddr ...net.IP) *dns.Msg {
m := dns.Msg{}
m.Id = dns.Id()
@@ -127,7 +150,7 @@ func TestCache(t *testing.T) {
now := time.Date(2019, 1, 1, 0, 0, 0, 0, time.UTC)
nowFn := func() time.Time { return now }
- c := newCache(100, nil, nowFn)
+ c := newCache(100, nil, &defaultBackend{}, nowFn)
var tests = []struct {
msg *dns.Msg
queriedAt time.Time
@@ -266,7 +289,7 @@ func TestCachePrefetch(t *testing.T) {
exchanger := newTestExchanger()
client := &dnsutil.Client{Exchanger: exchanger, Addresses: []string{"resolver"}}
now := time.Now()
- c := newCache(10, client, func() time.Time { return now })
+ c := newCache(10, client, &defaultBackend{}, func() time.Time { return now })
var tests = []struct {
initialAnswer string
@@ -326,7 +349,7 @@ func TestCacheEvictAndUpdate(t *testing.T) {
exchanger := newTestExchanger()
client := &dnsutil.Client{Exchanger: exchanger, Addresses: []string{"resolver"}}
now := time.Now()
- c := newCache(10, client, func() time.Time { return now })
+ c := newCache(10, client, &defaultBackend{}, func() time.Time { return now })
msg := newA("example.com.", 60, net.ParseIP("192.0.2.1"))
var key uint32 = 1
@@ -385,6 +408,51 @@ func TestPackValue(t *testing.T) {
}
}
+func TestCacheWithBackend(t *testing.T) {
+ var tests = []struct {
+ capacity int
+ backendSize int
+ cacheSize int
+ }{
+ {0, 0, 0},
+ {0, 1, 0},
+ {1, 0, 0},
+ {1, 1, 1},
+ {1, 2, 1},
+ {2, 1, 1},
+ {2, 2, 2},
+ {3, 2, 2},
+ }
+ for i, tt := range tests {
+ backend := &testBackend{}
+ for j := 0; j < tt.backendSize; j++ {
+ v := Value{
+ Key: uint32(j),
+ CreatedAt: time.Now(),
+ msg: newA("example.com.", 60, net.ParseIP("192.0.2.1")),
+ }
+ backend.Set(v.Key, v)
+ }
+ c := NewWithBackend(tt.capacity, nil, backend)
+ if got, want := len(c.values), tt.cacheSize; got != want {
+ t.Errorf("#%d: len(values) = %d, want %d", i, got, want)
+ }
+ if tt.backendSize > tt.capacity {
+ if got, want := len(backend.Read()), tt.capacity; got != want {
+ t.Errorf("#%d: len(backend.Read()) = %d, want %d", i, got, want)
+ }
+ }
+ if tt.capacity == tt.backendSize {
+ // Adding a new entry to a cache at capacity removes the oldest from backend
+ msg := newA("example.com.", 60, net.ParseIP("192.0.2.1"))
+ c.Set(42, msg)
+ if got, want := len(backend.Read()), tt.capacity; got != want {
+ t.Errorf("#%d: len(backend.Read()) = %d, want %d", i, got, want)
+ }
+ }
+ }
+}
+
func BenchmarkNewKey(b *testing.B) {
for n := 0; n < b.N; n++ {
NewKey("key", 1, 1)