summaryrefslogtreecommitdiffstats
path: root/client/go/internal/cli
diff options
context:
space:
mode:
authorMartin Polden <mpolden@mpolden.no>2023-05-02 15:10:10 +0200
committerMartin Polden <mpolden@mpolden.no>2023-05-02 15:21:04 +0200
commit22cdb33a8ef8f235d0c63935bc9055b9cc7f8cc4 (patch)
tree4a9ee48c8da55f754e469688ce9d92f12ab773e6 /client/go/internal/cli
parent6ce4ab0f38389d93d8b057afc07b55f603a8c44f (diff)
Cache ZTS access token until expiry
Diffstat (limited to 'client/go/internal/cli')
-rw-r--r--client/go/internal/cli/auth/zts/zts.go62
-rw-r--r--client/go/internal/cli/auth/zts/zts_test.go37
2 files changed, 80 insertions, 19 deletions
diff --git a/client/go/internal/cli/auth/zts/zts.go b/client/go/internal/cli/auth/zts/zts.go
index 2c66ff13e8b..0f73ea5912d 100644
--- a/client/go/internal/cli/auth/zts/zts.go
+++ b/client/go/internal/cli/auth/zts/zts.go
@@ -3,23 +3,39 @@ package zts
import (
"encoding/json"
"fmt"
+ "io"
"net/http"
"net/url"
"strings"
+ "sync"
"time"
"github.com/vespa-engine/vespa/client/go/internal/util"
)
-const DefaultURL = "https://zts.athenz.ouroath.com:4443"
+const (
+ DefaultURL = "https://zts.athenz.ouroath.com:4443"
+ expirySlack = 5 * time.Minute
+)
// Client is a client for Athenz ZTS, an authentication token service.
type Client struct {
client util.HTTPClient
tokenURL *url.URL
domain string
+ now func() time.Time
+ token Token
+ mu sync.Mutex
+}
+
+// Token is an access token retrieved from ZTS.
+type Token struct {
+ Value string
+ ExpiresAt time.Time
}
+func (t *Token) isExpired(now time.Time) bool { return t.ExpiresAt.Sub(now) < expirySlack }
+
// NewClient creates a new client for an Athenz ZTS service located at serviceURL.
func NewClient(client util.HTTPClient, domain, serviceURL string) (*Client, error) {
tokenURL, err := url.Parse(serviceURL)
@@ -27,44 +43,60 @@ func NewClient(client util.HTTPClient, domain, serviceURL string) (*Client, erro
return nil, err
}
tokenURL.Path = "/zts/v1/oauth2/token"
- return &Client{tokenURL: tokenURL, client: client, domain: domain}, nil
+ return &Client{tokenURL: tokenURL, client: client, domain: domain, now: time.Now}, nil
}
func (c *Client) Authenticate(request *http.Request) error {
- accessToken, err := c.AccessToken()
- if err != nil {
- return err
+ now := c.now()
+ if c.token.isExpired(now) {
+ c.mu.Lock()
+ if c.token.isExpired(now) {
+ accessToken, err := c.AccessToken()
+ if err != nil {
+ c.mu.Unlock()
+ return err
+ }
+ c.token = accessToken
+ }
+ c.mu.Unlock()
}
if request.Header == nil {
request.Header = make(http.Header)
}
- request.Header.Add("Authorization", "Bearer "+accessToken)
+ request.Header.Add("Authorization", "Bearer "+c.token.Value)
return nil
}
// AccessToken returns an access token within the domain configured in client c.
-func (c *Client) AccessToken() (string, error) {
- // TODO(mpolden): This should cache and re-use tokens until expiry
+func (c *Client) AccessToken() (Token, error) {
data := fmt.Sprintf("grant_type=client_credentials&scope=%s:domain", c.domain)
req, err := http.NewRequest("POST", c.tokenURL.String(), strings.NewReader(data))
if err != nil {
- return "", err
+ return Token{}, err
}
+ now := c.now()
response, err := c.client.Do(req, 10*time.Second)
if err != nil {
- return "", err
+ return Token{}, err
}
defer response.Body.Close()
if response.StatusCode != http.StatusOK {
- return "", fmt.Errorf("zts: got status %d from %s", response.StatusCode, c.tokenURL.String())
+ return Token{}, fmt.Errorf("zts: got status %d from %s", response.StatusCode, c.tokenURL.String())
}
var ztsResponse struct {
AccessToken string `json:"access_token"`
+ ExpirySecs int `json:"expires_in"`
+ }
+ b, err := io.ReadAll(response.Body)
+ if err != nil {
+ return Token{}, err
}
- dec := json.NewDecoder(response.Body)
- if err := dec.Decode(&ztsResponse); err != nil {
- return "", err
+ if err := json.Unmarshal(b, &ztsResponse); err != nil {
+ return Token{}, err
}
- return ztsResponse.AccessToken, nil
+ return Token{
+ Value: ztsResponse.AccessToken,
+ ExpiresAt: now.Add(time.Duration(ztsResponse.ExpirySecs) * time.Second),
+ }, nil
}
diff --git a/client/go/internal/cli/auth/zts/zts_test.go b/client/go/internal/cli/auth/zts/zts_test.go
index 1c75a94ee03..15c60ed46d7 100644
--- a/client/go/internal/cli/auth/zts/zts_test.go
+++ b/client/go/internal/cli/auth/zts/zts_test.go
@@ -2,28 +2,57 @@ package zts
import (
"testing"
+ "time"
"github.com/vespa-engine/vespa/client/go/internal/mock"
)
+type manualClock struct{ t time.Time }
+
+func (c *manualClock) now() time.Time { return c.t }
+func (c *manualClock) advance(d time.Duration) { c.t = c.t.Add(d) }
+
func TestAccessToken(t *testing.T) {
httpClient := mock.HTTPClient{}
client, err := NewClient(&httpClient, "vespa.vespa", "http://example.com")
if err != nil {
t.Fatal(err)
}
+ clock := &manualClock{t: time.Now()}
+ client.now = clock.now
httpClient.NextResponseString(400, `{"message": "bad request"}`)
_, err = client.AccessToken()
if err == nil {
t.Fatal("want error for non-ok response status")
}
- httpClient.NextResponseString(200, `{"access_token": "foo bar"}`)
+ httpClient.NextResponseString(200, `{"access_token": "foo", "expires_in": 3600}`)
token, err := client.AccessToken()
if err != nil {
t.Fatal(err)
}
- want := "foo bar"
- if token != want {
- t.Errorf("got %q, want %q", token, want)
+
+ // Token is cached
+ expiresAt := clock.now().Add(time.Hour)
+ assertToken(t, Token{Value: "foo", ExpiresAt: expiresAt}, token)
+ clock.advance(54 * time.Minute)
+ assertToken(t, Token{Value: "foo", ExpiresAt: expiresAt}, token)
+
+ // Token is renewed when nearing expiry
+ clock.advance(time.Minute + time.Second)
+ httpClient.NextResponseString(200, `{"access_token": "bar", "expires_in": 1800}`)
+ token, err = client.AccessToken()
+ if err != nil {
+ t.Fatal(err)
+ }
+ expiresAt = clock.now().Add(30 * time.Minute)
+ assertToken(t, Token{Value: "bar", ExpiresAt: expiresAt}, token)
+}
+
+func assertToken(t *testing.T, want, got Token) {
+ if want.Value != got.Value {
+ t.Errorf("got Value=%q, want %q", got.Value, want.Value)
+ }
+ if want.ExpiresAt != got.ExpiresAt {
+ t.Errorf("got ExpiresAt=%s, want %s", got.ExpiresAt, want.ExpiresAt)
}
}