diff options
author | Martin Polden <mpolden@mpolden.no> | 2023-04-13 15:21:18 +0200 |
---|---|---|
committer | Martin Polden <mpolden@mpolden.no> | 2023-04-17 10:31:40 +0200 |
commit | 96d8aae1ec9b4f6130b6b610ce23d2bbdb79298a (patch) | |
tree | f482eaa488eb5d5925b49d665b29c07ab516ef7f /client/go/internal/cli/auth/zts | |
parent | cce3b08cbe1864e80d5b9e57891622706b1d8181 (diff) |
Support TLS in custom target
Diffstat (limited to 'client/go/internal/cli/auth/zts')
-rw-r--r-- | client/go/internal/cli/auth/zts/zts.go | 28 | ||||
-rw-r--r-- | client/go/internal/cli/auth/zts/zts_test.go | 7 |
2 files changed, 23 insertions, 12 deletions
diff --git a/client/go/internal/cli/auth/zts/zts.go b/client/go/internal/cli/auth/zts/zts.go index caa2d03367d..2c66ff13e8b 100644 --- a/client/go/internal/cli/auth/zts/zts.go +++ b/client/go/internal/cli/auth/zts/zts.go @@ -1,7 +1,6 @@ package zts import ( - "crypto/tls" "encoding/json" "fmt" "net/http" @@ -18,26 +17,39 @@ const DefaultURL = "https://zts.athenz.ouroath.com:4443" type Client struct { client util.HTTPClient tokenURL *url.URL + domain string } // NewClient creates a new client for an Athenz ZTS service located at serviceURL. -func NewClient(client util.HTTPClient, serviceURL string) (*Client, error) { +func NewClient(client util.HTTPClient, domain, serviceURL string) (*Client, error) { tokenURL, err := url.Parse(serviceURL) if err != nil { return nil, err } tokenURL.Path = "/zts/v1/oauth2/token" - return &Client{tokenURL: tokenURL, client: client}, nil + return &Client{tokenURL: tokenURL, client: client, domain: domain}, nil } -// AccessToken returns an access token within the given domain, using certificate to authenticate with ZTS. -func (c *Client) AccessToken(domain string, certificate tls.Certificate) (string, error) { - data := fmt.Sprintf("grant_type=client_credentials&scope=%s:domain", domain) +func (c *Client) Authenticate(request *http.Request) error { + accessToken, err := c.AccessToken() + if err != nil { + return err + } + if request.Header == nil { + request.Header = make(http.Header) + } + request.Header.Add("Authorization", "Bearer "+accessToken) + return nil +} + +// AccessToken returns an access token within the domain configured in client c. +func (c *Client) AccessToken() (string, error) { + // TODO(mpolden): This should cache and re-use tokens until expiry + data := fmt.Sprintf("grant_type=client_credentials&scope=%s:domain", c.domain) req, err := http.NewRequest("POST", c.tokenURL.String(), strings.NewReader(data)) if err != nil { return "", err } - util.SetCertificates(c.client, []tls.Certificate{certificate}) response, err := c.client.Do(req, 10*time.Second) if err != nil { return "", err @@ -45,7 +57,7 @@ func (c *Client) AccessToken(domain string, certificate tls.Certificate) (string defer response.Body.Close() if response.StatusCode != http.StatusOK { - return "", fmt.Errorf("got status %d from %s", response.StatusCode, c.tokenURL.String()) + return "", fmt.Errorf("zts: got status %d from %s", response.StatusCode, c.tokenURL.String()) } var ztsResponse struct { AccessToken string `json:"access_token"` diff --git a/client/go/internal/cli/auth/zts/zts_test.go b/client/go/internal/cli/auth/zts/zts_test.go index d0cc7ea9f9d..1c75a94ee03 100644 --- a/client/go/internal/cli/auth/zts/zts_test.go +++ b/client/go/internal/cli/auth/zts/zts_test.go @@ -1,7 +1,6 @@ package zts import ( - "crypto/tls" "testing" "github.com/vespa-engine/vespa/client/go/internal/mock" @@ -9,17 +8,17 @@ import ( func TestAccessToken(t *testing.T) { httpClient := mock.HTTPClient{} - client, err := NewClient(&httpClient, "http://example.com") + client, err := NewClient(&httpClient, "vespa.vespa", "http://example.com") if err != nil { t.Fatal(err) } httpClient.NextResponseString(400, `{"message": "bad request"}`) - _, err = client.AccessToken("vespa.vespa", tls.Certificate{}) + _, err = client.AccessToken() if err == nil { t.Fatal("want error for non-ok response status") } httpClient.NextResponseString(200, `{"access_token": "foo bar"}`) - token, err := client.AccessToken("vespa.vespa", tls.Certificate{}) + token, err := client.AccessToken() if err != nil { t.Fatal(err) } |