aboutsummaryrefslogtreecommitdiffstats
path: root/client/go/internal/cli/auth/zts
diff options
context:
space:
mode:
authorMartin Polden <mpolden@mpolden.no>2023-04-13 15:21:18 +0200
committerMartin Polden <mpolden@mpolden.no>2023-04-17 10:31:40 +0200
commit96d8aae1ec9b4f6130b6b610ce23d2bbdb79298a (patch)
treef482eaa488eb5d5925b49d665b29c07ab516ef7f /client/go/internal/cli/auth/zts
parentcce3b08cbe1864e80d5b9e57891622706b1d8181 (diff)
Support TLS in custom target
Diffstat (limited to 'client/go/internal/cli/auth/zts')
-rw-r--r--client/go/internal/cli/auth/zts/zts.go28
-rw-r--r--client/go/internal/cli/auth/zts/zts_test.go7
2 files changed, 23 insertions, 12 deletions
diff --git a/client/go/internal/cli/auth/zts/zts.go b/client/go/internal/cli/auth/zts/zts.go
index caa2d03367d..2c66ff13e8b 100644
--- a/client/go/internal/cli/auth/zts/zts.go
+++ b/client/go/internal/cli/auth/zts/zts.go
@@ -1,7 +1,6 @@
package zts
import (
- "crypto/tls"
"encoding/json"
"fmt"
"net/http"
@@ -18,26 +17,39 @@ const DefaultURL = "https://zts.athenz.ouroath.com:4443"
type Client struct {
client util.HTTPClient
tokenURL *url.URL
+ domain string
}
// NewClient creates a new client for an Athenz ZTS service located at serviceURL.
-func NewClient(client util.HTTPClient, serviceURL string) (*Client, error) {
+func NewClient(client util.HTTPClient, domain, serviceURL string) (*Client, error) {
tokenURL, err := url.Parse(serviceURL)
if err != nil {
return nil, err
}
tokenURL.Path = "/zts/v1/oauth2/token"
- return &Client{tokenURL: tokenURL, client: client}, nil
+ return &Client{tokenURL: tokenURL, client: client, domain: domain}, nil
}
-// AccessToken returns an access token within the given domain, using certificate to authenticate with ZTS.
-func (c *Client) AccessToken(domain string, certificate tls.Certificate) (string, error) {
- data := fmt.Sprintf("grant_type=client_credentials&scope=%s:domain", domain)
+func (c *Client) Authenticate(request *http.Request) error {
+ accessToken, err := c.AccessToken()
+ if err != nil {
+ return err
+ }
+ if request.Header == nil {
+ request.Header = make(http.Header)
+ }
+ request.Header.Add("Authorization", "Bearer "+accessToken)
+ return nil
+}
+
+// AccessToken returns an access token within the domain configured in client c.
+func (c *Client) AccessToken() (string, error) {
+ // TODO(mpolden): This should cache and re-use tokens until expiry
+ data := fmt.Sprintf("grant_type=client_credentials&scope=%s:domain", c.domain)
req, err := http.NewRequest("POST", c.tokenURL.String(), strings.NewReader(data))
if err != nil {
return "", err
}
- util.SetCertificates(c.client, []tls.Certificate{certificate})
response, err := c.client.Do(req, 10*time.Second)
if err != nil {
return "", err
@@ -45,7 +57,7 @@ func (c *Client) AccessToken(domain string, certificate tls.Certificate) (string
defer response.Body.Close()
if response.StatusCode != http.StatusOK {
- return "", fmt.Errorf("got status %d from %s", response.StatusCode, c.tokenURL.String())
+ return "", fmt.Errorf("zts: got status %d from %s", response.StatusCode, c.tokenURL.String())
}
var ztsResponse struct {
AccessToken string `json:"access_token"`
diff --git a/client/go/internal/cli/auth/zts/zts_test.go b/client/go/internal/cli/auth/zts/zts_test.go
index d0cc7ea9f9d..1c75a94ee03 100644
--- a/client/go/internal/cli/auth/zts/zts_test.go
+++ b/client/go/internal/cli/auth/zts/zts_test.go
@@ -1,7 +1,6 @@
package zts
import (
- "crypto/tls"
"testing"
"github.com/vespa-engine/vespa/client/go/internal/mock"
@@ -9,17 +8,17 @@ import (
func TestAccessToken(t *testing.T) {
httpClient := mock.HTTPClient{}
- client, err := NewClient(&httpClient, "http://example.com")
+ client, err := NewClient(&httpClient, "vespa.vespa", "http://example.com")
if err != nil {
t.Fatal(err)
}
httpClient.NextResponseString(400, `{"message": "bad request"}`)
- _, err = client.AccessToken("vespa.vespa", tls.Certificate{})
+ _, err = client.AccessToken()
if err == nil {
t.Fatal("want error for non-ok response status")
}
httpClient.NextResponseString(200, `{"access_token": "foo bar"}`)
- token, err := client.AccessToken("vespa.vespa", tls.Certificate{})
+ token, err := client.AccessToken()
if err != nil {
t.Fatal(err)
}