aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorMartin Polden <mpolden@mpolden.no>2020-09-11 21:16:43 +0200
committerMartin Polden <mpolden@mpolden.no>2020-09-12 10:11:38 +0200
commitb7ed23345297d90c5bf3edca963de608b91d952e (patch)
treefe192e884b76217f39a83597052d89807a8877d2
parentceaff84709dfd014ec73a6e7bb0dd6ef48d2b19e (diff)
http, cache: Support cache resizing
-rw-r--r--http/cache.go28
-rw-r--r--http/cache_test.go20
-rw-r--r--http/http.go26
-rw-r--r--http/http_test.go33
4 files changed, 100 insertions, 7 deletions
diff --git a/http/cache.go b/http/cache.go
index 97aac30..6d61b20 100644
--- a/http/cache.go
+++ b/http/cache.go
@@ -44,13 +44,17 @@ func (c *Cache) Set(ip net.IP, resp Response) {
k := key(ip)
c.mu.Lock()
defer c.mu.Unlock()
- if len(c.entries) == c.capacity {
- // At capacity. Remove the oldest entry
- oldest := c.values.Front()
- oldestValue := oldest.Value.(Response)
- oldestKey := key(oldestValue.IP)
- delete(c.entries, oldestKey)
- c.values.Remove(oldest)
+ minEvictions := len(c.entries) - c.capacity + 1
+ if minEvictions > 0 { // At or above capacity. Shrink the cache
+ evicted := 0
+ for el := c.values.Front(); el != nil && evicted < minEvictions; {
+ value := el.Value.(Response)
+ delete(c.entries, key(value.IP))
+ next := el.Next()
+ c.values.Remove(el)
+ el = next
+ evicted++
+ }
}
current, ok := c.entries[k]
if ok {
@@ -70,6 +74,16 @@ func (c *Cache) Get(ip net.IP) (Response, bool) {
return r.Value.(Response), true
}
+func (c *Cache) Resize(capacity int) error {
+ if capacity < 0 {
+ return fmt.Errorf("invalid capacity: %d\n", capacity)
+ }
+ c.mu.Lock()
+ defer c.mu.Unlock()
+ c.capacity = capacity
+ return nil
+}
+
func (c *Cache) Stats() CacheStats {
c.mu.RLock()
defer c.mu.RUnlock()
diff --git a/http/cache_test.go b/http/cache_test.go
index 2c6d4ea..0867958 100644
--- a/http/cache_test.go
+++ b/http/cache_test.go
@@ -54,3 +54,23 @@ func TestCacheDuplicate(t *testing.T) {
t.Errorf("want %d values, got %d", want, got)
}
}
+
+func TestCacheResize(t *testing.T) {
+ c := NewCache(10)
+ for i := 1; i <= 10; i++ {
+ ip := net.ParseIP(fmt.Sprintf("192.0.2.%d", i))
+ r := Response{IP: ip}
+ c.Set(ip, r)
+ }
+ if got, want := len(c.entries), 10; got != want {
+ t.Errorf("want %d entries, got %d", want, got)
+ }
+ if err := c.Resize(5); err != nil {
+ t.Fatal(err)
+ }
+ r := Response{IP: net.ParseIP("192.0.2.42")}
+ c.Set(r.IP, r)
+ if got, want := len(c.entries), 5; got != want {
+ t.Errorf("want %d entries, got %d", want, got)
+ }
+}
diff --git a/http/http.go b/http/http.go
index 940684f..29295b5 100644
--- a/http/http.go
+++ b/http/http.go
@@ -4,6 +4,7 @@ import (
"encoding/json"
"fmt"
"html/template"
+ "io/ioutil"
"path/filepath"
"strings"
@@ -271,6 +272,30 @@ func (s *Server) PortHandler(w http.ResponseWriter, r *http.Request) *appError {
return nil
}
+func (s *Server) cacheResizeHandler(w http.ResponseWriter, r *http.Request) *appError {
+ body, err := ioutil.ReadAll(r.Body)
+ if err != nil {
+ return badRequest(err).WithMessage(err.Error()).AsJSON()
+ }
+ capacity, err := strconv.Atoi(string(body))
+ if err != nil {
+ return badRequest(err).WithMessage(err.Error()).AsJSON()
+ }
+ if err := s.cache.Resize(capacity); err != nil {
+ return badRequest(err).WithMessage(err.Error()).AsJSON()
+ }
+ data := struct {
+ Message string `json:"message"`
+ }{fmt.Sprintf("Changed cache capacity to %d.", capacity)}
+ b, err := json.Marshal(data)
+ if err != nil {
+ return internalServerError(err).AsJSON()
+ }
+ w.Header().Set("Content-Type", jsonMediaType)
+ w.Write(b)
+ return nil
+}
+
func (s *Server) cacheHandler(w http.ResponseWriter, r *http.Request) *appError {
cacheStats := s.cache.Stats()
var data = struct {
@@ -409,6 +434,7 @@ func (s *Server) Handler() http.Handler {
// Profiling
if s.profile {
+ r.Route("POST", "/debug/cache/resize", s.cacheResizeHandler)
r.Route("GET", "/debug/cache/", s.cacheHandler)
r.Route("GET", "/debug/pprof/cmdline", wrapHandlerFunc(pprof.Cmdline))
r.Route("GET", "/debug/pprof/profile", wrapHandlerFunc(pprof.Profile))
diff --git a/http/http_test.go b/http/http_test.go
index f081510..b7ee568 100644
--- a/http/http_test.go
+++ b/http/http_test.go
@@ -7,6 +7,7 @@ import (
"net/http"
"net/http/httptest"
"net/url"
+ "strings"
"testing"
"github.com/mpolden/echoip/iputil/geo"
@@ -56,6 +57,23 @@ func httpGet(url string, acceptMediaType string, userAgent string) (string, int,
return string(data), res.StatusCode, nil
}
+func httpPost(url, body string) (*http.Response, string, error) {
+ r, err := http.NewRequest(http.MethodPost, url, strings.NewReader(body))
+ if err != nil {
+ return nil, "", err
+ }
+ res, err := http.DefaultClient.Do(r)
+ if err != nil {
+ return nil, "", err
+ }
+ defer res.Body.Close()
+ data, err := ioutil.ReadAll(res.Body)
+ if err != nil {
+ return nil, "", err
+ }
+ return res, string(data), nil
+}
+
func TestCLIHandlers(t *testing.T) {
log.SetOutput(ioutil.Discard)
s := httptest.NewServer(testServer().Handler())
@@ -175,6 +193,21 @@ func TestCacheHandler(t *testing.T) {
}
}
+func TestCacheResizeHandler(t *testing.T) {
+ log.SetOutput(ioutil.Discard)
+ srv := testServer()
+ srv.profile = true
+ s := httptest.NewServer(srv.Handler())
+ _, got, err := httpPost(s.URL+"/debug/cache/resize", "10")
+ if err != nil {
+ t.Fatal(err)
+ }
+ want := `{"message":"Changed cache capacity to 10."}`
+ if got != want {
+ t.Errorf("got %q, want %q", got, want)
+ }
+}
+
func TestIPFromRequest(t *testing.T) {
var tests = []struct {
remoteAddr string