aboutsummaryrefslogtreecommitdiffstats
path: root/http/http_test.go
blob: e964e3b16335f876e7104c40d2b9b04624a625bc (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
package http

import (
	"io/ioutil"
	"net"
	"net/http"
	"net/http/httptest"
	"regexp"
	"strings"
	"testing"

	"github.com/miekg/dns"
	"github.com/mpolden/zdns/cache"
	"github.com/mpolden/zdns/log"
)

func newA(name string, ttl uint32, ipAddr ...net.IP) *dns.Msg {
	m := dns.Msg{}
	m.Id = dns.Id()
	m.SetQuestion(dns.Fqdn(name), dns.TypeA)
	rr := make([]dns.RR, 0, len(ipAddr))
	for _, ip := range ipAddr {
		rr = append(rr, &dns.A{
			A:   ip,
			Hdr: dns.RR_Header{Name: name, Rrtype: dns.TypeA, Class: dns.ClassINET, Ttl: ttl},
		})
	}
	m.Answer = rr
	return &m
}

func testServer() (*httptest.Server, *Server) {
	logger, err := log.New(ioutil.Discard, "", log.RecordOptions{Database: ":memory:"})
	if err != nil {
		panic(err)
	}
	cache := cache.New(10)
	server := Server{logger: logger, cache: cache}
	return httptest.NewServer(server.handler()), &server
}

func httpGet(url string) (string, int, error) {
	res, err := http.Get(url)
	if err != nil {
		return "", 0, err
	}
	defer res.Body.Close()
	data, err := ioutil.ReadAll(res.Body)
	if err != nil {
		return "", 0, err
	}
	return string(data), res.StatusCode, nil
}

func TestRequests(t *testing.T) {
	httpSrv, srv := testServer()
	defer httpSrv.Close()
	srv.logger.Record(net.IPv4(127, 0, 0, 42), false, 1, "example.com.", "192.0.2.100", "192.0.2.101")
	srv.logger.Record(net.IPv4(127, 0, 0, 254), true, 28, "example.com.", "2001:db8::1")
	srv.cache.Set(1, newA("1.example.com.", 60, net.IPv4(192, 0, 2, 200)))
	srv.cache.Set(2, newA("2.example.com.", 30, net.IPv4(192, 0, 2, 201)))

	cr1 := `[{"time":"RFC3339","ttl":30,"type":"A","question":"2.example.com.","answers":["192.0.2.201"],"rcode":"NOERROR"},` +
		`{"time":"RFC3339","ttl":60,"type":"A","question":"1.example.com.","answers":["192.0.2.200"],"rcode":"NOERROR"}]`
	cr2 := `[{"time":"RFC3339","ttl":30,"type":"A","question":"2.example.com.","answers":["192.0.2.201"],"rcode":"NOERROR"}]`
	lr1 := `[{"time":"RFC3339","remote_addr":"127.0.0.254","hijacked":true,"type":"AAAA","question":"example.com.","answers":["2001:db8::1"]},` +
		`{"time":"RFC3339","remote_addr":"127.0.0.42","hijacked":false,"type":"A","question":"example.com.","answers":["192.0.2.101","192.0.2.100"]}]`
	lr2 := `[{"time":"RFC3339","remote_addr":"127.0.0.254","hijacked":true,"type":"AAAA","question":"example.com.","answers":["2001:db8::1"]}]`

	var tests = []struct {
		method   string
		url      string
		response string
		status   int
	}{
		{http.MethodGet, "/not-found", `{"status":404,"message":"Resource not found"}`, 404},
		{http.MethodGet, "/log/v1/", lr1, 200},
		{http.MethodGet, "/log/v1/?n=foo", lr1, 200},
		{http.MethodGet, "/log/v1/?n=1", lr2, 200},
		{http.MethodGet, "/cache/v1/", cr1, 200},
		{http.MethodGet, "/cache/v1/?n=foo", cr1, 200},
		{http.MethodGet, "/cache/v1/?n=1", cr2, 200},
	}

	for i, tt := range tests {
		var (
			data   string
			status int
			err    error
		)
		switch tt.method {
		case http.MethodGet:
			data, status, err = httpGet(httpSrv.URL + tt.url)
		default:
			t.Fatalf("#%d: invalid method: %s", i, tt.method)
		}
		if err != nil {
			t.Fatal(err)
		}
		if got := status; status != tt.status {
			t.Errorf("#%d: %s %s returned status %d, want %d", i, tt.method, tt.url, got, tt.status)
		}

		got := string(data)
		want := regexp.QuoteMeta(tt.response)
		want = strings.ReplaceAll(want, "RFC3339", `\d{4}-\d{2}-\d{2}T\d{2}:\d{2}:\d{2}Z`)
		matched, err := regexp.MatchString(want, got)
		if err != nil {
			t.Fatal(err)
		}
		if !matched {
			t.Errorf("#%d: %s %s returned response %s, want %s", i, tt.method, tt.url, got, want)
		}
	}
}