diff options
author | Martin Polden <mpolden@mpolden.no> | 2017-07-10 13:33:08 +0200 |
---|---|---|
committer | Martin Polden <mpolden@mpolden.no> | 2017-07-10 13:33:08 +0200 |
commit | d45a1eca2372ceeac84e3db1204696aa905ca726 (patch) | |
tree | 2e79e259a757b5008a6f01eccd3961f7b7e176a4 | |
parent | 5d76fe53d37d903d8ea821aab56272febd73660f (diff) |
Remove gorilla/mux
-rw-r--r-- | api/api.go | 68 | ||||
-rw-r--r-- | api/api_test.go | 27 | ||||
-rw-r--r-- | main.go | 3 |
3 files changed, 55 insertions, 43 deletions
@@ -6,10 +6,10 @@ import ( "log" "net/http" "net/url" + "path/filepath" "strconv" "time" - "github.com/gorilla/mux" "github.com/mpolden/atbapi/atb" cache "github.com/pmylund/go-cache" ) @@ -105,8 +105,8 @@ func (a *API) setCacheHeader(w http.ResponseWriter, hit bool) { } // BusStopsHandler is a handler for retrieving bus stops. -func (a *API) BusStopsHandler(w http.ResponseWriter, req *http.Request) (interface{}, *Error) { - busStops, hit, err := a.getBusStops(urlPrefix(req)) +func (a *API) BusStopsHandler(w http.ResponseWriter, r *http.Request) (interface{}, *Error) { + busStops, hit, err := a.getBusStops(urlPrefix(r)) if err != nil { return nil, &Error{ err: err, @@ -115,7 +115,7 @@ func (a *API) BusStopsHandler(w http.ResponseWriter, req *http.Request) (interfa } } a.setCacheHeader(w, hit) - _, geojson := req.URL.Query()["geojson"] + _, geojson := r.URL.Query()["geojson"] if geojson { return busStops.GeoJSON(), nil } @@ -123,9 +123,8 @@ func (a *API) BusStopsHandler(w http.ResponseWriter, req *http.Request) (interfa } // BusStopHandler is a handler for retrieving info about a bus stop. -func (a *API) BusStopHandler(w http.ResponseWriter, req *http.Request) (interface{}, *Error) { - vars := mux.Vars(req) - nodeID, err := strconv.Atoi(vars["nodeID"]) +func (a *API) BusStopHandler(w http.ResponseWriter, r *http.Request) (interface{}, *Error) { + nodeID, err := strconv.Atoi(filepath.Base(r.URL.Path)) if err != nil { return nil, &Error{ err: err, @@ -133,7 +132,7 @@ func (a *API) BusStopHandler(w http.ResponseWriter, req *http.Request) (interfac Message: "missing or invalid nodeID", } } - busStops, hit, err := a.getBusStops(urlPrefix(req)) + busStops, hit, err := a.getBusStops(urlPrefix(r)) if err != nil { return nil, &Error{ err: err, @@ -151,17 +150,16 @@ func (a *API) BusStopHandler(w http.ResponseWriter, req *http.Request) (interfac } } a.setCacheHeader(w, hit) - _, geojson := req.URL.Query()["geojson"] + _, geojson := r.URL.Query()["geojson"] if geojson { return busStop.GeoJSON(), nil } return busStop, nil } -// DepartureHandler is a handler for retrieving departures for a given bus stop -func (a *API) DepartureHandler(w http.ResponseWriter, req *http.Request) (interface{}, *Error) { - vars := mux.Vars(req) - nodeID, err := strconv.Atoi(vars["nodeID"]) +// DepartureHandler is a handler for retrieving departures for a given bus stop. +func (a *API) DepartureHandler(w http.ResponseWriter, r *http.Request) (interface{}, *Error) { + nodeID, err := strconv.Atoi(filepath.Base(r.URL.Path)) if err != nil { return nil, &Error{ err: err, @@ -169,7 +167,7 @@ func (a *API) DepartureHandler(w http.ResponseWriter, req *http.Request) (interf Message: "missing or invalid nodeID", } } - busStops, hit, err := a.getBusStops(urlPrefix(req)) + busStops, hit, err := a.getBusStops(urlPrefix(r)) if err != nil { return nil, &Error{ err: err, @@ -186,7 +184,7 @@ func (a *API) DepartureHandler(w http.ResponseWriter, req *http.Request) (interf Message: msg, } } - departures, hit, err := a.getDepartures(urlPrefix(req), nodeID) + departures, hit, err := a.getDepartures(urlPrefix(r), nodeID) if err != nil { return nil, &Error{ err: err, @@ -198,9 +196,9 @@ func (a *API) DepartureHandler(w http.ResponseWriter, req *http.Request) (interf return departures, nil } -// DeparturesHandler lists all known departures -func (a *API) DeparturesHandler(w http.ResponseWriter, req *http.Request) (interface{}, *Error) { - busStops, hit, err := a.getBusStops(urlPrefix(req)) +// DeparturesHandler lists all known departures. +func (a *API) DeparturesHandler(w http.ResponseWriter, r *http.Request) (interface{}, *Error) { + busStops, hit, err := a.getBusStops(urlPrefix(r)) if err != nil { return nil, &Error{ err: err, @@ -214,14 +212,17 @@ func (a *API) DeparturesHandler(w http.ResponseWriter, req *http.Request) (inter } urls.URLs = make([]string, len(busStops.Stops)) for i, stop := range busStops.Stops { - urls.URLs[i] = fmt.Sprintf("%s/api/v1/departures/%d", urlPrefix(req), stop.NodeID) + urls.URLs[i] = fmt.Sprintf("%s/api/v1/departures/%d", urlPrefix(r), stop.NodeID) } return urls, nil } -// RootHandler lists known URLs -func (a *API) RootHandler(w http.ResponseWriter, req *http.Request) (interface{}, *Error) { - prefix := urlPrefix(req) +// RootHandler lists known URLs. +func (a *API) RootHandler(w http.ResponseWriter, r *http.Request) (interface{}, *Error) { + if r.URL.Path != "/" { + return a.NotFoundHandler(w, r) + } + prefix := urlPrefix(r) busStopsURL := fmt.Sprintf("%s/api/v1/busstops", prefix) departuresURL := fmt.Sprintf("%s/api/v1/departures", prefix) return struct { @@ -232,7 +233,7 @@ func (a *API) RootHandler(w http.ResponseWriter, req *http.Request) (interface{} } // NotFoundHandler handles requests to invalid routes. -func (a *API) NotFoundHandler(w http.ResponseWriter, req *http.Request) (interface{}, *Error) { +func (a *API) NotFoundHandler(w http.ResponseWriter, r *http.Request) (interface{}, *Error) { return nil, &Error{ err: nil, Status: http.StatusNotFound, @@ -291,16 +292,13 @@ func requestFilter(next http.Handler, cors bool) http.Handler { }) } -// ListenAndServe listens on the TCP network address addr and starts serving the -// API. -func (a *API) ListenAndServe(addr string) error { - r := mux.NewRouter() - r.Handle("/api/v1/busstops", appHandler(a.BusStopsHandler)) - r.Handle("/api/v1/busstops/{nodeID:[0-9]+}", appHandler(a.BusStopHandler)) - r.Handle("/api/v1/departures", appHandler(a.DeparturesHandler)) - r.Handle("/api/v1/departures/{nodeID:[0-9]+}", appHandler(a.DepartureHandler)) - r.Handle("/", appHandler(a.RootHandler)) - r.NotFoundHandler = appHandler(a.NotFoundHandler) - http.Handle("/", requestFilter(r, a.CORS)) - return http.ListenAndServe(addr, nil) +// Handler returns a root handler for the API. +func (a *API) Handler() http.Handler { + mux := http.NewServeMux() + mux.Handle("/api/v1/busstops", appHandler(a.BusStopsHandler)) + mux.Handle("/api/v1/busstops/", appHandler(a.BusStopHandler)) + mux.Handle("/api/v1/departures", appHandler(a.DeparturesHandler)) + mux.Handle("/api/v1/departures/", appHandler(a.DepartureHandler)) + mux.Handle("/", appHandler(a.RootHandler)) + return requestFilter(mux, a.CORS) } diff --git a/api/api_test.go b/api/api_test.go index e8a875d..3a49577 100644 --- a/api/api_test.go +++ b/api/api_test.go @@ -2,26 +2,39 @@ package api import ( "fmt" + "io/ioutil" "net/http" "net/http/httptest" + "strings" "testing" "time" "github.com/mpolden/atbapi/atb" ) -func newTestServer(path string, body string) *httptest.Server { +func atbServer() *httptest.Server { handler := func(w http.ResponseWriter, r *http.Request) { + b, err := ioutil.ReadAll(r.Body) + if err != nil { + panic(err) + } + xml := string(b) w.Header().Set("Content-Type", "application/soap+xml; charset=utf-8") - fmt.Fprint(w, body) + if strings.Contains(xml, "GetBusStopsList") { + fmt.Fprint(w, busStopsResponse) + } else if strings.Contains(xml, "getUserRealTimeForecastByStop") { + fmt.Fprint(w, forecastResponse) + } else { + panic("unknown request body: " + xml) + } } mux := http.NewServeMux() - mux.HandleFunc(path, handler) + mux.HandleFunc("/", handler) return httptest.NewServer(mux) } func TestGetBusStops(t *testing.T) { - server := newTestServer("/", busStopsResponse) + server := atbServer() defer server.Close() atb := atb.Client{URL: server.URL} api := New(atb, 168*time.Hour, 1*time.Minute, false) @@ -46,7 +59,7 @@ func TestGetBusStops(t *testing.T) { } func TestGetBusStopsCache(t *testing.T) { - server := newTestServer("/", busStopsResponse) + server := atbServer() defer server.Close() atb := atb.Client{URL: server.URL} api := New(atb, 168*time.Hour, 1*time.Minute, false) @@ -67,7 +80,7 @@ func TestGetBusStopsCache(t *testing.T) { } func TestGetDepartures(t *testing.T) { - server := newTestServer("/", forecastResponse) + server := atbServer() defer server.Close() atb := atb.Client{URL: server.URL} api := New(atb, 168*time.Hour, 1*time.Minute, false) @@ -89,7 +102,7 @@ func TestGetDepartures(t *testing.T) { } func TestGetDeparturesCache(t *testing.T) { - server := newTestServer("/", forecastResponse) + server := atbServer() defer server.Close() atb := atb.Client{URL: server.URL} api := New(atb, 168*time.Hour, 1*time.Minute, false) @@ -2,6 +2,7 @@ package main import ( "log" + "net/http" "os" "time" @@ -31,7 +32,7 @@ func main() { api := api.New(client, opts.CacheStops, opts.CacheDepartures, opts.CORS) log.Printf("Listening on %s", opts.Listen) - if err := api.ListenAndServe(opts.Listen); err != nil { + if err := http.ListenAndServe(opts.Listen, api.Handler()); err != nil { log.Fatal(err) } } |