aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--cache/cache.go30
-rw-r--r--cache/cache_test.go41
2 files changed, 47 insertions, 24 deletions
diff --git a/cache/cache.go b/cache/cache.go
index bf7cfae..014d3fe 100644
--- a/cache/cache.go
+++ b/cache/cache.go
@@ -3,7 +3,6 @@ package cache
import (
"encoding/binary"
"hash/fnv"
- "net"
"sync"
"time"
@@ -54,7 +53,7 @@ func (m *maintainer) run(cache *Cache) {
type Value struct {
Question string
Qtype uint16
- Answer net.IP
+ Answers []string
CreatedAt time.Time
msg *dns.Msg
}
@@ -147,7 +146,7 @@ func (c *Cache) Set(k uint32, msg *dns.Msg) {
}
c.entries[k] = &Value{
Question: question(msg),
- Answer: answer(msg),
+ Answers: answers(msg),
Qtype: qtype(msg),
CreatedAt: now,
msg: msg,
@@ -156,21 +155,24 @@ func (c *Cache) Set(k uint32, msg *dns.Msg) {
c.mu.Unlock()
}
-func qtype(m *dns.Msg) uint16 { return m.Question[0].Qtype }
+func qtype(msg *dns.Msg) uint16 { return msg.Question[0].Qtype }
-func question(m *dns.Msg) string { return m.Question[0].Name }
+func question(msg *dns.Msg) string { return msg.Question[0].Name }
-func answer(m *dns.Msg) net.IP {
- rr := m.Answer[0]
- switch v := rr.(type) {
- case *dns.A:
- return v.A
- case *dns.AAAA:
- return v.AAAA
+func answers(msg *dns.Msg) []string {
+ var answers []string
+ for _, answer := range msg.Answer {
+ switch v := answer.(type) {
+ case *dns.A:
+ answers = append(answers, v.A.String())
+ case *dns.AAAA:
+ answers = append(answers, v.AAAA.String())
+ case *dns.MX:
+ answers = append(answers, v.Mx)
+ }
}
- return net.IPv4zero
+ return answers
}
-
func (c *Cache) deleteExpired() {
c.mu.Lock()
for k, v := range c.entries {
diff --git a/cache/cache_test.go b/cache/cache_test.go
index 11936eb..f2e00e7 100644
--- a/cache/cache_test.go
+++ b/cache/cache_test.go
@@ -78,35 +78,56 @@ func awaitExpiry(t *testing.T, c *Cache, k uint32) {
}
func TestCache(t *testing.T) {
- msg := newA("foo.", 60, net.ParseIP("192.0.2.1"))
+ msg := newA("foo.", 60, net.ParseIP("192.0.2.1"), net.ParseIP("192.0.2.2"))
msgWithZeroTTL := newA("bar.", 0, net.ParseIP("192.0.2.2"))
msgFailure := newA("baz.", 60, net.ParseIP("192.0.2.2"))
msgFailure.Rcode = dns.RcodeServerFailure
- tt := date(2019, 1, 1)
+ createdAt := date(2019, 1, 1)
c := New(100, time.Duration(10*time.Millisecond))
defer handleErr(t, c.Close)
var tests = []struct {
msg *dns.Msg
createdAt, queriedAt time.Time
ok bool
+ value *Value
}{
- {msg, tt, tt, true}, // Not expired when query time == create time
- {msg, tt, tt.Add(30 * time.Second), true}, // Not expired when below TTL
- {msg, tt, tt.Add(60 * time.Second), true}, // Not expired until TTL exceeds
- {msg, tt, tt.Add(61 * time.Second), false}, // Expired
- {msgWithZeroTTL, tt, tt, false}, // 0 TTL is not cached
- {msgFailure, tt, tt, false}, // Non-cacheable rcode
+ {msg, createdAt, createdAt, true, &Value{
+ CreatedAt: createdAt,
+ Question: "foo.",
+ Qtype: 1,
+ Answers: []string{"192.0.2.1", "192.0.2.2"},
+ msg: msg},
+ }, // Not expired when query time == create time
+ {msg, createdAt, createdAt.Add(30 * time.Second), true, &Value{
+ CreatedAt: createdAt,
+ Question: "foo.",
+ Qtype: 1,
+ Answers: []string{"192.0.2.1", "192.0.2.2"},
+ msg: msg},
+ }, // Not expired when below TTL
+ {msg, createdAt, createdAt.Add(60 * time.Second), true, &Value{
+ CreatedAt: createdAt,
+ Question: "foo.",
+ Qtype: 1,
+ Answers: []string{"192.0.2.1", "192.0.2.2"},
+ msg: msg},
+ }, //, Not expired until TTL exceeds
+ {msg, createdAt, createdAt.Add(61 * time.Second), false, nil}, // Expired
+ {msgWithZeroTTL, createdAt, createdAt, false, nil}, // 0 TTL is not cached
+ {msgFailure, createdAt, createdAt, false, nil}, // Non-cacheable rcode
}
for i, tt := range tests {
c.now = func() time.Time { return tt.createdAt }
k := NewKey(tt.msg.Question[0].Name, tt.msg.Question[0].Qtype, tt.msg.Question[0].Qclass)
c.Set(k, tt.msg)
c.now = func() time.Time { return tt.queriedAt }
- msg, ok := c.Get(k)
- if ok != tt.ok {
+ if msg, ok := c.Get(k); ok != tt.ok {
t.Errorf("#%d: Get(%d) = (%+v, %t), want (_, %t)", i, k, msg, ok, tt.ok)
}
+ if v, ok := c.getValue(k); ok != tt.ok || !reflect.DeepEqual(v, tt.value) {
+ t.Errorf("#%d: getValue(%d) = (%+v, %t), want (%+v, %t)", i, k, v, ok, tt.value, tt.ok)
+ }
if !tt.ok {
awaitExpiry(t, c, k)
}