From df3a34406ba499a636f2ede5feacae60b08f1bff Mon Sep 17 00:00:00 2001 From: Martin Polden Date: Fri, 10 Jan 2020 22:31:16 +0100 Subject: Add Key field --- cache/cache.go | 22 ++++++++++++++++------ cache/cache_test.go | 22 +++++++++++++--------- 2 files changed, 29 insertions(+), 15 deletions(-) diff --git a/cache/cache.go b/cache/cache.go index 8ee6ce8..e32df8b 100644 --- a/cache/cache.go +++ b/cache/cache.go @@ -26,6 +26,7 @@ type Cache struct { // Value wraps a DNS message stored in the cache. type Value struct { + Key uint64 CreatedAt time.Time msg *dns.Msg } @@ -48,6 +49,8 @@ func (v *Value) TTL() time.Duration { return dnsutil.MinTTL(v.msg) } // Pack returns a string representation of Value v. func (v *Value) Pack() (string, error) { var sb strings.Builder + sb.WriteString(strconv.FormatUint(v.Key, 10)) + sb.WriteString(" ") sb.WriteString(strconv.FormatInt(v.CreatedAt.Unix(), 10)) sb.WriteString(" ") data, err := v.msg.Pack() @@ -61,14 +64,18 @@ func (v *Value) Pack() (string, error) { // Unpack converts a string value into a Value type. func Unpack(value string) (Value, error) { fields := strings.Fields(value) - if len(fields) < 2 { + if len(fields) < 3 { return Value{}, fmt.Errorf("invalid number of fields: %q", value) } - secs, err := strconv.ParseInt(fields[0], 10, 64) + key, err := strconv.ParseUint(fields[0], 10, 64) + if err != nil { + return Value{}, err + } + secs, err := strconv.ParseInt(fields[1], 10, 64) if err != nil { return Value{}, err } - data, err := hex.DecodeString(fields[1]) + data, err := hex.DecodeString(fields[2]) if err != nil { return Value{}, err } @@ -77,6 +84,7 @@ func Unpack(value string) (Value, error) { return Value{}, err } return Value{ + Key: key, CreatedAt: time.Unix(secs, 0), msg: msg, }, nil @@ -169,15 +177,17 @@ func (c *Cache) Set(key uint64, msg *dns.Msg) { } func (c *Cache) set(key uint64, msg *dns.Msg) bool { - if c.capacity == 0 || !canCache(msg) { + return c.setValue(key, Value{Key: key, CreatedAt: c.now(), msg: msg}) +} + +func (c *Cache) setValue(key uint64, value Value) bool { + if c.capacity == 0 || !canCache(value.msg) { return false } - now := c.now() if len(c.values) == c.capacity && c.capacity > 0 { delete(c.values, c.keys[0]) c.keys = c.keys[1:] } - value := Value{CreatedAt: now, msg: msg} c.values[key] = value c.appendKey(key) return true diff --git a/cache/cache_test.go b/cache/cache_test.go index bf4c610..fb2fe3c 100644 --- a/cache/cache_test.go +++ b/cache/cache_test.go @@ -134,13 +134,13 @@ func TestCache(t *testing.T) { ok bool value *Value }{ - {msg, now, true, &Value{CreatedAt: now, msg: msg}}, // Not expired when query time == create time - {msg, now.Add(30 * time.Second), true, &Value{CreatedAt: now, msg: msg}}, // Not expired when below TTL - {msg, now.Add(60 * time.Second), true, &Value{CreatedAt: now, msg: msg}}, // Not expired until TTL exceeds - {msgNameError, now, true, &Value{CreatedAt: now, msg: msgNameError}}, // NXDOMAIN is cached - {msg, now.Add(61 * time.Second), false, nil}, // Expired due to TTL exceeded - {msgWithZeroTTL, now, false, nil}, // 0 TTL is not cached - {msgFailure, now, false, nil}, // Non-cacheable rcode + {msg, now, true, &Value{Key: 16316346957082771326, CreatedAt: now, msg: msg}}, // Not expired when query time == create time + {msg, now.Add(30 * time.Second), true, &Value{Key: 16316346957082771326, CreatedAt: now, msg: msg}}, // Not expired when below TTL + {msg, now.Add(60 * time.Second), true, &Value{Key: 16316346957082771326, CreatedAt: now, msg: msg}}, // Not expired until TTL exceeds + {msgNameError, now, true, &Value{Key: 7258598034460334943, CreatedAt: now, msg: msgNameError}}, // NXDOMAIN is cached + {msg, now.Add(61 * time.Second), false, nil}, // Expired due to TTL exceeded + {msgWithZeroTTL, now, false, nil}, // 0 TTL is not cached + {msgFailure, now, false, nil}, // Non-cacheable rcode } for i, tt := range tests { c.now = nowFn @@ -362,6 +362,7 @@ func TestCacheEvictAndUpdate(t *testing.T) { func TestPackValue(t *testing.T) { v := Value{ + Key: 42, CreatedAt: time.Now().Truncate(time.Second), msg: newA("example.com.", 60, net.ParseIP("192.0.2.1")), } @@ -373,10 +374,13 @@ func TestPackValue(t *testing.T) { if err != nil { t.Fatal(err) } - if got, want := unpacked.CreatedAt, v.CreatedAt; !want.Equal(got) { + if got, want := unpacked.Key, v.Key; got != want { + t.Errorf("Key = %d, want %d", got, want) + } + if got, want := unpacked.CreatedAt, v.CreatedAt; !got.Equal(want) { t.Errorf("CreatedAt = %s, want %s", got, want) } - if got, want := unpacked.msg.String(), v.msg.String(); want != got { + if got, want := unpacked.msg.String(), v.msg.String(); got != want { t.Errorf("msg = %s, want %s", got, want) } } -- cgit v1.2.3