aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorMartin Polden <mpolden@mpolden.no>2019-12-30 15:21:07 +0100
committerMartin Polden <mpolden@mpolden.no>2019-12-30 15:21:07 +0100
commit5cc54e27335d24b51561818ea3f4a165f095e466 (patch)
tree3671bd43cde4d127dfdafd735f9e6b30a0f96ff5
parentbd0ec7519ebbfbd92a8c3063da0e8a823897171c (diff)
Implement cache prefetching
-rw-r--r--cache/cache.go81
-rw-r--r--cache/cache_test.go59
-rw-r--r--cmd/zdns/main.go8
-rw-r--r--dns/proxy.go2
-rw-r--r--dns/proxy_test.go4
-rw-r--r--http/http_test.go2
-rw-r--r--server_test.go2
7 files changed, 125 insertions, 33 deletions
diff --git a/cache/cache.go b/cache/cache.go
index a474170..3b650e9 100644
--- a/cache/cache.go
+++ b/cache/cache.go
@@ -12,6 +12,7 @@ import (
// Cache is a cache of DNS messages.
type Cache struct {
+ client *dnsutil.Client
capacity int
values map[uint64]*Value
keys []uint64
@@ -41,14 +42,18 @@ func (v *Value) Answers() []string { return dnsutil.Answers(v.msg) }
// TTL returns the time to live of the cached value v.
func (v *Value) TTL() time.Duration { return dnsutil.MinTTL(v.msg) }
-// New creates a new cache of given capacity.
-func New(capacity int) *Cache { return newCache(capacity, time.Minute, time.Now) }
+// New creates a new cache of given capacity. If client is non-nil, the cache will prefetch expired entries in an effort
+// to serve results faster.
+func New(capacity int, client *dnsutil.Client) *Cache {
+ return newCache(capacity, client, 10*time.Second, time.Now)
+}
-func newCache(capacity int, interval time.Duration, now func() time.Time) *Cache {
+func newCache(capacity int, client *dnsutil.Client, interval time.Duration, now func() time.Time) *Cache {
if capacity < 0 {
capacity = 0
}
cache := &Cache{
+ client: client,
now: now,
capacity: capacity,
values: make(map[uint64]*Value, capacity),
@@ -75,7 +80,11 @@ func maintain(cache *Cache, interval time.Duration) {
ticker.Stop()
return
case <-ticker.C:
- cache.evictExpired()
+ if cache.prefetch() {
+ cache.refreshExpired(interval)
+ } else {
+ cache.evictExpired()
+ }
}
}
}
@@ -100,7 +109,7 @@ func (c *Cache) getValue(k uint64) (*Value, bool) {
c.mu.RLock()
v, ok := c.values[k]
c.mu.RUnlock()
- if !ok || c.isExpired(v) {
+ if !ok || (!c.prefetch() && c.isExpired(v)) {
return nil, false
}
return v, true
@@ -152,34 +161,70 @@ func (c *Cache) Reset() {
c.keys = nil
}
+func (c *Cache) prefetch() bool { return c.client != nil }
+
+func (c *Cache) refreshExpired(interval time.Duration) {
+ c.mu.Lock()
+ defer c.mu.Unlock()
+ evicted := make(map[uint64]bool)
+ for k, v := range c.values {
+ // Value will expiry before the next interval. Refresh now
+ if c.isExpiredAfter(interval, v) {
+ q := v.msg.Question[0]
+ msg := dns.Msg{}
+ msg.SetQuestion(q.Name, q.Qtype)
+ r, err := c.client.Exchange(&msg)
+ if err != nil {
+ continue // Will be retried on next run
+ }
+ if canCache(r) {
+ c.values[k].CreatedAt = c.now()
+ c.values[k].msg = r
+ } else {
+ // Can no longer be cached. Evict
+ delete(c.values, k)
+ evicted[k] = true
+ }
+ }
+ }
+ c.reorderKeys(evicted)
+}
+
func (c *Cache) evictExpired() {
c.mu.Lock()
defer c.mu.Unlock()
- evictedKeys := make(map[uint64]bool)
+ evicted := make(map[uint64]bool)
for k, v := range c.values {
if c.isExpired(v) {
delete(c.values, k)
- evictedKeys[k] = true
+ evicted[k] = true
}
}
- if len(evictedKeys) > 0 {
- // At least one entry was evicted. The ordered list of keys must be updated.
- var keys []uint64
- for _, k := range c.keys {
- if _, ok := evictedKeys[k]; ok {
- continue
- }
- keys = append(keys, k)
+ c.reorderKeys(evicted)
+}
+
+func (c *Cache) reorderKeys(evicted map[uint64]bool) {
+ if len(evicted) == 0 {
+ return
+ }
+ // At least one entry was evicted. The ordered list of keys must be updated.
+ var keys []uint64
+ for _, k := range c.keys {
+ if _, ok := evicted[k]; ok {
+ continue
}
- c.keys = keys
+ keys = append(keys, k)
}
+ c.keys = keys
}
-func (c *Cache) isExpired(v *Value) bool {
+func (c *Cache) isExpiredAfter(d time.Duration, v *Value) bool {
expiresAt := v.CreatedAt.Add(dnsutil.MinTTL(v.msg))
- return c.now().After(expiresAt)
+ return c.now().Add(d).After(expiresAt)
}
+func (c *Cache) isExpired(v *Value) bool { return c.isExpiredAfter(0, v) }
+
func canCache(msg *dns.Msg) bool {
if dnsutil.MinTTL(msg) == 0 {
return false
diff --git a/cache/cache_test.go b/cache/cache_test.go
index 2ab082f..3ea6bf3 100644
--- a/cache/cache_test.go
+++ b/cache/cache_test.go
@@ -8,8 +8,17 @@ import (
"time"
"github.com/miekg/dns"
+ "github.com/mpolden/zdns/dns/dnsutil"
)
+type testExchanger struct {
+ answer *dns.Msg
+}
+
+func (e *testExchanger) Exchange(msg *dns.Msg, addr string) (*dns.Msg, time.Duration, error) {
+ return e.answer, time.Second, nil
+}
+
func newA(name string, ttl uint32, ipAddr ...net.IP) *dns.Msg {
m := dns.Msg{}
m.Id = dns.Id()
@@ -80,7 +89,7 @@ func TestCache(t *testing.T) {
now := time.Date(2019, 1, 1, 0, 0, 0, 0, time.UTC)
nowFn := func() time.Time { return now }
- c := newCache(100, 10*time.Millisecond, nowFn)
+ c := newCache(100, nil, 10*time.Millisecond, nowFn)
defer c.Close()
var tests = []struct {
msg *dns.Msg
@@ -136,7 +145,7 @@ func TestCacheCapacity(t *testing.T) {
{3, 2, 2},
}
for i, tt := range tests {
- c := New(tt.capacity)
+ c := New(tt.capacity, nil)
defer c.Close()
var msgs []*dns.Msg
for i := 0; i < tt.addCount; i++ {
@@ -176,7 +185,7 @@ func TestCacheList(t *testing.T) {
{2, 0, 0, true},
}
for i, tt := range tests {
- c := New(1024)
+ c := New(1024, nil)
defer c.Close()
var msgs []*dns.Msg
for i := 0; i < tt.addCount; i++ {
@@ -205,7 +214,7 @@ func TestCacheList(t *testing.T) {
}
func TestReset(t *testing.T) {
- c := New(10)
+ c := New(10, nil)
c.Set(uint64(1), &dns.Msg{})
c.Reset()
if got, want := len(c.values), 0; got != want {
@@ -216,6 +225,44 @@ func TestReset(t *testing.T) {
}
}
+func TestCachePrefetch(t *testing.T) {
+ exchanger := testExchanger{}
+ client := &dnsutil.Client{Exchanger: &exchanger, Addresses: []string{"resolver"}}
+ now := time.Now()
+ nowFn := func() time.Time { return now }
+ c := newCache(10, client, time.Hour, nowFn)
+
+ var key uint64 = 1
+ ip := net.ParseIP("192.0.2.1")
+ response := newA("r1.", 60, ip)
+ c.Set(key, response)
+
+ // Not refreshed yet
+ c.now = func() time.Time { return now.Add(30 * time.Second) }
+ c.refreshExpired(0)
+ rr, _ := c.Get(key)
+ answers := dnsutil.Answers(rr)
+ if got, want := answers[0], ip.String(); got != want {
+ t.Errorf("got ip %s, want %s", got, want)
+ }
+
+ // Expiry of cached value is ignored as prefetching is enabled
+ c.now = func() time.Time { return now.Add(61 * time.Second) }
+ if _, ok := c.Get(key); !ok {
+ t.Errorf("Get(%d) = (_, %t), want (_, %t)", key, ok, !ok)
+ }
+
+ // Refresh expired entry
+ ip = net.ParseIP("192.0.2.2")
+ exchanger.answer = newA("r1.", 60, ip)
+ c.refreshExpired(0)
+ rr, _ = c.Get(key)
+ answers = dnsutil.Answers(rr)
+ if got, want := answers[0], ip.String(); got != want {
+ t.Errorf("got ip %s, want %s", got, want)
+ }
+}
+
func BenchmarkNewKey(b *testing.B) {
for n := 0; n < b.N; n++ {
NewKey("key", 1, 1)
@@ -223,7 +270,7 @@ func BenchmarkNewKey(b *testing.B) {
}
func BenchmarkCache(b *testing.B) {
- c := New(1000)
+ c := New(1000, nil)
b.ResetTimer()
for n := 0; n < b.N; n++ {
c.Set(uint64(n), &dns.Msg{})
@@ -232,7 +279,7 @@ func BenchmarkCache(b *testing.B) {
}
func BenchmarkCacheEviction(b *testing.B) {
- c := New(1)
+ c := New(1, nil)
b.ResetTimer()
for n := 0; n < b.N; n++ {
c.Set(uint64(n), &dns.Msg{})
diff --git a/cmd/zdns/main.go b/cmd/zdns/main.go
index 4b2697d..fef8455 100644
--- a/cmd/zdns/main.go
+++ b/cmd/zdns/main.go
@@ -90,13 +90,13 @@ func (c *cli) run() {
sigHandler := signal.NewHandler(c.signal, logger)
sigHandler.OnClose(logger)
- // Cache
- cache := cache.New(config.DNS.CacheSize)
- sigHandler.OnClose(cache)
-
// Client
client := dnsutil.NewClient(config.Resolver.Protocol, config.Resolver.Timeout, config.DNS.Resolvers...)
+ // Cache
+ cache := cache.New(config.DNS.CacheSize, nil)
+ sigHandler.OnClose(cache)
+
// DNS server
proxy, err := dns.NewProxy(cache, client, logger)
fatal(err)
diff --git a/dns/proxy.go b/dns/proxy.go
index 060cb61..9daf552 100644
--- a/dns/proxy.go
+++ b/dns/proxy.go
@@ -143,8 +143,8 @@ func (p *Proxy) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
}
rr, err := p.client.Exchange(r)
if err == nil {
- p.cache.Set(key, rr)
p.writeMsg(w, rr, false)
+ p.cache.Set(key, rr)
} else {
p.logger.Printf("resolver(s) failed: %s", err)
dns.HandleFailed(w, r)
diff --git a/dns/proxy_test.go b/dns/proxy_test.go
index 5fa0646..f9783bf 100644
--- a/dns/proxy_test.go
+++ b/dns/proxy_test.go
@@ -57,7 +57,7 @@ func testProxy(t *testing.T) *Proxy {
if err != nil {
t.Fatal(err)
}
- proxy, err := NewProxy(cache.New(0), nil, log)
+ proxy, err := NewProxy(cache.New(0, nil), nil, log)
if err != nil {
t.Fatal(err)
}
@@ -185,7 +185,7 @@ func TestProxyWithResolvers(t *testing.T) {
func TestProxyWithCache(t *testing.T) {
p := testProxy(t)
- p.cache = cache.New(10)
+ p.cache = cache.New(10, nil)
exchanger := make(testExchanger)
p.client = &dnsutil.Client{Exchanger: exchanger}
p.client.Addresses = []string{"resolver1"}
diff --git a/http/http_test.go b/http/http_test.go
index 52e4eaa..8fd332b 100644
--- a/http/http_test.go
+++ b/http/http_test.go
@@ -34,7 +34,7 @@ func testServer() (*httptest.Server, *Server) {
if err != nil {
panic(err)
}
- cache := cache.New(10)
+ cache := cache.New(10, nil)
server := Server{logger: logger, cache: cache}
return httptest.NewServer(server.handler()), &server
}
diff --git a/server_test.go b/server_test.go
index 657f68f..17883e3 100644
--- a/server_test.go
+++ b/server_test.go
@@ -106,7 +106,7 @@ func testServer(t *testing.T, refreshInterval time.Duration) (*Server, func()) {
if err != nil {
t.Fatal(err)
}
- proxy, err := dns.NewProxy(cache.New(0), nil, logger)
+ proxy, err := dns.NewProxy(cache.New(0, nil), nil, logger)
if err != nil {
t.Fatal(err)
}