aboutsummaryrefslogtreecommitdiffstats
path: root/sql
diff options
context:
space:
mode:
authorMartin Polden <mpolden@mpolden.no>2020-01-11 23:19:22 +0100
committerMartin Polden <mpolden@mpolden.no>2020-01-11 23:19:22 +0100
commit0281da919acf94402047fd2e185a82b52fad1c23 (patch)
treee879eece3a96b14220453f5401a4e355a18600cb /sql
parent644b439d17a92d16c06e298148b0a69599205ef3 (diff)
Reduce visibility of database methods
Diffstat (limited to 'sql')
-rw-r--r--sql/logger.go36
-rw-r--r--sql/logger_test.go14
-rw-r--r--sql/sql.go14
-rw-r--r--sql/sql_test.go40
4 files changed, 50 insertions, 54 deletions
diff --git a/sql/logger.go b/sql/logger.go
index 4902011..0c2fd6a 100644
--- a/sql/logger.go
+++ b/sql/logger.go
@@ -16,17 +16,17 @@ const (
LogHijacked
)
-// Logger is a logs DNS requests to a SQL database.
+// Logger is a logger that logs DNS requests to a SQL database.
type Logger struct {
mode int
- queue chan Entry
+ queue chan LogEntry
client *Client
wg sync.WaitGroup
now func() time.Time
}
-// Entry represents a DNS request log entry.
-type Entry struct {
+// LogEntry represents a log entry for a DNS request.
+type LogEntry struct {
Time time.Time
RemoteAddr net.IP
Hijacked bool
@@ -39,7 +39,7 @@ type Entry struct {
func NewLogger(client *Client, mode int, ttl time.Duration) *Logger {
l := &Logger{
client: client,
- queue: make(chan Entry, 1024),
+ queue: make(chan LogEntry, 1024),
now: time.Now,
mode: mode,
}
@@ -64,7 +64,7 @@ func (l *Logger) Record(remoteAddr net.IP, hijacked bool, qtype uint16, question
return
}
l.wg.Add(1)
- l.queue <- Entry{
+ l.queue <- LogEntry{
Time: l.now(),
RemoteAddr: remoteAddr,
Hijacked: hijacked,
@@ -74,43 +74,43 @@ func (l *Logger) Record(remoteAddr net.IP, hijacked bool, qtype uint16, question
}
}
-// Get returns the n most recent persisted log entries.
-func (l *Logger) Get(n int) ([]Entry, error) {
- logEntries, err := l.client.ReadLog(n)
+// Read returns the n most recent log entries.
+func (l *Logger) Read(n int) ([]LogEntry, error) {
+ entries, err := l.client.readLog(n)
if err != nil {
return nil, err
}
- ids := make(map[int64]*Entry)
- entries := make([]Entry, 0, len(logEntries))
- for _, le := range logEntries {
+ ids := make(map[int64]*LogEntry)
+ logEntries := make([]LogEntry, 0, len(entries))
+ for _, le := range entries {
entry, ok := ids[le.ID]
if !ok {
- newEntry := Entry{
+ newEntry := LogEntry{
Time: time.Unix(le.Time, 0).UTC(),
RemoteAddr: le.RemoteAddr,
Hijacked: le.Hijacked,
Qtype: le.Qtype,
Question: le.Question,
}
- entries = append(entries, newEntry)
- entry = &entries[len(entries)-1]
+ logEntries = append(logEntries, newEntry)
+ entry = &logEntries[len(logEntries)-1]
ids[le.ID] = entry
}
if le.Answer != "" {
entry.Answers = append(entry.Answers, le.Answer)
}
}
- return entries, nil
+ return logEntries, nil
}
func (l *Logger) readQueue(ttl time.Duration) {
for e := range l.queue {
- if err := l.client.WriteLog(e.Time, e.RemoteAddr, e.Hijacked, e.Qtype, e.Question, e.Answers...); err != nil {
+ if err := l.client.writeLog(e.Time, e.RemoteAddr, e.Hijacked, e.Qtype, e.Question, e.Answers...); err != nil {
log.Printf("write failed: %+v: %s", e, err)
}
if ttl > 0 {
t := l.now().Add(-ttl)
- if err := l.client.DeleteLogBefore(t); err != nil {
+ if err := l.client.deleteLogBefore(t); err != nil {
log.Printf("deleting log entries before %v failed: %s", t, err)
}
}
diff --git a/sql/logger_test.go b/sql/logger_test.go
index ab6c94b..1d189c8 100644
--- a/sql/logger_test.go
+++ b/sql/logger_test.go
@@ -15,7 +15,7 @@ func TestRecord(t *testing.T) {
if err := logger.Close(); err != nil {
t.Fatal(err)
}
- logEntries, err := logger.client.ReadLog(1)
+ logEntries, err := logger.client.readLog(1)
if err != nil {
t.Fatal(err)
}
@@ -48,7 +48,7 @@ func TestMode(t *testing.T) {
if err := logger.Close(); err != nil { // Flush
t.Fatal(err)
}
- entries, err := logger.Get(1)
+ entries, err := logger.Read(1)
if err != nil {
t.Fatal(err)
}
@@ -69,11 +69,11 @@ func TestAnswerMerging(t *testing.T) {
t.Fatal(err)
}
// Multi-answer log entries are merged
- got, err := logger.Get(2)
+ got, err := logger.Read(2)
if err != nil {
t.Fatal(err)
}
- want := []Entry{
+ want := []LogEntry{
{
Time: now,
RemoteAddr: net.IPv4(192, 0, 2, 100),
@@ -103,10 +103,10 @@ func TestLogPruning(t *testing.T) {
// Wait until queue is flushed
ts := time.Now()
- var entries []Entry
+ var entries []LogEntry
var err error
for len(entries) == 0 {
- entries, err = logger.Get(1)
+ entries, err = logger.Read(1)
if err != nil {
t.Fatal(err)
}
@@ -121,7 +121,7 @@ func TestLogPruning(t *testing.T) {
// Trigger pruning by recording another entry
logger.Record(net.IPv4(192, 0, 2, 100), false, 1, "2.example.com.", "192.0.2.2")
for len(entries) > 1 {
- entries, err = logger.Get(2)
+ entries, err = logger.Read(2)
if err != nil {
t.Fatal(err)
}
diff --git a/sql/sql.go b/sql/sql.go
index ac39e09..5bc07e4 100644
--- a/sql/sql.go
+++ b/sql/sql.go
@@ -68,8 +68,7 @@ type Client struct {
mu sync.RWMutex
}
-// LogEntry represents an entry in the log.
-type LogEntry struct {
+type logEntry struct {
ID int64 `db:"id"`
Time int64 `db:"time"`
RemoteAddr []byte `db:"remote_addr"`
@@ -100,8 +99,7 @@ func New(filename string) (*Client, error) {
return &Client{db: db}, nil
}
-// ReadLog reads the n most recent entries from the log.
-func (c *Client) ReadLog(n int) ([]LogEntry, error) {
+func (c *Client) readLog(n int) ([]logEntry, error) {
c.mu.RLock()
defer c.mu.RUnlock()
query := `
@@ -121,7 +119,7 @@ LEFT JOIN rr_answer ON rr_answer.id = log_rr_answer.rr_answer_id
WHERE log.id IN (SELECT id FROM log ORDER BY time DESC, id DESC LIMIT $1)
ORDER BY time DESC, rr_answer.id DESC
`
- var entries []LogEntry
+ var entries []logEntry
err := c.db.Select(&entries, query, n)
return entries, err
}
@@ -139,8 +137,7 @@ func getOrInsert(tx *sqlx.Tx, table, column string, value interface{}) (int64, e
return id, err
}
-// WriteLog writes a new entry to the log.
-func (c *Client) WriteLog(time time.Time, remoteAddr []byte, hijacked bool, qtype uint16, question string, answers ...string) error {
+func (c *Client) writeLog(time time.Time, remoteAddr []byte, hijacked bool, qtype uint16, question string, answers ...string) error {
c.mu.Lock()
defer c.mu.Unlock()
tx, err := c.db.Beginx()
@@ -188,8 +185,7 @@ func (c *Client) WriteLog(time time.Time, remoteAddr []byte, hijacked bool, qtyp
return tx.Commit()
}
-// DeleteLogBefore deletes all log entries occurring before time t.
-func (c *Client) DeleteLogBefore(t time.Time) (err error) {
+func (c *Client) deleteLogBefore(t time.Time) (err error) {
c.mu.Lock()
defer c.mu.Unlock()
tx, err := c.db.Beginx()
diff --git a/sql/sql_test.go b/sql/sql_test.go
index efa9ba7..6f8bbbe 100644
--- a/sql/sql_test.go
+++ b/sql/sql_test.go
@@ -58,10 +58,18 @@ func count(t *testing.T, client *Client, query string, args ...interface{}) int
return rows
}
+func writeTests(c *Client, t *testing.T) {
+ for i, tt := range tests {
+ if err := c.writeLog(tt.t, tt.remoteAddr, tt.hijacked, tt.qtype, tt.question, tt.answers...); err != nil {
+ t.Errorf("#%d: WriteLog(%q, %s, %t, %d, %q, %q) = %s, want nil", i, tt.t, tt.remoteAddr.String(), tt.hijacked, tt.qtype, tt.question, tt.answers, err)
+ }
+ }
+}
+
func TestWriteLog(t *testing.T) {
c := testClient()
for i, tt := range tests {
- if err := c.WriteLog(tt.t, tt.remoteAddr, tt.hijacked, tt.qtype, tt.question, tt.answers...); err != nil {
+ if err := c.writeLog(tt.t, tt.remoteAddr, tt.hijacked, tt.qtype, tt.question, tt.answers...); err != nil {
t.Errorf("#%d: WriteLog(%q, %s, %t, %d, %q, %q) = %s, want nil", i, tt.t, tt.remoteAddr.String(), tt.hijacked, tt.qtype, tt.question, tt.answers, err)
}
for _, rowCount := range tt.rowCounts {
@@ -75,12 +83,8 @@ func TestWriteLog(t *testing.T) {
func TestReadLog(t *testing.T) {
c := testClient()
- for i, tt := range tests {
- if err := c.WriteLog(tt.t, tt.remoteAddr, tt.hijacked, tt.qtype, tt.question, tt.answers...); err != nil {
- t.Fatalf("#%d: WriteLog(%q, %s, %t, %d, %q, %q) = %s, want nil", i, tt.t, tt.remoteAddr.String(), tt.hijacked, tt.qtype, tt.question, tt.answers, err)
- }
- }
- allEntries := [][]LogEntry{
+ writeTests(c, t)
+ allEntries := [][]logEntry{
{{ID: 8, Question: "baz.example.com", Qtype: 28, Time: 1560647100, RemoteAddr: net.IPv4(192, 0, 2, 102)}},
{{ID: 7, Question: "baz.example.com", Qtype: 28, Answer: "2001:db8::4", Time: 1560641700, RemoteAddr: net.IPv4(192, 0, 2, 102)}},
{
@@ -94,11 +98,11 @@ func TestReadLog(t *testing.T) {
{{ID: 1, Question: "foo.example.com", Qtype: 1, Answer: "192.0.2.1", Time: 1560636910, RemoteAddr: net.IPv4(192, 0, 2, 100)}},
}
for n := 1; n <= len(allEntries); n++ {
- var want []LogEntry
+ var want []logEntry
for _, entries := range allEntries[:n] {
want = append(want, entries...)
}
- got, err := c.ReadLog(n)
+ got, err := c.readLog(n)
if len(got) != len(want) {
t.Errorf("len(got) = %d, want %d", len(got), len(want))
}
@@ -118,17 +122,13 @@ func TestReadLog(t *testing.T) {
func TestDeleteLogBefore(t *testing.T) {
c := testClient()
- for i, tt := range tests {
- if err := c.WriteLog(tt.t, tt.remoteAddr, tt.hijacked, tt.qtype, tt.question, tt.answers...); err != nil {
- t.Fatalf("#%d: WriteLog(%s, %s, %t, %q, %q) = %s, want nil", i, tt.t, tt.remoteAddr.String(), tt.hijacked, tt.question, tt.answers, err)
- }
- }
+ writeTests(c, t)
u := tests[1].t.Add(time.Second)
- if err := c.DeleteLogBefore(u); err != nil {
+ if err := c.deleteLogBefore(u); err != nil {
t.Fatalf("DeleteBefore(%s) = %v, want %v", u, err, nil)
}
- want := []LogEntry{
+ want := []logEntry{
{ID: 8, Question: "baz.example.com", Qtype: 28, Time: 1560647100, RemoteAddr: net.IPv4(192, 0, 2, 102)},
{ID: 7, Question: "baz.example.com", Qtype: 28, Answer: "2001:db8::4", Time: 1560641700, RemoteAddr: net.IPv4(192, 0, 2, 102)},
{ID: 6, Question: "bar.example.com", Qtype: 28, Answer: "2001:db8::3", Time: 1560641700, RemoteAddr: net.IPv4(192, 0, 2, 102)},
@@ -138,7 +138,7 @@ func TestDeleteLogBefore(t *testing.T) {
{ID: 3, Question: "bar.example.com", Qtype: 1, Answer: "192.0.2.2", Time: 1560637050, RemoteAddr: net.IPv4(192, 0, 2, 101)},
}
n := 10
- got, err := c.ReadLog(n)
+ got, err := c.readLog(n)
if err != nil || !reflect.DeepEqual(got, want) {
t.Errorf("ReadLog(%d) = (%+v, %v), want (%+v, %v)", n, got, err, want, nil)
}
@@ -155,7 +155,7 @@ func TestDeleteLogBefore(t *testing.T) {
// Delete logs in the far past which matches 0 entries.
oneYear := time.Hour * 8760
- if err := c.DeleteLogBefore(u.Add(-oneYear)); err != nil {
+ if err := c.deleteLogBefore(u.Add(-oneYear)); err != nil {
t.Fatal(err)
}
}
@@ -169,12 +169,12 @@ func TestInterleavedRW(t *testing.T) {
go func() {
defer wg.Done()
for range ch {
- err = c.WriteLog(time.Now(), net.IPv4(127, 0, 0, 1), false, 1, "example.com.", "192.0.2.1")
+ err = c.writeLog(time.Now(), net.IPv4(127, 0, 0, 1), false, 1, "example.com.", "192.0.2.1")
}
}()
ch <- true
close(ch)
- if _, err := c.ReadLog(1); err != nil {
+ if _, err := c.readLog(1); err != nil {
t.Fatal(err)
}
wg.Wait()