diff options
-rw-r--r-- | cache/cache.go | 30 | ||||
-rw-r--r-- | cache/cache_test.go | 41 |
2 files changed, 47 insertions, 24 deletions
diff --git a/cache/cache.go b/cache/cache.go index bf7cfae..014d3fe 100644 --- a/cache/cache.go +++ b/cache/cache.go @@ -3,7 +3,6 @@ package cache import ( "encoding/binary" "hash/fnv" - "net" "sync" "time" @@ -54,7 +53,7 @@ func (m *maintainer) run(cache *Cache) { type Value struct { Question string Qtype uint16 - Answer net.IP + Answers []string CreatedAt time.Time msg *dns.Msg } @@ -147,7 +146,7 @@ func (c *Cache) Set(k uint32, msg *dns.Msg) { } c.entries[k] = &Value{ Question: question(msg), - Answer: answer(msg), + Answers: answers(msg), Qtype: qtype(msg), CreatedAt: now, msg: msg, @@ -156,21 +155,24 @@ func (c *Cache) Set(k uint32, msg *dns.Msg) { c.mu.Unlock() } -func qtype(m *dns.Msg) uint16 { return m.Question[0].Qtype } +func qtype(msg *dns.Msg) uint16 { return msg.Question[0].Qtype } -func question(m *dns.Msg) string { return m.Question[0].Name } +func question(msg *dns.Msg) string { return msg.Question[0].Name } -func answer(m *dns.Msg) net.IP { - rr := m.Answer[0] - switch v := rr.(type) { - case *dns.A: - return v.A - case *dns.AAAA: - return v.AAAA +func answers(msg *dns.Msg) []string { + var answers []string + for _, answer := range msg.Answer { + switch v := answer.(type) { + case *dns.A: + answers = append(answers, v.A.String()) + case *dns.AAAA: + answers = append(answers, v.AAAA.String()) + case *dns.MX: + answers = append(answers, v.Mx) + } } - return net.IPv4zero + return answers } - func (c *Cache) deleteExpired() { c.mu.Lock() for k, v := range c.entries { diff --git a/cache/cache_test.go b/cache/cache_test.go index 11936eb..f2e00e7 100644 --- a/cache/cache_test.go +++ b/cache/cache_test.go @@ -78,35 +78,56 @@ func awaitExpiry(t *testing.T, c *Cache, k uint32) { } func TestCache(t *testing.T) { - msg := newA("foo.", 60, net.ParseIP("192.0.2.1")) + msg := newA("foo.", 60, net.ParseIP("192.0.2.1"), net.ParseIP("192.0.2.2")) msgWithZeroTTL := newA("bar.", 0, net.ParseIP("192.0.2.2")) msgFailure := newA("baz.", 60, net.ParseIP("192.0.2.2")) msgFailure.Rcode = dns.RcodeServerFailure - tt := date(2019, 1, 1) + createdAt := date(2019, 1, 1) c := New(100, time.Duration(10*time.Millisecond)) defer handleErr(t, c.Close) var tests = []struct { msg *dns.Msg createdAt, queriedAt time.Time ok bool + value *Value }{ - {msg, tt, tt, true}, // Not expired when query time == create time - {msg, tt, tt.Add(30 * time.Second), true}, // Not expired when below TTL - {msg, tt, tt.Add(60 * time.Second), true}, // Not expired until TTL exceeds - {msg, tt, tt.Add(61 * time.Second), false}, // Expired - {msgWithZeroTTL, tt, tt, false}, // 0 TTL is not cached - {msgFailure, tt, tt, false}, // Non-cacheable rcode + {msg, createdAt, createdAt, true, &Value{ + CreatedAt: createdAt, + Question: "foo.", + Qtype: 1, + Answers: []string{"192.0.2.1", "192.0.2.2"}, + msg: msg}, + }, // Not expired when query time == create time + {msg, createdAt, createdAt.Add(30 * time.Second), true, &Value{ + CreatedAt: createdAt, + Question: "foo.", + Qtype: 1, + Answers: []string{"192.0.2.1", "192.0.2.2"}, + msg: msg}, + }, // Not expired when below TTL + {msg, createdAt, createdAt.Add(60 * time.Second), true, &Value{ + CreatedAt: createdAt, + Question: "foo.", + Qtype: 1, + Answers: []string{"192.0.2.1", "192.0.2.2"}, + msg: msg}, + }, //, Not expired until TTL exceeds + {msg, createdAt, createdAt.Add(61 * time.Second), false, nil}, // Expired + {msgWithZeroTTL, createdAt, createdAt, false, nil}, // 0 TTL is not cached + {msgFailure, createdAt, createdAt, false, nil}, // Non-cacheable rcode } for i, tt := range tests { c.now = func() time.Time { return tt.createdAt } k := NewKey(tt.msg.Question[0].Name, tt.msg.Question[0].Qtype, tt.msg.Question[0].Qclass) c.Set(k, tt.msg) c.now = func() time.Time { return tt.queriedAt } - msg, ok := c.Get(k) - if ok != tt.ok { + if msg, ok := c.Get(k); ok != tt.ok { t.Errorf("#%d: Get(%d) = (%+v, %t), want (_, %t)", i, k, msg, ok, tt.ok) } + if v, ok := c.getValue(k); ok != tt.ok || !reflect.DeepEqual(v, tt.value) { + t.Errorf("#%d: getValue(%d) = (%+v, %t), want (%+v, %t)", i, k, v, ok, tt.value, tt.ok) + } if !tt.ok { awaitExpiry(t, c, k) } |