From 5cc54e27335d24b51561818ea3f4a165f095e466 Mon Sep 17 00:00:00 2001 From: Martin Polden Date: Mon, 30 Dec 2019 15:21:07 +0100 Subject: Implement cache prefetching --- cache/cache.go | 81 +++++++++++++++++++++++++++++++++++++++++------------ cache/cache_test.go | 59 ++++++++++++++++++++++++++++++++++---- cmd/zdns/main.go | 8 +++--- dns/proxy.go | 2 +- dns/proxy_test.go | 4 +-- http/http_test.go | 2 +- server_test.go | 2 +- 7 files changed, 125 insertions(+), 33 deletions(-) diff --git a/cache/cache.go b/cache/cache.go index a474170..3b650e9 100644 --- a/cache/cache.go +++ b/cache/cache.go @@ -12,6 +12,7 @@ import ( // Cache is a cache of DNS messages. type Cache struct { + client *dnsutil.Client capacity int values map[uint64]*Value keys []uint64 @@ -41,14 +42,18 @@ func (v *Value) Answers() []string { return dnsutil.Answers(v.msg) } // TTL returns the time to live of the cached value v. func (v *Value) TTL() time.Duration { return dnsutil.MinTTL(v.msg) } -// New creates a new cache of given capacity. -func New(capacity int) *Cache { return newCache(capacity, time.Minute, time.Now) } +// New creates a new cache of given capacity. If client is non-nil, the cache will prefetch expired entries in an effort +// to serve results faster. +func New(capacity int, client *dnsutil.Client) *Cache { + return newCache(capacity, client, 10*time.Second, time.Now) +} -func newCache(capacity int, interval time.Duration, now func() time.Time) *Cache { +func newCache(capacity int, client *dnsutil.Client, interval time.Duration, now func() time.Time) *Cache { if capacity < 0 { capacity = 0 } cache := &Cache{ + client: client, now: now, capacity: capacity, values: make(map[uint64]*Value, capacity), @@ -75,7 +80,11 @@ func maintain(cache *Cache, interval time.Duration) { ticker.Stop() return case <-ticker.C: - cache.evictExpired() + if cache.prefetch() { + cache.refreshExpired(interval) + } else { + cache.evictExpired() + } } } } @@ -100,7 +109,7 @@ func (c *Cache) getValue(k uint64) (*Value, bool) { c.mu.RLock() v, ok := c.values[k] c.mu.RUnlock() - if !ok || c.isExpired(v) { + if !ok || (!c.prefetch() && c.isExpired(v)) { return nil, false } return v, true @@ -152,34 +161,70 @@ func (c *Cache) Reset() { c.keys = nil } +func (c *Cache) prefetch() bool { return c.client != nil } + +func (c *Cache) refreshExpired(interval time.Duration) { + c.mu.Lock() + defer c.mu.Unlock() + evicted := make(map[uint64]bool) + for k, v := range c.values { + // Value will expiry before the next interval. Refresh now + if c.isExpiredAfter(interval, v) { + q := v.msg.Question[0] + msg := dns.Msg{} + msg.SetQuestion(q.Name, q.Qtype) + r, err := c.client.Exchange(&msg) + if err != nil { + continue // Will be retried on next run + } + if canCache(r) { + c.values[k].CreatedAt = c.now() + c.values[k].msg = r + } else { + // Can no longer be cached. Evict + delete(c.values, k) + evicted[k] = true + } + } + } + c.reorderKeys(evicted) +} + func (c *Cache) evictExpired() { c.mu.Lock() defer c.mu.Unlock() - evictedKeys := make(map[uint64]bool) + evicted := make(map[uint64]bool) for k, v := range c.values { if c.isExpired(v) { delete(c.values, k) - evictedKeys[k] = true + evicted[k] = true } } - if len(evictedKeys) > 0 { - // At least one entry was evicted. The ordered list of keys must be updated. - var keys []uint64 - for _, k := range c.keys { - if _, ok := evictedKeys[k]; ok { - continue - } - keys = append(keys, k) + c.reorderKeys(evicted) +} + +func (c *Cache) reorderKeys(evicted map[uint64]bool) { + if len(evicted) == 0 { + return + } + // At least one entry was evicted. The ordered list of keys must be updated. + var keys []uint64 + for _, k := range c.keys { + if _, ok := evicted[k]; ok { + continue } - c.keys = keys + keys = append(keys, k) } + c.keys = keys } -func (c *Cache) isExpired(v *Value) bool { +func (c *Cache) isExpiredAfter(d time.Duration, v *Value) bool { expiresAt := v.CreatedAt.Add(dnsutil.MinTTL(v.msg)) - return c.now().After(expiresAt) + return c.now().Add(d).After(expiresAt) } +func (c *Cache) isExpired(v *Value) bool { return c.isExpiredAfter(0, v) } + func canCache(msg *dns.Msg) bool { if dnsutil.MinTTL(msg) == 0 { return false diff --git a/cache/cache_test.go b/cache/cache_test.go index 2ab082f..3ea6bf3 100644 --- a/cache/cache_test.go +++ b/cache/cache_test.go @@ -8,8 +8,17 @@ import ( "time" "github.com/miekg/dns" + "github.com/mpolden/zdns/dns/dnsutil" ) +type testExchanger struct { + answer *dns.Msg +} + +func (e *testExchanger) Exchange(msg *dns.Msg, addr string) (*dns.Msg, time.Duration, error) { + return e.answer, time.Second, nil +} + func newA(name string, ttl uint32, ipAddr ...net.IP) *dns.Msg { m := dns.Msg{} m.Id = dns.Id() @@ -80,7 +89,7 @@ func TestCache(t *testing.T) { now := time.Date(2019, 1, 1, 0, 0, 0, 0, time.UTC) nowFn := func() time.Time { return now } - c := newCache(100, 10*time.Millisecond, nowFn) + c := newCache(100, nil, 10*time.Millisecond, nowFn) defer c.Close() var tests = []struct { msg *dns.Msg @@ -136,7 +145,7 @@ func TestCacheCapacity(t *testing.T) { {3, 2, 2}, } for i, tt := range tests { - c := New(tt.capacity) + c := New(tt.capacity, nil) defer c.Close() var msgs []*dns.Msg for i := 0; i < tt.addCount; i++ { @@ -176,7 +185,7 @@ func TestCacheList(t *testing.T) { {2, 0, 0, true}, } for i, tt := range tests { - c := New(1024) + c := New(1024, nil) defer c.Close() var msgs []*dns.Msg for i := 0; i < tt.addCount; i++ { @@ -205,7 +214,7 @@ func TestCacheList(t *testing.T) { } func TestReset(t *testing.T) { - c := New(10) + c := New(10, nil) c.Set(uint64(1), &dns.Msg{}) c.Reset() if got, want := len(c.values), 0; got != want { @@ -216,6 +225,44 @@ func TestReset(t *testing.T) { } } +func TestCachePrefetch(t *testing.T) { + exchanger := testExchanger{} + client := &dnsutil.Client{Exchanger: &exchanger, Addresses: []string{"resolver"}} + now := time.Now() + nowFn := func() time.Time { return now } + c := newCache(10, client, time.Hour, nowFn) + + var key uint64 = 1 + ip := net.ParseIP("192.0.2.1") + response := newA("r1.", 60, ip) + c.Set(key, response) + + // Not refreshed yet + c.now = func() time.Time { return now.Add(30 * time.Second) } + c.refreshExpired(0) + rr, _ := c.Get(key) + answers := dnsutil.Answers(rr) + if got, want := answers[0], ip.String(); got != want { + t.Errorf("got ip %s, want %s", got, want) + } + + // Expiry of cached value is ignored as prefetching is enabled + c.now = func() time.Time { return now.Add(61 * time.Second) } + if _, ok := c.Get(key); !ok { + t.Errorf("Get(%d) = (_, %t), want (_, %t)", key, ok, !ok) + } + + // Refresh expired entry + ip = net.ParseIP("192.0.2.2") + exchanger.answer = newA("r1.", 60, ip) + c.refreshExpired(0) + rr, _ = c.Get(key) + answers = dnsutil.Answers(rr) + if got, want := answers[0], ip.String(); got != want { + t.Errorf("got ip %s, want %s", got, want) + } +} + func BenchmarkNewKey(b *testing.B) { for n := 0; n < b.N; n++ { NewKey("key", 1, 1) @@ -223,7 +270,7 @@ func BenchmarkNewKey(b *testing.B) { } func BenchmarkCache(b *testing.B) { - c := New(1000) + c := New(1000, nil) b.ResetTimer() for n := 0; n < b.N; n++ { c.Set(uint64(n), &dns.Msg{}) @@ -232,7 +279,7 @@ func BenchmarkCache(b *testing.B) { } func BenchmarkCacheEviction(b *testing.B) { - c := New(1) + c := New(1, nil) b.ResetTimer() for n := 0; n < b.N; n++ { c.Set(uint64(n), &dns.Msg{}) diff --git a/cmd/zdns/main.go b/cmd/zdns/main.go index 4b2697d..fef8455 100644 --- a/cmd/zdns/main.go +++ b/cmd/zdns/main.go @@ -90,13 +90,13 @@ func (c *cli) run() { sigHandler := signal.NewHandler(c.signal, logger) sigHandler.OnClose(logger) - // Cache - cache := cache.New(config.DNS.CacheSize) - sigHandler.OnClose(cache) - // Client client := dnsutil.NewClient(config.Resolver.Protocol, config.Resolver.Timeout, config.DNS.Resolvers...) + // Cache + cache := cache.New(config.DNS.CacheSize, nil) + sigHandler.OnClose(cache) + // DNS server proxy, err := dns.NewProxy(cache, client, logger) fatal(err) diff --git a/dns/proxy.go b/dns/proxy.go index 060cb61..9daf552 100644 --- a/dns/proxy.go +++ b/dns/proxy.go @@ -143,8 +143,8 @@ func (p *Proxy) ServeDNS(w dns.ResponseWriter, r *dns.Msg) { } rr, err := p.client.Exchange(r) if err == nil { - p.cache.Set(key, rr) p.writeMsg(w, rr, false) + p.cache.Set(key, rr) } else { p.logger.Printf("resolver(s) failed: %s", err) dns.HandleFailed(w, r) diff --git a/dns/proxy_test.go b/dns/proxy_test.go index 5fa0646..f9783bf 100644 --- a/dns/proxy_test.go +++ b/dns/proxy_test.go @@ -57,7 +57,7 @@ func testProxy(t *testing.T) *Proxy { if err != nil { t.Fatal(err) } - proxy, err := NewProxy(cache.New(0), nil, log) + proxy, err := NewProxy(cache.New(0, nil), nil, log) if err != nil { t.Fatal(err) } @@ -185,7 +185,7 @@ func TestProxyWithResolvers(t *testing.T) { func TestProxyWithCache(t *testing.T) { p := testProxy(t) - p.cache = cache.New(10) + p.cache = cache.New(10, nil) exchanger := make(testExchanger) p.client = &dnsutil.Client{Exchanger: exchanger} p.client.Addresses = []string{"resolver1"} diff --git a/http/http_test.go b/http/http_test.go index 52e4eaa..8fd332b 100644 --- a/http/http_test.go +++ b/http/http_test.go @@ -34,7 +34,7 @@ func testServer() (*httptest.Server, *Server) { if err != nil { panic(err) } - cache := cache.New(10) + cache := cache.New(10, nil) server := Server{logger: logger, cache: cache} return httptest.NewServer(server.handler()), &server } diff --git a/server_test.go b/server_test.go index 657f68f..17883e3 100644 --- a/server_test.go +++ b/server_test.go @@ -106,7 +106,7 @@ func testServer(t *testing.T, refreshInterval time.Duration) (*Server, func()) { if err != nil { t.Fatal(err) } - proxy, err := dns.NewProxy(cache.New(0), nil, logger) + proxy, err := dns.NewProxy(cache.New(0, nil), nil, logger) if err != nil { t.Fatal(err) } -- cgit v1.2.3