aboutsummaryrefslogtreecommitdiffstats
path: root/cache
diff options
context:
space:
mode:
authorMartin Polden <mpolden@mpolden.no>2019-06-09 20:37:34 +0200
committerMartin Polden <mpolden@mpolden.no>2019-06-09 20:52:29 +0200
commit419c54374f6b10bd10e9a9b751e76c36a5e2a804 (patch)
tree43506f882cd27d675f63e8d1bb2c52f74f6c3685 /cache
parentd7db6bf84a684fc3027b5f016aacefebe8775370 (diff)
Implement cache
Diffstat (limited to 'cache')
-rw-r--r--cache/cache.go95
-rw-r--r--cache/cache_test.go116
2 files changed, 211 insertions, 0 deletions
diff --git a/cache/cache.go b/cache/cache.go
new file mode 100644
index 0000000..f2fe82d
--- /dev/null
+++ b/cache/cache.go
@@ -0,0 +1,95 @@
+package cache
+
+import (
+ "encoding/binary"
+ "fmt"
+ "hash/fnv"
+ "sync"
+ "time"
+
+ "github.com/miekg/dns"
+)
+
+// Cache represents a cache of DNS entries
+type Cache struct {
+ maxSize int
+ mu sync.RWMutex
+ entries map[uint32]value
+ keys []uint32
+}
+
+type value struct {
+ msg dns.Msg
+ createdAt time.Time
+}
+
+// New creates a new cache with the given maximum size. Adding a key to a cache of this size removes the oldest key.
+func New(maxSize int) (*Cache, error) {
+ if maxSize < 0 {
+ return nil, fmt.Errorf("invalid cache size: %d", maxSize)
+ }
+ return &Cache{
+ maxSize: maxSize,
+ entries: make(map[uint32]value),
+ }, nil
+}
+
+// NewKey creates a new cache key for the DNS name, qtype and qclass
+func NewKey(name string, qtype, qclass uint16) uint32 {
+ h := fnv.New32a()
+ h.Write([]byte(name))
+ _ = binary.Write(h, binary.LittleEndian, qtype)
+ _ = binary.Write(h, binary.LittleEndian, qclass)
+ return h.Sum32()
+}
+
+// Get returns the DNS message associated with key k. Get will return nil if any TTL in the answer section of the //
+// message is exceeded according to time t.
+func (c *Cache) Get(k uint32, t time.Time) (dns.Msg, bool) {
+ c.mu.RLock()
+ v, ok := c.entries[k]
+ c.mu.RUnlock()
+ if !ok {
+ return dns.Msg{}, false
+ }
+ if isExpired(v, t) {
+ c.mu.Lock()
+ delete(c.entries, k)
+ c.mu.Unlock()
+ return dns.Msg{}, false
+ }
+ return v.msg, true
+}
+
+// Add adds given DNS message msg to the cache with creation time t. Creation time plus the TTL of the answer section
+// decides when the message expires.
+func (c *Cache) Add(msg *dns.Msg, t time.Time) {
+ if c.maxSize == 0 {
+ return
+ }
+ q := msg.Question[0]
+ k := NewKey(q.Name, q.Qtype, q.Qclass)
+ c.mu.Lock()
+ if len(c.entries) == c.maxSize && c.maxSize > 0 {
+ // Reached max size, delete the oldest entry
+ delete(c.entries, c.keys[0])
+ c.keys = c.keys[1:]
+ }
+ c.entries[k] = value{*msg, t}
+ c.keys = append(c.keys, k)
+ c.mu.Unlock()
+}
+
+func isExpired(v value, t time.Time) bool {
+ for _, answer := range v.msg.Answer {
+ if t.After(v.createdAt.Add(ttl(answer))) {
+ return true
+ }
+ }
+ return false
+}
+
+func ttl(rr dns.RR) time.Duration {
+ ttlSecs := rr.Header().Ttl
+ return time.Duration(time.Duration(ttlSecs) * time.Second)
+}
diff --git a/cache/cache_test.go b/cache/cache_test.go
new file mode 100644
index 0000000..2b99b0f
--- /dev/null
+++ b/cache/cache_test.go
@@ -0,0 +1,116 @@
+package cache
+
+import (
+ "fmt"
+ "net"
+ "testing"
+ "time"
+
+ "github.com/miekg/dns"
+)
+
+func newA(name string, ttl uint32, ipAddr ...net.IP) *dns.Msg {
+ m := dns.Msg{}
+ m.Id = dns.Id()
+ m.SetQuestion(dns.Fqdn(name), dns.TypeA)
+ rr := make([]dns.RR, 0, len(ipAddr))
+ for _, ip := range ipAddr {
+ rr = append(rr, &dns.A{
+ A: ip,
+ Hdr: dns.RR_Header{Name: name, Rrtype: dns.TypeA, Class: dns.ClassINET, Ttl: ttl},
+ })
+ }
+ m.Answer = rr
+ return &m
+}
+
+func date(year int, month time.Month, day int) time.Time {
+ return time.Date(year, month, day, 0, 0, 0, 0, time.UTC)
+}
+
+func TestNewKey(t *testing.T) {
+ var tests = []struct {
+ name string
+ qtype, qclass uint16
+ out uint32
+ }{
+ {"foo.", dns.TypeA, dns.ClassINET, 3170238979},
+ {"foo.", dns.TypeAAAA, dns.ClassINET, 2108186350},
+ {"foo.", dns.TypeA, dns.ClassANY, 2025815293},
+ {"bar.", dns.TypeA, dns.ClassINET, 1620283204},
+ }
+ for i, tt := range tests {
+ got := NewKey(tt.name, tt.qtype, tt.qclass)
+ if got != tt.out {
+ t.Errorf("#%d: NewKey(%q, %d, %d) = %d, want %d", i, tt.name, tt.qtype, tt.qclass, got, tt.out)
+ }
+ }
+}
+
+func TestCache(t *testing.T) {
+ m := newA("foo.", 60, net.ParseIP("192.0.2.1"))
+ tt := date(2019, 1, 1)
+ c, err := New(100)
+ if err != nil {
+ t.Fatal(err)
+ }
+ var tests = []struct {
+ msg *dns.Msg
+ createdAt, queriedAt time.Time
+ ok bool
+ }{
+ {m, tt, tt, true}, // Not expired when query time == create time
+ {m, tt, tt.Add(30 * time.Second), true}, // Not expired when below TTL
+ {m, tt, tt.Add(60 * time.Second), true}, // Not expired until TTL exceeds
+ {m, tt, tt.Add(61 * time.Second), false}, // Expired
+ }
+ for i, tt := range tests {
+ k := NewKey(tt.msg.Question[0].Name, tt.msg.Question[0].Qtype, tt.msg.Question[0].Qclass)
+ c.Add(tt.msg, tt.createdAt)
+ msg, ok := c.Get(k, tt.queriedAt)
+ if ok != tt.ok {
+ t.Errorf("#%d: Get(%d) = (%+v, %t), want (_, %t)", i, k, msg, ok, tt.ok)
+ }
+ if _, ok := c.entries[k]; ok != tt.ok {
+ t.Errorf("#%d: Cache[%d] = %t, want %t", i, k, ok, tt.ok)
+ }
+ }
+}
+
+func TestCacheMaxSize(t *testing.T) {
+ var tests = []struct {
+ addCount, maxSize, size int
+ }{
+ {1, 0, 0},
+ {1, 2, 1},
+ {2, 2, 2},
+ {3, 2, 2},
+ }
+ for i, tt := range tests {
+ c, err := New(tt.maxSize)
+ if err != nil {
+ t.Fatal(err)
+ }
+ var msgs []*dns.Msg
+ for i := 0; i < tt.addCount; i++ {
+ m := newA(fmt.Sprintf("r%d", i), 60, net.ParseIP(fmt.Sprintf("192.0.2.%d", i)))
+ msgs = append(msgs, m)
+ c.Add(m, time.Time{})
+ }
+ if got := len(c.entries); got != tt.size {
+ t.Errorf("#%d: len(entries) = %d, want %d", i, got, tt.size)
+ }
+ if tt.maxSize > 0 && tt.addCount > tt.maxSize && tt.maxSize == tt.size {
+ lastAdded := msgs[tt.addCount-1].Question[0]
+ lastK := NewKey(lastAdded.Name, lastAdded.Qtype, lastAdded.Qclass)
+ if _, ok := c.Get(lastK, time.Time{}); !ok {
+ t.Errorf("#%d: Get(NewKey(%q, _, _)) = (_, %t), want (_, %t)", i, lastAdded.Name, ok, !ok)
+ }
+ firstAdded := msgs[0].Question[0]
+ firstK := NewKey(firstAdded.Name, firstAdded.Qtype, firstAdded.Qclass)
+ if _, ok := c.Get(firstK, time.Time{}); ok {
+ t.Errorf("#%d: Get(NewKey(%q, _, _)) = (_, %t), want (_, %t)", i, firstAdded.Name, ok, !ok)
+ }
+ }
+ }
+}