aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--dns/proxy.go4
-rw-r--r--dns/proxy_test.go2
-rw-r--r--http/http.go3
-rw-r--r--http/http_test.go10
-rw-r--r--log/logger.go11
-rw-r--r--log/logger_test.go7
-rw-r--r--sql/sql.go13
-rw-r--r--sql/sql_test.go31
8 files changed, 48 insertions, 33 deletions
diff --git a/dns/proxy.go b/dns/proxy.go
index fe868e1..d482e4c 100644
--- a/dns/proxy.go
+++ b/dns/proxy.go
@@ -68,7 +68,7 @@ type client interface {
type logger interface {
Printf(string, ...interface{})
- Record(net.IP, uint16, string, ...string)
+ Record(net.IP, bool, uint16, string, ...string)
Close() error
}
@@ -177,7 +177,7 @@ func (p *Proxy) writeMsg(w dns.ResponseWriter, msg *dns.Msg, hijacked bool) {
} else {
answers := answers(msg)
remoteAddr := net.ParseIP(ip)
- p.logger.Record(remoteAddr, msg.Question[0].Qtype, msg.Question[0].Name, answers...)
+ p.logger.Record(remoteAddr, hijacked, msg.Question[0].Qtype, msg.Question[0].Name, answers...)
}
}
w.WriteMsg(msg)
diff --git a/dns/proxy_test.go b/dns/proxy_test.go
index 4069dd3..003430d 100644
--- a/dns/proxy_test.go
+++ b/dns/proxy_test.go
@@ -55,7 +55,7 @@ type testLogger struct {
func (l *testLogger) Close() error { return nil }
func (l *testLogger) Printf(format string, v ...interface{}) {}
-func (l *testLogger) Record(remoteAddr net.IP, qtype uint16, question string, answers ...string) {
+func (l *testLogger) Record(remoteAddr net.IP, hijacked bool, qtype uint16, question string, answers ...string) {
l.question = question
l.remoteAddr = remoteAddr
}
diff --git a/http/http.go b/http/http.go
index a17fb7c..12d9e31 100644
--- a/http/http.go
+++ b/http/http.go
@@ -25,6 +25,7 @@ type entry struct {
Time string `json:"time"`
TTL int64 `json:"ttl,omitempty"`
RemoteAddr net.IP `json:"remote_addr,omitempty"`
+ Hijacked *bool `json:"hijacked,omitempty"`
Qtype string `json:"type"`
Question string `json:"question"`
Answers []string `json:"answers,omitempty"`
@@ -130,9 +131,11 @@ func (s *Server) logHandler(w http.ResponseWriter, r *http.Request) (interface{}
}
entries := make([]entry, 0, len(logEntries))
for _, le := range logEntries {
+ hijacked := le.Hijacked
entries = append(entries, entry{
Time: le.Time.UTC().Format(time.RFC3339),
RemoteAddr: le.RemoteAddr,
+ Hijacked: &hijacked,
Qtype: dns.TypeToString[le.Qtype],
Question: le.Question,
Answers: le.Answers,
diff --git a/http/http_test.go b/http/http_test.go
index 610f6f8..e964e3b 100644
--- a/http/http_test.go
+++ b/http/http_test.go
@@ -55,17 +55,17 @@ func httpGet(url string) (string, int, error) {
func TestRequests(t *testing.T) {
httpSrv, srv := testServer()
defer httpSrv.Close()
- srv.logger.Record(net.IPv4(127, 0, 0, 42), 1, "example.com.", "192.0.2.100", "192.0.2.101")
- srv.logger.Record(net.IPv4(127, 0, 0, 254), 28, "example.com.", "2001:db8::1")
+ srv.logger.Record(net.IPv4(127, 0, 0, 42), false, 1, "example.com.", "192.0.2.100", "192.0.2.101")
+ srv.logger.Record(net.IPv4(127, 0, 0, 254), true, 28, "example.com.", "2001:db8::1")
srv.cache.Set(1, newA("1.example.com.", 60, net.IPv4(192, 0, 2, 200)))
srv.cache.Set(2, newA("2.example.com.", 30, net.IPv4(192, 0, 2, 201)))
cr1 := `[{"time":"RFC3339","ttl":30,"type":"A","question":"2.example.com.","answers":["192.0.2.201"],"rcode":"NOERROR"},` +
`{"time":"RFC3339","ttl":60,"type":"A","question":"1.example.com.","answers":["192.0.2.200"],"rcode":"NOERROR"}]`
cr2 := `[{"time":"RFC3339","ttl":30,"type":"A","question":"2.example.com.","answers":["192.0.2.201"],"rcode":"NOERROR"}]`
- lr1 := `[{"time":"RFC3339","remote_addr":"127.0.0.254","type":"AAAA","question":"example.com.","answers":["2001:db8::1"]},` +
- `{"time":"RFC3339","remote_addr":"127.0.0.42","type":"A","question":"example.com.","answers":["192.0.2.101","192.0.2.100"]}]`
- lr2 := `[{"time":"RFC3339","remote_addr":"127.0.0.254","type":"AAAA","question":"example.com.","answers":["2001:db8::1"]}]`
+ lr1 := `[{"time":"RFC3339","remote_addr":"127.0.0.254","hijacked":true,"type":"AAAA","question":"example.com.","answers":["2001:db8::1"]},` +
+ `{"time":"RFC3339","remote_addr":"127.0.0.42","hijacked":false,"type":"A","question":"example.com.","answers":["192.0.2.101","192.0.2.100"]}]`
+ lr2 := `[{"time":"RFC3339","remote_addr":"127.0.0.254","hijacked":true,"type":"AAAA","question":"example.com.","answers":["2001:db8::1"]}]`
var tests = []struct {
method string
diff --git a/log/logger.go b/log/logger.go
index 8c272eb..914e1bd 100644
--- a/log/logger.go
+++ b/log/logger.go
@@ -31,6 +31,7 @@ type RecordOptions struct {
type Entry struct {
Time time.Time
RemoteAddr net.IP
+ Hijacked bool
Qtype uint16
Question string
Answers []string
@@ -89,13 +90,14 @@ func (l *Logger) Close() error {
}
// Record records the given DNS request to the log database.
-func (l *Logger) Record(remoteAddr net.IP, qtype uint16, question string, answers ...string) {
+func (l *Logger) Record(remoteAddr net.IP, hijacked bool, qtype uint16, question string, answers ...string) {
if l.db == nil {
return
}
l.queue <- Entry{
Time: l.now(),
RemoteAddr: remoteAddr,
+ Hijacked: hijacked,
Qtype: qtype,
Question: question,
Answers: answers,
@@ -116,6 +118,7 @@ func (l *Logger) Get(n int) ([]Entry, error) {
newEntry := Entry{
Time: time.Unix(le.Time, 0).UTC(),
RemoteAddr: le.RemoteAddr,
+ Hijacked: le.Hijacked,
Qtype: le.Qtype,
Question: le.Question,
}
@@ -130,9 +133,9 @@ func (l *Logger) Get(n int) ([]Entry, error) {
func (l *Logger) readQueue() {
defer l.wg.Done()
- for entry := range l.queue {
- if err := l.db.WriteLog(entry.Time, entry.RemoteAddr, entry.Qtype, entry.Question, entry.Answers...); err != nil {
- l.Printf("write failed: %+v: %s", entry, err)
+ for e := range l.queue {
+ if err := l.db.WriteLog(e.Time, e.RemoteAddr, e.Hijacked, e.Qtype, e.Question, e.Answers...); err != nil {
+ l.Printf("write failed: %+v: %s", e, err)
}
}
}
diff --git a/log/logger_test.go b/log/logger_test.go
index 72916b6..7359d5b 100644
--- a/log/logger_test.go
+++ b/log/logger_test.go
@@ -13,7 +13,7 @@ func TestRecord(t *testing.T) {
if err != nil {
t.Fatal(err)
}
- logger.Record(net.IPv4(192, 0, 2, 100), 1, "example.com.", "192.0.2.1", "192.0.2.2")
+ logger.Record(net.IPv4(192, 0, 2, 100), false, 1, "example.com.", "192.0.2.1", "192.0.2.2")
// Flush queue
if err := logger.Close(); err != nil {
t.Fatal(err)
@@ -34,7 +34,7 @@ func TestAnswerMerging(t *testing.T) {
}
now := time.Date(2006, 1, 2, 15, 4, 5, 0, time.UTC)
logger.now = func() time.Time { return now }
- logger.Record(net.IPv4(192, 0, 2, 100), 1, "example.com.", "192.0.2.1", "192.0.2.2")
+ logger.Record(net.IPv4(192, 0, 2, 100), true, 1, "example.com.", "192.0.2.1", "192.0.2.2")
// Flush queue
if err := logger.Close(); err != nil {
t.Fatal(err)
@@ -47,6 +47,7 @@ func TestAnswerMerging(t *testing.T) {
want := []Entry{{
Time: now,
RemoteAddr: net.IPv4(192, 0, 2, 100),
+ Hijacked: true,
Qtype: 1,
Question: "example.com.",
Answers: []string{"192.0.2.2", "192.0.2.1"},
@@ -68,7 +69,7 @@ func TestLogPruning(t *testing.T) {
defer logger.Close()
tt := time.Now()
logger.now = func() time.Time { return tt }
- logger.Record(net.IPv4(192, 0, 2, 100), 1, "example.com.", "192.0.2.1")
+ logger.Record(net.IPv4(192, 0, 2, 100), false, 1, "example.com.", "192.0.2.1")
// Wait until queue is flushed
ts := time.Now()
diff --git a/sql/sql.go b/sql/sql.go
index 806a5d8..a02c939 100644
--- a/sql/sql.go
+++ b/sql/sql.go
@@ -13,7 +13,7 @@ const schema = `
CREATE TABLE IF NOT EXISTS rr_question (
id INTEGER PRIMARY KEY,
name TEXT NOT NULL,
- CONSTRAINT name_unique UNIQUE (name)
+ CONSTRAINT name_unique UNIQUE(name)
);
CREATE TABLE IF NOT EXISTS rr_answer (
@@ -37,6 +37,7 @@ CREATE TABLE IF NOT EXISTS remote_addr (
CREATE TABLE IF NOT EXISTS log (
id INTEGER PRIMARY KEY,
time INTEGER NOT NULL,
+ hijacked INTEGER NOT NULL,
remote_addr_id INTEGER NOT NULL,
rr_type_id INTEGER NOT NULL,
rr_question_id INTEGER NOT NULL,
@@ -65,6 +66,7 @@ type LogEntry struct {
ID int64 `db:"id"`
Time int64 `db:"time"`
RemoteAddr []byte `db:"remote_addr"`
+ Hijacked bool `db:"hijacked"`
Qtype uint16 `db:"type"`
Question string `db:"question"`
Answer string `db:"answer"`
@@ -94,6 +96,7 @@ func (c *Client) ReadLog(n int) ([]LogEntry, error) {
SELECT log.id AS id,
time,
remote_addr.addr AS remote_addr,
+ hijacked,
type,
rr_question.name AS question,
rr_answer.name AS answer
@@ -125,7 +128,7 @@ func getOrInsert(tx *sqlx.Tx, table, column string, value interface{}) (int64, e
}
// WriteLog writes a new entry to the log.
-func (c *Client) WriteLog(time time.Time, remoteAddr []byte, qtype uint16, question string, answers ...string) error {
+func (c *Client) WriteLog(time time.Time, remoteAddr []byte, hijacked bool, qtype uint16, question string, answers ...string) error {
c.mu.Lock()
defer c.mu.Unlock()
tx, err := c.db.Beginx()
@@ -153,7 +156,11 @@ func (c *Client) WriteLog(time time.Time, remoteAddr []byte, qtype uint16, quest
}
answerIDs = append(answerIDs, answerID)
}
- res, err := tx.Exec("INSERT INTO log (time, remote_addr_id, rr_type_id, rr_question_id) VALUES ($1, $2, $3, $4)", time.Unix(), remoteAddrID, typeID, questionID)
+ hijackedInt := 0
+ if hijacked {
+ hijackedInt = 1
+ }
+ res, err := tx.Exec("INSERT INTO log (time, hijacked, remote_addr_id, rr_type_id, rr_question_id) VALUES ($1, $2, $3, $4, $5)", time.Unix(), hijackedInt, remoteAddrID, typeID, questionID)
if err != nil {
return err
}
diff --git a/sql/sql_test.go b/sql/sql_test.go
index 3655a66..868fecb 100644
--- a/sql/sql_test.go
+++ b/sql/sql_test.go
@@ -16,24 +16,25 @@ type rowCount struct {
var tests = []struct {
question string
qtype uint16
+ hijacked bool
answers []string
t time.Time
remoteAddr net.IP
rowCounts []rowCount
}{
- {"foo.example.com", 1, []string{"192.0.2.1"}, time.Date(2019, 6, 15, 22, 15, 10, 0, time.UTC), net.IPv4(192, 0, 2, 100),
+ {"foo.example.com", 1, false, []string{"192.0.2.1"}, time.Date(2019, 6, 15, 22, 15, 10, 0, time.UTC), net.IPv4(192, 0, 2, 100),
[]rowCount{{"rr_question", 1}, {"rr_answer", 1}, {"log", 1}, {"rr_type", 1}, {"remote_addr", 1}}},
- {"foo.example.com", 1, []string{"192.0.2.1"}, time.Date(2019, 6, 15, 22, 16, 20, 0, time.UTC), net.IPv4(192, 0, 2, 100),
+ {"foo.example.com", 1, true, []string{"192.0.2.1"}, time.Date(2019, 6, 15, 22, 16, 20, 0, time.UTC), net.IPv4(192, 0, 2, 100),
[]rowCount{{"rr_question", 1}, {"rr_answer", 1}, {"log", 2}, {"rr_type", 1}, {"remote_addr", 1}}},
- {"bar.example.com", 1, []string{"192.0.2.2"}, time.Date(2019, 6, 15, 22, 17, 30, 0, time.UTC), net.IPv4(192, 0, 2, 101),
+ {"bar.example.com", 1, false, []string{"192.0.2.2"}, time.Date(2019, 6, 15, 22, 17, 30, 0, time.UTC), net.IPv4(192, 0, 2, 101),
[]rowCount{{"rr_question", 2}, {"rr_answer", 2}, {"log", 3}, {"rr_type", 1}, {"remote_addr", 2}}},
- {"bar.example.com", 1, []string{"192.0.2.2"}, time.Date(2019, 6, 15, 22, 18, 40, 0, time.UTC), net.IPv4(192, 0, 2, 102),
+ {"bar.example.com", 1, false, []string{"192.0.2.2"}, time.Date(2019, 6, 15, 22, 18, 40, 0, time.UTC), net.IPv4(192, 0, 2, 102),
[]rowCount{{"rr_question", 2}, {"rr_answer", 2}, {"log", 4}, {"rr_type", 1}, {"remote_addr", 3}}},
- {"bar.example.com", 28, []string{"2001:db8::1"}, time.Date(2019, 6, 15, 23, 4, 40, 0, time.UTC), net.IPv4(192, 0, 2, 102),
+ {"bar.example.com", 28, false, []string{"2001:db8::1"}, time.Date(2019, 6, 15, 23, 4, 40, 0, time.UTC), net.IPv4(192, 0, 2, 102),
[]rowCount{{"rr_question", 2}, {"rr_answer", 3}, {"log", 5}, {"rr_type", 2}, {"remote_addr", 3}}},
- {"bar.example.com", 28, []string{"2001:db8::2", "2001:db8::3"}, time.Date(2019, 6, 15, 23, 35, 0, 0, time.UTC), net.IPv4(192, 0, 2, 102),
+ {"bar.example.com", 28, false, []string{"2001:db8::2", "2001:db8::3"}, time.Date(2019, 6, 15, 23, 35, 0, 0, time.UTC), net.IPv4(192, 0, 2, 102),
[]rowCount{{"rr_question", 2}, {"rr_answer", 5}, {"log", 6}, {"rr_type", 2}, {"remote_addr", 3}}},
- {"baz.example.com", 28, []string{"2001:db8::4"}, time.Date(2019, 6, 15, 23, 35, 0, 0, time.UTC), net.IPv4(192, 0, 2, 102),
+ {"baz.example.com", 28, false, []string{"2001:db8::4"}, time.Date(2019, 6, 15, 23, 35, 0, 0, time.UTC), net.IPv4(192, 0, 2, 102),
[]rowCount{{"rr_question", 3}, {"rr_answer", 6}, {"log", 7}, {"rr_type", 2}, {"remote_addr", 3}}},
}
@@ -56,8 +57,8 @@ func count(t *testing.T, client *Client, query string, args ...interface{}) int
func TestWriteLog(t *testing.T) {
c := testClient()
for i, tt := range tests {
- if err := c.WriteLog(tt.t, tt.remoteAddr, tt.qtype, tt.question, tt.answers...); err != nil {
- t.Errorf("#%d: WriteLog(%q, %s, %d, %q, %q) = %s, want nil", i, tt.t, tt.remoteAddr.String(), tt.qtype, tt.question, tt.answers, err)
+ if err := c.WriteLog(tt.t, tt.remoteAddr, tt.hijacked, tt.qtype, tt.question, tt.answers...); err != nil {
+ t.Errorf("#%d: WriteLog(%q, %s, %t, %d, %q, %q) = %s, want nil", i, tt.t, tt.remoteAddr.String(), tt.hijacked, tt.qtype, tt.question, tt.answers, err)
}
for _, rowCount := range tt.rowCounts {
rows := count(t, c, "SELECT COUNT(*) FROM "+rowCount.table+" LIMIT 1")
@@ -71,8 +72,8 @@ func TestWriteLog(t *testing.T) {
func TestReadLog(t *testing.T) {
c := testClient()
for i, tt := range tests {
- if err := c.WriteLog(tt.t, tt.remoteAddr, tt.qtype, tt.question, tt.answers...); err != nil {
- t.Fatalf("#%d: WriteLog(%q, %s, %d, %q, %q) = %s, want nil", i, tt.t, tt.remoteAddr.String(), tt.qtype, tt.question, tt.answers, err)
+ if err := c.WriteLog(tt.t, tt.remoteAddr, tt.hijacked, tt.qtype, tt.question, tt.answers...); err != nil {
+ t.Fatalf("#%d: WriteLog(%q, %s, %t, %d, %q, %q) = %s, want nil", i, tt.t, tt.remoteAddr.String(), tt.hijacked, tt.qtype, tt.question, tt.answers, err)
}
}
allEntries := [][]LogEntry{
@@ -84,7 +85,7 @@ func TestReadLog(t *testing.T) {
{{ID: 5, Question: "bar.example.com", Qtype: 28, Answer: "2001:db8::1", Time: 1560639880, RemoteAddr: net.IPv4(192, 0, 2, 102)}},
{{ID: 4, Question: "bar.example.com", Qtype: 1, Answer: "192.0.2.2", Time: 1560637120, RemoteAddr: net.IPv4(192, 0, 2, 102)}},
{{ID: 3, Question: "bar.example.com", Qtype: 1, Answer: "192.0.2.2", Time: 1560637050, RemoteAddr: net.IPv4(192, 0, 2, 101)}},
- {{ID: 2, Question: "foo.example.com", Qtype: 1, Answer: "192.0.2.1", Time: 1560636980, RemoteAddr: net.IPv4(192, 0, 2, 100)}},
+ {{ID: 2, Question: "foo.example.com", Qtype: 1, Answer: "192.0.2.1", Time: 1560636980, RemoteAddr: net.IPv4(192, 0, 2, 100), Hijacked: true}},
{{ID: 1, Question: "foo.example.com", Qtype: 1, Answer: "192.0.2.1", Time: 1560636910, RemoteAddr: net.IPv4(192, 0, 2, 100)}},
}
for n := 1; n <= len(allEntries); n++ {
@@ -105,8 +106,8 @@ func TestReadLog(t *testing.T) {
func TestDeleteLogBefore(t *testing.T) {
c := testClient()
for i, tt := range tests {
- if err := c.WriteLog(tt.t, tt.remoteAddr, tt.qtype, tt.question, tt.answers...); err != nil {
- t.Fatalf("#%d: WriteLog(%s, %s, %q, %q) = %s, want nil", i, tt.t, tt.remoteAddr.String(), tt.question, tt.answers, err)
+ if err := c.WriteLog(tt.t, tt.remoteAddr, tt.hijacked, tt.qtype, tt.question, tt.answers...); err != nil {
+ t.Fatalf("#%d: WriteLog(%s, %s, %t, %q, %q) = %s, want nil", i, tt.t, tt.remoteAddr.String(), tt.hijacked, tt.question, tt.answers, err)
}
}
u := tests[1].t.Add(time.Second)
@@ -148,7 +149,7 @@ func TestInterleavedRW(t *testing.T) {
go func() {
defer wg.Done()
for range ch {
- err = c.WriteLog(time.Now(), net.IPv4(127, 0, 0, 1), 1, "example.com.", "192.0.2.1")
+ err = c.WriteLog(time.Now(), net.IPv4(127, 0, 0, 1), false, 1, "example.com.", "192.0.2.1")
}
}()
ch <- true