1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
|
package sql
import (
"log"
"net"
"sync"
"time"
)
// Logging modes accepted by NewLogger. Record compares the configured mode
// against LogDiscard and LogHijacked; any other mode logs every request.
const (
	// LogDiscard disables logging of DNS requests.
	LogDiscard = 0
	// LogAll logs all DNS requests.
	LogAll = 1
	// LogHijacked only logs hijacked DNS requests.
	LogHijacked = 2
)
// Logger logs DNS requests to a SQL database.
//
// Entries passed to Record are queued on a buffered channel and written by a
// background goroutine started in NewLogger (unless the mode is LogDiscard).
type Logger struct {
	mode   int
	queue  chan Entry
	client *Client
	wg     sync.WaitGroup
	now    func() time.Time

	// mu guards closed. Record holds the read lock while sending on queue,
	// so once Close has acquired the write lock no send is in flight and
	// none can start, which makes closing the channel safe.
	mu     sync.RWMutex
	closed bool
}

// Entry represents a DNS request log entry.
type Entry struct {
	Time       time.Time
	RemoteAddr net.IP
	Hijacked   bool
	Qtype      uint16
	Question   string
	Answers    []string
}

// NewLogger creates a new logger. Persisted entries are kept according to ttl.
//
// mode is one of LogDiscard, LogAll or LogHijacked. Unless mode is LogDiscard,
// a goroutine is started that writes queued entries through client; it stops
// once Close has drained the queue.
func NewLogger(client *Client, mode int, ttl time.Duration) *Logger {
	l := &Logger{
		client: client,
		queue:  make(chan Entry, 1024),
		now:    time.Now,
		mode:   mode,
	}
	if mode != LogDiscard {
		go l.readQueue(ttl)
	}
	return l
}

// Close waits for all queued entries to be processed, then closes the queue so
// the background writer goroutine exits instead of leaking. Record calls made
// after Close are ignored. Calling Close more than once is safe. It always
// returns nil.
func (l *Logger) Close() error {
	l.mu.Lock()
	alreadyClosed := l.closed
	l.closed = true
	l.mu.Unlock()

	// From here on no Record can call wg.Add or send on the queue, so waiting
	// on the WaitGroup cannot race with a new Add.
	l.wg.Wait()
	if !alreadyClosed {
		// Ends the range loop in the writer goroutine.
		close(l.queue)
	}
	return nil
}

// Record queues the given DNS request to be written to the log database.
//
// Nothing is recorded in these cases:
//   - the mode is LogDiscard
//   - the mode is LogHijacked and hijacked is false
//   - Close has already been called
//
// Record blocks while the queue (capacity 1024) is full. answers is copied, so
// the caller may reuse the slice after Record returns.
func (l *Logger) Record(remoteAddr net.IP, hijacked bool, qtype uint16, question string, answers ...string) {
	if l.mode == LogDiscard {
		return
	}
	if l.mode == LogHijacked && !hijacked {
		return
	}
	// The entry is consumed asynchronously, so it must not alias a slice the
	// caller may modify once Record returns.
	answersCopy := append([]string(nil), answers...)
	e := Entry{
		Time:       l.now(),
		RemoteAddr: remoteAddr,
		Hijacked:   hijacked,
		Qtype:      qtype,
		Question:   question,
		Answers:    answersCopy,
	}

	l.mu.RLock()
	defer l.mu.RUnlock()
	if l.closed {
		// Sending would panic on the closed channel; drop the entry instead.
		return
	}
	l.wg.Add(1)
	l.queue <- e
}
// Get returns the n most recent persisted log entries.
//
// Rows returned by the client that share an ID are merged into a single Entry.
// Every non-empty Answer from those rows is appended to Entry.Answers. Entries
// are returned in the order their ID first appears in the client's result, and
// their Time is converted to UTC. Any error from the client is returned as-is.
func (l *Logger) Get(n int) ([]Entry, error) {
	logEntries, err := l.client.ReadLog(n)
	if err != nil {
		return nil, err
	}
	// Map each ID to its index in entries rather than to a pointer into the
	// slice, so grouping stays correct even if append reallocates entries.
	index := make(map[int64]int, len(logEntries))
	entries := make([]Entry, 0, len(logEntries))
	for _, le := range logEntries {
		i, ok := index[le.ID]
		if !ok {
			entries = append(entries, Entry{
				Time:       time.Unix(le.Time, 0).UTC(),
				RemoteAddr: le.RemoteAddr,
				Hijacked:   le.Hijacked,
				Qtype:      le.Qtype,
				Question:   le.Question,
			})
			i = len(entries) - 1
			index[le.ID] = i
		}
		if le.Answer != "" {
			entries[i].Answers = append(entries[i].Answers, le.Answer)
		}
	}
	return entries, nil
}
// readQueue persists entries received from l.queue until that channel is
// closed. Each entry is written through the client; a failed write is logged
// and does not stop the loop.
//
// When ttl is positive, every write is followed by deleting stored entries
// older than l.now()-ttl, and a failed delete is also only logged. wg.Done is
// called once per entry after all of that work, so waiting on l.wg observes
// fully processed entries.
func (l *Logger) readQueue(ttl time.Duration) {
	prune := ttl > 0
	for {
		e, open := <-l.queue
		if !open {
			return
		}

		writeErr := l.client.WriteLog(e.Time, e.RemoteAddr, e.Hijacked, e.Qtype, e.Question, e.Answers...)
		if writeErr != nil {
			log.Printf("write failed: %+v: %s", e, writeErr)
		}

		if prune {
			cutoff := l.now().Add(-ttl)
			if delErr := l.client.DeleteLogBefore(cutoff); delErr != nil {
				log.Printf("deleting log entries before %v failed: %s", cutoff, delErr)
			}
		}

		l.wg.Done()
	}
}
|