diff options
-rw-r--r-- | dns/proxy.go | 4 | ||||
-rw-r--r-- | dns/proxy_test.go | 2 | ||||
-rw-r--r-- | http/http.go | 3 | ||||
-rw-r--r-- | http/http_test.go | 10 | ||||
-rw-r--r-- | log/logger.go | 11 | ||||
-rw-r--r-- | log/logger_test.go | 7 | ||||
-rw-r--r-- | sql/sql.go | 13 | ||||
-rw-r--r-- | sql/sql_test.go | 31 |
8 files changed, 48 insertions, 33 deletions
diff --git a/dns/proxy.go b/dns/proxy.go index fe868e1..d482e4c 100644 --- a/dns/proxy.go +++ b/dns/proxy.go @@ -68,7 +68,7 @@ type client interface { type logger interface { Printf(string, ...interface{}) - Record(net.IP, uint16, string, ...string) + Record(net.IP, bool, uint16, string, ...string) Close() error } @@ -177,7 +177,7 @@ func (p *Proxy) writeMsg(w dns.ResponseWriter, msg *dns.Msg, hijacked bool) { } else { answers := answers(msg) remoteAddr := net.ParseIP(ip) - p.logger.Record(remoteAddr, msg.Question[0].Qtype, msg.Question[0].Name, answers...) + p.logger.Record(remoteAddr, hijacked, msg.Question[0].Qtype, msg.Question[0].Name, answers...) } } w.WriteMsg(msg) diff --git a/dns/proxy_test.go b/dns/proxy_test.go index 4069dd3..003430d 100644 --- a/dns/proxy_test.go +++ b/dns/proxy_test.go @@ -55,7 +55,7 @@ type testLogger struct { func (l *testLogger) Close() error { return nil } func (l *testLogger) Printf(format string, v ...interface{}) {} -func (l *testLogger) Record(remoteAddr net.IP, qtype uint16, question string, answers ...string) { +func (l *testLogger) Record(remoteAddr net.IP, hijacked bool, qtype uint16, question string, answers ...string) { l.question = question l.remoteAddr = remoteAddr } diff --git a/http/http.go b/http/http.go index a17fb7c..12d9e31 100644 --- a/http/http.go +++ b/http/http.go @@ -25,6 +25,7 @@ type entry struct { Time string `json:"time"` TTL int64 `json:"ttl,omitempty"` RemoteAddr net.IP `json:"remote_addr,omitempty"` + Hijacked *bool `json:"hijacked,omitempty"` Qtype string `json:"type"` Question string `json:"question"` Answers []string `json:"answers,omitempty"` @@ -130,9 +131,11 @@ func (s *Server) logHandler(w http.ResponseWriter, r *http.Request) (interface{} } entries := make([]entry, 0, len(logEntries)) for _, le := range logEntries { + hijacked := le.Hijacked entries = append(entries, entry{ Time: le.Time.UTC().Format(time.RFC3339), RemoteAddr: le.RemoteAddr, + Hijacked: &hijacked, Qtype: dns.TypeToString[le.Qtype], Question: le.Question, Answers: le.Answers, diff --git a/http/http_test.go b/http/http_test.go index 610f6f8..e964e3b 100644 --- a/http/http_test.go +++ b/http/http_test.go @@ -55,17 +55,17 @@ func httpGet(url string) (string, int, error) { func TestRequests(t *testing.T) { httpSrv, srv := testServer() defer httpSrv.Close() - srv.logger.Record(net.IPv4(127, 0, 0, 42), 1, "example.com.", "192.0.2.100", "192.0.2.101") - srv.logger.Record(net.IPv4(127, 0, 0, 254), 28, "example.com.", "2001:db8::1") + srv.logger.Record(net.IPv4(127, 0, 0, 42), false, 1, "example.com.", "192.0.2.100", "192.0.2.101") + srv.logger.Record(net.IPv4(127, 0, 0, 254), true, 28, "example.com.", "2001:db8::1") srv.cache.Set(1, newA("1.example.com.", 60, net.IPv4(192, 0, 2, 200))) srv.cache.Set(2, newA("2.example.com.", 30, net.IPv4(192, 0, 2, 201))) cr1 := `[{"time":"RFC3339","ttl":30,"type":"A","question":"2.example.com.","answers":["192.0.2.201"],"rcode":"NOERROR"},` + `{"time":"RFC3339","ttl":60,"type":"A","question":"1.example.com.","answers":["192.0.2.200"],"rcode":"NOERROR"}]` cr2 := `[{"time":"RFC3339","ttl":30,"type":"A","question":"2.example.com.","answers":["192.0.2.201"],"rcode":"NOERROR"}]` - lr1 := `[{"time":"RFC3339","remote_addr":"127.0.0.254","type":"AAAA","question":"example.com.","answers":["2001:db8::1"]},` + - `{"time":"RFC3339","remote_addr":"127.0.0.42","type":"A","question":"example.com.","answers":["192.0.2.101","192.0.2.100"]}]` - lr2 := `[{"time":"RFC3339","remote_addr":"127.0.0.254","type":"AAAA","question":"example.com.","answers":["2001:db8::1"]}]` + lr1 := `[{"time":"RFC3339","remote_addr":"127.0.0.254","hijacked":true,"type":"AAAA","question":"example.com.","answers":["2001:db8::1"]},` + + `{"time":"RFC3339","remote_addr":"127.0.0.42","hijacked":false,"type":"A","question":"example.com.","answers":["192.0.2.101","192.0.2.100"]}]` + lr2 := `[{"time":"RFC3339","remote_addr":"127.0.0.254","hijacked":true,"type":"AAAA","question":"example.com.","answers":["2001:db8::1"]}]` var tests = []struct { method string diff --git a/log/logger.go b/log/logger.go index 8c272eb..914e1bd 100644 --- a/log/logger.go +++ b/log/logger.go @@ -31,6 +31,7 @@ type RecordOptions struct { type Entry struct { Time time.Time RemoteAddr net.IP + Hijacked bool Qtype uint16 Question string Answers []string @@ -89,13 +90,14 @@ func (l *Logger) Close() error { } // Record records the given DNS request to the log database. -func (l *Logger) Record(remoteAddr net.IP, qtype uint16, question string, answers ...string) { +func (l *Logger) Record(remoteAddr net.IP, hijacked bool, qtype uint16, question string, answers ...string) { if l.db == nil { return } l.queue <- Entry{ Time: l.now(), RemoteAddr: remoteAddr, + Hijacked: hijacked, Qtype: qtype, Question: question, Answers: answers, @@ -116,6 +118,7 @@ func (l *Logger) Get(n int) ([]Entry, error) { newEntry := Entry{ Time: time.Unix(le.Time, 0).UTC(), RemoteAddr: le.RemoteAddr, + Hijacked: le.Hijacked, Qtype: le.Qtype, Question: le.Question, } @@ -130,9 +133,9 @@ func (l *Logger) Get(n int) ([]Entry, error) { func (l *Logger) readQueue() { defer l.wg.Done() - for entry := range l.queue { - if err := l.db.WriteLog(entry.Time, entry.RemoteAddr, entry.Qtype, entry.Question, entry.Answers...); err != nil { - l.Printf("write failed: %+v: %s", entry, err) + for e := range l.queue { + if err := l.db.WriteLog(e.Time, e.RemoteAddr, e.Hijacked, e.Qtype, e.Question, e.Answers...); err != nil { + l.Printf("write failed: %+v: %s", e, err) } } } diff --git a/log/logger_test.go b/log/logger_test.go index 72916b6..7359d5b 100644 --- a/log/logger_test.go +++ b/log/logger_test.go @@ -13,7 +13,7 @@ func TestRecord(t *testing.T) { if err != nil { t.Fatal(err) } - logger.Record(net.IPv4(192, 0, 2, 100), 1, "example.com.", "192.0.2.1", "192.0.2.2") + logger.Record(net.IPv4(192, 0, 2, 100), false, 1, "example.com.", "192.0.2.1", "192.0.2.2") // Flush queue if err := logger.Close(); err != nil { t.Fatal(err) @@ -34,7 +34,7 @@ func TestAnswerMerging(t *testing.T) { } now := time.Date(2006, 1, 2, 15, 4, 5, 0, time.UTC) logger.now = func() time.Time { return now } - logger.Record(net.IPv4(192, 0, 2, 100), 1, "example.com.", "192.0.2.1", "192.0.2.2") + logger.Record(net.IPv4(192, 0, 2, 100), true, 1, "example.com.", "192.0.2.1", "192.0.2.2") // Flush queue if err := logger.Close(); err != nil { t.Fatal(err) @@ -47,6 +47,7 @@ func TestAnswerMerging(t *testing.T) { want := []Entry{{ Time: now, RemoteAddr: net.IPv4(192, 0, 2, 100), + Hijacked: true, Qtype: 1, Question: "example.com.", Answers: []string{"192.0.2.2", "192.0.2.1"}, @@ -68,7 +69,7 @@ func TestLogPruning(t *testing.T) { defer logger.Close() tt := time.Now() logger.now = func() time.Time { return tt } - logger.Record(net.IPv4(192, 0, 2, 100), 1, "example.com.", "192.0.2.1") + logger.Record(net.IPv4(192, 0, 2, 100), false, 1, "example.com.", "192.0.2.1") // Wait until queue is flushed ts := time.Now() @@ -13,7 +13,7 @@ const schema = ` CREATE TABLE IF NOT EXISTS rr_question ( id INTEGER PRIMARY KEY, name TEXT NOT NULL, - CONSTRAINT name_unique UNIQUE (name) + CONSTRAINT name_unique UNIQUE(name) ); CREATE TABLE IF NOT EXISTS rr_answer ( @@ -37,6 +37,7 @@ CREATE TABLE IF NOT EXISTS remote_addr ( CREATE TABLE IF NOT EXISTS log ( id INTEGER PRIMARY KEY, time INTEGER NOT NULL, + hijacked INTEGER NOT NULL, remote_addr_id INTEGER NOT NULL, rr_type_id INTEGER NOT NULL, rr_question_id INTEGER NOT NULL, @@ -65,6 +66,7 @@ type LogEntry struct { ID int64 `db:"id"` Time int64 `db:"time"` RemoteAddr []byte `db:"remote_addr"` + Hijacked bool `db:"hijacked"` Qtype uint16 `db:"type"` Question string `db:"question"` Answer string `db:"answer"` @@ -94,6 +96,7 @@ func (c *Client) ReadLog(n int) ([]LogEntry, error) { SELECT log.id AS id, time, remote_addr.addr AS remote_addr, + hijacked, type, rr_question.name AS question, rr_answer.name AS answer @@ -125,7 +128,7 @@ func getOrInsert(tx *sqlx.Tx, table, column string, value interface{}) (int64, e } // WriteLog writes a new entry to the log. -func (c *Client) WriteLog(time time.Time, remoteAddr []byte, qtype uint16, question string, answers ...string) error { +func (c *Client) WriteLog(time time.Time, remoteAddr []byte, hijacked bool, qtype uint16, question string, answers ...string) error { c.mu.Lock() defer c.mu.Unlock() tx, err := c.db.Beginx() @@ -153,7 +156,11 @@ func (c *Client) WriteLog(time time.Time, remoteAddr []byte, qtype uint16, quest } answerIDs = append(answerIDs, answerID) } - res, err := tx.Exec("INSERT INTO log (time, remote_addr_id, rr_type_id, rr_question_id) VALUES ($1, $2, $3, $4)", time.Unix(), remoteAddrID, typeID, questionID) + hijackedInt := 0 + if hijacked { + hijackedInt = 1 + } + res, err := tx.Exec("INSERT INTO log (time, hijacked, remote_addr_id, rr_type_id, rr_question_id) VALUES ($1, $2, $3, $4, $5)", time.Unix(), hijackedInt, remoteAddrID, typeID, questionID) if err != nil { return err } diff --git a/sql/sql_test.go b/sql/sql_test.go index 3655a66..868fecb 100644 --- a/sql/sql_test.go +++ b/sql/sql_test.go @@ -16,24 +16,25 @@ type rowCount struct { var tests = []struct { question string qtype uint16 + hijacked bool answers []string t time.Time remoteAddr net.IP rowCounts []rowCount }{ - {"foo.example.com", 1, []string{"192.0.2.1"}, time.Date(2019, 6, 15, 22, 15, 10, 0, time.UTC), net.IPv4(192, 0, 2, 100), + {"foo.example.com", 1, false, []string{"192.0.2.1"}, time.Date(2019, 6, 15, 22, 15, 10, 0, time.UTC), net.IPv4(192, 0, 2, 100), []rowCount{{"rr_question", 1}, {"rr_answer", 1}, {"log", 1}, {"rr_type", 1}, {"remote_addr", 1}}}, - {"foo.example.com", 1, []string{"192.0.2.1"}, time.Date(2019, 6, 15, 22, 16, 20, 0, time.UTC), net.IPv4(192, 0, 2, 100), + {"foo.example.com", 1, true, []string{"192.0.2.1"}, time.Date(2019, 6, 15, 22, 16, 20, 0, time.UTC), net.IPv4(192, 0, 2, 100), []rowCount{{"rr_question", 1}, {"rr_answer", 1}, {"log", 2}, {"rr_type", 1}, {"remote_addr", 1}}}, - {"bar.example.com", 1, []string{"192.0.2.2"}, time.Date(2019, 6, 15, 22, 17, 30, 0, time.UTC), net.IPv4(192, 0, 2, 101), + {"bar.example.com", 1, false, []string{"192.0.2.2"}, time.Date(2019, 6, 15, 22, 17, 30, 0, time.UTC), net.IPv4(192, 0, 2, 101), []rowCount{{"rr_question", 2}, {"rr_answer", 2}, {"log", 3}, {"rr_type", 1}, {"remote_addr", 2}}}, - {"bar.example.com", 1, []string{"192.0.2.2"}, time.Date(2019, 6, 15, 22, 18, 40, 0, time.UTC), net.IPv4(192, 0, 2, 102), + {"bar.example.com", 1, false, []string{"192.0.2.2"}, time.Date(2019, 6, 15, 22, 18, 40, 0, time.UTC), net.IPv4(192, 0, 2, 102), []rowCount{{"rr_question", 2}, {"rr_answer", 2}, {"log", 4}, {"rr_type", 1}, {"remote_addr", 3}}}, - {"bar.example.com", 28, []string{"2001:db8::1"}, time.Date(2019, 6, 15, 23, 4, 40, 0, time.UTC), net.IPv4(192, 0, 2, 102), + {"bar.example.com", 28, false, []string{"2001:db8::1"}, time.Date(2019, 6, 15, 23, 4, 40, 0, time.UTC), net.IPv4(192, 0, 2, 102), []rowCount{{"rr_question", 2}, {"rr_answer", 3}, {"log", 5}, {"rr_type", 2}, {"remote_addr", 3}}}, - {"bar.example.com", 28, []string{"2001:db8::2", "2001:db8::3"}, time.Date(2019, 6, 15, 23, 35, 0, 0, time.UTC), net.IPv4(192, 0, 2, 102), + {"bar.example.com", 28, false, []string{"2001:db8::2", "2001:db8::3"}, time.Date(2019, 6, 15, 23, 35, 0, 0, time.UTC), net.IPv4(192, 0, 2, 102), []rowCount{{"rr_question", 2}, {"rr_answer", 5}, {"log", 6}, {"rr_type", 2}, {"remote_addr", 3}}}, - {"baz.example.com", 28, []string{"2001:db8::4"}, time.Date(2019, 6, 15, 23, 35, 0, 0, time.UTC), net.IPv4(192, 0, 2, 102), + {"baz.example.com", 28, false, []string{"2001:db8::4"}, time.Date(2019, 6, 15, 23, 35, 0, 0, time.UTC), net.IPv4(192, 0, 2, 102), []rowCount{{"rr_question", 3}, {"rr_answer", 6}, {"log", 7}, {"rr_type", 2}, {"remote_addr", 3}}}, } @@ -56,8 +57,8 @@ func count(t *testing.T, client *Client, query string, args ...interface{}) int func TestWriteLog(t *testing.T) { c := testClient() for i, tt := range tests { - if err := c.WriteLog(tt.t, tt.remoteAddr, tt.qtype, tt.question, tt.answers...); err != nil { - t.Errorf("#%d: WriteLog(%q, %s, %d, %q, %q) = %s, want nil", i, tt.t, tt.remoteAddr.String(), tt.qtype, tt.question, tt.answers, err) + if err := c.WriteLog(tt.t, tt.remoteAddr, tt.hijacked, tt.qtype, tt.question, tt.answers...); err != nil { + t.Errorf("#%d: WriteLog(%q, %s, %t, %d, %q, %q) = %s, want nil", i, tt.t, tt.remoteAddr.String(), tt.hijacked, tt.qtype, tt.question, tt.answers, err) } for _, rowCount := range tt.rowCounts { rows := count(t, c, "SELECT COUNT(*) FROM "+rowCount.table+" LIMIT 1") @@ -71,8 +72,8 @@ func TestWriteLog(t *testing.T) { func TestReadLog(t *testing.T) { c := testClient() for i, tt := range tests { - if err := c.WriteLog(tt.t, tt.remoteAddr, tt.qtype, tt.question, tt.answers...); err != nil { - t.Fatalf("#%d: WriteLog(%q, %s, %d, %q, %q) = %s, want nil", i, tt.t, tt.remoteAddr.String(), tt.qtype, tt.question, tt.answers, err) + if err := c.WriteLog(tt.t, tt.remoteAddr, tt.hijacked, tt.qtype, tt.question, tt.answers...); err != nil { + t.Fatalf("#%d: WriteLog(%q, %s, %t, %d, %q, %q) = %s, want nil", i, tt.t, tt.remoteAddr.String(), tt.hijacked, tt.qtype, tt.question, tt.answers, err) } } allEntries := [][]LogEntry{ @@ -84,7 +85,7 @@ func TestReadLog(t *testing.T) { {{ID: 5, Question: "bar.example.com", Qtype: 28, Answer: "2001:db8::1", Time: 1560639880, RemoteAddr: net.IPv4(192, 0, 2, 102)}}, {{ID: 4, Question: "bar.example.com", Qtype: 1, Answer: "192.0.2.2", Time: 1560637120, RemoteAddr: net.IPv4(192, 0, 2, 102)}}, {{ID: 3, Question: "bar.example.com", Qtype: 1, Answer: "192.0.2.2", Time: 1560637050, RemoteAddr: net.IPv4(192, 0, 2, 101)}}, - {{ID: 2, Question: "foo.example.com", Qtype: 1, Answer: "192.0.2.1", Time: 1560636980, RemoteAddr: net.IPv4(192, 0, 2, 100)}}, + {{ID: 2, Question: "foo.example.com", Qtype: 1, Answer: "192.0.2.1", Time: 1560636980, RemoteAddr: net.IPv4(192, 0, 2, 100), Hijacked: true}}, {{ID: 1, Question: "foo.example.com", Qtype: 1, Answer: "192.0.2.1", Time: 1560636910, RemoteAddr: net.IPv4(192, 0, 2, 100)}}, } for n := 1; n <= len(allEntries); n++ { @@ -105,8 +106,8 @@ func TestReadLog(t *testing.T) { func TestDeleteLogBefore(t *testing.T) { c := testClient() for i, tt := range tests { - if err := c.WriteLog(tt.t, tt.remoteAddr, tt.qtype, tt.question, tt.answers...); err != nil { - t.Fatalf("#%d: WriteLog(%s, %s, %q, %q) = %s, want nil", i, tt.t, tt.remoteAddr.String(), tt.question, tt.answers, err) + if err := c.WriteLog(tt.t, tt.remoteAddr, tt.hijacked, tt.qtype, tt.question, tt.answers...); err != nil { + t.Fatalf("#%d: WriteLog(%s, %s, %t, %q, %q) = %s, want nil", i, tt.t, tt.remoteAddr.String(), tt.hijacked, tt.question, tt.answers, err) } } u := tests[1].t.Add(time.Second) @@ -148,7 +149,7 @@ func TestInterleavedRW(t *testing.T) { go func() { defer wg.Done() for range ch { - err = c.WriteLog(time.Now(), net.IPv4(127, 0, 0, 1), 1, "example.com.", "192.0.2.1") + err = c.WriteLog(time.Now(), net.IPv4(127, 0, 0, 1), false, 1, "example.com.", "192.0.2.1") } }() ch <- true |