aboutsummaryrefslogtreecommitdiffstats
path: root/http
diff options
context:
space:
mode:
authorMartin Polden <mpolden@mpolden.no>2019-12-28 19:56:44 +0100
committerMartin Polden <mpolden@mpolden.no>2019-12-28 19:56:44 +0100
commit0b9976424c737e03a580f31255cc40bb118ee276 (patch)
tree6efcab8da465d683b1b4920f93822f366a918f6b /http
parent285d65efb619a2564ffa8dca4f15cd9942cc5276 (diff)
Extract request router
Diffstat (limited to 'http')
-rw-r--r--http/http.go47
-rw-r--r--http/http_test.go22
-rw-r--r--http/router.go80
3 files changed, 97 insertions, 52 deletions
diff --git a/http/http.go b/http/http.go
index 12d9e31..98e6616 100644
--- a/http/http.go
+++ b/http/http.go
@@ -2,7 +2,6 @@ package http
import (
"context"
- "encoding/json"
"net"
"net/http"
"strconv"
@@ -51,48 +50,10 @@ func NewServer(logger *log.Logger, cache *cache.Cache, addr string) *Server {
}
func (s *Server) handler() http.Handler {
- mux := http.NewServeMux()
- mux.Handle("/cache/v1/", appHandler(s.cacheHandler))
- mux.Handle("/log/v1/", appHandler(s.logHandler))
- mux.Handle("/", appHandler(notFoundHandler))
- return requestFilter(mux)
-}
-
-type appHandler func(http.ResponseWriter, *http.Request) (interface{}, *httpError)
-
-func (fn appHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
- data, e := fn(w, r)
- if e != nil { // e is *Error, not os.Error.
- if e.Message == "" {
- e.Message = e.err.Error()
- }
- out, err := json.Marshal(e)
- if err != nil {
- panic(err)
- }
- w.WriteHeader(e.Status)
- w.Write(out)
- } else if data != nil {
- out, err := json.Marshal(data)
- if err != nil {
- panic(err)
- }
- w.Write(out)
- }
-}
-
-func requestFilter(next http.Handler) http.Handler {
- return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
- w.Header().Set("Content-Type", "application/json")
- next.ServeHTTP(w, r)
- })
-}
-
-func notFoundHandler(w http.ResponseWriter, r *http.Request) (interface{}, *httpError) {
- return nil, &httpError{
- Status: http.StatusNotFound,
- Message: "Resource not found",
- }
+ r := newRouter()
+ r.route("GET", "/cache/v1/", s.cacheHandler)
+ r.route("GET", "/log/v1/", s.logHandler)
+ return r.handler()
}
func listCountFrom(r *http.Request) int {
diff --git a/http/http_test.go b/http/http_test.go
index e964e3b..bdeb213 100644
--- a/http/http_test.go
+++ b/http/http_test.go
@@ -39,17 +39,17 @@ func testServer() (*httptest.Server, *Server) {
return httptest.NewServer(server.handler()), &server
}
-func httpGet(url string) (string, int, error) {
+func httpGet(url string) (*http.Response, string, error) {
res, err := http.Get(url)
if err != nil {
- return "", 0, err
+ return nil, "", err
}
defer res.Body.Close()
data, err := ioutil.ReadAll(res.Body)
if err != nil {
- return "", 0, err
+ return nil, "", err
}
- return string(data), res.StatusCode, nil
+ return res, string(data), nil
}
func TestRequests(t *testing.T) {
@@ -84,23 +84,27 @@ func TestRequests(t *testing.T) {
for i, tt := range tests {
var (
- data string
- status int
- err error
+ resp *http.Response
+ data string
+ err error
)
switch tt.method {
case http.MethodGet:
- data, status, err = httpGet(httpSrv.URL + tt.url)
+ resp, data, err = httpGet(httpSrv.URL + tt.url)
default:
t.Fatalf("#%d: invalid method: %s", i, tt.method)
}
if err != nil {
t.Fatal(err)
}
- if got := status; status != tt.status {
+ if got := resp.StatusCode; got != tt.status {
t.Errorf("#%d: %s %s returned status %d, want %d", i, tt.method, tt.url, got, tt.status)
}
+ if got, want := resp.Header.Get("Content-Type"), "application/json"; got != want {
+ t.Errorf("#%d: got Content-Type %q, want %q", i, got, want)
+ }
+
got := string(data)
want := regexp.QuoteMeta(tt.response)
want = strings.ReplaceAll(want, "RFC3339", `\d{4}-\d{2}-\d{2}T\d{2}:\d{2}:\d{2}Z`)
diff --git a/http/router.go b/http/router.go
new file mode 100644
index 0000000..e7087f7
--- /dev/null
+++ b/http/router.go
@@ -0,0 +1,80 @@
+package http
+
+import (
+ "encoding/json"
+ "net/http"
+)
+
+type router struct {
+ routes []*route
+}
+
+type route struct {
+ method string
+ path string
+ handler appHandler
+}
+
+type appHandler func(http.ResponseWriter, *http.Request) (interface{}, *httpError)
+
+func (fn appHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
+ data, e := fn(w, r)
+ w.Header().Set("Content-Type", "application/json")
+ if e != nil { // e is *Error, not os.Error.
+ if e.Message == "" {
+ e.Message = e.err.Error()
+ }
+ out, err := json.Marshal(e)
+ if err != nil {
+ panic(err)
+ }
+ w.WriteHeader(e.Status)
+ w.Write(out)
+ } else if data != nil {
+ out, err := json.Marshal(data)
+ if err != nil {
+ panic(err)
+ }
+ w.Write(out)
+ }
+}
+
+func newRouter() *router { return &router{} }
+
+func notFoundHandler(w http.ResponseWriter, r *http.Request) (interface{}, *httpError) {
+ return nil, &httpError{
+ Status: http.StatusNotFound,
+ Message: "Resource not found",
+ }
+}
+
+func (r *router) route(method, path string, handler appHandler) *route {
+ route := route{
+ method: method,
+ path: path,
+ handler: handler,
+ }
+ r.routes = append(r.routes, &route)
+ return &route
+}
+
+func (r *router) handler() http.Handler {
+ return appHandler(func(w http.ResponseWriter, req *http.Request) (interface{}, *httpError) {
+ for _, route := range r.routes {
+ if route.match(req) {
+ return route.handler(w, req)
+ }
+ }
+ return notFoundHandler(w, req)
+ })
+}
+
+func (r *route) match(req *http.Request) bool {
+ if req.Method != r.method {
+ return false
+ }
+ if r.path != req.URL.Path {
+ return false
+ }
+ return true
+}