From 5763fd73d8ad34c4ecd1b75772f136373cdaa733 Mon Sep 17 00:00:00 2001 From: Martin Polden Date: Wed, 12 Apr 2023 15:32:13 +0200 Subject: Remove unused pointer receiver --- client/go/internal/vespa/document/http.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) (limited to 'client') diff --git a/client/go/internal/vespa/document/http.go b/client/go/internal/vespa/document/http.go index 588330a0574..ae823686a76 100644 --- a/client/go/internal/vespa/document/http.go +++ b/client/go/internal/vespa/document/http.go @@ -166,7 +166,7 @@ func (c *Client) Send(document Document) Result { } defer resp.Body.Close() elapsed := c.now().Sub(start) - return c.resultWithResponse(resp, result, document, elapsed) + return resultWithResponse(resp, result, document, elapsed) } func resultWithErr(result Result, err error) Result { @@ -176,7 +176,7 @@ func resultWithErr(result Result, err error) Result { return result } -func (c *Client) resultWithResponse(resp *http.Response, result Result, document Document, elapsed time.Duration) Result { +func resultWithResponse(resp *http.Response, result Result, document Document, elapsed time.Duration) Result { result.HTTPStatus = resp.StatusCode result.Stats.Responses++ result.Stats.ResponsesByCode = map[int]int64{resp.StatusCode: 1} -- cgit v1.2.3 From 86c3219f28f89a0c5c4c15415b65f4d5a81b8774 Mon Sep 17 00:00:00 2001 From: Martin Polden Date: Fri, 14 Apr 2023 10:16:56 +0200 Subject: Wire client options --- client/go/internal/cli/cmd/feed.go | 39 ++++++++++++++++++++----------- client/go/internal/vespa/document/http.go | 6 ++--- 2 files changed, 28 insertions(+), 17 deletions(-) (limited to 'client') diff --git a/client/go/internal/cli/cmd/feed.go b/client/go/internal/cli/cmd/feed.go index 895a22d2be5..19bf84e492a 100644 --- a/client/go/internal/cli/cmd/feed.go +++ b/client/go/internal/cli/cmd/feed.go @@ -14,16 +14,24 @@ import ( "github.com/vespa-engine/vespa/client/go/internal/vespa/document" ) -func addFeedFlags(cmd *cobra.Command, verbose *bool, connections *int) { - cmd.PersistentFlags().IntVarP(connections, "connections", "N", 8, "The number of connections to use") - cmd.PersistentFlags().BoolVarP(verbose, "verbose", "v", false, "Verbose mode. Print errors as they happen") +func addFeedFlags(cmd *cobra.Command, options *feedOptions) { + cmd.PersistentFlags().IntVar(&options.connections, "connections", 8, "The number of connections to use") + cmd.PersistentFlags().StringVar(&options.route, "route", "", "Target Vespa route for feed operations") + cmd.PersistentFlags().IntVar(&options.traceLevel, "trace", 0, "The trace level of network traffic. 0 to disable") + cmd.PersistentFlags().IntVar(&options.timeoutSecs, "timeout", 0, "Feed operation timeout in seconds. 0 to disable") + cmd.PersistentFlags().BoolVar(&options.verbose, "verbose", false, "Verbose mode. Print errors as they happen") +} + +type feedOptions struct { + connections int + route string + verbose bool + traceLevel int + timeoutSecs int } func newFeedCmd(cli *CLI) *cobra.Command { - var ( - verbose bool - connections int - ) + var options feedOptions cmd := &cobra.Command{ Use: "feed FILE", Short: "Feed documents to a Vespa cluster", @@ -56,10 +64,10 @@ $ cat documents.jsonl | vespa feed - defer f.Close() r = f } - return feed(r, cli, verbose, connections) + return feed(r, cli, options) }, } - addFeedFlags(cmd, &verbose, &connections) + addFeedFlags(cmd, &options) return cmd } @@ -73,20 +81,23 @@ func createServiceClients(service *vespa.Service, n int) []util.HTTPClient { return clients } -func feed(r io.Reader, cli *CLI, verbose bool, connections int) error { +func feed(r io.Reader, cli *CLI, options feedOptions) error { service, err := documentService(cli) if err != nil { return err } - clients := createServiceClients(service, connections) + clients := createServiceClients(service, options.connections) client := document.NewClient(document.ClientOptions{ - BaseURL: service.BaseURL, + Timeout: time.Duration(options.timeoutSecs) * time.Second, + Route: options.route, + TraceLevel: options.traceLevel, + BaseURL: service.BaseURL, }, clients) - throttler := document.NewThrottler(connections) + throttler := document.NewThrottler(options.connections) // TODO(mpolden): Make doom duration configurable circuitBreaker := document.NewCircuitBreaker(10*time.Second, 0) errWriter := io.Discard - if verbose { + if options.verbose { errWriter = cli.Stderr } dispatcher := document.NewDispatcher(client, throttler, circuitBreaker, errWriter) diff --git a/client/go/internal/vespa/document/http.go b/client/go/internal/vespa/document/http.go index ae823686a76..d602821d603 100644 --- a/client/go/internal/vespa/document/http.go +++ b/client/go/internal/vespa/document/http.go @@ -29,7 +29,7 @@ type ClientOptions struct { BaseURL string Timeout time.Duration Route string - TraceLevel *int + TraceLevel int } type countingHTTPClient struct { @@ -78,8 +78,8 @@ func (c *Client) queryParams() url.Values { if c.options.Route != "" { params.Set("route", c.options.Route) } - if c.options.TraceLevel != nil { - params.Set("tracelevel", strconv.Itoa(*c.options.TraceLevel)) + if c.options.TraceLevel > 0 { + params.Set("tracelevel", strconv.Itoa(c.options.TraceLevel)) } return params } -- cgit v1.2.3 From cce3b08cbe1864e80d5b9e57891622706b1d8181 Mon Sep 17 00:00:00 2001 From: Martin Polden Date: Fri, 14 Apr 2023 10:35:01 +0200 Subject: Adjust request timeout like Java client --- client/go/internal/vespa/document/http.go | 8 ++++++-- client/go/internal/vespa/document/http_test.go | 2 +- 2 files changed, 7 insertions(+), 3 deletions(-) (limited to 'client') diff --git a/client/go/internal/vespa/document/http.go b/client/go/internal/vespa/document/http.go index d602821d603..1bcd7eff39e 100644 --- a/client/go/internal/vespa/document/http.go +++ b/client/go/internal/vespa/document/http.go @@ -72,9 +72,13 @@ func NewClient(options ClientOptions, httpClients []util.HTTPClient) *Client { func (c *Client) queryParams() url.Values { params := url.Values{} - if c.options.Timeout > 0 { - params.Set("timeout", strconv.FormatInt(c.options.Timeout.Milliseconds(), 10)+"ms") + timeout := c.options.Timeout + if timeout == 0 { + timeout = 200 * time.Second + } else { + timeout = timeout*11/10 + 1000 } + params.Set("timeout", strconv.FormatInt(timeout.Milliseconds(), 10)+"ms") if c.options.Route != "" { params.Set("route", c.options.Route) } diff --git a/client/go/internal/vespa/document/http_test.go b/client/go/internal/vespa/document/http_test.go index 43eaf1bfdf9..8f8394a5d4e 100644 --- a/client/go/internal/vespa/document/http_test.go +++ b/client/go/internal/vespa/document/http_test.go @@ -108,7 +108,7 @@ func TestClientSend(t *testing.T) { if r.Method != http.MethodPut { t.Errorf("got r.Method = %q, want %q", r.Method, http.MethodPut) } - wantURL := fmt.Sprintf("https://example.com:1337/document/v1/ns/type/docid/%s?create=true&timeout=5000ms", doc.Id.UserSpecific) + wantURL := fmt.Sprintf("https://example.com:1337/document/v1/ns/type/docid/%s?create=true&timeout=5500ms", doc.Id.UserSpecific) if r.URL.String() != wantURL { t.Errorf("got r.URL = %q, want %q", r.URL, wantURL) } -- cgit v1.2.3 From 96d8aae1ec9b4f6130b6b610ce23d2bbdb79298a Mon Sep 17 00:00:00 2001 From: Martin Polden Date: Thu, 13 Apr 2023 15:21:18 +0200 Subject: Support TLS in custom target --- client/go/internal/cli/auth/auth0/auth0.go | 38 +++++---- client/go/internal/cli/auth/zts/zts.go | 28 +++++-- client/go/internal/cli/auth/zts/zts_test.go | 7 +- client/go/internal/cli/cmd/cert.go | 9 ++- client/go/internal/cli/cmd/config.go | 91 +++++++++++++-------- client/go/internal/cli/cmd/config_test.go | 119 ++++++++++++++++++++++++++-- client/go/internal/cli/cmd/curl.go | 12 +-- client/go/internal/cli/cmd/feed.go | 3 +- client/go/internal/cli/cmd/root.go | 97 ++++++++++++----------- client/go/internal/cli/cmd/test.go | 2 +- client/go/internal/cli/cmd/testutil_test.go | 21 ++--- client/go/internal/util/http.go | 29 +++---- client/go/internal/vespa/crypto.go | 2 + client/go/internal/vespa/deploy.go | 8 -- client/go/internal/vespa/deploy_test.go | 2 +- client/go/internal/vespa/target.go | 54 +++++++------ client/go/internal/vespa/target_cloud.go | 93 +++++++--------------- client/go/internal/vespa/target_custom.go | 19 +++-- client/go/internal/vespa/target_test.go | 27 +++---- 19 files changed, 375 insertions(+), 286 deletions(-) (limited to 'client') diff --git a/client/go/internal/cli/auth/auth0/auth0.go b/client/go/internal/cli/auth/auth0/auth0.go index 5f7612d4d2e..6fcd3f7680e 100644 --- a/client/go/internal/cli/auth/auth0/auth0.go +++ b/client/go/internal/cli/auth/auth0/auth0.go @@ -110,28 +110,40 @@ func (a *Client) getDeviceFlowConfig() (flowConfig, error) { } r, err := a.httpClient.Do(req, time.Second*30) if err != nil { - return flowConfig{}, fmt.Errorf("failed to get device flow config: %w", err) + return flowConfig{}, fmt.Errorf("auth0: failed to get device flow config: %w", err) } defer r.Body.Close() if r.StatusCode/100 != 2 { - return flowConfig{}, fmt.Errorf("failed to get device flow config: got response code %d from %s", r.StatusCode, url) + return flowConfig{}, fmt.Errorf("auth0: failed to get device flow config: got response code %d from %s", r.StatusCode, url) } var cfg flowConfig if err := json.NewDecoder(r.Body).Decode(&cfg); err != nil { - return flowConfig{}, fmt.Errorf("failed to decode response: %w", err) + return flowConfig{}, fmt.Errorf("auth0: failed to decode response: %w", err) } return cfg, nil } +func (a *Client) Authenticate(request *http.Request) error { + accessToken, err := a.AccessToken() + if err != nil { + return err + } + if request.Header == nil { + request.Header = make(http.Header) + } + request.Header.Set("Authorization", "Bearer "+accessToken) + return nil +} + // AccessToken returns an access token for the configured system, refreshing it if necessary. func (a *Client) AccessToken() (string, error) { creds, ok := a.provider.Systems[a.options.SystemName] if !ok { - return "", fmt.Errorf("system %s is not configured", a.options.SystemName) + return "", fmt.Errorf("auth0: system %s is not configured: %s", a.options.SystemName, reauthMessage) } else if creds.AccessToken == "" { - return "", fmt.Errorf("access token missing: %s", reauthMessage) + return "", fmt.Errorf("auth0: access token missing: %s", reauthMessage) } else if scopesChanged(creds) { - return "", fmt.Errorf("authentication scopes changed: %s", reauthMessage) + return "", fmt.Errorf("auth0: authentication scopes changed: %s", reauthMessage) } else if isExpired(creds.ExpiresAt, accessTokenExpiry) { // check if the stored access token is expired: // use the refresh token to get a new access token: @@ -142,7 +154,7 @@ func (a *Client) AccessToken() (string, error) { } resp, err := tr.Refresh(cancelOnInterrupt(), a.options.SystemName) if err != nil { - return "", fmt.Errorf("failed to renew access token: %w: %s", err, reauthMessage) + return "", fmt.Errorf("auth0: failed to renew access token: %w: %s", err, reauthMessage) } else { // persist the updated system with renewed access token creds.AccessToken = resp.AccessToken @@ -173,12 +185,6 @@ func scopesChanged(s Credentials) bool { return false } -// HasCredentials returns true if this client has retrived credentials for the configured system. -func (a *Client) HasCredentials() bool { - _, ok := a.provider.Systems[a.options.SystemName] - return ok -} - // WriteCredentials writes given credentials to the configuration file. func (a *Client) WriteCredentials(credentials Credentials) error { if a.provider.Systems == nil { @@ -186,7 +192,7 @@ func (a *Client) WriteCredentials(credentials Credentials) error { } a.provider.Systems[a.options.SystemName] = credentials if err := writeConfig(a.provider, a.options.ConfigPath); err != nil { - return fmt.Errorf("failed to write config: %w", err) + return fmt.Errorf("auth0: failed to write config: %w", err) } return nil } @@ -195,11 +201,11 @@ func (a *Client) WriteCredentials(credentials Credentials) error { func (a *Client) RemoveCredentials() error { tr := &auth.TokenRetriever{Secrets: &auth.Keyring{}} if err := tr.Delete(a.options.SystemName); err != nil { - return fmt.Errorf("failed to remove system %s from secret storage: %w", a.options.SystemName, err) + return fmt.Errorf("auth0: failed to remove system %s from secret storage: %w", a.options.SystemName, err) } delete(a.provider.Systems, a.options.SystemName) if err := writeConfig(a.provider, a.options.ConfigPath); err != nil { - return fmt.Errorf("failed to write config: %w", err) + return fmt.Errorf("auth0: failed to write config: %w", err) } return nil } diff --git a/client/go/internal/cli/auth/zts/zts.go b/client/go/internal/cli/auth/zts/zts.go index caa2d03367d..2c66ff13e8b 100644 --- a/client/go/internal/cli/auth/zts/zts.go +++ b/client/go/internal/cli/auth/zts/zts.go @@ -1,7 +1,6 @@ package zts import ( - "crypto/tls" "encoding/json" "fmt" "net/http" @@ -18,26 +17,39 @@ const DefaultURL = "https://zts.athenz.ouroath.com:4443" type Client struct { client util.HTTPClient tokenURL *url.URL + domain string } // NewClient creates a new client for an Athenz ZTS service located at serviceURL. -func NewClient(client util.HTTPClient, serviceURL string) (*Client, error) { +func NewClient(client util.HTTPClient, domain, serviceURL string) (*Client, error) { tokenURL, err := url.Parse(serviceURL) if err != nil { return nil, err } tokenURL.Path = "/zts/v1/oauth2/token" - return &Client{tokenURL: tokenURL, client: client}, nil + return &Client{tokenURL: tokenURL, client: client, domain: domain}, nil } -// AccessToken returns an access token within the given domain, using certificate to authenticate with ZTS. -func (c *Client) AccessToken(domain string, certificate tls.Certificate) (string, error) { - data := fmt.Sprintf("grant_type=client_credentials&scope=%s:domain", domain) +func (c *Client) Authenticate(request *http.Request) error { + accessToken, err := c.AccessToken() + if err != nil { + return err + } + if request.Header == nil { + request.Header = make(http.Header) + } + request.Header.Add("Authorization", "Bearer "+accessToken) + return nil +} + +// AccessToken returns an access token within the domain configured in client c. +func (c *Client) AccessToken() (string, error) { + // TODO(mpolden): This should cache and re-use tokens until expiry + data := fmt.Sprintf("grant_type=client_credentials&scope=%s:domain", c.domain) req, err := http.NewRequest("POST", c.tokenURL.String(), strings.NewReader(data)) if err != nil { return "", err } - util.SetCertificates(c.client, []tls.Certificate{certificate}) response, err := c.client.Do(req, 10*time.Second) if err != nil { return "", err @@ -45,7 +57,7 @@ func (c *Client) AccessToken(domain string, certificate tls.Certificate) (string defer response.Body.Close() if response.StatusCode != http.StatusOK { - return "", fmt.Errorf("got status %d from %s", response.StatusCode, c.tokenURL.String()) + return "", fmt.Errorf("zts: got status %d from %s", response.StatusCode, c.tokenURL.String()) } var ztsResponse struct { AccessToken string `json:"access_token"` diff --git a/client/go/internal/cli/auth/zts/zts_test.go b/client/go/internal/cli/auth/zts/zts_test.go index d0cc7ea9f9d..1c75a94ee03 100644 --- a/client/go/internal/cli/auth/zts/zts_test.go +++ b/client/go/internal/cli/auth/zts/zts_test.go @@ -1,7 +1,6 @@ package zts import ( - "crypto/tls" "testing" "github.com/vespa-engine/vespa/client/go/internal/mock" @@ -9,17 +8,17 @@ import ( func TestAccessToken(t *testing.T) { httpClient := mock.HTTPClient{} - client, err := NewClient(&httpClient, "http://example.com") + client, err := NewClient(&httpClient, "vespa.vespa", "http://example.com") if err != nil { t.Fatal(err) } httpClient.NextResponseString(400, `{"message": "bad request"}`) - _, err = client.AccessToken("vespa.vespa", tls.Certificate{}) + _, err = client.AccessToken() if err == nil { t.Fatal("want error for non-ok response status") } httpClient.NextResponseString(200, `{"access_token": "foo bar"}`) - token, err := client.AccessToken("vespa.vespa", tls.Certificate{}) + token, err := client.AccessToken() if err != nil { t.Fatal(err) } diff --git a/client/go/internal/cli/cmd/cert.go b/client/go/internal/cli/cmd/cert.go index 7f79a9db358..48bad974c3f 100644 --- a/client/go/internal/cli/cmd/cert.go +++ b/client/go/internal/cli/cmd/cert.go @@ -34,13 +34,18 @@ package specified as an argument to this command (default '.'). It's possible to override the private key and certificate used through environment variables. This can be useful in continuous integration systems. -Example of setting the certificate and key in-line: +It's also possible override the CA certificate which can be useful when using self-signed certificates with a +self-hosted Vespa service. See https://docs.vespa.ai/en/mtls.html for more information. +Example of setting the CA certificate, certificate and key in-line: + + export VESPA_CLI_DATA_PLANE_CA_CERT="my CA cert" export VESPA_CLI_DATA_PLANE_CERT="my cert" export VESPA_CLI_DATA_PLANE_KEY="my private key" -Example of loading certificate and key from custom paths: +Example of loading CA certificate, certificate and key from custom paths: + export VESPA_CLI_DATA_PLANE_CA_CERT_FILE=/path/to/cacert export VESPA_CLI_DATA_PLANE_CERT_FILE=/path/to/cert export VESPA_CLI_DATA_PLANE_KEY_FILE=/path/to/key diff --git a/client/go/internal/cli/cmd/config.go b/client/go/internal/cli/cmd/config.go index 2d32c454842..e2132814386 100644 --- a/client/go/internal/cli/cmd/config.go +++ b/client/go/internal/cli/cmd/config.go @@ -19,7 +19,6 @@ import ( "github.com/fatih/color" "github.com/spf13/cobra" "github.com/spf13/pflag" - "github.com/vespa-engine/vespa/client/go/internal/cli/auth/auth0" "github.com/vespa-engine/vespa/client/go/internal/cli/config" "github.com/vespa-engine/vespa/client/go/internal/vespa" ) @@ -250,9 +249,10 @@ type Config struct { } type KeyPair struct { - KeyPair tls.Certificate - CertificateFile string - PrivateKeyFile string + KeyPair tls.Certificate + RootCertificates []byte + CertificateFile string + PrivateKeyFile string } func loadConfig(environment map[string]string, flags map[string]*pflag.Flag) (*Config, error) { @@ -392,6 +392,10 @@ func (c *Config) deploymentIn(system vespa.System) (vespa.Deployment, error) { return vespa.Deployment{System: system, Application: app, Zone: zone}, nil } +func (c *Config) caCertificatePath() string { + return c.environment["VESPA_CLI_DATA_PLANE_CA_CERT_FILE"] +} + func (c *Config) certificatePath(app vespa.ApplicationID, targetType string) (string, error) { if override, ok := c.environment["VESPA_CLI_DATA_PLANE_CERT_FILE"]; ok { return override, nil @@ -412,50 +416,68 @@ func (c *Config) privateKeyPath(app vespa.ApplicationID, targetType string) (str return c.applicationFilePath(app, "data-plane-private-key.pem") } -func (c *Config) x509KeyPair(app vespa.ApplicationID, targetType string) (KeyPair, error) { +func (c *Config) readTLSOptions(app vespa.ApplicationID, targetType string) (vespa.TLSOptions, error) { + _, trustAll := c.environment["VESPA_CLI_DATA_PLANE_TRUST_ALL"] cert, certOk := c.environment["VESPA_CLI_DATA_PLANE_CERT"] key, keyOk := c.environment["VESPA_CLI_DATA_PLANE_KEY"] - var ( - kp tls.Certificate - err error - certFile string - keyFile string - ) + caCertText, caCertOk := c.environment["VESPA_CLI_DATA_PLANE_CA_CERT"] + options := vespa.TLSOptions{TrustAll: trustAll} + // CA certificate + if caCertOk { + options.CACertificate = []byte(caCertText) + } else { + caCertFile := c.caCertificatePath() + if caCertFile != "" { + b, err := os.ReadFile(caCertFile) + if err != nil { + return options, err + } + options.CACertificate = b + options.CACertificateFile = caCertFile + } + } + // Certificate and private key if certOk && keyOk { - // Use key pair from environment - kp, err = tls.X509KeyPair([]byte(cert), []byte(key)) + kp, err := tls.X509KeyPair([]byte(cert), []byte(key)) + if err != nil { + return vespa.TLSOptions{}, err + } + options.KeyPair = []tls.Certificate{kp} } else { - keyFile, err = c.privateKeyPath(app, targetType) + keyFile, err := c.privateKeyPath(app, targetType) if err != nil { - return KeyPair{}, err + return vespa.TLSOptions{}, err } - certFile, err = c.certificatePath(app, targetType) + certFile, err := c.certificatePath(app, targetType) if err != nil { - return KeyPair{}, err + return vespa.TLSOptions{}, err + } + kp, err := tls.LoadX509KeyPair(certFile, keyFile) + if err == nil { + options.KeyPair = []tls.Certificate{kp} + options.PrivateKeyFile = keyFile + options.CertificateFile = certFile + } else if err != nil && !os.IsNotExist(err) { + return vespa.TLSOptions{}, err } - kp, err = tls.LoadX509KeyPair(certFile, keyFile) - } - if err != nil { - return KeyPair{}, err } - if targetType == vespa.TargetHosted { - cert, err := x509.ParseCertificate(kp.Certificate[0]) + if options.KeyPair != nil { + cert, err := x509.ParseCertificate(options.KeyPair[0].Certificate[0]) if err != nil { - return KeyPair{}, err + return vespa.TLSOptions{}, err } now := time.Now() expiredAt := cert.NotAfter if expiredAt.Before(now) { delta := now.Sub(expiredAt).Truncate(time.Second) - return KeyPair{}, fmt.Errorf("certificate %s expired at %s (%s ago)", certFile, cert.NotAfter, delta) + source := options.CertificateFile + if source == "" { + source = "environment" + } + return vespa.TLSOptions{}, fmt.Errorf("certificate in %s expired at %s (%s ago)", source, cert.NotAfter, delta) } - return KeyPair{KeyPair: kp, CertificateFile: certFile, PrivateKeyFile: keyFile}, nil } - return KeyPair{ - KeyPair: kp, - CertificateFile: certFile, - PrivateKeyFile: keyFile, - }, nil + return options, nil } func (c *Config) apiKeyFileFromEnv() (string, bool) { @@ -490,11 +512,10 @@ func (c *Config) readAPIKey(cli *CLI, system vespa.System, tenantName string) ([ return nil, nil // Vespa Cloud CI only talks to data plane and does not have an API key } if !cli.isCI() { - client, err := cli.auth0Factory(cli.httpClient, auth0.Options{ConfigPath: c.authConfigPath(), SystemName: system.Name, SystemURL: system.URL}) - if err == nil && client.HasCredentials() { - return nil, nil // use Auth0 + if _, err := os.Stat(c.authConfigPath()); err == nil { + return nil, nil // We have auth config, so we should prefer Auth0 over API key } - cli.printWarning("Authenticating with API key. This is discouraged in non-CI environments", "Authenticate with 'vespa auth login'") + cli.printWarning("Authenticating with API key. This is discouraged in non-CI environments", "Authenticate with 'vespa auth login' instead") } return os.ReadFile(c.apiKeyPath(tenantName)) } diff --git a/client/go/internal/cli/cmd/config_test.go b/client/go/internal/cli/cmd/config_test.go index 458878b4356..66b65bf402b 100644 --- a/client/go/internal/cli/cmd/config_test.go +++ b/client/go/internal/cli/cmd/config_test.go @@ -2,15 +2,21 @@ package cmd import ( + "crypto/rand" + "crypto/rsa" + "crypto/tls" + "crypto/x509" + "crypto/x509/pkix" + "encoding/pem" + "math/big" "os" "path/filepath" "testing" + "time" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" - "github.com/vespa-engine/vespa/client/go/internal/cli/auth/auth0" "github.com/vespa-engine/vespa/client/go/internal/mock" - "github.com/vespa-engine/vespa/client/go/internal/util" "github.com/vespa-engine/vespa/client/go/internal/vespa" ) @@ -166,7 +172,7 @@ func TestReadAPIKey(t *testing.T) { require.Nil(t, err) assert.Equal(t, []byte("foo"), key) - // Cloud CI does not read key from disk as it's not expected to have any + // Cloud CI never reads key from disk as it's not expected to have any cli, _, _ = newTestCLI(t, "VESPA_CLI_CLOUD_CI=true") key, err = cli.config.readAPIKey(cli, vespa.PublicSystem, "t1") require.Nil(t, err) @@ -186,12 +192,111 @@ func TestReadAPIKey(t *testing.T) { require.Nil(t, err) assert.Equal(t, []byte("baz"), key) - // Auth0 is preferred when configured + // Prefer Auth0 if we have auth config cli, _, _ = newTestCLI(t) - cli.auth0Factory = func(httpClient util.HTTPClient, options auth0.Options) (auth0Client, error) { - return &mockAuth0{hasCredentials: true}, nil - } + require.Nil(t, os.WriteFile(filepath.Join(cli.config.homeDir, "auth.json"), []byte("foo"), 0600)) key, err = cli.config.readAPIKey(cli, vespa.PublicSystem, "t1") require.Nil(t, err) assert.Nil(t, key) } + +func TestConfigReadTLSOptions(t *testing.T) { + app := vespa.ApplicationID{Tenant: "t1", Application: "a1", Instance: "i1"} + homeDir := t.TempDir() + + // No environment variables, and no files on disk + assertTLSOptions(t, homeDir, app, vespa.TargetLocal, vespa.TLSOptions{}) + + // A single environment variable is set + assertTLSOptions(t, homeDir, app, vespa.TargetLocal, vespa.TLSOptions{TrustAll: true}, "VESPA_CLI_DATA_PLANE_TRUST_ALL=true") + + // Key pair is provided in-line in environment variables + pemCert, pemKey, keyPair := createKeyPair(t) + assertTLSOptions(t, homeDir, app, + vespa.TargetLocal, + vespa.TLSOptions{ + TrustAll: true, + CACertificate: []byte("cacert"), + KeyPair: []tls.Certificate{keyPair}, + }, + "VESPA_CLI_DATA_PLANE_TRUST_ALL=true", + "VESPA_CLI_DATA_PLANE_CA_CERT=cacert", + "VESPA_CLI_DATA_PLANE_CERT="+string(pemCert), + "VESPA_CLI_DATA_PLANE_KEY="+string(pemKey), + ) + + // Key pair is provided as file paths through environment variables + certFile := filepath.Join(homeDir, "cert") + keyFile := filepath.Join(homeDir, "key") + caCertFile := filepath.Join(homeDir, "cacert") + require.Nil(t, os.WriteFile(certFile, pemCert, 0600)) + require.Nil(t, os.WriteFile(keyFile, pemKey, 0600)) + require.Nil(t, os.WriteFile(caCertFile, []byte("cacert"), 0600)) + assertTLSOptions(t, homeDir, app, + vespa.TargetLocal, + vespa.TLSOptions{ + KeyPair: []tls.Certificate{keyPair}, + CACertificate: []byte("cacert"), + CACertificateFile: caCertFile, + CertificateFile: certFile, + PrivateKeyFile: keyFile, + }, + "VESPA_CLI_DATA_PLANE_CERT_FILE="+certFile, + "VESPA_CLI_DATA_PLANE_KEY_FILE="+keyFile, + "VESPA_CLI_DATA_PLANE_CA_CERT_FILE="+caCertFile, + ) + + // Key pair resides in default paths + defaultCertFile := filepath.Join(homeDir, app.String(), "data-plane-public-cert.pem") + defaultKeyFile := filepath.Join(homeDir, app.String(), "data-plane-private-key.pem") + require.Nil(t, os.WriteFile(defaultCertFile, pemCert, 0600)) + require.Nil(t, os.WriteFile(defaultKeyFile, pemKey, 0600)) + assertTLSOptions(t, homeDir, app, + vespa.TargetLocal, + vespa.TLSOptions{ + KeyPair: []tls.Certificate{keyPair}, + CertificateFile: defaultCertFile, + PrivateKeyFile: defaultKeyFile, + }, + ) +} + +func assertTLSOptions(t *testing.T, homeDir string, app vespa.ApplicationID, target string, want vespa.TLSOptions, envVars ...string) { + t.Helper() + envVars = append(envVars, "VESPA_CLI_HOME="+homeDir) + cli, _, _ := newTestCLI(t, envVars...) + require.Nil(t, cli.Run("config", "set", "application", app.String())) + config, err := cli.config.readTLSOptions(app, vespa.TargetLocal) + require.Nil(t, err) + assert.Equal(t, want, config) +} + +func createKeyPair(t *testing.T) ([]byte, []byte, tls.Certificate) { + privateKey, err := rsa.GenerateKey(rand.Reader, 2048) + if err != nil { + t.Fatal(err) + } + notBefore := time.Now() + notAfter := notBefore.Add(24 * time.Hour) + template := x509.Certificate{ + SerialNumber: big.NewInt(1), + Subject: pkix.Name{CommonName: "example.com"}, + NotBefore: notBefore, + NotAfter: notAfter, + } + certificateDER, err := x509.CreateCertificate(rand.Reader, &template, &template, &privateKey.PublicKey, privateKey) + if err != nil { + t.Fatal(err) + } + privateKeyDER, err := x509.MarshalPKCS8PrivateKey(privateKey) + if err != nil { + t.Fatal(err) + } + pemCert := pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: certificateDER}) + pemKey := pem.EncodeToMemory(&pem.Block{Type: "PRIVATE KEY", Bytes: privateKeyDER}) + kp, err := tls.X509KeyPair(pemCert, pemKey) + if err != nil { + t.Fatal(err) + } + return pemCert, pemKey, kp +} diff --git a/client/go/internal/cli/cmd/curl.go b/client/go/internal/cli/cmd/curl.go index 8fcd1fa6ef7..3d5aaff24dc 100644 --- a/client/go/internal/cli/cmd/curl.go +++ b/client/go/internal/cli/cmd/curl.go @@ -4,7 +4,6 @@ package cmd import ( "fmt" "log" - "net/http" "os" "strings" @@ -54,6 +53,7 @@ $ vespa curl -- -v --data-urlencode "yql=select * from music where album contain return err } case vespa.DocumentService, vespa.QueryService: + c.CaCertificate = service.TLSOptions.CACertificateFile c.PrivateKey = service.TLSOptions.PrivateKeyFile c.Certificate = service.TLSOptions.CertificateFile default: @@ -79,15 +79,7 @@ func addAccessToken(cmd *curl.Command, target vespa.Target) error { if target.Type() != vespa.TargetCloud { return nil } - req := http.Request{} - if err := target.SignRequest(&req, ""); err != nil { - return err - } - headerValue := req.Header.Get("Authorization") - if headerValue == "" { - return fmt.Errorf("no authorization header added when signing request") - } - cmd.Header("Authorization", headerValue) + cmd.Header("Authorization", "secret") return nil } diff --git a/client/go/internal/cli/cmd/feed.go b/client/go/internal/cli/cmd/feed.go index 19bf84e492a..c284328255a 100644 --- a/client/go/internal/cli/cmd/feed.go +++ b/client/go/internal/cli/cmd/feed.go @@ -75,7 +75,8 @@ func createServiceClients(service *vespa.Service, n int) []util.HTTPClient { clients := make([]util.HTTPClient, 0, n) for i := 0; i < n; i++ { client := service.Client().Clone() - util.ForceHTTP2(client, service.TLSOptions.KeyPair) // Feeding should always use HTTP/2 + // Feeding should always use HTTP/2 + util.ForceHTTP2(client, service.TLSOptions.KeyPair, service.TLSOptions.CACertificate, service.TLSOptions.TrustAll) clients = append(clients, client) } return clients diff --git a/client/go/internal/cli/cmd/root.go b/client/go/internal/cli/cmd/root.go index 360af9d0dcf..695d2eaca8f 100644 --- a/client/go/internal/cli/cmd/root.go +++ b/client/go/internal/cli/cmd/root.go @@ -2,7 +2,6 @@ package cmd import ( - "crypto/tls" "encoding/json" "fmt" "io" @@ -88,18 +87,9 @@ func (c *execSubprocess) Run(name string, args ...string) ([]byte, error) { return exec.Command(name, args...).Output() } -type ztsClient interface { - AccessToken(domain string, certficiate tls.Certificate) (string, error) -} - -type auth0Client interface { - AccessToken() (string, error) - HasCredentials() bool -} - -type auth0Factory func(httpClient util.HTTPClient, options auth0.Options) (auth0Client, error) +type auth0Factory func(httpClient util.HTTPClient, options auth0.Options) (vespa.Authenticator, error) -type ztsFactory func(httpClient util.HTTPClient, url string) (ztsClient, error) +type ztsFactory func(httpClient util.HTTPClient, domain, url string) (vespa.Authenticator, error) // New creates the Vespa CLI, writing output to stdout and stderr, and reading environment variables from environment. func New(stdout, stderr io.Writer, environment []string) (*CLI, error) { @@ -143,11 +133,11 @@ For detailed description of flags and configuration, see 'vespa help config'. httpClient: util.CreateClient(time.Second * 10), exec: &execSubprocess{}, now: time.Now, - auth0Factory: func(httpClient util.HTTPClient, options auth0.Options) (auth0Client, error) { + auth0Factory: func(httpClient util.HTTPClient, options auth0.Options) (vespa.Authenticator, error) { return auth0.NewClient(httpClient, options) }, - ztsFactory: func(httpClient util.HTTPClient, url string) (ztsClient, error) { - return zts.NewClient(httpClient, url) + ztsFactory: func(httpClient util.HTTPClient, domain, url string) (vespa.Authenticator, error) { + return zts.NewClient(httpClient, domain, url) }, } cli.isTerminal = func() bool { return isTerminal(cli.Stdout) && isTerminal(cli.Stderr) } @@ -321,16 +311,34 @@ func (c *CLI) createTarget(opts targetOptions) (vespa.Target, error) { if err != nil { return nil, err } + customURL := "" if strings.HasPrefix(targetType, "http") { - return vespa.CustomTarget(c.httpClient, targetType), nil + customURL = targetType + targetType = vespa.TargetCustom } switch targetType { - case vespa.TargetLocal: - return vespa.LocalTarget(c.httpClient), nil + case vespa.TargetLocal, vespa.TargetCustom: + return c.createCustomTarget(targetType, customURL) case vespa.TargetCloud, vespa.TargetHosted: return c.createCloudTarget(targetType, opts) + default: + return nil, errHint(fmt.Errorf("invalid target: %s", targetType), "Valid targets are 'local', 'cloud', 'hosted' or an URL") + } +} + +func (c *CLI) createCustomTarget(targetType, customURL string) (vespa.Target, error) { + tlsOptions, err := c.config.readTLSOptions(vespa.DefaultApplication, targetType) + if err != nil { + return nil, err + } + switch targetType { + case vespa.TargetLocal: + return vespa.LocalTarget(c.httpClient, tlsOptions), nil + case vespa.TargetCustom: + return vespa.CustomTarget(c.httpClient, customURL, tlsOptions), nil + default: + return nil, fmt.Errorf("invalid custom target: %s", targetType) } - return nil, errHint(fmt.Errorf("invalid target: %s", targetType), "Valid targets are 'local', 'cloud', 'hosted' or an URL") } func (c *CLI) createCloudTarget(targetType string, opts targetOptions) (vespa.Target, error) { @@ -347,48 +355,53 @@ func (c *CLI) createCloudTarget(targetType string, opts targetOptions) (vespa.Ta return nil, err } var ( - apiKey []byte - authConfigPath string + apiAuth vespa.Authenticator + deploymentAuth vespa.Authenticator apiTLSOptions vespa.TLSOptions deploymentTLSOptions vespa.TLSOptions ) switch targetType { case vespa.TargetCloud: - apiKey, err = c.config.readAPIKey(c, system, deployment.Application.Tenant) + apiKey, err := c.config.readAPIKey(c, system, deployment.Application.Tenant) if err != nil { return nil, err } - authConfigPath = c.config.authConfigPath() + if apiKey == nil { + authConfigPath := c.config.authConfigPath() + auth0, err := c.auth0Factory(c.httpClient, auth0.Options{ConfigPath: authConfigPath, SystemName: system.Name, SystemURL: system.URL}) + if err != nil { + return nil, err + } + apiAuth = auth0 + } else { + apiAuth = vespa.NewRequestSigner(deployment.Application.SerializedForm(), apiKey) + } deploymentTLSOptions = vespa.TLSOptions{} if !opts.noCertificate { - kp, err := c.config.x509KeyPair(deployment.Application, targetType) + kp, err := c.config.readTLSOptions(deployment.Application, targetType) if err != nil { - return nil, errHint(err, "Deployment to cloud requires a certificate. Try 'vespa auth cert'") - } - deploymentTLSOptions = vespa.TLSOptions{ - KeyPair: []tls.Certificate{kp.KeyPair}, - CertificateFile: kp.CertificateFile, - PrivateKeyFile: kp.PrivateKeyFile, + return nil, errHint(err, "Deployment to cloud requires a certificate", "Try 'vespa auth cert' to create a self-signed certificate") } + deploymentTLSOptions = kp } case vespa.TargetHosted: - kp, err := c.config.x509KeyPair(deployment.Application, targetType) + kp, err := c.config.readTLSOptions(deployment.Application, targetType) if err != nil { return nil, errHint(err, "Deployment to hosted requires an Athenz certificate", "Try renewing certificate with 'athenz-user-cert'") } - apiTLSOptions = vespa.TLSOptions{ - KeyPair: []tls.Certificate{kp.KeyPair}, - CertificateFile: kp.CertificateFile, - PrivateKeyFile: kp.PrivateKeyFile, + zts, err := c.ztsFactory(c.httpClient, system.AthenzDomain, zts.DefaultURL) + if err != nil { + return nil, err } - deploymentTLSOptions = apiTLSOptions + deploymentAuth = zts + apiTLSOptions = kp + deploymentTLSOptions = kp default: return nil, fmt.Errorf("invalid cloud target: %s", targetType) } apiOptions := vespa.APIOptions{ System: system, TLSOptions: apiTLSOptions, - APIKey: apiKey, } deploymentOptions := vespa.CloudDeploymentOptions{ Deployment: deployment, @@ -403,15 +416,7 @@ func (c *CLI) createCloudTarget(targetType string, opts targetOptions) (vespa.Ta Writer: c.Stdout, Level: vespa.LogLevel(logLevel), } - auth0, err := c.auth0Factory(c.httpClient, auth0.Options{ConfigPath: authConfigPath, SystemName: apiOptions.System.Name, SystemURL: apiOptions.System.URL}) - if err != nil { - return nil, err - } - zts, err := c.ztsFactory(c.httpClient, zts.DefaultURL) - if err != nil { - return nil, err - } - return vespa.CloudTarget(c.httpClient, zts, auth0, apiOptions, deploymentOptions, logOptions) + return vespa.CloudTarget(c.httpClient, apiAuth, deploymentAuth, apiOptions, deploymentOptions, logOptions) } // system returns the appropiate system for the target configured in this CLI. diff --git a/client/go/internal/cli/cmd/test.go b/client/go/internal/cli/cmd/test.go index 05633b1135e..8c4501e2870 100644 --- a/client/go/internal/cli/cmd/test.go +++ b/client/go/internal/cli/cmd/test.go @@ -263,7 +263,7 @@ func verify(step step, defaultCluster string, defaultParameters map[string]strin var response *http.Response if externalEndpoint { - util.SetCertificates(context.cli.httpClient, []tls.Certificate{}) + util.ConfigureTLS(context.cli.httpClient, []tls.Certificate{}, nil, false) response, err = context.cli.httpClient.Do(request, 60*time.Second) } else { response, err = service.Do(request, 600*time.Second) // Vespa should provide a response within the given request timeout diff --git a/client/go/internal/cli/cmd/testutil_test.go b/client/go/internal/cli/cmd/testutil_test.go index 61f8dab2264..492e40d8855 100644 --- a/client/go/internal/cli/cmd/testutil_test.go +++ b/client/go/internal/cli/cmd/testutil_test.go @@ -3,13 +3,14 @@ package cmd import ( "bytes" - "crypto/tls" + "net/http" "path/filepath" "testing" "github.com/vespa-engine/vespa/client/go/internal/cli/auth/auth0" "github.com/vespa-engine/vespa/client/go/internal/mock" "github.com/vespa-engine/vespa/client/go/internal/util" + "github.com/vespa-engine/vespa/client/go/internal/vespa" ) func newTestCLI(t *testing.T, envVars ...string) (*CLI, *bytes.Buffer, *bytes.Buffer) { @@ -29,21 +30,15 @@ func newTestCLI(t *testing.T, envVars ...string) (*CLI, *bytes.Buffer, *bytes.Bu httpClient := &mock.HTTPClient{} cli.httpClient = httpClient cli.exec = &mock.Exec{} - cli.auth0Factory = func(httpClient util.HTTPClient, options auth0.Options) (auth0Client, error) { - return &mockAuth0{}, nil + cli.auth0Factory = func(httpClient util.HTTPClient, options auth0.Options) (vespa.Authenticator, error) { + return &mockAuthenticator{}, nil } - cli.ztsFactory = func(httpClient util.HTTPClient, url string) (ztsClient, error) { - return &mockZTS{}, nil + cli.ztsFactory = func(httpClient util.HTTPClient, domain, url string) (vespa.Authenticator, error) { + return &mockAuthenticator{}, nil } return cli, &stdout, &stderr } -type mockZTS struct{} +type mockAuthenticator struct{} -func (z *mockZTS) AccessToken(domain string, cert tls.Certificate) (string, error) { return "", nil } - -type mockAuth0 struct{ hasCredentials bool } - -func (a *mockAuth0) AccessToken() (string, error) { return "", nil } - -func (a *mockAuth0) HasCredentials() bool { return a.hasCredentials } +func (a *mockAuthenticator) Authenticate(request *http.Request) error { return nil } diff --git a/client/go/internal/util/http.go b/client/go/internal/util/http.go index dcf05ed3a14..8a67b24dffb 100644 --- a/client/go/internal/util/http.go +++ b/client/go/internal/util/http.go @@ -4,6 +4,7 @@ package util import ( "context" "crypto/tls" + "crypto/x509" "fmt" "net" "net/http" @@ -35,7 +36,7 @@ func (c *defaultHTTPClient) Do(request *http.Request, timeout time.Duration) (re func (c *defaultHTTPClient) Clone() HTTPClient { return CreateClient(c.client.Timeout) } -func SetCertificates(client HTTPClient, certificates []tls.Certificate) { +func ConfigureTLS(client HTTPClient, certificates []tls.Certificate, caCertificate []byte, trustAll bool) { c, ok := client.(*defaultHTTPClient) if !ok { return @@ -43,8 +44,14 @@ func SetCertificates(client HTTPClient, certificates []tls.Certificate) { var tlsConfig *tls.Config = nil if certificates != nil { tlsConfig = &tls.Config{ - Certificates: certificates, - MinVersion: tls.VersionTLS12, + Certificates: certificates, + MinVersion: tls.VersionTLS12, + InsecureSkipVerify: trustAll, + } + if caCertificate != nil { + certs := x509.NewCertPool() + certs.AppendCertsFromPEM(caCertificate) + tlsConfig.RootCAs = certs } } if tr, ok := c.client.Transport.(*http.Transport); ok { @@ -56,19 +63,13 @@ func SetCertificates(client HTTPClient, certificates []tls.Certificate) { } } -func ForceHTTP2(client HTTPClient, certificates []tls.Certificate) { +func ForceHTTP2(client HTTPClient, certificates []tls.Certificate, caCertificate []byte, trustAll bool) { c, ok := client.(*defaultHTTPClient) if !ok { return } - var tlsConfig *tls.Config = nil var dialFunc func(ctx context.Context, network, addr string, cfg *tls.Config) (net.Conn, error) - if certificates != nil { - tlsConfig = &tls.Config{ - Certificates: certificates, - MinVersion: tls.VersionTLS12, - } - } else { + if certificates == nil { // No certificate, so force H2C (HTTP/2 over clear-text) by using a non-TLS Dialer dialer := net.Dialer{} dialFunc = func(ctx context.Context, network, addr string, cfg *tls.Config) (net.Conn, error) { @@ -80,10 +81,10 @@ func ForceHTTP2(client HTTPClient, certificates []tls.Certificate) { // https://github.com/golang/go/issues/16582 // https://github.com/golang/go/issues/22091 c.client.Transport = &http2.Transport{ - AllowHTTP: true, - TLSClientConfig: tlsConfig, - DialTLSContext: dialFunc, + AllowHTTP: true, + DialTLSContext: dialFunc, } + ConfigureTLS(client, certificates, caCertificate, trustAll) } func CreateClient(timeout time.Duration) HTTPClient { diff --git a/client/go/internal/vespa/crypto.go b/client/go/internal/vespa/crypto.go index 9621d0c1180..5e273538869 100644 --- a/client/go/internal/vespa/crypto.go +++ b/client/go/internal/vespa/crypto.go @@ -111,6 +111,8 @@ func NewRequestSigner(keyID string, pemPrivateKey []byte) *RequestSigner { } } +func (rs *RequestSigner) Authenticate(request *http.Request) error { return rs.SignRequest(request) } + // SignRequest signs the given HTTP request using the private key in rs func (rs *RequestSigner) SignRequest(request *http.Request) error { timestamp := rs.now().UTC().Format(time.RFC3339) diff --git a/client/go/internal/vespa/deploy.go b/client/go/internal/vespa/deploy.go index 687bfc46124..82fd014b377 100644 --- a/client/go/internal/vespa/deploy.go +++ b/client/go/internal/vespa/deploy.go @@ -263,10 +263,6 @@ func Submit(opts DeploymentOptions) error { } request.Header.Set("Content-Type", writer.FormDataContentType()) serviceDescription := "Submit service" - sigKeyId := opts.Target.Deployment().Application.SerializedForm() - if err := opts.Target.SignRequest(request, sigKeyId); err != nil { - return fmt.Errorf("failed to sign api request: %w", err) - } response, err := opts.HTTPClient.Do(request, time.Minute*10) if err != nil { return err @@ -335,10 +331,6 @@ func uploadApplicationPackage(url *url.URL, opts DeploymentOptions) (PrepareResu return PrepareResult{}, err } - keyID := opts.Target.Deployment().Application.SerializedForm() - if err := opts.Target.SignRequest(request, keyID); err != nil { - return PrepareResult{}, err - } response, err := service.Do(request, time.Minute*10) if err != nil { return PrepareResult{}, err diff --git a/client/go/internal/vespa/deploy_test.go b/client/go/internal/vespa/deploy_test.go index 3e74e9ab3b6..db3d17c432a 100644 --- a/client/go/internal/vespa/deploy_test.go +++ b/client/go/internal/vespa/deploy_test.go @@ -19,7 +19,7 @@ import ( func TestDeploy(t *testing.T) { httpClient := mock.HTTPClient{} - target := LocalTarget(&httpClient) + target := LocalTarget(&httpClient, TLSOptions{}) appDir, _ := mock.ApplicationPackageDir(t, false, false) opts := DeploymentOptions{ Target: target, diff --git a/client/go/internal/vespa/target.go b/client/go/internal/vespa/target.go index bc936623bcb..6d5d7efad91 100644 --- a/client/go/internal/vespa/target.go +++ b/client/go/internal/vespa/target.go @@ -7,6 +7,7 @@ import ( "fmt" "io" "net/http" + "sync" "time" "github.com/vespa-engine/vespa/client/go/internal/util" @@ -17,7 +18,7 @@ const ( // A target for a local Vespa service TargetLocal = "local" - // A target for a custom URL + // A target for a Vespa service at a custom URL TargetCustom = "custom" // A Vespa Cloud target @@ -38,13 +39,19 @@ const ( retryInterval = 2 * time.Second ) +// Authenticator authenticates the given HTTP request. +type Authenticator interface { + Authenticate(request *http.Request) error +} + // Service represents a Vespa service. type Service struct { BaseURL string Name string TLSOptions TLSOptions - zts zts + once sync.Once + auth Authenticator httpClient util.HTTPClient } @@ -65,19 +72,19 @@ type Target interface { // PrintLog writes the logs of this deployment using given options to control output. PrintLog(options LogOptions) error - // SignRequest signs request with given keyID as required by the implementation of this target. - SignRequest(request *http.Request, keyID string) error - // CheckVersion verifies whether clientVersion is compatible with this target. CheckVersion(clientVersion version.Version) error } -// TLSOptions configures the client certificate to use for cloud API or service requests. +// TLSOptions holds the client certificate to use for cloud API or service requests. type TLSOptions struct { - KeyPair []tls.Certificate - CertificateFile string - PrivateKeyFile string - AthenzDomain string + CACertificate []byte + KeyPair []tls.Certificate + TrustAll bool + + CACertificateFile string + CertificateFile string + PrivateKeyFile string } // LogOptions configures the log output to produce when writing log messages. @@ -90,17 +97,15 @@ type LogOptions struct { Level int } -// Do sends request to this service. Any required authentication happens automatically. +// Do sends request to this service. Authentication of the request happens automatically. func (s *Service) Do(request *http.Request, timeout time.Duration) (*http.Response, error) { - if s.TLSOptions.AthenzDomain != "" && s.TLSOptions.KeyPair != nil { - accessToken, err := s.zts.AccessToken(s.TLSOptions.AthenzDomain, s.TLSOptions.KeyPair[0]) - if err != nil { + s.once.Do(func() { + util.ConfigureTLS(s.httpClient, s.TLSOptions.KeyPair, s.TLSOptions.CACertificate, s.TLSOptions.TrustAll) + }) + if s.auth != nil { + if err := s.auth.Authenticate(request); err != nil { return nil, err } - if request.Header == nil { - request.Header = make(http.Header) - } - request.Header.Add("Authorization", "Bearer "+accessToken) } return s.httpClient.Do(request, timeout) } @@ -118,7 +123,7 @@ func (s *Service) Wait(timeout time.Duration) (int, error) { default: return 0, fmt.Errorf("invalid service: %s", s.Name) } - return waitForOK(s.httpClient, url, s.TLSOptions.KeyPair, timeout) + return waitForOK(s, url, timeout) } func (s *Service) Description() string { @@ -141,19 +146,16 @@ type requestFunc func() *http.Request // waitForOK queries url and returns its status code. If the url returns a non-200 status code, it is repeatedly queried // until timeout elapses. -func waitForOK(client util.HTTPClient, url string, certificates []tls.Certificate, timeout time.Duration) (int, error) { +func waitForOK(service *Service, url string, timeout time.Duration) (int, error) { req, err := http.NewRequest("GET", url, nil) if err != nil { return 0, err } okFunc := func(status int, response []byte) (bool, error) { return isOK(status), nil } - return wait(client, okFunc, func() *http.Request { return req }, certificates, timeout) + return wait(service, okFunc, func() *http.Request { return req }, timeout) } -func wait(client util.HTTPClient, fn responseFunc, reqFn requestFunc, certificates []tls.Certificate, timeout time.Duration) (int, error) { - if certificates != nil { - util.SetCertificates(client, certificates) - } +func wait(service *Service, fn responseFunc, reqFn requestFunc, timeout time.Duration) (int, error) { var ( httpErr error response *http.Response @@ -163,7 +165,7 @@ func wait(client util.HTTPClient, fn responseFunc, reqFn requestFunc, certificat loopOnce := timeout == 0 for time.Now().Before(deadline) || loopOnce { req := reqFn() - response, httpErr = client.Do(req, 10*time.Second) + response, httpErr = service.Do(req, 10*time.Second) if httpErr == nil { statusCode = response.StatusCode body, err := io.ReadAll(response.Body) diff --git a/client/go/internal/vespa/target_cloud.go b/client/go/internal/vespa/target_cloud.go index 1fb3edd78c5..e9dca55f654 100644 --- a/client/go/internal/vespa/target_cloud.go +++ b/client/go/internal/vespa/target_cloud.go @@ -2,7 +2,6 @@ package vespa import ( "bytes" - "crypto/tls" "encoding/json" "fmt" "math" @@ -35,8 +34,8 @@ type cloudTarget struct { deploymentOptions CloudDeploymentOptions logOptions LogOptions httpClient util.HTTPClient - zts zts - auth0 auth0 + apiAuth Authenticator + deploymentAuth Authenticator } type deploymentEndpoint struct { @@ -62,23 +61,15 @@ type logMessage struct { Message string `json:"message"` } -type zts interface { - AccessToken(domain string, certficiate tls.Certificate) (string, error) -} - -type auth0 interface { - AccessToken() (string, error) -} - // CloudTarget creates a Target for the Vespa Cloud or hosted Vespa platform. -func CloudTarget(httpClient util.HTTPClient, ztsClient zts, auth0Client auth0, apiOptions APIOptions, deploymentOptions CloudDeploymentOptions, logOptions LogOptions) (Target, error) { +func CloudTarget(httpClient util.HTTPClient, apiAuth Authenticator, deploymentAuth Authenticator, apiOptions APIOptions, deploymentOptions CloudDeploymentOptions, logOptions LogOptions) (Target, error) { return &cloudTarget{ httpClient: httpClient, apiOptions: apiOptions, deploymentOptions: deploymentOptions, logOptions: logOptions, - zts: ztsClient, - auth0: auth0Client, + apiAuth: apiAuth, + deploymentAuth: deploymentAuth, }, nil } @@ -118,15 +109,14 @@ func (t *cloudTarget) IsCloud() bool { return true } func (t *cloudTarget) Deployment() Deployment { return t.deploymentOptions.Deployment } func (t *cloudTarget) Service(name string, timeout time.Duration, runID int64, cluster string) (*Service, error) { - var service *Service switch name { case DeployService: - service = &Service{ + service := &Service{ Name: name, BaseURL: t.apiOptions.System.URL, TLSOptions: t.apiOptions.TLSOptions, - zts: t.zts, httpClient: t.httpClient, + auth: t.apiAuth, } if timeout > 0 { status, err := service.Wait(timeout) @@ -137,6 +127,7 @@ func (t *cloudTarget) Service(name string, timeout time.Duration, runID int64, c return nil, fmt.Errorf("got status %d from deploy service at %s", status, service.BaseURL) } } + return service, nil case QueryService, DocumentService: if t.deploymentOptions.ClusterURLs == nil { if err := t.waitForEndpoints(timeout, runID); err != nil { @@ -147,38 +138,15 @@ func (t *cloudTarget) Service(name string, timeout time.Duration, runID int64, c if err != nil { return nil, err } - t.deploymentOptions.TLSOptions.AthenzDomain = t.apiOptions.System.AthenzDomain - service = &Service{ + return &Service{ Name: name, BaseURL: url, TLSOptions: t.deploymentOptions.TLSOptions, - zts: t.zts, httpClient: t.httpClient, - } - + auth: t.deploymentAuth, + }, nil default: return nil, fmt.Errorf("unknown service: %s", name) - - } - if service.TLSOptions.KeyPair != nil { - util.SetCertificates(service.httpClient, service.TLSOptions.KeyPair) - } - return service, nil -} - -func (t *cloudTarget) SignRequest(req *http.Request, keyID string) error { - if t.apiOptions.System.IsPublic() { - if t.apiOptions.APIKey != nil { - signer := NewRequestSigner(keyID, t.apiOptions.APIKey) - return signer.SignRequest(req) - } else { - return t.addAuth0AccessToken(req) - } - } else { - if t.apiOptions.TLSOptions.KeyPair == nil { - return fmt.Errorf("system %s requires a certificate for authentication", t.apiOptions.System.Name) - } - return nil } } @@ -190,7 +158,11 @@ func (t *cloudTarget) CheckVersion(clientVersion version.Version) error { if err != nil { return err } - response, err := t.httpClient.Do(req, 10*time.Second) + deployService, err := t.Service(DeployService, 0, 0, "") + if err != nil { + return err + } + response, err := deployService.Do(req, 10*time.Second) if err != nil { return err } @@ -212,18 +184,6 @@ func (t *cloudTarget) CheckVersion(clientVersion version.Version) error { return nil } -func (t *cloudTarget) addAuth0AccessToken(request *http.Request) error { - accessToken, err := t.auth0.AccessToken() - if err != nil { - return err - } - if request.Header == nil { - request.Header = make(http.Header) - } - request.Header.Set("Authorization", "Bearer "+accessToken) - return nil -} - func (t *cloudTarget) logsURL() string { return fmt.Sprintf("%s/application/v4/tenant/%s/application/%s/instance/%s/environment/%s/region/%s/logs", t.apiOptions.System.URL, @@ -246,7 +206,6 @@ func (t *cloudTarget) PrintLog(options LogOptions) error { q.Set("to", strconv.FormatInt(toMillis, 10)) } req.URL.RawQuery = q.Encode() - t.SignRequest(req, t.deploymentOptions.Deployment.Application.SerializedForm()) return req } logFunc := func(status int, response []byte) (bool, error) { @@ -275,10 +234,18 @@ func (t *cloudTarget) PrintLog(options LogOptions) error { if options.Follow { timeout = math.MaxInt64 // No timeout } - _, err = wait(t.httpClient, logFunc, requestFunc, t.apiOptions.TLSOptions.KeyPair, timeout) + _, err = t.deployServiceWait(logFunc, requestFunc, timeout) return err } +func (t *cloudTarget) deployServiceWait(fn responseFunc, reqFn requestFunc, timeout time.Duration) (int, error) { + deployService, err := t.Service(DeployService, 0, 0, "") + if err != nil { + return 0, err + } + return wait(deployService, fn, reqFn, timeout) +} + func (t *cloudTarget) waitForEndpoints(timeout time.Duration, runID int64) error { if runID > 0 { if err := t.waitForRun(runID, timeout); err != nil { @@ -302,9 +269,6 @@ func (t *cloudTarget) waitForRun(runID int64, timeout time.Duration) error { q := req.URL.Query() q.Set("after", strconv.FormatInt(lastID, 10)) req.URL.RawQuery = q.Encode() - if err := t.SignRequest(req, t.deploymentOptions.Deployment.Application.SerializedForm()); err != nil { - util.JustExitWith(err) - } return req } jobSuccessFunc := func(status int, response []byte) (bool, error) { @@ -326,7 +290,7 @@ func (t *cloudTarget) waitForRun(runID int64, timeout time.Duration) error { } return true, nil } - _, err = wait(t.httpClient, jobSuccessFunc, requestFunc, t.apiOptions.TLSOptions.KeyPair, timeout) + _, err = t.deployServiceWait(jobSuccessFunc, requestFunc, timeout) return err } @@ -361,9 +325,6 @@ func (t *cloudTarget) discoverEndpoints(timeout time.Duration) error { if err != nil { return err } - if err := t.SignRequest(req, t.deploymentOptions.Deployment.Application.SerializedForm()); err != nil { - return err - } urlsByCluster := make(map[string]string) endpointFunc := func(status int, response []byte) (bool, error) { if ok, err := isCloudOK(status); !ok { @@ -384,7 +345,7 @@ func (t *cloudTarget) discoverEndpoints(timeout time.Duration) error { } return true, nil } - if _, err = wait(t.httpClient, endpointFunc, func() *http.Request { return req }, t.apiOptions.TLSOptions.KeyPair, timeout); err != nil { + if _, err := t.deployServiceWait(endpointFunc, func() *http.Request { return req }, timeout); err != nil { return err } if len(urlsByCluster) == 0 { diff --git a/client/go/internal/vespa/target_custom.go b/client/go/internal/vespa/target_custom.go index 848d19f0a90..df50e90a55b 100644 --- a/client/go/internal/vespa/target_custom.go +++ b/client/go/internal/vespa/target_custom.go @@ -15,6 +15,7 @@ type customTarget struct { targetType string baseURL string httpClient util.HTTPClient + tlsOptions TLSOptions } type serviceConvergeResponse struct { @@ -22,13 +23,13 @@ type serviceConvergeResponse struct { } // LocalTarget creates a target for a Vespa platform running locally. -func LocalTarget(httpClient util.HTTPClient) Target { - return &customTarget{targetType: TargetLocal, baseURL: "http://127.0.0.1", httpClient: httpClient} +func LocalTarget(httpClient util.HTTPClient, tlsOptions TLSOptions) Target { + return &customTarget{targetType: TargetLocal, baseURL: "http://127.0.0.1", httpClient: httpClient, tlsOptions: tlsOptions} } // CustomTarget creates a Target for a Vespa platform running at baseURL. -func CustomTarget(httpClient util.HTTPClient, baseURL string) Target { - return &customTarget{targetType: TargetCustom, baseURL: baseURL, httpClient: httpClient} +func CustomTarget(httpClient util.HTTPClient, baseURL string, tlsOptions TLSOptions) Target { + return &customTarget{targetType: TargetCustom, baseURL: baseURL, httpClient: httpClient, tlsOptions: tlsOptions} } func (t *customTarget) Type() string { return t.targetType } @@ -44,7 +45,7 @@ func (t *customTarget) createService(name string) (*Service, error) { if err != nil { return nil, err } - return &Service{BaseURL: url, Name: name, httpClient: t.httpClient}, nil + return &Service{BaseURL: url, Name: name, httpClient: t.httpClient, TLSOptions: t.tlsOptions}, nil } return nil, fmt.Errorf("unknown service: %s", name) } @@ -76,8 +77,6 @@ func (t *customTarget) PrintLog(options LogOptions) error { return fmt.Errorf("log access is only supported on cloud: run vespa-logfmt on the admin node instead") } -func (t *customTarget) SignRequest(req *http.Request, sigKeyId string) error { return nil } - func (t *customTarget) CheckVersion(version version.Version) error { return nil } func (t *customTarget) urlWithPort(serviceName string) (string, error) { @@ -101,11 +100,11 @@ func (t *customTarget) urlWithPort(serviceName string) (string, error) { } func (t *customTarget) waitForConvergence(timeout time.Duration) error { - deployURL, err := t.urlWithPort(DeployService) + deployService, err := t.createService(DeployService) if err != nil { return err } - url := fmt.Sprintf("%s/application/v2/tenant/default/application/default/environment/prod/region/default/instance/default/serviceconverge", deployURL) + url := fmt.Sprintf("%s/application/v2/tenant/default/application/default/environment/prod/region/default/instance/default/serviceconverge", deployService.BaseURL) req, err := http.NewRequest("GET", url, nil) if err != nil { return err @@ -122,7 +121,7 @@ func (t *customTarget) waitForConvergence(timeout time.Duration) error { converged = resp.Converged return converged, nil } - if _, err := wait(t.httpClient, convergedFunc, func() *http.Request { return req }, nil, timeout); err != nil { + if _, err := wait(deployService, convergedFunc, func() *http.Request { return req }, timeout); err != nil { return err } if !converged { diff --git a/client/go/internal/vespa/target_test.go b/client/go/internal/vespa/target_test.go index b9d65f3d8a4..d15001911d0 100644 --- a/client/go/internal/vespa/target_test.go +++ b/client/go/internal/vespa/target_test.go @@ -3,7 +3,6 @@ package vespa import ( "bytes" - "crypto/tls" "fmt" "io" "net/http" @@ -65,17 +64,17 @@ func (v *mockVespaApi) mockVespaHandler(w http.ResponseWriter, req *http.Request } func TestCustomTarget(t *testing.T) { - lt := LocalTarget(&mock.HTTPClient{}) + lt := LocalTarget(&mock.HTTPClient{}, TLSOptions{}) assertServiceURL(t, "http://127.0.0.1:19071", lt, "deploy") assertServiceURL(t, "http://127.0.0.1:8080", lt, "query") assertServiceURL(t, "http://127.0.0.1:8080", lt, "document") - ct := CustomTarget(&mock.HTTPClient{}, "http://192.0.2.42") + ct := CustomTarget(&mock.HTTPClient{}, "http://192.0.2.42", TLSOptions{}) assertServiceURL(t, "http://192.0.2.42:19071", ct, "deploy") assertServiceURL(t, "http://192.0.2.42:8080", ct, "query") assertServiceURL(t, "http://192.0.2.42:8080", ct, "document") - ct2 := CustomTarget(&mock.HTTPClient{}, "http://192.0.2.42:60000") + ct2 := CustomTarget(&mock.HTTPClient{}, "http://192.0.2.42:60000", TLSOptions{}) assertServiceURL(t, "http://192.0.2.42:60000", ct2, "deploy") assertServiceURL(t, "http://192.0.2.42:60000", ct2, "query") assertServiceURL(t, "http://192.0.2.42:60000", ct2, "document") @@ -85,7 +84,7 @@ func TestCustomTargetWait(t *testing.T) { vc := mockVespaApi{} srv := httptest.NewServer(http.HandlerFunc(vc.mockVespaHandler)) defer srv.Close() - target := CustomTarget(util.CreateClient(time.Second*10), srv.URL) + target := CustomTarget(util.CreateClient(time.Second*10), srv.URL, TLSOptions{}) _, err := target.Service("query", time.Millisecond, 42, "") assert.NotNil(t, err) @@ -157,10 +156,11 @@ func createCloudTarget(t *testing.T, url string, logWriter io.Writer) Target { apiKey, err := CreateAPIKey() assert.Nil(t, err) + auth := &mockAuthenticator{} target, err := CloudTarget( util.CreateClient(time.Second*10), - &mockZTS{}, - &mockAuth0{}, + auth, + auth, APIOptions{APIKey: apiKey, System: PublicSystem}, CloudDeploymentOptions{ Deployment: Deployment{ @@ -175,7 +175,6 @@ func createCloudTarget(t *testing.T, url string, logWriter io.Writer) Target { } if ct, ok := target.(*cloudTarget); ok { ct.apiOptions.System.URL = url - ct.zts = &mockZTS{token: "foo bar"} } else { t.Fatalf("Wrong target type %T", ct) } @@ -197,14 +196,6 @@ func assertServiceWait(t *testing.T, expectedStatus int, target Target, service assert.Equal(t, expectedStatus, status) } -type mockZTS struct{ token string } +type mockAuthenticator struct{} -func (c *mockZTS) AccessToken(domain string, certificate tls.Certificate) (string, error) { - return c.token, nil -} - -type mockAuth0 struct{} - -func (a *mockAuth0) AccessToken() (string, error) { return "", nil } - -func (a *mockAuth0) HasCredentials() bool { return true } +func (a *mockAuthenticator) Authenticate(request *http.Request) error { return nil } -- cgit v1.2.3 From 58fb99d57e3b81f9c1c4567355355ef4a97e989f Mon Sep 17 00:00:00 2001 From: Martin Polden Date: Fri, 14 Apr 2023 13:18:17 +0200 Subject: Always print errors --- client/go/internal/cli/cmd/feed.go | 8 +-- client/go/internal/vespa/document/dispatcher.go | 79 ++++++++++++++-------- .../go/internal/vespa/document/dispatcher_test.go | 6 +- 3 files changed, 56 insertions(+), 37 deletions(-) (limited to 'client') diff --git a/client/go/internal/cli/cmd/feed.go b/client/go/internal/cli/cmd/feed.go index c284328255a..f0f82dd80d1 100644 --- a/client/go/internal/cli/cmd/feed.go +++ b/client/go/internal/cli/cmd/feed.go @@ -19,7 +19,7 @@ func addFeedFlags(cmd *cobra.Command, options *feedOptions) { cmd.PersistentFlags().StringVar(&options.route, "route", "", "Target Vespa route for feed operations") cmd.PersistentFlags().IntVar(&options.traceLevel, "trace", 0, "The trace level of network traffic. 0 to disable") cmd.PersistentFlags().IntVar(&options.timeoutSecs, "timeout", 0, "Feed operation timeout in seconds. 0 to disable") - cmd.PersistentFlags().BoolVar(&options.verbose, "verbose", false, "Verbose mode. Print errors as they happen") + cmd.PersistentFlags().BoolVar(&options.verbose, "verbose", false, "Verbose mode. Print successful operations in addition to errors") } type feedOptions struct { @@ -97,11 +97,7 @@ func feed(r io.Reader, cli *CLI, options feedOptions) error { throttler := document.NewThrottler(options.connections) // TODO(mpolden): Make doom duration configurable circuitBreaker := document.NewCircuitBreaker(10*time.Second, 0) - errWriter := io.Discard - if options.verbose { - errWriter = cli.Stderr - } - dispatcher := document.NewDispatcher(client, throttler, circuitBreaker, errWriter) + dispatcher := document.NewDispatcher(client, throttler, circuitBreaker, cli.Stderr, options.verbose) dec := document.NewDecoder(r) start := cli.now() diff --git a/client/go/internal/vespa/document/dispatcher.go b/client/go/internal/vespa/document/dispatcher.go index 838a7bc45ee..798a888d677 100644 --- a/client/go/internal/vespa/document/dispatcher.go +++ b/client/go/internal/vespa/document/dispatcher.go @@ -4,6 +4,7 @@ import ( "container/list" "fmt" "io" + "strings" "sync" "sync/atomic" "time" @@ -18,12 +19,15 @@ type Dispatcher struct { circuitBreaker CircuitBreaker stats Stats - started bool - ready chan Id - results chan Result + started bool + ready chan Id + results chan Result + msgs chan string + inflight map[string]*documentGroup inflightCount int64 - errWriter io.Writer + output io.Writer + verbose bool mu sync.RWMutex wg sync.WaitGroup @@ -55,13 +59,14 @@ func (g *documentGroup) add(op documentOp, first bool) { } } -func NewDispatcher(feeder Feeder, throttler Throttler, breaker CircuitBreaker, errWriter io.Writer) *Dispatcher { +func NewDispatcher(feeder Feeder, throttler Throttler, breaker CircuitBreaker, output io.Writer, verbose bool) *Dispatcher { d := &Dispatcher{ feeder: feeder, throttler: throttler, circuitBreaker: breaker, inflight: make(map[string]*documentGroup), - errWriter: errWriter, + output: output, + verbose: verbose, } d.start() return d @@ -86,29 +91,35 @@ func (d *Dispatcher) sendDocumentIn(group *documentGroup) { func (d *Dispatcher) shouldRetry(op documentOp, result Result) bool { if result.HTTPStatus/100 == 2 || result.HTTPStatus == 404 || result.HTTPStatus == 412 { + if d.verbose { + d.msgs <- fmt.Sprintf("feed: successfully fed %s with status %d", op.document.Id, result.HTTPStatus) + } d.throttler.Success() d.circuitBreaker.Success() return false } if result.HTTPStatus == 429 || result.HTTPStatus == 503 { - fmt.Fprintf(d.errWriter, "feed: %s was throttled with status %d: retrying\n", op.document, result.HTTPStatus) + d.msgs <- fmt.Sprintf("feed: %s was throttled with status %d: retrying\n", op.document, result.HTTPStatus) d.throttler.Throttled(atomic.LoadInt64(&d.inflightCount)) return true } if result.Err != nil || result.HTTPStatus == 500 || result.HTTPStatus == 502 || result.HTTPStatus == 504 { retry := op.attempts <= maxAttempts - msg := "feed: " + op.document.String() + " failed with " + var msg strings.Builder + msg.WriteString("feed: ") + msg.WriteString(op.document.String()) if result.Err != nil { - msg += "error " + result.Err.Error() + msg.WriteString("error ") + msg.WriteString(result.Err.Error()) } else { - msg += fmt.Sprintf("status %d", result.HTTPStatus) + msg.WriteString(fmt.Sprintf("status %d", result.HTTPStatus)) } if retry { - msg += ": retrying" + msg.WriteString(": retrying") } else { - msg += fmt.Sprintf(": giving up after %d attempts", maxAttempts) + msg.WriteString(fmt.Sprintf(": giving up after %d attempts", maxAttempts)) } - fmt.Fprintln(d.errWriter, msg) + d.msgs <- msg.String() d.circuitBreaker.Error(fmt.Errorf("request failed with status %d", result.HTTPStatus)) if retry { return true @@ -125,17 +136,22 @@ func (d *Dispatcher) start() { } d.ready = make(chan Id, 4096) d.results = make(chan Result, 4096) + d.msgs = make(chan string, 4096) d.started = true d.wg.Add(1) go func() { defer d.wg.Done() d.readDocuments() }() - d.resultWg.Add(1) + d.resultWg.Add(2) go func() { defer d.resultWg.Done() d.readResults() }() + go func() { + defer d.resultWg.Done() + d.readMessages() + }() } func (d *Dispatcher) readDocuments() { @@ -157,6 +173,12 @@ func (d *Dispatcher) readResults() { } } +func (d *Dispatcher) readMessages() { + for msg := range d.msgs { + fmt.Fprintln(d.output, msg) + } +} + func (d *Dispatcher) enqueue(op documentOp) error { d.mu.Lock() if !d.started { @@ -188,25 +210,26 @@ func (d *Dispatcher) acquireSlot() { func (d *Dispatcher) releaseSlot() { atomic.AddInt64(&d.inflightCount, -1) } -func closeAndWait[T any](ch chan T, wg *sync.WaitGroup, d *Dispatcher, markClosed bool) { - d.mu.Lock() - if d.started { - close(ch) - if markClosed { - d.started = false - } - } - d.mu.Unlock() - wg.Wait() -} - func (d *Dispatcher) Enqueue(doc Document) error { return d.enqueue(documentOp{document: doc}) } func (d *Dispatcher) Stats() Stats { return d.stats } // Close closes the dispatcher and waits for all inflight operations to complete. func (d *Dispatcher) Close() error { - closeAndWait(d.ready, &d.wg, d, false) - closeAndWait(d.results, &d.resultWg, d, true) + d.mu.Lock() + if d.started { + close(d.ready) + } + d.mu.Unlock() + d.wg.Wait() // Wait for inflight operations to complete + + d.mu.Lock() + if d.started { + close(d.results) + close(d.msgs) + d.started = false + } + d.mu.Unlock() + d.resultWg.Wait() // Wait for results return nil } diff --git a/client/go/internal/vespa/document/dispatcher_test.go b/client/go/internal/vespa/document/dispatcher_test.go index 80bc5f603ae..d066f5bc9ae 100644 --- a/client/go/internal/vespa/document/dispatcher_test.go +++ b/client/go/internal/vespa/document/dispatcher_test.go @@ -41,7 +41,7 @@ func TestDispatcher(t *testing.T) { clock := &manualClock{tick: time.Second} throttler := newThrottler(8, clock.now) breaker := NewCircuitBreaker(time.Second, 0) - dispatcher := NewDispatcher(feeder, throttler, breaker, io.Discard) + dispatcher := NewDispatcher(feeder, throttler, breaker, io.Discard, false) docs := []Document{ {Id: mustParseId("id:ns:type::doc1"), Operation: OperationPut, Body: []byte(`{"fields":{"foo": "123"}}`)}, {Id: mustParseId("id:ns:type::doc2"), Operation: OperationPut, Body: []byte(`{"fields":{"bar": "456"}}`)}, @@ -74,7 +74,7 @@ func TestDispatcherOrdering(t *testing.T) { clock := &manualClock{tick: time.Second} throttler := newThrottler(8, clock.now) breaker := NewCircuitBreaker(time.Second, 0) - dispatcher := NewDispatcher(feeder, throttler, breaker, io.Discard) + dispatcher := NewDispatcher(feeder, throttler, breaker, io.Discard, false) for _, d := range docs { dispatcher.Enqueue(d) } @@ -110,7 +110,7 @@ func TestDispatcherOrderingWithFailures(t *testing.T) { clock := &manualClock{tick: time.Second} throttler := newThrottler(8, clock.now) breaker := NewCircuitBreaker(time.Second, 0) - dispatcher := NewDispatcher(feeder, throttler, breaker, io.Discard) + dispatcher := NewDispatcher(feeder, throttler, breaker, io.Discard, false) for _, d := range docs { dispatcher.Enqueue(d) } -- cgit v1.2.3 From bb273b823216308a1b33ce2ef6dc7b0a4639494b Mon Sep 17 00:00:00 2001 From: Martin Polden Date: Fri, 14 Apr 2023 15:52:02 +0200 Subject: Create key once --- client/go/internal/vespa/document/dispatcher.go | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) (limited to 'client') diff --git a/client/go/internal/vespa/document/dispatcher.go b/client/go/internal/vespa/document/dispatcher.go index 798a888d677..dc6045ac791 100644 --- a/client/go/internal/vespa/document/dispatcher.go +++ b/client/go/internal/vespa/document/dispatcher.go @@ -184,10 +184,11 @@ func (d *Dispatcher) enqueue(op documentOp) error { if !d.started { return fmt.Errorf("dispatcher is closed") } - group, ok := d.inflight[op.document.Id.String()] + key := op.document.Id.String() + group, ok := d.inflight[key] if !ok { group = &documentGroup{} - d.inflight[op.document.Id.String()] = group + d.inflight[key] = group } d.mu.Unlock() group.add(op, op.attempts > 0) -- cgit v1.2.3 From 468dc5f1a47f3d7d90ae7e83476344c55c20b149 Mon Sep 17 00:00:00 2001 From: Martin Polden Date: Fri, 14 Apr 2023 16:14:51 +0200 Subject: Release lock before retrying --- client/go/internal/vespa/document/dispatcher.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) (limited to 'client') diff --git a/client/go/internal/vespa/document/dispatcher.go b/client/go/internal/vespa/document/dispatcher.go index dc6045ac791..533ca7a0019 100644 --- a/client/go/internal/vespa/document/dispatcher.go +++ b/client/go/internal/vespa/document/dispatcher.go @@ -74,8 +74,6 @@ func NewDispatcher(feeder Feeder, throttler Throttler, breaker CircuitBreaker, o func (d *Dispatcher) sendDocumentIn(group *documentGroup) { group.mu.Lock() - defer group.mu.Unlock() - defer d.releaseSlot() first := group.ops.Front() if first == nil { panic("sending from empty document group, this should not happen") @@ -84,6 +82,8 @@ func (d *Dispatcher) sendDocumentIn(group *documentGroup) { op.attempts++ result := d.feeder.Send(op.document) d.results <- result + d.releaseSlot() + group.mu.Unlock() if d.shouldRetry(op, result) { d.enqueue(op) } -- cgit v1.2.3 From 7040b8c2c454a79316f800a5c4a9977e96905b81 Mon Sep 17 00:00:00 2001 From: Martin Polden Date: Mon, 17 Apr 2023 10:27:31 +0200 Subject: Never wait on 4xx for any target --- client/go/internal/cli/cmd/root.go | 1 - client/go/internal/vespa/deploy.go | 16 +++++++++++----- client/go/internal/vespa/deploy_test.go | 2 -- client/go/internal/vespa/target.go | 22 +++++++++++++++++++--- client/go/internal/vespa/target_cloud.go | 16 ++++------------ client/go/internal/vespa/target_custom.go | 6 +++--- client/go/internal/vespa/target_test.go | 19 ++++++++++++++++++- 7 files changed, 55 insertions(+), 27 deletions(-) (limited to 'client') diff --git a/client/go/internal/cli/cmd/root.go b/client/go/internal/cli/cmd/root.go index 695d2eaca8f..88d43411983 100644 --- a/client/go/internal/cli/cmd/root.go +++ b/client/go/internal/cli/cmd/root.go @@ -465,7 +465,6 @@ func (c *CLI) createDeploymentOptions(pkg vespa.ApplicationPackage, target vespa ApplicationPackage: pkg, Target: target, Timeout: timeout, - HTTPClient: c.httpClient, }, nil } diff --git a/client/go/internal/vespa/deploy.go b/client/go/internal/vespa/deploy.go index 82fd014b377..f633c8ed9ee 100644 --- a/client/go/internal/vespa/deploy.go +++ b/client/go/internal/vespa/deploy.go @@ -45,7 +45,6 @@ type DeploymentOptions struct { ApplicationPackage ApplicationPackage Timeout time.Duration Version version.Version - HTTPClient util.HTTPClient } type LogLinePrepareResponse struct { @@ -130,7 +129,7 @@ func Prepare(deployment DeploymentOptions) (PrepareResult, error) { return PrepareResult{}, err } serviceDescription := "Deploy service" - response, err := deployment.HTTPClient.Do(req, time.Second*30) + response, err := deployServiceDo(req, time.Second*30, deployment) if err != nil { return PrepareResult{}, err } @@ -171,7 +170,7 @@ func Activate(sessionID int64, deployment DeploymentOptions) error { return err } serviceDescription := "Deploy service" - response, err := deployment.HTTPClient.Do(req, time.Second*30) + response, err := deployServiceDo(req, time.Second*30, deployment) if err != nil { return err } @@ -263,7 +262,7 @@ func Submit(opts DeploymentOptions) error { } request.Header.Set("Content-Type", writer.FormDataContentType()) serviceDescription := "Submit service" - response, err := opts.HTTPClient.Do(request, time.Minute*10) + response, err := deployServiceDo(request, time.Minute*10, opts) if err != nil { return err } @@ -271,6 +270,14 @@ func Submit(opts DeploymentOptions) error { return checkResponse(request, response, serviceDescription) } +func deployServiceDo(request *http.Request, timeout time.Duration, opts DeploymentOptions) (*http.Response, error) { + s, err := opts.Target.Service(DeployService, 0, 0, "") + if err != nil { + return nil, err + } + return s.Do(request, timeout) +} + func checkDeploymentOpts(opts DeploymentOptions) error { if opts.Target.Type() == TargetCloud && !opts.ApplicationPackage.HasCertificate() { return fmt.Errorf("%s: missing certificate in package", opts) @@ -330,7 +337,6 @@ func uploadApplicationPackage(url *url.URL, opts DeploymentOptions) (PrepareResu if err != nil { return PrepareResult{}, err } - response, err := service.Do(request, time.Minute*10) if err != nil { return PrepareResult{}, err diff --git a/client/go/internal/vespa/deploy_test.go b/client/go/internal/vespa/deploy_test.go index db3d17c432a..da2604282c0 100644 --- a/client/go/internal/vespa/deploy_test.go +++ b/client/go/internal/vespa/deploy_test.go @@ -24,7 +24,6 @@ func TestDeploy(t *testing.T) { opts := DeploymentOptions{ Target: target, ApplicationPackage: ApplicationPackage{Path: appDir}, - HTTPClient: &httpClient, } _, err := Deploy(opts) assert.Nil(t, err) @@ -47,7 +46,6 @@ func TestDeployCloud(t *testing.T) { opts := DeploymentOptions{ Target: target, ApplicationPackage: ApplicationPackage{Path: appDir}, - HTTPClient: &httpClient, } _, err := Deploy(opts) require.Nil(t, err) diff --git a/client/go/internal/vespa/target.go b/client/go/internal/vespa/target.go index 6d5d7efad91..9f3fd7f5c65 100644 --- a/client/go/internal/vespa/target.go +++ b/client/go/internal/vespa/target.go @@ -138,20 +138,36 @@ func (s *Service) Description() string { return fmt.Sprintf("No description of service %s", s.Name) } -func isOK(status int) bool { return status/100 == 2 } +func isOK(status int) (bool, error) { + class := status / 100 + switch class { + case 2: // success + return true, nil + case 4: // client error + return false, fmt.Errorf("request failed with status %d", status) + default: // retry + return false, nil + } +} type responseFunc func(status int, response []byte) (bool, error) type requestFunc func() *http.Request -// waitForOK queries url and returns its status code. If the url returns a non-200 status code, it is repeatedly queried +// waitForOK queries url and returns its status code. If response status is not 2xx or 4xx, it is repeatedly queried // until timeout elapses. func waitForOK(service *Service, url string, timeout time.Duration) (int, error) { req, err := http.NewRequest("GET", url, nil) if err != nil { return 0, err } - okFunc := func(status int, response []byte) (bool, error) { return isOK(status), nil } + okFunc := func(status int, response []byte) (bool, error) { + ok, err := isOK(status) + if err != nil { + return false, fmt.Errorf("failed to query %s at %s: %w", service.Description(), url, err) + } + return ok, err + } return wait(service, okFunc, func() *http.Request { return req }, timeout) } diff --git a/client/go/internal/vespa/target_cloud.go b/client/go/internal/vespa/target_cloud.go index e9dca55f654..928bb788494 100644 --- a/client/go/internal/vespa/target_cloud.go +++ b/client/go/internal/vespa/target_cloud.go @@ -123,7 +123,7 @@ func (t *cloudTarget) Service(name string, timeout time.Duration, runID int64, c if err != nil { return nil, err } - if !isOK(status) { + if ok, _ := isOK(status); !ok { return nil, fmt.Errorf("got status %d from deploy service at %s", status, service.BaseURL) } } @@ -209,7 +209,7 @@ func (t *cloudTarget) PrintLog(options LogOptions) error { return req } logFunc := func(status int, response []byte) (bool, error) { - if ok, err := isCloudOK(status); !ok { + if ok, err := isOK(status); !ok { return ok, err } logEntries, err := ReadLogEntries(bytes.NewReader(response)) @@ -272,7 +272,7 @@ func (t *cloudTarget) waitForRun(runID int64, timeout time.Duration) error { return req } jobSuccessFunc := func(status int, response []byte) (bool, error) { - if ok, err := isCloudOK(status); !ok { + if ok, err := isOK(status); !ok { return ok, err } var resp jobResponse @@ -327,7 +327,7 @@ func (t *cloudTarget) discoverEndpoints(timeout time.Duration) error { } urlsByCluster := make(map[string]string) endpointFunc := func(status int, response []byte) (bool, error) { - if ok, err := isCloudOK(status); !ok { + if ok, err := isOK(status); !ok { return ok, err } var resp deploymentResponse @@ -354,11 +354,3 @@ func (t *cloudTarget) discoverEndpoints(timeout time.Duration) error { t.deploymentOptions.ClusterURLs = urlsByCluster return nil } - -func isCloudOK(status int) (bool, error) { - if status == 401 { - // when retrying we should give up immediately if we're not authorized - return false, fmt.Errorf("status %d: invalid credentials", status) - } - return isOK(status), nil -} diff --git a/client/go/internal/vespa/target_custom.go b/client/go/internal/vespa/target_custom.go index df50e90a55b..0a3a9d48fed 100644 --- a/client/go/internal/vespa/target_custom.go +++ b/client/go/internal/vespa/target_custom.go @@ -61,7 +61,7 @@ func (t *customTarget) Service(name string, timeout time.Duration, sessionOrRunI if err != nil { return nil, err } - if !isOK(status) { + if ok, _ := isOK(status); !ok { return nil, fmt.Errorf("got status %d from deploy service at %s", status, service.BaseURL) } } else { @@ -111,8 +111,8 @@ func (t *customTarget) waitForConvergence(timeout time.Duration) error { } converged := false convergedFunc := func(status int, response []byte) (bool, error) { - if !isOK(status) { - return false, nil + if ok, err := isOK(status); !ok { + return ok, err } var resp serviceConvergeResponse if err := json.Unmarshal(response, &resp); err != nil { diff --git a/client/go/internal/vespa/target_test.go b/client/go/internal/vespa/target_test.go index d15001911d0..bf266e8f9ec 100644 --- a/client/go/internal/vespa/target_test.go +++ b/client/go/internal/vespa/target_test.go @@ -18,10 +18,16 @@ import ( type mockVespaApi struct { deploymentConverged bool + authFailure bool serverURL string } func (v *mockVespaApi) mockVespaHandler(w http.ResponseWriter, req *http.Request) { + if v.authFailure { + response := `{"message":"unauthorized"}` + w.WriteHeader(401) + w.Write([]byte(response)) + } switch req.URL.Path { case "/cli/v1/": response := `{"minVersion":"8.0.0"}` @@ -106,6 +112,9 @@ func TestCloudTargetWait(t *testing.T) { var logWriter bytes.Buffer target := createCloudTarget(t, srv.URL, &logWriter) + vc.authFailure = true + assertServiceWaitErr(t, 401, true, target, "deploy") + vc.authFailure = false assertServiceWait(t, 200, target, "deploy") _, err := target.Service("query", time.Millisecond, 42, "") @@ -188,11 +197,19 @@ func assertServiceURL(t *testing.T, url string, target Target, service string) { } func assertServiceWait(t *testing.T, expectedStatus int, target Target, service string) { + assertServiceWaitErr(t, expectedStatus, false, target, service) +} + +func assertServiceWaitErr(t *testing.T, expectedStatus int, expectErr bool, target Target, service string) { s, err := target.Service(service, 0, 42, "") assert.Nil(t, err) status, err := s.Wait(0) - assert.Nil(t, err) + if expectErr { + assert.NotNil(t, err) + } else { + assert.Nil(t, err) + } assert.Equal(t, expectedStatus, status) } -- cgit v1.2.3