summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--client/go/internal/cli/auth/zts/zts.go62
-rw-r--r--client/go/internal/cli/auth/zts/zts_test.go37
-rw-r--r--client/go/internal/cli/cmd/feed.go33
-rw-r--r--client/go/internal/cli/cmd/feed_test.go3
-rw-r--r--client/go/internal/cli/cmd/root.go27
-rw-r--r--client/go/internal/cli/cmd/testutil_test.go2
-rw-r--r--client/go/internal/util/http.go3
-rw-r--r--client/go/internal/vespa/target.go3
8 files changed, 124 insertions, 46 deletions
diff --git a/client/go/internal/cli/auth/zts/zts.go b/client/go/internal/cli/auth/zts/zts.go
index 2c66ff13e8b..0f73ea5912d 100644
--- a/client/go/internal/cli/auth/zts/zts.go
+++ b/client/go/internal/cli/auth/zts/zts.go
@@ -3,23 +3,39 @@ package zts
import (
"encoding/json"
"fmt"
+ "io"
"net/http"
"net/url"
"strings"
+ "sync"
"time"
"github.com/vespa-engine/vespa/client/go/internal/util"
)
-const DefaultURL = "https://zts.athenz.ouroath.com:4443"
+const (
+ DefaultURL = "https://zts.athenz.ouroath.com:4443"
+ expirySlack = 5 * time.Minute
+)
// Client is a client for Athenz ZTS, an authentication token service.
type Client struct {
client util.HTTPClient
tokenURL *url.URL
domain string
+ now func() time.Time
+ token Token
+ mu sync.Mutex
+}
+
+// Token is an access token retrieved from ZTS.
+type Token struct {
+ Value string
+ ExpiresAt time.Time
}
+func (t *Token) isExpired(now time.Time) bool { return t.ExpiresAt.Sub(now) < expirySlack }
+
// NewClient creates a new client for an Athenz ZTS service located at serviceURL.
func NewClient(client util.HTTPClient, domain, serviceURL string) (*Client, error) {
tokenURL, err := url.Parse(serviceURL)
@@ -27,44 +43,60 @@ func NewClient(client util.HTTPClient, domain, serviceURL string) (*Client, erro
return nil, err
}
tokenURL.Path = "/zts/v1/oauth2/token"
- return &Client{tokenURL: tokenURL, client: client, domain: domain}, nil
+ return &Client{tokenURL: tokenURL, client: client, domain: domain, now: time.Now}, nil
}
func (c *Client) Authenticate(request *http.Request) error {
- accessToken, err := c.AccessToken()
- if err != nil {
- return err
+ now := c.now()
+ if c.token.isExpired(now) {
+ c.mu.Lock()
+ if c.token.isExpired(now) {
+ accessToken, err := c.AccessToken()
+ if err != nil {
+ c.mu.Unlock()
+ return err
+ }
+ c.token = accessToken
+ }
+ c.mu.Unlock()
}
if request.Header == nil {
request.Header = make(http.Header)
}
- request.Header.Add("Authorization", "Bearer "+accessToken)
+ request.Header.Add("Authorization", "Bearer "+c.token.Value)
return nil
}
// AccessToken returns an access token within the domain configured in client c.
-func (c *Client) AccessToken() (string, error) {
- // TODO(mpolden): This should cache and re-use tokens until expiry
+func (c *Client) AccessToken() (Token, error) {
data := fmt.Sprintf("grant_type=client_credentials&scope=%s:domain", c.domain)
req, err := http.NewRequest("POST", c.tokenURL.String(), strings.NewReader(data))
if err != nil {
- return "", err
+ return Token{}, err
}
+ now := c.now()
response, err := c.client.Do(req, 10*time.Second)
if err != nil {
- return "", err
+ return Token{}, err
}
defer response.Body.Close()
if response.StatusCode != http.StatusOK {
- return "", fmt.Errorf("zts: got status %d from %s", response.StatusCode, c.tokenURL.String())
+ return Token{}, fmt.Errorf("zts: got status %d from %s", response.StatusCode, c.tokenURL.String())
}
var ztsResponse struct {
AccessToken string `json:"access_token"`
+ ExpirySecs int `json:"expires_in"`
+ }
+ b, err := io.ReadAll(response.Body)
+ if err != nil {
+ return Token{}, err
}
- dec := json.NewDecoder(response.Body)
- if err := dec.Decode(&ztsResponse); err != nil {
- return "", err
+ if err := json.Unmarshal(b, &ztsResponse); err != nil {
+ return Token{}, err
}
- return ztsResponse.AccessToken, nil
+ return Token{
+ Value: ztsResponse.AccessToken,
+ ExpiresAt: now.Add(time.Duration(ztsResponse.ExpirySecs) * time.Second),
+ }, nil
}
diff --git a/client/go/internal/cli/auth/zts/zts_test.go b/client/go/internal/cli/auth/zts/zts_test.go
index 1c75a94ee03..15c60ed46d7 100644
--- a/client/go/internal/cli/auth/zts/zts_test.go
+++ b/client/go/internal/cli/auth/zts/zts_test.go
@@ -2,28 +2,57 @@ package zts
import (
"testing"
+ "time"
"github.com/vespa-engine/vespa/client/go/internal/mock"
)
+type manualClock struct{ t time.Time }
+
+func (c *manualClock) now() time.Time { return c.t }
+func (c *manualClock) advance(d time.Duration) { c.t = c.t.Add(d) }
+
func TestAccessToken(t *testing.T) {
httpClient := mock.HTTPClient{}
client, err := NewClient(&httpClient, "vespa.vespa", "http://example.com")
if err != nil {
t.Fatal(err)
}
+ clock := &manualClock{t: time.Now()}
+ client.now = clock.now
httpClient.NextResponseString(400, `{"message": "bad request"}`)
_, err = client.AccessToken()
if err == nil {
t.Fatal("want error for non-ok response status")
}
- httpClient.NextResponseString(200, `{"access_token": "foo bar"}`)
+ httpClient.NextResponseString(200, `{"access_token": "foo", "expires_in": 3600}`)
token, err := client.AccessToken()
if err != nil {
t.Fatal(err)
}
- want := "foo bar"
- if token != want {
- t.Errorf("got %q, want %q", token, want)
+
+ // Token is cached
+ expiresAt := clock.now().Add(time.Hour)
+ assertToken(t, Token{Value: "foo", ExpiresAt: expiresAt}, token)
+ clock.advance(54 * time.Minute)
+ assertToken(t, Token{Value: "foo", ExpiresAt: expiresAt}, token)
+
+ // Token is renewed when nearing expiry
+ clock.advance(time.Minute + time.Second)
+ httpClient.NextResponseString(200, `{"access_token": "bar", "expires_in": 1800}`)
+ token, err = client.AccessToken()
+ if err != nil {
+ t.Fatal(err)
+ }
+ expiresAt = clock.now().Add(30 * time.Minute)
+ assertToken(t, Token{Value: "bar", ExpiresAt: expiresAt}, token)
+}
+
+func assertToken(t *testing.T, want, got Token) {
+ if want.Value != got.Value {
+ t.Errorf("got Value=%q, want %q", got.Value, want.Value)
+ }
+ if want.ExpiresAt != got.ExpiresAt {
+ t.Errorf("got ExpiresAt=%s, want %s", got.ExpiresAt, want.ExpiresAt)
}
}
diff --git a/client/go/internal/cli/cmd/feed.go b/client/go/internal/cli/cmd/feed.go
index 5b168ef79a2..2ea3ee1c4ed 100644
--- a/client/go/internal/cli/cmd/feed.go
+++ b/client/go/internal/cli/cmd/feed.go
@@ -93,15 +93,30 @@ $ cat docs.jsonl | vespa feed -`,
return cmd
}
-func createServiceClients(service *vespa.Service, n int) []util.HTTPClient {
- clients := make([]util.HTTPClient, 0, n)
+func createServices(n int, timeout time.Duration, cli *CLI) ([]util.HTTPClient, string, error) {
+ if n < 1 {
+ return nil, "", fmt.Errorf("need at least one client")
+ }
+ target, err := cli.target(targetOptions{})
+ if err != nil {
+ return nil, "", err
+ }
+ services := make([]util.HTTPClient, 0, n)
+ baseURL := ""
for i := 0; i < n; i++ {
- client := service.Client().Clone()
+ service, err := cli.service(target, vespa.DocumentService, 0, cli.config.cluster())
+ if err != nil {
+ return nil, "", err
+ }
+ baseURL = service.BaseURL
+ // Create a separate HTTP client for each service
+ client := cli.httpClientFactory(timeout)
// Feeding should always use HTTP/2
util.ForceHTTP2(client, service.TLSOptions.KeyPair, service.TLSOptions.CACertificate, service.TLSOptions.TrustAll)
- clients = append(clients, client)
+ service.SetClient(client)
+ services = append(services, service)
}
- return clients
+ return services, baseURL, nil
}
func summaryTicker(secs int, cli *CLI, start time.Time, statsFunc func() document.Stats) *time.Ticker {
@@ -130,21 +145,21 @@ func (opts feedOptions) compressionMode() (document.Compression, error) {
}
func feed(files []string, options feedOptions, cli *CLI) error {
- service, err := documentService(cli)
+ timeout := time.Duration(options.timeoutSecs) * time.Second
+ clients, baseURL, err := createServices(options.connections, timeout, cli)
if err != nil {
return err
}
- clients := createServiceClients(service, options.connections)
compression, err := options.compressionMode()
if err != nil {
return err
}
client, err := document.NewClient(document.ClientOptions{
Compression: compression,
- Timeout: time.Duration(options.timeoutSecs) * time.Second,
+ Timeout: timeout,
Route: options.route,
TraceLevel: options.traceLevel,
- BaseURL: service.BaseURL,
+ BaseURL: baseURL,
NowFunc: cli.now,
}, clients)
if err != nil {
diff --git a/client/go/internal/cli/cmd/feed_test.go b/client/go/internal/cli/cmd/feed_test.go
index 467d55a0a6e..097d4ae5fa3 100644
--- a/client/go/internal/cli/cmd/feed_test.go
+++ b/client/go/internal/cli/cmd/feed_test.go
@@ -24,10 +24,9 @@ func (c *manualClock) now() time.Time {
}
func TestFeed(t *testing.T) {
- httpClient := &mock.HTTPClient{}
clock := &manualClock{tick: time.Second}
cli, stdout, stderr := newTestCLI(t)
- cli.httpClient = httpClient
+ httpClient := cli.httpClient.(*mock.HTTPClient)
cli.now = clock.now
td := t.TempDir()
diff --git a/client/go/internal/cli/cmd/root.go b/client/go/internal/cli/cmd/root.go
index 1b37ff00269..c4012024426 100644
--- a/client/go/internal/cli/cmd/root.go
+++ b/client/go/internal/cli/cmd/root.go
@@ -47,13 +47,14 @@ type CLI struct {
config *Config
version version.Version
- httpClient util.HTTPClient
- auth0Factory auth0Factory
- ztsFactory ztsFactory
- exec executor
- isTerminal func() bool
- spinner func(w io.Writer, message string, fn func() error) error
- now func() time.Time
+ httpClient util.HTTPClient
+ httpClientFactory func(timeout time.Duration) util.HTTPClient
+ auth0Factory auth0Factory
+ ztsFactory ztsFactory
+ exec executor
+ isTerminal func() bool
+ spinner func(w io.Writer, message string, fn func() error) error
+ now func() time.Time
}
// ErrCLI is an error returned to the user. It wraps an exit status, a regular error and optional hints for resolving
@@ -122,17 +123,19 @@ For detailed description of flags and configuration, see 'vespa help config'.
if err != nil {
return nil, err
}
+ httpClientFactory := util.CreateClient
cli := CLI{
Environment: env,
Stdin: os.Stdin,
Stdout: stdout,
Stderr: stderr,
- version: version,
- cmd: cmd,
- httpClient: util.CreateClient(time.Second * 10),
- exec: &execSubprocess{},
- now: time.Now,
+ version: version,
+ cmd: cmd,
+ httpClient: httpClientFactory(time.Second * 10),
+ httpClientFactory: httpClientFactory,
+ exec: &execSubprocess{},
+ now: time.Now,
auth0Factory: func(httpClient util.HTTPClient, options auth0.Options) (vespa.Authenticator, error) {
return auth0.NewClient(httpClient, options)
},
diff --git a/client/go/internal/cli/cmd/testutil_test.go b/client/go/internal/cli/cmd/testutil_test.go
index 492e40d8855..61d6c15c5a0 100644
--- a/client/go/internal/cli/cmd/testutil_test.go
+++ b/client/go/internal/cli/cmd/testutil_test.go
@@ -6,6 +6,7 @@ import (
"net/http"
"path/filepath"
"testing"
+ "time"
"github.com/vespa-engine/vespa/client/go/internal/cli/auth/auth0"
"github.com/vespa-engine/vespa/client/go/internal/mock"
@@ -28,6 +29,7 @@ func newTestCLI(t *testing.T, envVars ...string) (*CLI, *bytes.Buffer, *bytes.Bu
t.Fatal(err)
}
httpClient := &mock.HTTPClient{}
+ cli.httpClientFactory = func(timeout time.Duration) util.HTTPClient { return httpClient }
cli.httpClient = httpClient
cli.exec = &mock.Exec{}
cli.auth0Factory = func(httpClient util.HTTPClient, options auth0.Options) (vespa.Authenticator, error) {
diff --git a/client/go/internal/util/http.go b/client/go/internal/util/http.go
index 8a67b24dffb..e49d1254f26 100644
--- a/client/go/internal/util/http.go
+++ b/client/go/internal/util/http.go
@@ -16,7 +16,6 @@ import (
type HTTPClient interface {
Do(request *http.Request, timeout time.Duration) (response *http.Response, error error)
- Clone() HTTPClient
}
type defaultHTTPClient struct {
@@ -34,8 +33,6 @@ func (c *defaultHTTPClient) Do(request *http.Request, timeout time.Duration) (re
return c.client.Do(request)
}
-func (c *defaultHTTPClient) Clone() HTTPClient { return CreateClient(c.client.Timeout) }
-
func ConfigureTLS(client HTTPClient, certificates []tls.Certificate, caCertificate []byte, trustAll bool) {
c, ok := client.(*defaultHTTPClient)
if !ok {
diff --git a/client/go/internal/vespa/target.go b/client/go/internal/vespa/target.go
index 9f3fd7f5c65..6dd64dd1275 100644
--- a/client/go/internal/vespa/target.go
+++ b/client/go/internal/vespa/target.go
@@ -110,7 +110,8 @@ func (s *Service) Do(request *http.Request, timeout time.Duration) (*http.Respon
return s.httpClient.Do(request, timeout)
}
-func (s *Service) Client() util.HTTPClient { return s.httpClient }
+// SetClient sets the HTTP client that this service should use.
+func (s *Service) SetClient(client util.HTTPClient) { s.httpClient = client }
// Wait polls the health check of this service until it succeeds or timeout passes.
func (s *Service) Wait(timeout time.Duration) (int, error) {