aboutsummaryrefslogtreecommitdiffstats
path: root/http
diff options
context:
space:
mode:
authorMartin Polden <mpolden@mpolden.no>2019-12-25 14:29:05 +0100
committerMartin Polden <mpolden@mpolden.no>2019-12-25 14:29:05 +0100
commita923896c997b48bb6c39600b3704b532397c4628 (patch)
treec676c9364b664c4ec8c5fbf4356abbdcb1bc3448 /http
parent5df739087a832a0471ae201bf31166ac211e337d (diff)
Implement REST API for inspecting logs
Diffstat (limited to 'http')
-rw-r--r--http/http.go128
-rw-r--r--http/http_test.go82
2 files changed, 210 insertions, 0 deletions
diff --git a/http/http.go b/http/http.go
new file mode 100644
index 0000000..2b0c5cc
--- /dev/null
+++ b/http/http.go
@@ -0,0 +1,128 @@
+package http
+
+import (
+ "context"
+ "encoding/json"
+ "net"
+ "net/http"
+ "time"
+
+ "github.com/mpolden/zdns/dns"
+ "github.com/mpolden/zdns/log"
+)
+
+// A Server defines paramaters for running an HTTP server. The HTTP server serves an API for inspecting cache contents
+// and request log.
+type Server struct {
+ server *http.Server
+ logger *log.Logger
+}
+
+type logEntry struct {
+ Time string `json:"time"`
+ RemoteAddr net.IP `json:"remote_addr"`
+ Qtype string `json:"type"`
+ Question string `json:"question"`
+ Answer string `json:"answer"`
+}
+
+type httpError struct {
+ err error
+ Status int `json:"status"`
+ Message string `json:"message"`
+}
+
+// NewServer creates a new HTTP server, serving logs from the given logger and listening on listenAddr.
+func NewServer(logger *log.Logger, listenAddr string) *Server {
+ server := &http.Server{Addr: listenAddr}
+ s := &Server{logger: logger, server: server}
+ s.server.Handler = s.handler()
+ return s
+}
+
+func (s *Server) handler() http.Handler {
+ mux := http.NewServeMux()
+ mux.Handle("/log/v1/", appHandler(s.logHandler))
+ //mux.Handle("/cache/v1")
+ mux.Handle("/", appHandler(notFoundHandler))
+ return requestFilter(mux)
+}
+
+type appHandler func(http.ResponseWriter, *http.Request) (interface{}, *httpError)
+
+func (fn appHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
+ data, e := fn(w, r)
+ if e != nil { // e is *Error, not os.Error.
+ if e.Message == "" {
+ e.Message = e.err.Error()
+ }
+ out, err := json.Marshal(e)
+ if err != nil {
+ panic(err)
+ }
+ w.WriteHeader(e.Status)
+ w.Write(out)
+ } else if data != nil {
+ out, err := json.Marshal(data)
+ if err != nil {
+ panic(err)
+ }
+ w.Write(out)
+ }
+}
+
+func requestFilter(next http.Handler) http.Handler {
+ return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ w.Header().Set("Content-Type", "application/json")
+ next.ServeHTTP(w, r)
+ })
+}
+
+func notFoundHandler(w http.ResponseWriter, r *http.Request) (interface{}, *httpError) {
+ return nil, &httpError{
+ Status: http.StatusNotFound,
+ Message: "Resource not found",
+ }
+}
+
+func (s *Server) logHandler(w http.ResponseWriter, r *http.Request) (interface{}, *httpError) {
+ logEntries, err := s.logger.Get(100)
+ if err != nil {
+ return nil, &httpError{
+ err: err,
+ Status: http.StatusInternalServerError,
+ }
+ }
+ entries := make([]logEntry, 0, len(logEntries))
+ for _, entry := range logEntries {
+ dnsType := ""
+ switch entry.Qtype {
+ case dns.TypeA:
+ dnsType = "A"
+ case dns.TypeAAAA:
+ dnsType = "AAAA"
+ }
+ e := logEntry{
+ Time: entry.Time.UTC().Format(time.RFC3339),
+ RemoteAddr: entry.RemoteAddr,
+ Qtype: dnsType,
+ Question: entry.Question,
+ Answer: entry.Answer,
+ }
+ entries = append(entries, e)
+ }
+ return entries, nil
+}
+
+// Close shuts down the HTTP server.
+func (s *Server) Close() error { return s.server.Shutdown(context.Background()) }
+
+// ListenAndServe starts the HTTP server listening on the configured address.
+func (s *Server) ListenAndServe() error {
+ s.logger.Printf("http server listening on http://%s", s.server.Addr)
+ err := s.server.ListenAndServe()
+ if err == http.ErrServerClosed {
+ return nil // Do not treat server closing as an error
+ }
+ return err
+}
diff --git a/http/http_test.go b/http/http_test.go
new file mode 100644
index 0000000..3e257f3
--- /dev/null
+++ b/http/http_test.go
@@ -0,0 +1,82 @@
+package http
+
+import (
+ "io/ioutil"
+ "net"
+ "net/http"
+ "net/http/httptest"
+ "testing"
+ "time"
+
+ "github.com/mpolden/zdns/log"
+)
+
+func testServer() (*httptest.Server, *log.Logger) {
+ logger, err := log.New(ioutil.Discard, "", log.RecordOptions{
+ Database: ":memory:",
+ })
+ logger.Now = func() time.Time { return time.Date(2006, 1, 2, 15, 4, 5, 0, time.UTC) }
+ if err != nil {
+ panic(err)
+ }
+ server := Server{logger: logger}
+ return httptest.NewServer(server.handler()), logger
+}
+
+func httpGet(url string) (string, int, error) {
+ res, err := http.Get(url)
+ if err != nil {
+ return "", 0, err
+ }
+ defer res.Body.Close()
+ data, err := ioutil.ReadAll(res.Body)
+ if err != nil {
+ return "", 0, err
+ }
+ return string(data), res.StatusCode, nil
+}
+
+func TestRequests(t *testing.T) {
+ server, logger := testServer()
+ defer server.Close()
+ logger.Record(net.IPv4(127, 0, 0, 42), 1, "example.com.", "192.0.2.100")
+ logger.Record(net.IPv4(127, 0, 0, 254), 28, "example.com.", "2001:db8::1")
+
+ var logResponse = "[{\"time\":\"2006-01-02T15:04:05Z\",\"remote_addr\":\"127.0.0.42\",\"type\":\"A\",\"question\":\"example.com.\",\"answer\":\"192.0.2.100\"}," +
+ "{\"time\":\"2006-01-02T15:04:05Z\",\"remote_addr\":\"127.0.0.254\",\"type\":\"AAAA\",\"question\":\"example.com.\",\"answer\":\"2001:db8::1\"}]"
+
+ var tests = []struct {
+ method string
+ body string
+ url string
+ response string
+ status int
+ }{
+ // Unknown resources
+ {http.MethodGet, "", "/not-found", `{"status":404,"message":"Resource not found"}`, 404},
+ {http.MethodGet, "", "/log/v1/", logResponse, 200},
+ }
+
+ for _, tt := range tests {
+ var (
+ data string
+ status int
+ err error
+ )
+ switch tt.method {
+ case http.MethodGet:
+ data, status, err = httpGet(server.URL + tt.url)
+ default:
+ t.Fatal("invalid method: " + tt.method)
+ }
+ if err != nil {
+ t.Fatal(err)
+ }
+ if got := status; status != tt.status {
+ t.Errorf("want status %d for %q, got %d", tt.status, tt.url, got)
+ }
+ if got := string(data); got != tt.response {
+ t.Errorf("want response %q for %s, got %q", tt.response, tt.url, got)
+ }
+ }
+}