diff options
-rw-r--r-- | client/go/internal/cli/auth/zts/zts.go | 62 | ||||
-rw-r--r-- | client/go/internal/cli/auth/zts/zts_test.go | 37 | ||||
-rw-r--r-- | client/go/internal/cli/cmd/feed.go | 33 | ||||
-rw-r--r-- | client/go/internal/cli/cmd/feed_test.go | 3 | ||||
-rw-r--r-- | client/go/internal/cli/cmd/root.go | 27 | ||||
-rw-r--r-- | client/go/internal/cli/cmd/testutil_test.go | 2 | ||||
-rw-r--r-- | client/go/internal/util/http.go | 3 | ||||
-rw-r--r-- | client/go/internal/vespa/target.go | 3 |
8 files changed, 124 insertions, 46 deletions
diff --git a/client/go/internal/cli/auth/zts/zts.go b/client/go/internal/cli/auth/zts/zts.go index 2c66ff13e8b..0f73ea5912d 100644 --- a/client/go/internal/cli/auth/zts/zts.go +++ b/client/go/internal/cli/auth/zts/zts.go @@ -3,23 +3,39 @@ package zts import ( "encoding/json" "fmt" + "io" "net/http" "net/url" "strings" + "sync" "time" "github.com/vespa-engine/vespa/client/go/internal/util" ) -const DefaultURL = "https://zts.athenz.ouroath.com:4443" +const ( + DefaultURL = "https://zts.athenz.ouroath.com:4443" + expirySlack = 5 * time.Minute +) // Client is a client for Athenz ZTS, an authentication token service. type Client struct { client util.HTTPClient tokenURL *url.URL domain string + now func() time.Time + token Token + mu sync.Mutex +} + +// Token is an access token retrieved from ZTS. +type Token struct { + Value string + ExpiresAt time.Time } +func (t *Token) isExpired(now time.Time) bool { return t.ExpiresAt.Sub(now) < expirySlack } + // NewClient creates a new client for an Athenz ZTS service located at serviceURL. func NewClient(client util.HTTPClient, domain, serviceURL string) (*Client, error) { tokenURL, err := url.Parse(serviceURL) @@ -27,44 +43,60 @@ func NewClient(client util.HTTPClient, domain, serviceURL string) (*Client, erro return nil, err } tokenURL.Path = "/zts/v1/oauth2/token" - return &Client{tokenURL: tokenURL, client: client, domain: domain}, nil + return &Client{tokenURL: tokenURL, client: client, domain: domain, now: time.Now}, nil } func (c *Client) Authenticate(request *http.Request) error { - accessToken, err := c.AccessToken() - if err != nil { - return err + now := c.now() + if c.token.isExpired(now) { + c.mu.Lock() + if c.token.isExpired(now) { + accessToken, err := c.AccessToken() + if err != nil { + c.mu.Unlock() + return err + } + c.token = accessToken + } + c.mu.Unlock() } if request.Header == nil { request.Header = make(http.Header) } - request.Header.Add("Authorization", "Bearer "+accessToken) + request.Header.Add("Authorization", "Bearer "+c.token.Value) return nil } // AccessToken returns an access token within the domain configured in client c. -func (c *Client) AccessToken() (string, error) { - // TODO(mpolden): This should cache and re-use tokens until expiry +func (c *Client) AccessToken() (Token, error) { data := fmt.Sprintf("grant_type=client_credentials&scope=%s:domain", c.domain) req, err := http.NewRequest("POST", c.tokenURL.String(), strings.NewReader(data)) if err != nil { - return "", err + return Token{}, err } + now := c.now() response, err := c.client.Do(req, 10*time.Second) if err != nil { - return "", err + return Token{}, err } defer response.Body.Close() if response.StatusCode != http.StatusOK { - return "", fmt.Errorf("zts: got status %d from %s", response.StatusCode, c.tokenURL.String()) + return Token{}, fmt.Errorf("zts: got status %d from %s", response.StatusCode, c.tokenURL.String()) } var ztsResponse struct { AccessToken string `json:"access_token"` + ExpirySecs int `json:"expires_in"` + } + b, err := io.ReadAll(response.Body) + if err != nil { + return Token{}, err } - dec := json.NewDecoder(response.Body) - if err := dec.Decode(&ztsResponse); err != nil { - return "", err + if err := json.Unmarshal(b, &ztsResponse); err != nil { + return Token{}, err } - return ztsResponse.AccessToken, nil + return Token{ + Value: ztsResponse.AccessToken, + ExpiresAt: now.Add(time.Duration(ztsResponse.ExpirySecs) * time.Second), + }, nil } diff --git a/client/go/internal/cli/auth/zts/zts_test.go b/client/go/internal/cli/auth/zts/zts_test.go index 1c75a94ee03..15c60ed46d7 100644 --- a/client/go/internal/cli/auth/zts/zts_test.go +++ b/client/go/internal/cli/auth/zts/zts_test.go @@ -2,28 +2,57 @@ package zts import ( "testing" + "time" "github.com/vespa-engine/vespa/client/go/internal/mock" ) +type manualClock struct{ t time.Time } + +func (c *manualClock) now() time.Time { return c.t } +func (c *manualClock) advance(d time.Duration) { c.t = c.t.Add(d) } + func TestAccessToken(t *testing.T) { httpClient := mock.HTTPClient{} client, err := NewClient(&httpClient, "vespa.vespa", "http://example.com") if err != nil { t.Fatal(err) } + clock := &manualClock{t: time.Now()} + client.now = clock.now httpClient.NextResponseString(400, `{"message": "bad request"}`) _, err = client.AccessToken() if err == nil { t.Fatal("want error for non-ok response status") } - httpClient.NextResponseString(200, `{"access_token": "foo bar"}`) + httpClient.NextResponseString(200, `{"access_token": "foo", "expires_in": 3600}`) token, err := client.AccessToken() if err != nil { t.Fatal(err) } - want := "foo bar" - if token != want { - t.Errorf("got %q, want %q", token, want) + + // Token is cached + expiresAt := clock.now().Add(time.Hour) + assertToken(t, Token{Value: "foo", ExpiresAt: expiresAt}, token) + clock.advance(54 * time.Minute) + assertToken(t, Token{Value: "foo", ExpiresAt: expiresAt}, token) + + // Token is renewed when nearing expiry + clock.advance(time.Minute + time.Second) + httpClient.NextResponseString(200, `{"access_token": "bar", "expires_in": 1800}`) + token, err = client.AccessToken() + if err != nil { + t.Fatal(err) + } + expiresAt = clock.now().Add(30 * time.Minute) + assertToken(t, Token{Value: "bar", ExpiresAt: expiresAt}, token) +} + +func assertToken(t *testing.T, want, got Token) { + if want.Value != got.Value { + t.Errorf("got Value=%q, want %q", got.Value, want.Value) + } + if want.ExpiresAt != got.ExpiresAt { + t.Errorf("got ExpiresAt=%s, want %s", got.ExpiresAt, want.ExpiresAt) } } diff --git a/client/go/internal/cli/cmd/feed.go b/client/go/internal/cli/cmd/feed.go index 5b168ef79a2..2ea3ee1c4ed 100644 --- a/client/go/internal/cli/cmd/feed.go +++ b/client/go/internal/cli/cmd/feed.go @@ -93,15 +93,30 @@ $ cat docs.jsonl | vespa feed -`, return cmd } -func createServiceClients(service *vespa.Service, n int) []util.HTTPClient { - clients := make([]util.HTTPClient, 0, n) +func createServices(n int, timeout time.Duration, cli *CLI) ([]util.HTTPClient, string, error) { + if n < 1 { + return nil, "", fmt.Errorf("need at least one client") + } + target, err := cli.target(targetOptions{}) + if err != nil { + return nil, "", err + } + services := make([]util.HTTPClient, 0, n) + baseURL := "" for i := 0; i < n; i++ { - client := service.Client().Clone() + service, err := cli.service(target, vespa.DocumentService, 0, cli.config.cluster()) + if err != nil { + return nil, "", err + } + baseURL = service.BaseURL + // Create a separate HTTP client for each service + client := cli.httpClientFactory(timeout) // Feeding should always use HTTP/2 util.ForceHTTP2(client, service.TLSOptions.KeyPair, service.TLSOptions.CACertificate, service.TLSOptions.TrustAll) - clients = append(clients, client) + service.SetClient(client) + services = append(services, service) } - return clients + return services, baseURL, nil } func summaryTicker(secs int, cli *CLI, start time.Time, statsFunc func() document.Stats) *time.Ticker { @@ -130,21 +145,21 @@ func (opts feedOptions) compressionMode() (document.Compression, error) { } func feed(files []string, options feedOptions, cli *CLI) error { - service, err := documentService(cli) + timeout := time.Duration(options.timeoutSecs) * time.Second + clients, baseURL, err := createServices(options.connections, timeout, cli) if err != nil { return err } - clients := createServiceClients(service, options.connections) compression, err := options.compressionMode() if err != nil { return err } client, err := document.NewClient(document.ClientOptions{ Compression: compression, - Timeout: time.Duration(options.timeoutSecs) * time.Second, + Timeout: timeout, Route: options.route, TraceLevel: options.traceLevel, - BaseURL: service.BaseURL, + BaseURL: baseURL, NowFunc: cli.now, }, clients) if err != nil { diff --git a/client/go/internal/cli/cmd/feed_test.go b/client/go/internal/cli/cmd/feed_test.go index 467d55a0a6e..097d4ae5fa3 100644 --- a/client/go/internal/cli/cmd/feed_test.go +++ b/client/go/internal/cli/cmd/feed_test.go @@ -24,10 +24,9 @@ func (c *manualClock) now() time.Time { } func TestFeed(t *testing.T) { - httpClient := &mock.HTTPClient{} clock := &manualClock{tick: time.Second} cli, stdout, stderr := newTestCLI(t) - cli.httpClient = httpClient + httpClient := cli.httpClient.(*mock.HTTPClient) cli.now = clock.now td := t.TempDir() diff --git a/client/go/internal/cli/cmd/root.go b/client/go/internal/cli/cmd/root.go index 1b37ff00269..c4012024426 100644 --- a/client/go/internal/cli/cmd/root.go +++ b/client/go/internal/cli/cmd/root.go @@ -47,13 +47,14 @@ type CLI struct { config *Config version version.Version - httpClient util.HTTPClient - auth0Factory auth0Factory - ztsFactory ztsFactory - exec executor - isTerminal func() bool - spinner func(w io.Writer, message string, fn func() error) error - now func() time.Time + httpClient util.HTTPClient + httpClientFactory func(timeout time.Duration) util.HTTPClient + auth0Factory auth0Factory + ztsFactory ztsFactory + exec executor + isTerminal func() bool + spinner func(w io.Writer, message string, fn func() error) error + now func() time.Time } // ErrCLI is an error returned to the user. It wraps an exit status, a regular error and optional hints for resolving @@ -122,17 +123,19 @@ For detailed description of flags and configuration, see 'vespa help config'. if err != nil { return nil, err } + httpClientFactory := util.CreateClient cli := CLI{ Environment: env, Stdin: os.Stdin, Stdout: stdout, Stderr: stderr, - version: version, - cmd: cmd, - httpClient: util.CreateClient(time.Second * 10), - exec: &execSubprocess{}, - now: time.Now, + version: version, + cmd: cmd, + httpClient: httpClientFactory(time.Second * 10), + httpClientFactory: httpClientFactory, + exec: &execSubprocess{}, + now: time.Now, auth0Factory: func(httpClient util.HTTPClient, options auth0.Options) (vespa.Authenticator, error) { return auth0.NewClient(httpClient, options) }, diff --git a/client/go/internal/cli/cmd/testutil_test.go b/client/go/internal/cli/cmd/testutil_test.go index 492e40d8855..61d6c15c5a0 100644 --- a/client/go/internal/cli/cmd/testutil_test.go +++ b/client/go/internal/cli/cmd/testutil_test.go @@ -6,6 +6,7 @@ import ( "net/http" "path/filepath" "testing" + "time" "github.com/vespa-engine/vespa/client/go/internal/cli/auth/auth0" "github.com/vespa-engine/vespa/client/go/internal/mock" @@ -28,6 +29,7 @@ func newTestCLI(t *testing.T, envVars ...string) (*CLI, *bytes.Buffer, *bytes.Bu t.Fatal(err) } httpClient := &mock.HTTPClient{} + cli.httpClientFactory = func(timeout time.Duration) util.HTTPClient { return httpClient } cli.httpClient = httpClient cli.exec = &mock.Exec{} cli.auth0Factory = func(httpClient util.HTTPClient, options auth0.Options) (vespa.Authenticator, error) { diff --git a/client/go/internal/util/http.go b/client/go/internal/util/http.go index 8a67b24dffb..e49d1254f26 100644 --- a/client/go/internal/util/http.go +++ b/client/go/internal/util/http.go @@ -16,7 +16,6 @@ import ( type HTTPClient interface { Do(request *http.Request, timeout time.Duration) (response *http.Response, error error) - Clone() HTTPClient } type defaultHTTPClient struct { @@ -34,8 +33,6 @@ func (c *defaultHTTPClient) Do(request *http.Request, timeout time.Duration) (re return c.client.Do(request) } -func (c *defaultHTTPClient) Clone() HTTPClient { return CreateClient(c.client.Timeout) } - func ConfigureTLS(client HTTPClient, certificates []tls.Certificate, caCertificate []byte, trustAll bool) { c, ok := client.(*defaultHTTPClient) if !ok { diff --git a/client/go/internal/vespa/target.go b/client/go/internal/vespa/target.go index 9f3fd7f5c65..6dd64dd1275 100644 --- a/client/go/internal/vespa/target.go +++ b/client/go/internal/vespa/target.go @@ -110,7 +110,8 @@ func (s *Service) Do(request *http.Request, timeout time.Duration) (*http.Respon return s.httpClient.Do(request, timeout) } -func (s *Service) Client() util.HTTPClient { return s.httpClient } +// SetClient sets the HTTP client that this service should use. +func (s *Service) SetClient(client util.HTTPClient) { s.httpClient = client } // Wait polls the health check of this service until it succeeds or timeout passes. func (s *Service) Wait(timeout time.Duration) (int, error) { |