aboutsummaryrefslogtreecommitdiffstats
path: root/sql/logger.go
blob: 490201135652ec4a0512b603b950692b3b516710 (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
package sql

import (
	"log"
	"net"
	"sync"
	"time"
)

// Logging modes accepted by NewLogger's mode argument. The values are spelled
// out explicitly so the numeric meaning of each mode is visible at a glance.
const (
	// LogDiscard disables logging of DNS requests.
	LogDiscard = 0
	// LogAll logs all DNS requests.
	LogAll = 1
	// LogHijacked only logs hijacked DNS requests.
	LogHijacked = 2
)

// Logger logs DNS requests to a SQL database.
//
// Record places entries on a buffered queue. Unless the mode is LogDiscard,
// NewLogger starts a goroutine that writes queued entries through the client.
// When a positive TTL was given, that goroutine also prunes older entries.
type Logger struct {
	// mode is the logging mode passed to NewLogger (expected to be LogDiscard,
	// LogAll or LogHijacked). queue holds entries waiting for readQueue.
	// client performs reads, writes and pruning. wg counts entries queued by
	// Record that readQueue has not finished processing; Close waits on it.
	// now is the clock used for entry timestamps and the pruning cutoff
	// (time.Now as set by NewLogger).
	mode   int
	queue  chan Entry
	client *Client
	wg     sync.WaitGroup
	now    func() time.Time
}

// Entry represents a DNS request log entry.
//
// Field notes:
//   - Time is taken from the logger's clock when the entry is recorded.
//     Entries returned by Logger.Get have whole-second precision and are in
//     UTC.
//   - RemoteAddr, Hijacked, Qtype and Question are the values given to
//     Logger.Record. Qtype is presumably a DNS RR type code; confirm with
//     callers.
//   - Answers holds the answers recorded for Question and may be empty.
//     Logger.Get omits empty answer strings.
type Entry struct {
	Time       time.Time
	RemoteAddr net.IP
	Hijacked   bool
	Qtype      uint16
	Question   string
	Answers    []string
}

// NewLogger creates a Logger that records DNS requests through client
// according to mode.
//
// Unless mode is LogDiscard, a background goroutine is started to drain the
// request queue, which buffers up to 1024 entries. After each write, if ttl
// is positive, that goroutine asks the client to delete entries logged before
// now minus ttl. A ttl of zero or less keeps persisted entries indefinitely.
func NewLogger(client *Client, mode int, ttl time.Duration) *Logger {
	logger := &Logger{
		mode:   mode,
		queue:  make(chan Entry, 1024),
		client: client,
		now:    time.Now,
	}
	if mode == LogDiscard {
		// Record drops everything in this mode, so no consumer is needed.
		return logger
	}
	go logger.readQueue(ttl)
	return logger
}

// Close blocks until every entry queued by Record has been processed by the
// background writer. It always returns nil.
//
// Despite its name, Close does not shut the logger down. It does not close
// the queue, so the readQueue goroutine started by NewLogger keeps running
// after Close returns; that goroutine only exits when the queue is closed.
// Record may still be called afterwards and its entries are still written.
//
// NOTE(review): sync.WaitGroup requires that an Add which raises the counter
// from zero happens before Wait. Callers must therefore not run Record
// concurrently with Close.
func (l *Logger) Close() error {
	l.wg.Wait()
	return nil
}

// Record queues the given DNS request for writing to the log database.
//
// Nothing is queued when the logger mode is LogDiscard, or when the mode is
// LogHijacked and hijacked is false; any other mode records every request.
// The entry is timestamped with the logger's clock. The send blocks while the
// queue buffer is full.
//
// The entry is consumed by another goroutine after Record returns, so
// remoteAddr and answers are copied first. Without the copy, a caller that
// reuses its buffers would race with the writer. This matters especially for
// Record(ip, h, q, name, answers...), where the variadic slice is passed
// through unchanged.
func (l *Logger) Record(remoteAddr net.IP, hijacked bool, qtype uint16, question string, answers ...string) {
	if l.mode == LogDiscard {
		return
	}
	if l.mode == LogHijacked && !hijacked {
		return
	}
	l.wg.Add(1)
	l.queue <- Entry{
		Time:       l.now(),
		RemoteAddr: append(net.IP(nil), remoteAddr...),
		Hijacked:   hijacked,
		Qtype:      qtype,
		Question:   question,
		Answers:    append([]string(nil), answers...),
	}
}

// Get returns the n most recent persisted log entries.
//
// ReadLog may return several rows for the same ID (presumably one per
// answer). Rows sharing an ID are merged into a single Entry, whose Answers
// collects the non-empty answers in row order. Entries appear in the order
// their ID is first seen. Times are whole seconds converted to UTC.
func (l *Logger) Get(n int) ([]Entry, error) {
	rows, err := l.client.ReadLog(n)
	if err != nil {
		return nil, err
	}
	entries := make([]Entry, 0, len(rows))
	// index maps a row ID to the position of its merged Entry in entries.
	// Storing positions rather than pointers stays valid even if append has
	// to grow the slice.
	index := make(map[int64]int, len(rows))
	for _, row := range rows {
		pos, seen := index[row.ID]
		if !seen {
			pos = len(entries)
			index[row.ID] = pos
			entries = append(entries, Entry{
				Time:       time.Unix(row.Time, 0).UTC(),
				RemoteAddr: row.RemoteAddr,
				Hijacked:   row.Hijacked,
				Qtype:      row.Qtype,
				Question:   row.Question,
			})
		}
		if row.Answer == "" {
			continue
		}
		entries[pos].Answers = append(entries[pos].Answers, row.Answer)
	}
	return entries, nil
}

// readQueue writes every entry received from l.queue to the database and then
// marks it done on l.wg. When ttl is positive, each write is followed by a
// request to delete entries logged before l.now() minus ttl. Write and delete
// failures are logged and do not stop the loop, which only ends once l.queue
// is closed.
func (l *Logger) readQueue(ttl time.Duration) {
	prune := ttl > 0
	for entry := range l.queue {
		err := l.client.WriteLog(entry.Time, entry.RemoteAddr, entry.Hijacked, entry.Qtype, entry.Question, entry.Answers...)
		if err != nil {
			log.Printf("write failed: %+v: %s", entry, err)
		}
		if prune {
			cutoff := l.now().Add(-ttl)
			if err := l.client.DeleteLogBefore(cutoff); err != nil {
				log.Printf("deleting log entries before %v failed: %s", cutoff, err)
			}
		}
		l.wg.Done()
	}
}