diff options
author | Martin Polden <mpolden@mpolden.no> | 2020-01-11 23:19:22 +0100 |
---|---|---|
committer | Martin Polden <mpolden@mpolden.no> | 2020-01-11 23:19:22 +0100 |
commit | 0281da919acf94402047fd2e185a82b52fad1c23 (patch) | |
tree | e879eece3a96b14220453f5401a4e355a18600cb /sql | |
parent | 644b439d17a92d16c06e298148b0a69599205ef3 (diff) |
Reduce visibility of database methods
Diffstat (limited to 'sql')
-rw-r--r-- | sql/logger.go | 36 | ||||
-rw-r--r-- | sql/logger_test.go | 14 | ||||
-rw-r--r-- | sql/sql.go | 14 | ||||
-rw-r--r-- | sql/sql_test.go | 40 |
4 files changed, 50 insertions, 54 deletions
diff --git a/sql/logger.go b/sql/logger.go index 4902011..0c2fd6a 100644 --- a/sql/logger.go +++ b/sql/logger.go @@ -16,17 +16,17 @@ const ( LogHijacked ) -// Logger is a logs DNS requests to a SQL database. +// Logger is a logger that logs DNS requests to a SQL database. type Logger struct { mode int - queue chan Entry + queue chan LogEntry client *Client wg sync.WaitGroup now func() time.Time } -// Entry represents a DNS request log entry. -type Entry struct { +// LogEntry represents a log entry for a DNS request. +type LogEntry struct { Time time.Time RemoteAddr net.IP Hijacked bool @@ -39,7 +39,7 @@ type Entry struct { func NewLogger(client *Client, mode int, ttl time.Duration) *Logger { l := &Logger{ client: client, - queue: make(chan Entry, 1024), + queue: make(chan LogEntry, 1024), now: time.Now, mode: mode, } @@ -64,7 +64,7 @@ func (l *Logger) Record(remoteAddr net.IP, hijacked bool, qtype uint16, question return } l.wg.Add(1) - l.queue <- Entry{ + l.queue <- LogEntry{ Time: l.now(), RemoteAddr: remoteAddr, Hijacked: hijacked, @@ -74,43 +74,43 @@ func (l *Logger) Record(remoteAddr net.IP, hijacked bool, qtype uint16, question } } -// Get returns the n most recent persisted log entries. -func (l *Logger) Get(n int) ([]Entry, error) { - logEntries, err := l.client.ReadLog(n) +// Read returns the n most recent log entries. +func (l *Logger) Read(n int) ([]LogEntry, error) { + entries, err := l.client.readLog(n) if err != nil { return nil, err } - ids := make(map[int64]*Entry) - entries := make([]Entry, 0, len(logEntries)) - for _, le := range logEntries { + ids := make(map[int64]*LogEntry) + logEntries := make([]LogEntry, 0, len(entries)) + for _, le := range entries { entry, ok := ids[le.ID] if !ok { - newEntry := Entry{ + newEntry := LogEntry{ Time: time.Unix(le.Time, 0).UTC(), RemoteAddr: le.RemoteAddr, Hijacked: le.Hijacked, Qtype: le.Qtype, Question: le.Question, } - entries = append(entries, newEntry) - entry = &entries[len(entries)-1] + logEntries = append(logEntries, newEntry) + entry = &logEntries[len(logEntries)-1] ids[le.ID] = entry } if le.Answer != "" { entry.Answers = append(entry.Answers, le.Answer) } } - return entries, nil + return logEntries, nil } func (l *Logger) readQueue(ttl time.Duration) { for e := range l.queue { - if err := l.client.WriteLog(e.Time, e.RemoteAddr, e.Hijacked, e.Qtype, e.Question, e.Answers...); err != nil { + if err := l.client.writeLog(e.Time, e.RemoteAddr, e.Hijacked, e.Qtype, e.Question, e.Answers...); err != nil { log.Printf("write failed: %+v: %s", e, err) } if ttl > 0 { t := l.now().Add(-ttl) - if err := l.client.DeleteLogBefore(t); err != nil { + if err := l.client.deleteLogBefore(t); err != nil { log.Printf("deleting log entries before %v failed: %s", t, err) } } diff --git a/sql/logger_test.go b/sql/logger_test.go index ab6c94b..1d189c8 100644 --- a/sql/logger_test.go +++ b/sql/logger_test.go @@ -15,7 +15,7 @@ func TestRecord(t *testing.T) { if err := logger.Close(); err != nil { t.Fatal(err) } - logEntries, err := logger.client.ReadLog(1) + logEntries, err := logger.client.readLog(1) if err != nil { t.Fatal(err) } @@ -48,7 +48,7 @@ func TestMode(t *testing.T) { if err := logger.Close(); err != nil { // Flush t.Fatal(err) } - entries, err := logger.Get(1) + entries, err := logger.Read(1) if err != nil { t.Fatal(err) } @@ -69,11 +69,11 @@ func TestAnswerMerging(t *testing.T) { t.Fatal(err) } // Multi-answer log entries are merged - got, err := logger.Get(2) + got, err := logger.Read(2) if err != nil { t.Fatal(err) } - want := []Entry{ + want := []LogEntry{ { Time: now, RemoteAddr: net.IPv4(192, 0, 2, 100), @@ -103,10 +103,10 @@ func TestLogPruning(t *testing.T) { // Wait until queue is flushed ts := time.Now() - var entries []Entry + var entries []LogEntry var err error for len(entries) == 0 { - entries, err = logger.Get(1) + entries, err = logger.Read(1) if err != nil { t.Fatal(err) } @@ -121,7 +121,7 @@ func TestLogPruning(t *testing.T) { // Trigger pruning by recording another entry logger.Record(net.IPv4(192, 0, 2, 100), false, 1, "2.example.com.", "192.0.2.2") for len(entries) > 1 { - entries, err = logger.Get(2) + entries, err = logger.Read(2) if err != nil { t.Fatal(err) } @@ -68,8 +68,7 @@ type Client struct { mu sync.RWMutex } -// LogEntry represents an entry in the log. -type LogEntry struct { +type logEntry struct { ID int64 `db:"id"` Time int64 `db:"time"` RemoteAddr []byte `db:"remote_addr"` @@ -100,8 +99,7 @@ func New(filename string) (*Client, error) { return &Client{db: db}, nil } -// ReadLog reads the n most recent entries from the log. -func (c *Client) ReadLog(n int) ([]LogEntry, error) { +func (c *Client) readLog(n int) ([]logEntry, error) { c.mu.RLock() defer c.mu.RUnlock() query := ` @@ -121,7 +119,7 @@ LEFT JOIN rr_answer ON rr_answer.id = log_rr_answer.rr_answer_id WHERE log.id IN (SELECT id FROM log ORDER BY time DESC, id DESC LIMIT $1) ORDER BY time DESC, rr_answer.id DESC ` - var entries []LogEntry + var entries []logEntry err := c.db.Select(&entries, query, n) return entries, err } @@ -139,8 +137,7 @@ func getOrInsert(tx *sqlx.Tx, table, column string, value interface{}) (int64, e return id, err } -// WriteLog writes a new entry to the log. -func (c *Client) WriteLog(time time.Time, remoteAddr []byte, hijacked bool, qtype uint16, question string, answers ...string) error { +func (c *Client) writeLog(time time.Time, remoteAddr []byte, hijacked bool, qtype uint16, question string, answers ...string) error { c.mu.Lock() defer c.mu.Unlock() tx, err := c.db.Beginx() @@ -188,8 +185,7 @@ func (c *Client) WriteLog(time time.Time, remoteAddr []byte, hijacked bool, qtyp return tx.Commit() } -// DeleteLogBefore deletes all log entries occurring before time t. -func (c *Client) DeleteLogBefore(t time.Time) (err error) { +func (c *Client) deleteLogBefore(t time.Time) (err error) { c.mu.Lock() defer c.mu.Unlock() tx, err := c.db.Beginx() diff --git a/sql/sql_test.go b/sql/sql_test.go index efa9ba7..6f8bbbe 100644 --- a/sql/sql_test.go +++ b/sql/sql_test.go @@ -58,10 +58,18 @@ func count(t *testing.T, client *Client, query string, args ...interface{}) int return rows } +func writeTests(c *Client, t *testing.T) { + for i, tt := range tests { + if err := c.writeLog(tt.t, tt.remoteAddr, tt.hijacked, tt.qtype, tt.question, tt.answers...); err != nil { + t.Errorf("#%d: WriteLog(%q, %s, %t, %d, %q, %q) = %s, want nil", i, tt.t, tt.remoteAddr.String(), tt.hijacked, tt.qtype, tt.question, tt.answers, err) + } + } +} + func TestWriteLog(t *testing.T) { c := testClient() for i, tt := range tests { - if err := c.WriteLog(tt.t, tt.remoteAddr, tt.hijacked, tt.qtype, tt.question, tt.answers...); err != nil { + if err := c.writeLog(tt.t, tt.remoteAddr, tt.hijacked, tt.qtype, tt.question, tt.answers...); err != nil { t.Errorf("#%d: WriteLog(%q, %s, %t, %d, %q, %q) = %s, want nil", i, tt.t, tt.remoteAddr.String(), tt.hijacked, tt.qtype, tt.question, tt.answers, err) } for _, rowCount := range tt.rowCounts { @@ -75,12 +83,8 @@ func TestWriteLog(t *testing.T) { func TestReadLog(t *testing.T) { c := testClient() - for i, tt := range tests { - if err := c.WriteLog(tt.t, tt.remoteAddr, tt.hijacked, tt.qtype, tt.question, tt.answers...); err != nil { - t.Fatalf("#%d: WriteLog(%q, %s, %t, %d, %q, %q) = %s, want nil", i, tt.t, tt.remoteAddr.String(), tt.hijacked, tt.qtype, tt.question, tt.answers, err) - } - } - allEntries := [][]LogEntry{ + writeTests(c, t) + allEntries := [][]logEntry{ {{ID: 8, Question: "baz.example.com", Qtype: 28, Time: 1560647100, RemoteAddr: net.IPv4(192, 0, 2, 102)}}, {{ID: 7, Question: "baz.example.com", Qtype: 28, Answer: "2001:db8::4", Time: 1560641700, RemoteAddr: net.IPv4(192, 0, 2, 102)}}, { @@ -94,11 +98,11 @@ func TestReadLog(t *testing.T) { {{ID: 1, Question: "foo.example.com", Qtype: 1, Answer: "192.0.2.1", Time: 1560636910, RemoteAddr: net.IPv4(192, 0, 2, 100)}}, } for n := 1; n <= len(allEntries); n++ { - var want []LogEntry + var want []logEntry for _, entries := range allEntries[:n] { want = append(want, entries...) } - got, err := c.ReadLog(n) + got, err := c.readLog(n) if len(got) != len(want) { t.Errorf("len(got) = %d, want %d", len(got), len(want)) } @@ -118,17 +122,13 @@ func TestReadLog(t *testing.T) { func TestDeleteLogBefore(t *testing.T) { c := testClient() - for i, tt := range tests { - if err := c.WriteLog(tt.t, tt.remoteAddr, tt.hijacked, tt.qtype, tt.question, tt.answers...); err != nil { - t.Fatalf("#%d: WriteLog(%s, %s, %t, %q, %q) = %s, want nil", i, tt.t, tt.remoteAddr.String(), tt.hijacked, tt.question, tt.answers, err) - } - } + writeTests(c, t) u := tests[1].t.Add(time.Second) - if err := c.DeleteLogBefore(u); err != nil { + if err := c.deleteLogBefore(u); err != nil { t.Fatalf("DeleteBefore(%s) = %v, want %v", u, err, nil) } - want := []LogEntry{ + want := []logEntry{ {ID: 8, Question: "baz.example.com", Qtype: 28, Time: 1560647100, RemoteAddr: net.IPv4(192, 0, 2, 102)}, {ID: 7, Question: "baz.example.com", Qtype: 28, Answer: "2001:db8::4", Time: 1560641700, RemoteAddr: net.IPv4(192, 0, 2, 102)}, {ID: 6, Question: "bar.example.com", Qtype: 28, Answer: "2001:db8::3", Time: 1560641700, RemoteAddr: net.IPv4(192, 0, 2, 102)}, @@ -138,7 +138,7 @@ func TestDeleteLogBefore(t *testing.T) { {ID: 3, Question: "bar.example.com", Qtype: 1, Answer: "192.0.2.2", Time: 1560637050, RemoteAddr: net.IPv4(192, 0, 2, 101)}, } n := 10 - got, err := c.ReadLog(n) + got, err := c.readLog(n) if err != nil || !reflect.DeepEqual(got, want) { t.Errorf("ReadLog(%d) = (%+v, %v), want (%+v, %v)", n, got, err, want, nil) } @@ -155,7 +155,7 @@ func TestDeleteLogBefore(t *testing.T) { // Delete logs in the far past which matches 0 entries. oneYear := time.Hour * 8760 - if err := c.DeleteLogBefore(u.Add(-oneYear)); err != nil { + if err := c.deleteLogBefore(u.Add(-oneYear)); err != nil { t.Fatal(err) } } @@ -169,12 +169,12 @@ func TestInterleavedRW(t *testing.T) { go func() { defer wg.Done() for range ch { - err = c.WriteLog(time.Now(), net.IPv4(127, 0, 0, 1), false, 1, "example.com.", "192.0.2.1") + err = c.writeLog(time.Now(), net.IPv4(127, 0, 0, 1), false, 1, "example.com.", "192.0.2.1") } }() ch <- true close(ch) - if _, err := c.ReadLog(1); err != nil { + if _, err := c.readLog(1); err != nil { t.Fatal(err) } wg.Wait() |