diff options
-rw-r--r-- | dns/proxy.go | 2 | ||||
-rw-r--r-- | http/http.go | 14 | ||||
-rw-r--r-- | http/http_test.go | 6 | ||||
-rw-r--r-- | log/logger.go | 30 | ||||
-rw-r--r-- | log/logger_test.go | 34 |
5 files changed, 63 insertions, 23 deletions
diff --git a/dns/proxy.go b/dns/proxy.go index bca2a56..a159b26 100644 --- a/dns/proxy.go +++ b/dns/proxy.go @@ -14,6 +14,8 @@ const ( TypeA = dns.TypeA // TypeAAAA represents the resource record type AAAA, an IPv6 address. TypeAAAA = dns.TypeAAAA + // TypeMX represents the resource record type MX, a mail exchange address. + TypeMX = dns.TypeMX // LogDiscard disables logging of DNS requests LogDiscard = iota // LogAll logs all DNS requests diff --git a/http/http.go b/http/http.go index 2b0c5cc..dac035c 100644 --- a/http/http.go +++ b/http/http.go @@ -19,11 +19,11 @@ type Server struct { } type logEntry struct { - Time string `json:"time"` - RemoteAddr net.IP `json:"remote_addr"` - Qtype string `json:"type"` - Question string `json:"question"` - Answer string `json:"answer"` + Time string `json:"time"` + RemoteAddr net.IP `json:"remote_addr"` + Qtype string `json:"type"` + Question string `json:"question"` + Answers []string `json:"answers"` } type httpError struct { @@ -101,13 +101,15 @@ func (s *Server) logHandler(w http.ResponseWriter, r *http.Request) (interface{} dnsType = "A" case dns.TypeAAAA: dnsType = "AAAA" + case dns.TypeMX: + dnsType = "MX" } e := logEntry{ Time: entry.Time.UTC().Format(time.RFC3339), RemoteAddr: entry.RemoteAddr, Qtype: dnsType, Question: entry.Question, - Answer: entry.Answer, + Answers: entry.Answers, } entries = append(entries, e) } diff --git a/http/http_test.go b/http/http_test.go index 61ea70a..94708f8 100644 --- a/http/http_test.go +++ b/http/http_test.go @@ -39,11 +39,11 @@ func httpGet(url string) (string, int, error) { func TestRequests(t *testing.T) { server, logger := testServer() defer server.Close() - logger.Record(net.IPv4(127, 0, 0, 42), 1, "example.com.", "192.0.2.100") + logger.Record(net.IPv4(127, 0, 0, 42), 1, "example.com.", "192.0.2.100", "192.0.2.101") logger.Record(net.IPv4(127, 0, 0, 254), 28, "example.com.", "2001:db8::1") - var logResponse = "[{\"time\":\"2006-01-02T15:04:05Z\",\"remote_addr\":\"127.0.0.254\",\"type\":\"AAAA\",\"question\":\"example.com.\",\"answer\":\"2001:db8::1\"}," + - "{\"time\":\"2006-01-02T15:04:05Z\",\"remote_addr\":\"127.0.0.42\",\"type\":\"A\",\"question\":\"example.com.\",\"answer\":\"192.0.2.100\"}]" + var logResponse = "[{\"time\":\"2006-01-02T15:04:05Z\",\"remote_addr\":\"127.0.0.254\",\"type\":\"AAAA\",\"question\":\"example.com.\",\"answers\":[\"2001:db8::1\"]}," + + "{\"time\":\"2006-01-02T15:04:05Z\",\"remote_addr\":\"127.0.0.42\",\"type\":\"A\",\"question\":\"example.com.\",\"answers\":[\"192.0.2.101\",\"192.0.2.100\"]}]" var tests = []struct { method string diff --git a/log/logger.go b/log/logger.go index 6030997..47c33da 100644 --- a/log/logger.go +++ b/log/logger.go @@ -33,8 +33,7 @@ type Entry struct { RemoteAddr net.IP Qtype uint16 Question string - Answer string - answers []string + Answers []string } type maintainer struct { @@ -116,7 +115,7 @@ func (l *Logger) Record(remoteAddr net.IP, qtype uint16, question string, answer RemoteAddr: remoteAddr, Qtype: qtype, Question: question, - answers: answers, + Answers: answers, } } @@ -126,15 +125,22 @@ func (l *Logger) Get(n int) ([]Entry, error) { if err != nil { return nil, err } - entries := make([]Entry, len(logEntries)) - for i, le := range logEntries { - entries[i] = Entry{ - Time: time.Unix(le.Time, 0), - RemoteAddr: le.RemoteAddr, - Qtype: le.Qtype, - Question: le.Question, - Answer: le.Answer, + ids := make(map[int64]*Entry) + entries := make([]Entry, 0, len(logEntries)) + for _, le := range logEntries { + entry, ok := ids[le.ID] + if !ok { + newEntry := Entry{ + Time: time.Unix(le.Time, 0).UTC(), + RemoteAddr: le.RemoteAddr, + Qtype: le.Qtype, + Question: le.Question, + } + entries = append(entries, newEntry) + entry = &entries[len(entries)-1] + ids[le.ID] = entry } + entry.Answers = append(entry.Answers, le.Answer) } return entries, nil } @@ -142,7 +148,7 @@ func (l *Logger) Get(n int) ([]Entry, error) { func (l *Logger) readQueue() { defer l.wg.Done() for entry := range l.queue { - if err := l.db.WriteLog(entry.Time, entry.RemoteAddr, entry.Qtype, entry.Question, entry.answers...); err != nil { + if err := l.db.WriteLog(entry.Time, entry.RemoteAddr, entry.Qtype, entry.Question, entry.Answers...); err != nil { l.Printf("write failed: %+v: %s", entry, err) } } diff --git a/log/logger_test.go b/log/logger_test.go index 2e9a351..1da1ead 100644 --- a/log/logger_test.go +++ b/log/logger_test.go @@ -3,6 +3,7 @@ package log import ( "net" "os" + "reflect" "testing" "time" ) @@ -17,15 +18,44 @@ func TestRecord(t *testing.T) { if err := logger.Close(); err != nil { t.Fatal(err) } - entries, err := logger.db.ReadLog(2) + logEntries, err := logger.db.ReadLog(1) if err != nil { t.Fatal(err) } - if want, got := 2, len(entries); want != got { + if want, got := 2, len(logEntries); want != got { t.Errorf("len(entries) = %d, want %d", got, want) } } +func TestAnswerMerging(t *testing.T) { + logger, err := New(os.Stderr, "test: ", RecordOptions{Database: ":memory:"}) + if err != nil { + t.Fatal(err) + } + now := time.Date(2006, 1, 2, 15, 4, 5, 0, time.UTC) + logger.Now = func() time.Time { return now } + logger.Record(net.IPv4(192, 0, 2, 100), 1, "example.com.", "192.0.2.1", "192.0.2.2") + // Flush queue + if err := logger.Close(); err != nil { + t.Fatal(err) + } + // Multi-answer log entry is merged + got, err := logger.Get(1) + if err != nil { + t.Fatal(err) + } + want := []Entry{{ + Time: now, + RemoteAddr: net.IPv4(192, 0, 2, 100), + Qtype: 1, + Question: "example.com.", + Answers: []string{"192.0.2.2", "192.0.2.1"}, + }} + if !reflect.DeepEqual(want, got) { + t.Errorf("Get(1) = %+v, want %+v", got, want) + } +} + func TestLogPruning(t *testing.T) { logger, err := New(os.Stderr, "test: ", RecordOptions{ Database: ":memory:", |