diff options
author | Martin Polden <mpolden@mpolden.no> | 2021-08-09 11:00:04 +0200 |
---|---|---|
committer | Martin Polden <mpolden@mpolden.no> | 2021-08-09 12:36:23 +0200 |
commit | dd9109069b568b0381f8ef8131451a8060571c78 (patch) | |
tree | 68ed54848e497a3c79b52d4b363cd221d96c3fd5 | |
parent | 8422364b2c5f3503052eecd7ad2e634fa7f79e20 (diff) |
cache: Implement custom cache
-rw-r--r-- | cache/cache.go | 70 | ||||
-rw-r--r-- | cache/cache_test.go | 36 | ||||
-rw-r--r-- | go.mod | 4 | ||||
-rw-r--r-- | go.sum | 2 | ||||
-rw-r--r-- | http/http.go | 6 |
5 files changed, 110 insertions, 8 deletions
diff --git a/cache/cache.go b/cache/cache.go new file mode 100644 index 0000000..32ec054 --- /dev/null +++ b/cache/cache.go @@ -0,0 +1,70 @@ +package cache + +import ( + "sync" + "time" +) + +// Cache is a key-value cache that expires and evicts entries according to a TTL. +type Cache struct { + entries map[string]entry + now func() time.Time + mu sync.RWMutex +} + +type entry struct { + value interface{} + expiry time.Time +} + +func (e *entry) isExpired(now time.Time) bool { return now.After(e.expiry) } + +// New creates a new cache which evicts expired entries every expiryInterval. +func New(expiryInterval time.Duration) *Cache { + entries := make(map[string]entry) + c := &Cache{entries: entries, now: time.Now} + go func() { + select { + case <-time.After(expiryInterval): + c.evictExpired() + } + }() + return c +} + +func (c *Cache) evictExpired() { + c.mu.Lock() + defer c.mu.Unlock() + now := c.now() + for k, v := range c.entries { + if v.isExpired(now) { + delete(c.entries, k) + } + } +} + +// Len returns the number of values in the cache. This includes entries that have expired, but are not yet evicted. +func (c *Cache) Len() int { + c.mu.RLock() + defer c.mu.RUnlock() + return len(c.entries) +} + +// Get returns the cached value associated with key. +func (c *Cache) Get(key string) (interface{}, bool) { + c.mu.RLock() + defer c.mu.RUnlock() + v, ok := c.entries[key] + if !ok || v.isExpired(c.now()) { + return nil, false + } + return v.value, true +} + +// Set associates key with given value in the cache. The value is invalidated after ttl has passed. +func (c *Cache) Set(key string, value interface{}, ttl time.Duration) { + c.mu.Lock() + defer c.mu.Unlock() + expiry := c.now().Add(ttl) + c.entries[key] = entry{value: value, expiry: expiry} +} diff --git a/cache/cache_test.go b/cache/cache_test.go new file mode 100644 index 0000000..50651b9 --- /dev/null +++ b/cache/cache_test.go @@ -0,0 +1,36 @@ +package cache + +import ( + "testing" + "time" +) + +func TestCache(t *testing.T) { + now := time.Now() + c := New(time.Second) + var tests = []struct { + key string + value interface{} + ok bool + ttl time.Duration + nowOffset time.Duration + }{ + {"k1", 1, true, time.Minute, 0}, + {"k2", 2, true, time.Minute, time.Minute}, + {"k3", nil, false, time.Minute, -time.Second * 61}, + {"k4", nil, false, time.Second * 5, -time.Second * 6}, + } + for i, tt := range tests { + c.now = func() time.Time { return now.Add(tt.nowOffset) } + c.Set(tt.key, tt.value, tt.ttl) + c.now = func() time.Time { return now } + v, ok := c.Get(tt.key) + if ok != tt.ok || v != tt.value { + t.Errorf("#%d: Get(%q) = (%v, %t), want (%v, %t)", i, tt.key, v, ok, tt.value, tt.ok) + } + } + c.evictExpired() + if got, want := c.Len(), 2; got != want { + t.Errorf("Len() = %d, want %d", got, want) + } +} @@ -1,5 +1,3 @@ module github.com/mpolden/atb -go 1.13 - -require github.com/pmylund/go-cache v2.1.0+incompatible +go 1.16 @@ -1,2 +0,0 @@ -github.com/pmylund/go-cache v2.1.0+incompatible h1:n+7K51jLz6a3sCvff3BppuCAkixuDHuJ/C57Vw/XjTE= -github.com/pmylund/go-cache v2.1.0+incompatible/go.mod h1:hmz95dGvINpbRZGsqPcd7B5xXY5+EKb5PpGhQY3NTHk= diff --git a/http/http.go b/http/http.go index eeeb9c2..0f3ab88 100644 --- a/http/http.go +++ b/http/http.go @@ -11,7 +11,7 @@ import ( "time" "github.com/mpolden/atb/atb" - cache "github.com/pmylund/go-cache" + "github.com/mpolden/atb/cache" ) // Server represents an Server server. @@ -86,7 +86,7 @@ func (s *Server) getDepartures(urlPrefix string, nodeID int) (Departures, bool, return Departures{}, hit, err } departures.URL = fmt.Sprintf("%s/api/v1/departures/%d", urlPrefix, nodeID) - s.cache.Set(cacheKey, departures, cache.DefaultExpiration) + s.cache.Set(cacheKey, departures, s.ttl.departures) return departures, hit, nil } @@ -225,7 +225,7 @@ func (s *Server) DefaultHandler(w http.ResponseWriter, r *http.Request) (interfa // New returns a new Server using client to communicate with AtB. stopTTL and departureTTL control the cache TTL bus // stops and departures. func New(client atb.Client, stopTTL, departureTTL time.Duration, cors bool) Server { - cache := cache.New(departureTTL, 30*time.Second) + cache := cache.New(time.Minute) return Server{ Client: client, CORS: cors, |