1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
|
package sql
import (
"log"
"net"
"sync"
"time"
)
const (
// LogDiscard disables logging of DNS requests.
LogDiscard = iota
// LogAll logs all DNS requests.
LogAll
// LogHijacked only logs hijacked DNS requests.
LogHijacked
)
// Logger is a logger that logs DNS requests to a SQL database.
type Logger struct {
mode int
queue chan LogEntry
client *Client
wg sync.WaitGroup
now func() time.Time
}
// LogEntry represents a log entry for a DNS request.
type LogEntry struct {
Time time.Time
RemoteAddr net.IP
Hijacked bool
Qtype uint16
Question string
Answers []string
}
// NewLogger creates a new logger. Persisted entries are kept according to ttl.
func NewLogger(client *Client, mode int, ttl time.Duration) *Logger {
l := &Logger{
client: client,
queue: make(chan LogEntry, 1024),
now: time.Now,
mode: mode,
}
if mode != LogDiscard {
go l.readQueue(ttl)
}
return l
}
// Close consumes any outstanding log requests and closes the logger.
func (l *Logger) Close() error {
l.wg.Wait()
return nil
}
// Record records the given DNS request to the log database.
func (l *Logger) Record(remoteAddr net.IP, hijacked bool, qtype uint16, question string, answers ...string) {
if l.mode == LogDiscard {
return
}
if l.mode == LogHijacked && !hijacked {
return
}
l.wg.Add(1)
l.queue <- LogEntry{
Time: l.now(),
RemoteAddr: remoteAddr,
Hijacked: hijacked,
Qtype: qtype,
Question: question,
Answers: answers,
}
}
// Read returns the n most recent log entries.
func (l *Logger) Read(n int) ([]LogEntry, error) {
entries, err := l.client.readLog(n)
if err != nil {
return nil, err
}
ids := make(map[int64]*LogEntry)
logEntries := make([]LogEntry, 0, len(entries))
for _, le := range entries {
entry, ok := ids[le.ID]
if !ok {
newEntry := LogEntry{
Time: time.Unix(le.Time, 0).UTC(),
RemoteAddr: le.RemoteAddr,
Hijacked: le.Hijacked,
Qtype: le.Qtype,
Question: le.Question,
}
logEntries = append(logEntries, newEntry)
entry = &logEntries[len(logEntries)-1]
ids[le.ID] = entry
}
if le.Answer != "" {
entry.Answers = append(entry.Answers, le.Answer)
}
}
return logEntries, nil
}
func (l *Logger) readQueue(ttl time.Duration) {
for e := range l.queue {
if err := l.client.writeLog(e.Time, e.RemoteAddr, e.Hijacked, e.Qtype, e.Question, e.Answers...); err != nil {
log.Printf("write failed: %+v: %s", e, err)
}
if ttl > 0 {
t := l.now().Add(-ttl)
if err := l.client.deleteLogBefore(t); err != nil {
log.Printf("deleting log entries before %v failed: %s", t, err)
}
}
l.wg.Done()
}
}
|