From aa4ed6f02282a7269ae6c2f8e078f34515e4a5b9 Mon Sep 17 00:00:00 2001 From: Martin Polden Date: Sun, 12 Jan 2020 11:12:39 +0100 Subject: Reject invalid parameters --- http/http.go | 32 +++++++++++++++++++++++++------- http/http_test.go | 4 ++-- 2 files changed, 27 insertions(+), 9 deletions(-) (limited to 'http') diff --git a/http/http.go b/http/http.go index 3da9b4f..26e5b25 100644 --- a/http/http.go +++ b/http/http.go @@ -2,6 +2,7 @@ package http import ( "context" + "fmt" "log" "net" "net/http" @@ -72,6 +73,13 @@ func newHTTPError(err error) *httpError { } } +func newHTTPBadRequest(err error) *httpError { + return &httpError{ + err: err, + Status: http.StatusBadRequest, + } +} + // NewServer creates a new HTTP server, serving logs from the given logger and listening on addr. func NewServer(cache *cache.Cache, logger *sql.Logger, addr string) *Server { server := &http.Server{Addr: addr} @@ -93,18 +101,24 @@ func (s *Server) handler() http.Handler { return r.handler() } -func listCountFrom(r *http.Request) int { - defaultCount := 100 +func countFrom(r *http.Request) (int, error) { param := r.URL.Query().Get("n") + if param == "" { + return 100, nil + } n, err := strconv.Atoi(param) - if err != nil { - return defaultCount + if err != nil || n < 0 { + return 0, fmt.Errorf("invalid value for parameter n: %s", param) } - return n + return n, nil } func (s *Server) cacheHandler(w http.ResponseWriter, r *http.Request) (interface{}, *httpError) { - cacheValues := s.cache.List(listCountFrom(r)) + count, err := countFrom(r) + if err != nil { + return nil, newHTTPBadRequest(err) + } + cacheValues := s.cache.List(count) entries := make([]entry, 0, len(cacheValues)) for _, v := range cacheValues { entries = append(entries, entry{ @@ -127,7 +141,11 @@ func (s *Server) cacheResetHandler(w http.ResponseWriter, r *http.Request) (inte } func (s *Server) logHandler(w http.ResponseWriter, r *http.Request) (interface{}, *httpError) { - logEntries, err := s.logger.Read(listCountFrom(r)) + count, err := countFrom(r) + if err != nil { + return nil, newHTTPBadRequest(err) + } + logEntries, err := s.logger.Read(count) if err != nil { return nil, newHTTPError(err) } diff --git a/http/http_test.go b/http/http_test.go index 2909271..a07a7c4 100644 --- a/http/http_test.go +++ b/http/http_test.go @@ -99,10 +99,10 @@ func TestRequests(t *testing.T) { }{ {http.MethodGet, "/not-found", `{"status":404,"message":"Resource not found"}`, 404}, {http.MethodGet, "/log/v1/", lr1, 200}, - {http.MethodGet, "/log/v1/?n=foo", lr1, 200}, + {http.MethodGet, "/log/v1/?n=foo", `{"status":400,"message":"invalid value for parameter n: foo"}`, 400}, {http.MethodGet, "/log/v1/?n=1", lr2, 200}, {http.MethodGet, "/cache/v1/", cr1, 200}, - {http.MethodGet, "/cache/v1/?n=foo", cr1, 200}, + {http.MethodGet, "/cache/v1/?n=foo", `{"status":400,"message":"invalid value for parameter n: foo"}`, 400}, {http.MethodGet, "/cache/v1/?n=1", cr2, 200}, {http.MethodGet, "/metric/v1/", mr1, 200}, {http.MethodDelete, "/cache/v1/", `{"message":"Cleared cache."}`, 200}, -- cgit v1.2.3