diff options
-rw-r--r-- | cmd/zdns/main.go | 12 | ||||
-rw-r--r-- | config.go | 1 | ||||
-rw-r--r-- | http/http.go | 128 | ||||
-rw-r--r-- | http/http_test.go | 82 | ||||
-rw-r--r-- | log/logger.go | 8 | ||||
-rw-r--r-- | log/logger_test.go | 2 | ||||
-rw-r--r-- | zdnsrc | 5 |
7 files changed, 231 insertions, 7 deletions
diff --git a/cmd/zdns/main.go b/cmd/zdns/main.go index 5565ef4..623175c 100644 --- a/cmd/zdns/main.go +++ b/cmd/zdns/main.go @@ -10,6 +10,7 @@ import ( "flag" "github.com/mpolden/zdns" + "github.com/mpolden/zdns/http" "github.com/mpolden/zdns/log" "github.com/mpolden/zdns/signal" ) @@ -79,13 +80,20 @@ func (c *cli) run() { }) fatal(err) + sigHandler := signal.NewHandler(c.signal, logger) + dnsSrv, err := zdns.NewServer(logger, config) fatal(err) - - sigHandler := signal.NewHandler(c.signal, logger) sigHandler.OnReload(dnsSrv) sigHandler.OnClose(dnsSrv) c.runServer(dnsSrv) + + httpSrv := http.NewServer(logger, config.DNS.ListenHTTP) + if httpSrv != nil { + sigHandler.OnClose(httpSrv) + c.runServer(httpSrv) + } + c.wg.Wait() } @@ -37,6 +37,7 @@ type DNSOptions struct { logMode int LogTTLString string `toml:"log_ttl"` LogTTL time.Duration + ListenHTTP string `toml:"listen_http"` } // ResolverOptions controls the behaviour of resolvers. diff --git a/http/http.go b/http/http.go new file mode 100644 index 0000000..2b0c5cc --- /dev/null +++ b/http/http.go @@ -0,0 +1,128 @@ +package http + +import ( + "context" + "encoding/json" + "net" + "net/http" + "time" + + "github.com/mpolden/zdns/dns" + "github.com/mpolden/zdns/log" +) + +// A Server defines paramaters for running an HTTP server. The HTTP server serves an API for inspecting cache contents +// and request log. +type Server struct { + server *http.Server + logger *log.Logger +} + +type logEntry struct { + Time string `json:"time"` + RemoteAddr net.IP `json:"remote_addr"` + Qtype string `json:"type"` + Question string `json:"question"` + Answer string `json:"answer"` +} + +type httpError struct { + err error + Status int `json:"status"` + Message string `json:"message"` +} + +// NewServer creates a new HTTP server, serving logs from the given logger and listening on listenAddr. +func NewServer(logger *log.Logger, listenAddr string) *Server { + server := &http.Server{Addr: listenAddr} + s := &Server{logger: logger, server: server} + s.server.Handler = s.handler() + return s +} + +func (s *Server) handler() http.Handler { + mux := http.NewServeMux() + mux.Handle("/log/v1/", appHandler(s.logHandler)) + //mux.Handle("/cache/v1") + mux.Handle("/", appHandler(notFoundHandler)) + return requestFilter(mux) +} + +type appHandler func(http.ResponseWriter, *http.Request) (interface{}, *httpError) + +func (fn appHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { + data, e := fn(w, r) + if e != nil { // e is *Error, not os.Error. + if e.Message == "" { + e.Message = e.err.Error() + } + out, err := json.Marshal(e) + if err != nil { + panic(err) + } + w.WriteHeader(e.Status) + w.Write(out) + } else if data != nil { + out, err := json.Marshal(data) + if err != nil { + panic(err) + } + w.Write(out) + } +} + +func requestFilter(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + next.ServeHTTP(w, r) + }) +} + +func notFoundHandler(w http.ResponseWriter, r *http.Request) (interface{}, *httpError) { + return nil, &httpError{ + Status: http.StatusNotFound, + Message: "Resource not found", + } +} + +func (s *Server) logHandler(w http.ResponseWriter, r *http.Request) (interface{}, *httpError) { + logEntries, err := s.logger.Get(100) + if err != nil { + return nil, &httpError{ + err: err, + Status: http.StatusInternalServerError, + } + } + entries := make([]logEntry, 0, len(logEntries)) + for _, entry := range logEntries { + dnsType := "" + switch entry.Qtype { + case dns.TypeA: + dnsType = "A" + case dns.TypeAAAA: + dnsType = "AAAA" + } + e := logEntry{ + Time: entry.Time.UTC().Format(time.RFC3339), + RemoteAddr: entry.RemoteAddr, + Qtype: dnsType, + Question: entry.Question, + Answer: entry.Answer, + } + entries = append(entries, e) + } + return entries, nil +} + +// Close shuts down the HTTP server. +func (s *Server) Close() error { return s.server.Shutdown(context.Background()) } + +// ListenAndServe starts the HTTP server listening on the configured address. +func (s *Server) ListenAndServe() error { + s.logger.Printf("http server listening on http://%s", s.server.Addr) + err := s.server.ListenAndServe() + if err == http.ErrServerClosed { + return nil // Do not treat server closing as an error + } + return err +} diff --git a/http/http_test.go b/http/http_test.go new file mode 100644 index 0000000..3e257f3 --- /dev/null +++ b/http/http_test.go @@ -0,0 +1,82 @@ +package http + +import ( + "io/ioutil" + "net" + "net/http" + "net/http/httptest" + "testing" + "time" + + "github.com/mpolden/zdns/log" +) + +func testServer() (*httptest.Server, *log.Logger) { + logger, err := log.New(ioutil.Discard, "", log.RecordOptions{ + Database: ":memory:", + }) + logger.Now = func() time.Time { return time.Date(2006, 1, 2, 15, 4, 5, 0, time.UTC) } + if err != nil { + panic(err) + } + server := Server{logger: logger} + return httptest.NewServer(server.handler()), logger +} + +func httpGet(url string) (string, int, error) { + res, err := http.Get(url) + if err != nil { + return "", 0, err + } + defer res.Body.Close() + data, err := ioutil.ReadAll(res.Body) + if err != nil { + return "", 0, err + } + return string(data), res.StatusCode, nil +} + +func TestRequests(t *testing.T) { + server, logger := testServer() + defer server.Close() + logger.Record(net.IPv4(127, 0, 0, 42), 1, "example.com.", "192.0.2.100") + logger.Record(net.IPv4(127, 0, 0, 254), 28, "example.com.", "2001:db8::1") + + var logResponse = "[{\"time\":\"2006-01-02T15:04:05Z\",\"remote_addr\":\"127.0.0.42\",\"type\":\"A\",\"question\":\"example.com.\",\"answer\":\"192.0.2.100\"}," + + "{\"time\":\"2006-01-02T15:04:05Z\",\"remote_addr\":\"127.0.0.254\",\"type\":\"AAAA\",\"question\":\"example.com.\",\"answer\":\"2001:db8::1\"}]" + + var tests = []struct { + method string + body string + url string + response string + status int + }{ + // Unknown resources + {http.MethodGet, "", "/not-found", `{"status":404,"message":"Resource not found"}`, 404}, + {http.MethodGet, "", "/log/v1/", logResponse, 200}, + } + + for _, tt := range tests { + var ( + data string + status int + err error + ) + switch tt.method { + case http.MethodGet: + data, status, err = httpGet(server.URL + tt.url) + default: + t.Fatal("invalid method: " + tt.method) + } + if err != nil { + t.Fatal(err) + } + if got := status; status != tt.status { + t.Errorf("want status %d for %q, got %d", tt.status, tt.url, got) + } + if got := string(data); got != tt.response { + t.Errorf("want response %q for %s, got %q", tt.response, tt.url, got) + } + } +} diff --git a/log/logger.go b/log/logger.go index d46fe52..99e53d1 100644 --- a/log/logger.go +++ b/log/logger.go @@ -13,7 +13,7 @@ import ( // Logger wraps a standard log.Logger and an optional log database. type Logger struct { *log.Logger - now func() time.Time + Now func() time.Time queue chan Entry db *sql.Client maintainer *maintainer @@ -47,7 +47,7 @@ func New(w io.Writer, prefix string, options RecordOptions) (*Logger, error) { logger := &Logger{ Logger: log.New(w, prefix, 0), queue: make(chan Entry, 100), - now: time.Now, + Now: time.Now, } var err error if options.Database != "" { @@ -84,7 +84,7 @@ func (m *maintainer) run(logger *Logger) { for { select { case <-ticker.C: - t := logger.now().Add(-m.ttl) + t := logger.Now().Add(-m.ttl) if err := logger.db.DeleteLogBefore(t); err != nil { logger.Printf("error deleting log entries before %v: %s", t, err) } @@ -111,7 +111,7 @@ func (l *Logger) Record(remoteAddr net.IP, qtype uint16, question, answer string return } l.queue <- Entry{ - Time: l.now(), + Time: l.Now(), RemoteAddr: remoteAddr, Qtype: qtype, Question: question, diff --git a/log/logger_test.go b/log/logger_test.go index 37166ab..c28637c 100644 --- a/log/logger_test.go +++ b/log/logger_test.go @@ -36,7 +36,7 @@ func TestLogPruning(t *testing.T) { t.Fatal(err) } tt := time.Now() - logger.now = func() time.Time { return tt } + logger.Now = func() time.Time { return tt } logger.Record(net.IPv4(192, 0, 2, 100), 1, "example.com.", "192.0.2.1") // Wait until queue is flushed @@ -55,6 +55,11 @@ # # log_mode = "" +# HTTP server for inspecting logs and cache. Setting a listening address on the +# form addr:port will enable the server. +# +# listen_http = "" + [resolver] # Set the protocol to use when sending requests to upstream resolvers. Support protocols: # |