aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorMartin Polden <mpolden@mpolden.no>2021-08-09 11:00:04 +0200
committerMartin Polden <mpolden@mpolden.no>2021-08-09 12:36:23 +0200
commitdd9109069b568b0381f8ef8131451a8060571c78 (patch)
tree68ed54848e497a3c79b52d4b363cd221d96c3fd5
parent8422364b2c5f3503052eecd7ad2e634fa7f79e20 (diff)
cache: Implement custom cache
-rw-r--r--cache/cache.go70
-rw-r--r--cache/cache_test.go36
-rw-r--r--go.mod4
-rw-r--r--go.sum2
-rw-r--r--http/http.go6
5 files changed, 110 insertions, 8 deletions
diff --git a/cache/cache.go b/cache/cache.go
new file mode 100644
index 0000000..32ec054
--- /dev/null
+++ b/cache/cache.go
@@ -0,0 +1,70 @@
+package cache
+
+import (
+ "sync"
+ "time"
+)
+
+// Cache is a key-value cache that expires and evicts entries according to a TTL.
+type Cache struct {
+ entries map[string]entry
+ now func() time.Time
+ mu sync.RWMutex
+}
+
+type entry struct {
+ value interface{}
+ expiry time.Time
+}
+
+func (e *entry) isExpired(now time.Time) bool { return now.After(e.expiry) }
+
+// New creates a new cache which evicts expired entries every expiryInterval.
+func New(expiryInterval time.Duration) *Cache {
+ entries := make(map[string]entry)
+ c := &Cache{entries: entries, now: time.Now}
+ go func() {
+ select {
+ case <-time.After(expiryInterval):
+ c.evictExpired()
+ }
+ }()
+ return c
+}
+
+func (c *Cache) evictExpired() {
+ c.mu.Lock()
+ defer c.mu.Unlock()
+ now := c.now()
+ for k, v := range c.entries {
+ if v.isExpired(now) {
+ delete(c.entries, k)
+ }
+ }
+}
+
+// Len returns the number of values in the cache. This includes entries that have expired, but are not yet evicted.
+func (c *Cache) Len() int {
+ c.mu.RLock()
+ defer c.mu.RUnlock()
+ return len(c.entries)
+}
+
+// Get returns the cached value associated with key.
+func (c *Cache) Get(key string) (interface{}, bool) {
+ c.mu.RLock()
+ defer c.mu.RUnlock()
+ v, ok := c.entries[key]
+ if !ok || v.isExpired(c.now()) {
+ return nil, false
+ }
+ return v.value, true
+}
+
+// Set associates key with given value in the cache. The value is invalidated after ttl has passed.
+func (c *Cache) Set(key string, value interface{}, ttl time.Duration) {
+ c.mu.Lock()
+ defer c.mu.Unlock()
+ expiry := c.now().Add(ttl)
+ c.entries[key] = entry{value: value, expiry: expiry}
+}
diff --git a/cache/cache_test.go b/cache/cache_test.go
new file mode 100644
index 0000000..50651b9
--- /dev/null
+++ b/cache/cache_test.go
@@ -0,0 +1,36 @@
+package cache
+
+import (
+ "testing"
+ "time"
+)
+
+func TestCache(t *testing.T) {
+ now := time.Now()
+ c := New(time.Second)
+ var tests = []struct {
+ key string
+ value interface{}
+ ok bool
+ ttl time.Duration
+ nowOffset time.Duration
+ }{
+ {"k1", 1, true, time.Minute, 0},
+ {"k2", 2, true, time.Minute, time.Minute},
+ {"k3", nil, false, time.Minute, -time.Second * 61},
+ {"k4", nil, false, time.Second * 5, -time.Second * 6},
+ }
+ for i, tt := range tests {
+ c.now = func() time.Time { return now.Add(tt.nowOffset) }
+ c.Set(tt.key, tt.value, tt.ttl)
+ c.now = func() time.Time { return now }
+ v, ok := c.Get(tt.key)
+ if ok != tt.ok || v != tt.value {
+ t.Errorf("#%d: Get(%q) = (%v, %t), want (%v, %t)", i, tt.key, v, ok, tt.value, tt.ok)
+ }
+ }
+ c.evictExpired()
+ if got, want := c.Len(), 2; got != want {
+ t.Errorf("Len() = %d, want %d", got, want)
+ }
+}
diff --git a/go.mod b/go.mod
index e2a7581..dbfc87a 100644
--- a/go.mod
+++ b/go.mod
@@ -1,5 +1,3 @@
module github.com/mpolden/atb
-go 1.13
-
-require github.com/pmylund/go-cache v2.1.0+incompatible
+go 1.16
diff --git a/go.sum b/go.sum
index 89fefa8..e69de29 100644
--- a/go.sum
+++ b/go.sum
@@ -1,2 +0,0 @@
-github.com/pmylund/go-cache v2.1.0+incompatible h1:n+7K51jLz6a3sCvff3BppuCAkixuDHuJ/C57Vw/XjTE=
-github.com/pmylund/go-cache v2.1.0+incompatible/go.mod h1:hmz95dGvINpbRZGsqPcd7B5xXY5+EKb5PpGhQY3NTHk=
diff --git a/http/http.go b/http/http.go
index eeeb9c2..0f3ab88 100644
--- a/http/http.go
+++ b/http/http.go
@@ -11,7 +11,7 @@ import (
"time"
"github.com/mpolden/atb/atb"
- cache "github.com/pmylund/go-cache"
+ "github.com/mpolden/atb/cache"
)
// Server represents an Server server.
@@ -86,7 +86,7 @@ func (s *Server) getDepartures(urlPrefix string, nodeID int) (Departures, bool,
return Departures{}, hit, err
}
departures.URL = fmt.Sprintf("%s/api/v1/departures/%d", urlPrefix, nodeID)
- s.cache.Set(cacheKey, departures, cache.DefaultExpiration)
+ s.cache.Set(cacheKey, departures, s.ttl.departures)
return departures, hit, nil
}
@@ -225,7 +225,7 @@ func (s *Server) DefaultHandler(w http.ResponseWriter, r *http.Request) (interfa
// New returns a new Server using client to communicate with AtB. stopTTL and departureTTL control the cache TTL bus
// stops and departures.
func New(client atb.Client, stopTTL, departureTTL time.Duration, cors bool) Server {
- cache := cache.New(departureTTL, 30*time.Second)
+ cache := cache.New(time.Minute)
return Server{
Client: client,
CORS: cors,