diff options
author | Martin Polden <mpolden@mpolden.no> | 2019-06-09 20:37:34 +0200 |
---|---|---|
committer | Martin Polden <mpolden@mpolden.no> | 2019-06-09 20:52:29 +0200 |
commit | 419c54374f6b10bd10e9a9b751e76c36a5e2a804 (patch) | |
tree | 43506f882cd27d675f63e8d1bb2c52f74f6c3685 /cache | |
parent | d7db6bf84a684fc3027b5f016aacefebe8775370 (diff) |
Implement cache
Diffstat (limited to 'cache')
-rw-r--r-- | cache/cache.go | 95 | ||||
-rw-r--r-- | cache/cache_test.go | 116 |
2 files changed, 211 insertions, 0 deletions
diff --git a/cache/cache.go b/cache/cache.go new file mode 100644 index 0000000..f2fe82d --- /dev/null +++ b/cache/cache.go @@ -0,0 +1,95 @@ +package cache + +import ( + "encoding/binary" + "fmt" + "hash/fnv" + "sync" + "time" + + "github.com/miekg/dns" +) + +// Cache represents a cache of DNS entries +type Cache struct { + maxSize int + mu sync.RWMutex + entries map[uint32]value + keys []uint32 +} + +type value struct { + msg dns.Msg + createdAt time.Time +} + +// New creates a new cache with the given maximum size. Adding a key to a cache of this size removes the oldest key. +func New(maxSize int) (*Cache, error) { + if maxSize < 0 { + return nil, fmt.Errorf("invalid cache size: %d", maxSize) + } + return &Cache{ + maxSize: maxSize, + entries: make(map[uint32]value), + }, nil +} + +// NewKey creates a new cache key for the DNS name, qtype and qclass +func NewKey(name string, qtype, qclass uint16) uint32 { + h := fnv.New32a() + h.Write([]byte(name)) + _ = binary.Write(h, binary.LittleEndian, qtype) + _ = binary.Write(h, binary.LittleEndian, qclass) + return h.Sum32() +} + +// Get returns the DNS message associated with key k. Get will return nil if any TTL in the answer section of the // +// message is exceeded according to time t. +func (c *Cache) Get(k uint32, t time.Time) (dns.Msg, bool) { + c.mu.RLock() + v, ok := c.entries[k] + c.mu.RUnlock() + if !ok { + return dns.Msg{}, false + } + if isExpired(v, t) { + c.mu.Lock() + delete(c.entries, k) + c.mu.Unlock() + return dns.Msg{}, false + } + return v.msg, true +} + +// Add adds given DNS message msg to the cache with creation time t. Creation time plus the TTL of the answer section +// decides when the message expires. +func (c *Cache) Add(msg *dns.Msg, t time.Time) { + if c.maxSize == 0 { + return + } + q := msg.Question[0] + k := NewKey(q.Name, q.Qtype, q.Qclass) + c.mu.Lock() + if len(c.entries) == c.maxSize && c.maxSize > 0 { + // Reached max size, delete the oldest entry + delete(c.entries, c.keys[0]) + c.keys = c.keys[1:] + } + c.entries[k] = value{*msg, t} + c.keys = append(c.keys, k) + c.mu.Unlock() +} + +func isExpired(v value, t time.Time) bool { + for _, answer := range v.msg.Answer { + if t.After(v.createdAt.Add(ttl(answer))) { + return true + } + } + return false +} + +func ttl(rr dns.RR) time.Duration { + ttlSecs := rr.Header().Ttl + return time.Duration(time.Duration(ttlSecs) * time.Second) +} diff --git a/cache/cache_test.go b/cache/cache_test.go new file mode 100644 index 0000000..2b99b0f --- /dev/null +++ b/cache/cache_test.go @@ -0,0 +1,116 @@ +package cache + +import ( + "fmt" + "net" + "testing" + "time" + + "github.com/miekg/dns" +) + +func newA(name string, ttl uint32, ipAddr ...net.IP) *dns.Msg { + m := dns.Msg{} + m.Id = dns.Id() + m.SetQuestion(dns.Fqdn(name), dns.TypeA) + rr := make([]dns.RR, 0, len(ipAddr)) + for _, ip := range ipAddr { + rr = append(rr, &dns.A{ + A: ip, + Hdr: dns.RR_Header{Name: name, Rrtype: dns.TypeA, Class: dns.ClassINET, Ttl: ttl}, + }) + } + m.Answer = rr + return &m +} + +func date(year int, month time.Month, day int) time.Time { + return time.Date(year, month, day, 0, 0, 0, 0, time.UTC) +} + +func TestNewKey(t *testing.T) { + var tests = []struct { + name string + qtype, qclass uint16 + out uint32 + }{ + {"foo.", dns.TypeA, dns.ClassINET, 3170238979}, + {"foo.", dns.TypeAAAA, dns.ClassINET, 2108186350}, + {"foo.", dns.TypeA, dns.ClassANY, 2025815293}, + {"bar.", dns.TypeA, dns.ClassINET, 1620283204}, + } + for i, tt := range tests { + got := NewKey(tt.name, tt.qtype, tt.qclass) + if got != tt.out { + t.Errorf("#%d: NewKey(%q, %d, %d) = %d, want %d", i, tt.name, tt.qtype, tt.qclass, got, tt.out) + } + } +} + +func TestCache(t *testing.T) { + m := newA("foo.", 60, net.ParseIP("192.0.2.1")) + tt := date(2019, 1, 1) + c, err := New(100) + if err != nil { + t.Fatal(err) + } + var tests = []struct { + msg *dns.Msg + createdAt, queriedAt time.Time + ok bool + }{ + {m, tt, tt, true}, // Not expired when query time == create time + {m, tt, tt.Add(30 * time.Second), true}, // Not expired when below TTL + {m, tt, tt.Add(60 * time.Second), true}, // Not expired until TTL exceeds + {m, tt, tt.Add(61 * time.Second), false}, // Expired + } + for i, tt := range tests { + k := NewKey(tt.msg.Question[0].Name, tt.msg.Question[0].Qtype, tt.msg.Question[0].Qclass) + c.Add(tt.msg, tt.createdAt) + msg, ok := c.Get(k, tt.queriedAt) + if ok != tt.ok { + t.Errorf("#%d: Get(%d) = (%+v, %t), want (_, %t)", i, k, msg, ok, tt.ok) + } + if _, ok := c.entries[k]; ok != tt.ok { + t.Errorf("#%d: Cache[%d] = %t, want %t", i, k, ok, tt.ok) + } + } +} + +func TestCacheMaxSize(t *testing.T) { + var tests = []struct { + addCount, maxSize, size int + }{ + {1, 0, 0}, + {1, 2, 1}, + {2, 2, 2}, + {3, 2, 2}, + } + for i, tt := range tests { + c, err := New(tt.maxSize) + if err != nil { + t.Fatal(err) + } + var msgs []*dns.Msg + for i := 0; i < tt.addCount; i++ { + m := newA(fmt.Sprintf("r%d", i), 60, net.ParseIP(fmt.Sprintf("192.0.2.%d", i))) + msgs = append(msgs, m) + c.Add(m, time.Time{}) + } + if got := len(c.entries); got != tt.size { + t.Errorf("#%d: len(entries) = %d, want %d", i, got, tt.size) + } + if tt.maxSize > 0 && tt.addCount > tt.maxSize && tt.maxSize == tt.size { + lastAdded := msgs[tt.addCount-1].Question[0] + lastK := NewKey(lastAdded.Name, lastAdded.Qtype, lastAdded.Qclass) + if _, ok := c.Get(lastK, time.Time{}); !ok { + t.Errorf("#%d: Get(NewKey(%q, _, _)) = (_, %t), want (_, %t)", i, lastAdded.Name, ok, !ok) + } + firstAdded := msgs[0].Question[0] + firstK := NewKey(firstAdded.Name, firstAdded.Qtype, firstAdded.Qclass) + if _, ok := c.Get(firstK, time.Time{}); ok { + t.Errorf("#%d: Get(NewKey(%q, _, _)) = (_, %t), want (_, %t)", i, firstAdded.Name, ok, !ok) + } + } + } +} |