aboutsummaryrefslogtreecommitdiffstats
path: root/client
diff options
context:
space:
mode:
authorMartin Polden <mpolden@mpolden.no>2023-04-18 11:18:04 +0200
committerGitHub <noreply@github.com>2023-04-18 11:18:04 +0200
commitf871ed1ecec477f016168293cddb59edf4b2b1a7 (patch)
tree52eced0d7613def0d507a93d83fbf2640e44fbdd /client
parent95805d16d0b043f0149edf733876067f1a8ee8e7 (diff)
parent7040b8c2c454a79316f800a5c4a9977e96905b81 (diff)
Merge pull request #26751 from vespa-engine/mpolden/feed-client-7
Support TLS in custom target
Diffstat (limited to 'client')
-rw-r--r--client/go/internal/cli/auth/auth0/auth0.go38
-rw-r--r--client/go/internal/cli/auth/zts/zts.go28
-rw-r--r--client/go/internal/cli/auth/zts/zts_test.go7
-rw-r--r--client/go/internal/cli/cmd/cert.go9
-rw-r--r--client/go/internal/cli/cmd/config.go91
-rw-r--r--client/go/internal/cli/cmd/config_test.go119
-rw-r--r--client/go/internal/cli/cmd/curl.go12
-rw-r--r--client/go/internal/cli/cmd/feed.go46
-rw-r--r--client/go/internal/cli/cmd/root.go98
-rw-r--r--client/go/internal/cli/cmd/test.go2
-rw-r--r--client/go/internal/cli/cmd/testutil_test.go21
-rw-r--r--client/go/internal/util/http.go29
-rw-r--r--client/go/internal/vespa/crypto.go2
-rw-r--r--client/go/internal/vespa/deploy.go24
-rw-r--r--client/go/internal/vespa/deploy_test.go4
-rw-r--r--client/go/internal/vespa/document/dispatcher.go88
-rw-r--r--client/go/internal/vespa/document/dispatcher_test.go6
-rw-r--r--client/go/internal/vespa/document/http.go18
-rw-r--r--client/go/internal/vespa/document/http_test.go2
-rw-r--r--client/go/internal/vespa/target.go76
-rw-r--r--client/go/internal/vespa/target_cloud.go109
-rw-r--r--client/go/internal/vespa/target_custom.go25
-rw-r--r--client/go/internal/vespa/target_test.go46
23 files changed, 526 insertions, 374 deletions
diff --git a/client/go/internal/cli/auth/auth0/auth0.go b/client/go/internal/cli/auth/auth0/auth0.go
index 5f7612d4d2e..6fcd3f7680e 100644
--- a/client/go/internal/cli/auth/auth0/auth0.go
+++ b/client/go/internal/cli/auth/auth0/auth0.go
@@ -110,28 +110,40 @@ func (a *Client) getDeviceFlowConfig() (flowConfig, error) {
}
r, err := a.httpClient.Do(req, time.Second*30)
if err != nil {
- return flowConfig{}, fmt.Errorf("failed to get device flow config: %w", err)
+ return flowConfig{}, fmt.Errorf("auth0: failed to get device flow config: %w", err)
}
defer r.Body.Close()
if r.StatusCode/100 != 2 {
- return flowConfig{}, fmt.Errorf("failed to get device flow config: got response code %d from %s", r.StatusCode, url)
+ return flowConfig{}, fmt.Errorf("auth0: failed to get device flow config: got response code %d from %s", r.StatusCode, url)
}
var cfg flowConfig
if err := json.NewDecoder(r.Body).Decode(&cfg); err != nil {
- return flowConfig{}, fmt.Errorf("failed to decode response: %w", err)
+ return flowConfig{}, fmt.Errorf("auth0: failed to decode response: %w", err)
}
return cfg, nil
}
+func (a *Client) Authenticate(request *http.Request) error {
+ accessToken, err := a.AccessToken()
+ if err != nil {
+ return err
+ }
+ if request.Header == nil {
+ request.Header = make(http.Header)
+ }
+ request.Header.Set("Authorization", "Bearer "+accessToken)
+ return nil
+}
+
// AccessToken returns an access token for the configured system, refreshing it if necessary.
func (a *Client) AccessToken() (string, error) {
creds, ok := a.provider.Systems[a.options.SystemName]
if !ok {
- return "", fmt.Errorf("system %s is not configured", a.options.SystemName)
+ return "", fmt.Errorf("auth0: system %s is not configured: %s", a.options.SystemName, reauthMessage)
} else if creds.AccessToken == "" {
- return "", fmt.Errorf("access token missing: %s", reauthMessage)
+ return "", fmt.Errorf("auth0: access token missing: %s", reauthMessage)
} else if scopesChanged(creds) {
- return "", fmt.Errorf("authentication scopes changed: %s", reauthMessage)
+ return "", fmt.Errorf("auth0: authentication scopes changed: %s", reauthMessage)
} else if isExpired(creds.ExpiresAt, accessTokenExpiry) {
// check if the stored access token is expired:
// use the refresh token to get a new access token:
@@ -142,7 +154,7 @@ func (a *Client) AccessToken() (string, error) {
}
resp, err := tr.Refresh(cancelOnInterrupt(), a.options.SystemName)
if err != nil {
- return "", fmt.Errorf("failed to renew access token: %w: %s", err, reauthMessage)
+ return "", fmt.Errorf("auth0: failed to renew access token: %w: %s", err, reauthMessage)
} else {
// persist the updated system with renewed access token
creds.AccessToken = resp.AccessToken
@@ -173,12 +185,6 @@ func scopesChanged(s Credentials) bool {
return false
}
-// HasCredentials returns true if this client has retrived credentials for the configured system.
-func (a *Client) HasCredentials() bool {
- _, ok := a.provider.Systems[a.options.SystemName]
- return ok
-}
-
// WriteCredentials writes given credentials to the configuration file.
func (a *Client) WriteCredentials(credentials Credentials) error {
if a.provider.Systems == nil {
@@ -186,7 +192,7 @@ func (a *Client) WriteCredentials(credentials Credentials) error {
}
a.provider.Systems[a.options.SystemName] = credentials
if err := writeConfig(a.provider, a.options.ConfigPath); err != nil {
- return fmt.Errorf("failed to write config: %w", err)
+ return fmt.Errorf("auth0: failed to write config: %w", err)
}
return nil
}
@@ -195,11 +201,11 @@ func (a *Client) WriteCredentials(credentials Credentials) error {
func (a *Client) RemoveCredentials() error {
tr := &auth.TokenRetriever{Secrets: &auth.Keyring{}}
if err := tr.Delete(a.options.SystemName); err != nil {
- return fmt.Errorf("failed to remove system %s from secret storage: %w", a.options.SystemName, err)
+ return fmt.Errorf("auth0: failed to remove system %s from secret storage: %w", a.options.SystemName, err)
}
delete(a.provider.Systems, a.options.SystemName)
if err := writeConfig(a.provider, a.options.ConfigPath); err != nil {
- return fmt.Errorf("failed to write config: %w", err)
+ return fmt.Errorf("auth0: failed to write config: %w", err)
}
return nil
}
diff --git a/client/go/internal/cli/auth/zts/zts.go b/client/go/internal/cli/auth/zts/zts.go
index caa2d03367d..2c66ff13e8b 100644
--- a/client/go/internal/cli/auth/zts/zts.go
+++ b/client/go/internal/cli/auth/zts/zts.go
@@ -1,7 +1,6 @@
package zts
import (
- "crypto/tls"
"encoding/json"
"fmt"
"net/http"
@@ -18,26 +17,39 @@ const DefaultURL = "https://zts.athenz.ouroath.com:4443"
type Client struct {
client util.HTTPClient
tokenURL *url.URL
+ domain string
}
// NewClient creates a new client for an Athenz ZTS service located at serviceURL.
-func NewClient(client util.HTTPClient, serviceURL string) (*Client, error) {
+func NewClient(client util.HTTPClient, domain, serviceURL string) (*Client, error) {
tokenURL, err := url.Parse(serviceURL)
if err != nil {
return nil, err
}
tokenURL.Path = "/zts/v1/oauth2/token"
- return &Client{tokenURL: tokenURL, client: client}, nil
+ return &Client{tokenURL: tokenURL, client: client, domain: domain}, nil
}
-// AccessToken returns an access token within the given domain, using certificate to authenticate with ZTS.
-func (c *Client) AccessToken(domain string, certificate tls.Certificate) (string, error) {
- data := fmt.Sprintf("grant_type=client_credentials&scope=%s:domain", domain)
+func (c *Client) Authenticate(request *http.Request) error {
+ accessToken, err := c.AccessToken()
+ if err != nil {
+ return err
+ }
+ if request.Header == nil {
+ request.Header = make(http.Header)
+ }
+ request.Header.Add("Authorization", "Bearer "+accessToken)
+ return nil
+}
+
+// AccessToken returns an access token within the domain configured in client c.
+func (c *Client) AccessToken() (string, error) {
+ // TODO(mpolden): This should cache and re-use tokens until expiry
+ data := fmt.Sprintf("grant_type=client_credentials&scope=%s:domain", c.domain)
req, err := http.NewRequest("POST", c.tokenURL.String(), strings.NewReader(data))
if err != nil {
return "", err
}
- util.SetCertificates(c.client, []tls.Certificate{certificate})
response, err := c.client.Do(req, 10*time.Second)
if err != nil {
return "", err
@@ -45,7 +57,7 @@ func (c *Client) AccessToken(domain string, certificate tls.Certificate) (string
defer response.Body.Close()
if response.StatusCode != http.StatusOK {
- return "", fmt.Errorf("got status %d from %s", response.StatusCode, c.tokenURL.String())
+ return "", fmt.Errorf("zts: got status %d from %s", response.StatusCode, c.tokenURL.String())
}
var ztsResponse struct {
AccessToken string `json:"access_token"`
diff --git a/client/go/internal/cli/auth/zts/zts_test.go b/client/go/internal/cli/auth/zts/zts_test.go
index d0cc7ea9f9d..1c75a94ee03 100644
--- a/client/go/internal/cli/auth/zts/zts_test.go
+++ b/client/go/internal/cli/auth/zts/zts_test.go
@@ -1,7 +1,6 @@
package zts
import (
- "crypto/tls"
"testing"
"github.com/vespa-engine/vespa/client/go/internal/mock"
@@ -9,17 +8,17 @@ import (
func TestAccessToken(t *testing.T) {
httpClient := mock.HTTPClient{}
- client, err := NewClient(&httpClient, "http://example.com")
+ client, err := NewClient(&httpClient, "vespa.vespa", "http://example.com")
if err != nil {
t.Fatal(err)
}
httpClient.NextResponseString(400, `{"message": "bad request"}`)
- _, err = client.AccessToken("vespa.vespa", tls.Certificate{})
+ _, err = client.AccessToken()
if err == nil {
t.Fatal("want error for non-ok response status")
}
httpClient.NextResponseString(200, `{"access_token": "foo bar"}`)
- token, err := client.AccessToken("vespa.vespa", tls.Certificate{})
+ token, err := client.AccessToken()
if err != nil {
t.Fatal(err)
}
diff --git a/client/go/internal/cli/cmd/cert.go b/client/go/internal/cli/cmd/cert.go
index 7f79a9db358..48bad974c3f 100644
--- a/client/go/internal/cli/cmd/cert.go
+++ b/client/go/internal/cli/cmd/cert.go
@@ -34,13 +34,18 @@ package specified as an argument to this command (default '.').
It's possible to override the private key and certificate used through
environment variables. This can be useful in continuous integration systems.
-Example of setting the certificate and key in-line:
+It's also possible override the CA certificate which can be useful when using self-signed certificates with a
+self-hosted Vespa service. See https://docs.vespa.ai/en/mtls.html for more information.
+Example of setting the CA certificate, certificate and key in-line:
+
+ export VESPA_CLI_DATA_PLANE_CA_CERT="my CA cert"
export VESPA_CLI_DATA_PLANE_CERT="my cert"
export VESPA_CLI_DATA_PLANE_KEY="my private key"
-Example of loading certificate and key from custom paths:
+Example of loading CA certificate, certificate and key from custom paths:
+ export VESPA_CLI_DATA_PLANE_CA_CERT_FILE=/path/to/cacert
export VESPA_CLI_DATA_PLANE_CERT_FILE=/path/to/cert
export VESPA_CLI_DATA_PLANE_KEY_FILE=/path/to/key
diff --git a/client/go/internal/cli/cmd/config.go b/client/go/internal/cli/cmd/config.go
index 2d32c454842..e2132814386 100644
--- a/client/go/internal/cli/cmd/config.go
+++ b/client/go/internal/cli/cmd/config.go
@@ -19,7 +19,6 @@ import (
"github.com/fatih/color"
"github.com/spf13/cobra"
"github.com/spf13/pflag"
- "github.com/vespa-engine/vespa/client/go/internal/cli/auth/auth0"
"github.com/vespa-engine/vespa/client/go/internal/cli/config"
"github.com/vespa-engine/vespa/client/go/internal/vespa"
)
@@ -250,9 +249,10 @@ type Config struct {
}
type KeyPair struct {
- KeyPair tls.Certificate
- CertificateFile string
- PrivateKeyFile string
+ KeyPair tls.Certificate
+ RootCertificates []byte
+ CertificateFile string
+ PrivateKeyFile string
}
func loadConfig(environment map[string]string, flags map[string]*pflag.Flag) (*Config, error) {
@@ -392,6 +392,10 @@ func (c *Config) deploymentIn(system vespa.System) (vespa.Deployment, error) {
return vespa.Deployment{System: system, Application: app, Zone: zone}, nil
}
+func (c *Config) caCertificatePath() string {
+ return c.environment["VESPA_CLI_DATA_PLANE_CA_CERT_FILE"]
+}
+
func (c *Config) certificatePath(app vespa.ApplicationID, targetType string) (string, error) {
if override, ok := c.environment["VESPA_CLI_DATA_PLANE_CERT_FILE"]; ok {
return override, nil
@@ -412,50 +416,68 @@ func (c *Config) privateKeyPath(app vespa.ApplicationID, targetType string) (str
return c.applicationFilePath(app, "data-plane-private-key.pem")
}
-func (c *Config) x509KeyPair(app vespa.ApplicationID, targetType string) (KeyPair, error) {
+func (c *Config) readTLSOptions(app vespa.ApplicationID, targetType string) (vespa.TLSOptions, error) {
+ _, trustAll := c.environment["VESPA_CLI_DATA_PLANE_TRUST_ALL"]
cert, certOk := c.environment["VESPA_CLI_DATA_PLANE_CERT"]
key, keyOk := c.environment["VESPA_CLI_DATA_PLANE_KEY"]
- var (
- kp tls.Certificate
- err error
- certFile string
- keyFile string
- )
+ caCertText, caCertOk := c.environment["VESPA_CLI_DATA_PLANE_CA_CERT"]
+ options := vespa.TLSOptions{TrustAll: trustAll}
+ // CA certificate
+ if caCertOk {
+ options.CACertificate = []byte(caCertText)
+ } else {
+ caCertFile := c.caCertificatePath()
+ if caCertFile != "" {
+ b, err := os.ReadFile(caCertFile)
+ if err != nil {
+ return options, err
+ }
+ options.CACertificate = b
+ options.CACertificateFile = caCertFile
+ }
+ }
+ // Certificate and private key
if certOk && keyOk {
- // Use key pair from environment
- kp, err = tls.X509KeyPair([]byte(cert), []byte(key))
+ kp, err := tls.X509KeyPair([]byte(cert), []byte(key))
+ if err != nil {
+ return vespa.TLSOptions{}, err
+ }
+ options.KeyPair = []tls.Certificate{kp}
} else {
- keyFile, err = c.privateKeyPath(app, targetType)
+ keyFile, err := c.privateKeyPath(app, targetType)
if err != nil {
- return KeyPair{}, err
+ return vespa.TLSOptions{}, err
}
- certFile, err = c.certificatePath(app, targetType)
+ certFile, err := c.certificatePath(app, targetType)
if err != nil {
- return KeyPair{}, err
+ return vespa.TLSOptions{}, err
+ }
+ kp, err := tls.LoadX509KeyPair(certFile, keyFile)
+ if err == nil {
+ options.KeyPair = []tls.Certificate{kp}
+ options.PrivateKeyFile = keyFile
+ options.CertificateFile = certFile
+ } else if err != nil && !os.IsNotExist(err) {
+ return vespa.TLSOptions{}, err
}
- kp, err = tls.LoadX509KeyPair(certFile, keyFile)
- }
- if err != nil {
- return KeyPair{}, err
}
- if targetType == vespa.TargetHosted {
- cert, err := x509.ParseCertificate(kp.Certificate[0])
+ if options.KeyPair != nil {
+ cert, err := x509.ParseCertificate(options.KeyPair[0].Certificate[0])
if err != nil {
- return KeyPair{}, err
+ return vespa.TLSOptions{}, err
}
now := time.Now()
expiredAt := cert.NotAfter
if expiredAt.Before(now) {
delta := now.Sub(expiredAt).Truncate(time.Second)
- return KeyPair{}, fmt.Errorf("certificate %s expired at %s (%s ago)", certFile, cert.NotAfter, delta)
+ source := options.CertificateFile
+ if source == "" {
+ source = "environment"
+ }
+ return vespa.TLSOptions{}, fmt.Errorf("certificate in %s expired at %s (%s ago)", source, cert.NotAfter, delta)
}
- return KeyPair{KeyPair: kp, CertificateFile: certFile, PrivateKeyFile: keyFile}, nil
}
- return KeyPair{
- KeyPair: kp,
- CertificateFile: certFile,
- PrivateKeyFile: keyFile,
- }, nil
+ return options, nil
}
func (c *Config) apiKeyFileFromEnv() (string, bool) {
@@ -490,11 +512,10 @@ func (c *Config) readAPIKey(cli *CLI, system vespa.System, tenantName string) ([
return nil, nil // Vespa Cloud CI only talks to data plane and does not have an API key
}
if !cli.isCI() {
- client, err := cli.auth0Factory(cli.httpClient, auth0.Options{ConfigPath: c.authConfigPath(), SystemName: system.Name, SystemURL: system.URL})
- if err == nil && client.HasCredentials() {
- return nil, nil // use Auth0
+ if _, err := os.Stat(c.authConfigPath()); err == nil {
+ return nil, nil // We have auth config, so we should prefer Auth0 over API key
}
- cli.printWarning("Authenticating with API key. This is discouraged in non-CI environments", "Authenticate with 'vespa auth login'")
+ cli.printWarning("Authenticating with API key. This is discouraged in non-CI environments", "Authenticate with 'vespa auth login' instead")
}
return os.ReadFile(c.apiKeyPath(tenantName))
}
diff --git a/client/go/internal/cli/cmd/config_test.go b/client/go/internal/cli/cmd/config_test.go
index 458878b4356..66b65bf402b 100644
--- a/client/go/internal/cli/cmd/config_test.go
+++ b/client/go/internal/cli/cmd/config_test.go
@@ -2,15 +2,21 @@
package cmd
import (
+ "crypto/rand"
+ "crypto/rsa"
+ "crypto/tls"
+ "crypto/x509"
+ "crypto/x509/pkix"
+ "encoding/pem"
+ "math/big"
"os"
"path/filepath"
"testing"
+ "time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
- "github.com/vespa-engine/vespa/client/go/internal/cli/auth/auth0"
"github.com/vespa-engine/vespa/client/go/internal/mock"
- "github.com/vespa-engine/vespa/client/go/internal/util"
"github.com/vespa-engine/vespa/client/go/internal/vespa"
)
@@ -166,7 +172,7 @@ func TestReadAPIKey(t *testing.T) {
require.Nil(t, err)
assert.Equal(t, []byte("foo"), key)
- // Cloud CI does not read key from disk as it's not expected to have any
+ // Cloud CI never reads key from disk as it's not expected to have any
cli, _, _ = newTestCLI(t, "VESPA_CLI_CLOUD_CI=true")
key, err = cli.config.readAPIKey(cli, vespa.PublicSystem, "t1")
require.Nil(t, err)
@@ -186,12 +192,111 @@ func TestReadAPIKey(t *testing.T) {
require.Nil(t, err)
assert.Equal(t, []byte("baz"), key)
- // Auth0 is preferred when configured
+ // Prefer Auth0 if we have auth config
cli, _, _ = newTestCLI(t)
- cli.auth0Factory = func(httpClient util.HTTPClient, options auth0.Options) (auth0Client, error) {
- return &mockAuth0{hasCredentials: true}, nil
- }
+ require.Nil(t, os.WriteFile(filepath.Join(cli.config.homeDir, "auth.json"), []byte("foo"), 0600))
key, err = cli.config.readAPIKey(cli, vespa.PublicSystem, "t1")
require.Nil(t, err)
assert.Nil(t, key)
}
+
+func TestConfigReadTLSOptions(t *testing.T) {
+ app := vespa.ApplicationID{Tenant: "t1", Application: "a1", Instance: "i1"}
+ homeDir := t.TempDir()
+
+ // No environment variables, and no files on disk
+ assertTLSOptions(t, homeDir, app, vespa.TargetLocal, vespa.TLSOptions{})
+
+ // A single environment variable is set
+ assertTLSOptions(t, homeDir, app, vespa.TargetLocal, vespa.TLSOptions{TrustAll: true}, "VESPA_CLI_DATA_PLANE_TRUST_ALL=true")
+
+ // Key pair is provided in-line in environment variables
+ pemCert, pemKey, keyPair := createKeyPair(t)
+ assertTLSOptions(t, homeDir, app,
+ vespa.TargetLocal,
+ vespa.TLSOptions{
+ TrustAll: true,
+ CACertificate: []byte("cacert"),
+ KeyPair: []tls.Certificate{keyPair},
+ },
+ "VESPA_CLI_DATA_PLANE_TRUST_ALL=true",
+ "VESPA_CLI_DATA_PLANE_CA_CERT=cacert",
+ "VESPA_CLI_DATA_PLANE_CERT="+string(pemCert),
+ "VESPA_CLI_DATA_PLANE_KEY="+string(pemKey),
+ )
+
+ // Key pair is provided as file paths through environment variables
+ certFile := filepath.Join(homeDir, "cert")
+ keyFile := filepath.Join(homeDir, "key")
+ caCertFile := filepath.Join(homeDir, "cacert")
+ require.Nil(t, os.WriteFile(certFile, pemCert, 0600))
+ require.Nil(t, os.WriteFile(keyFile, pemKey, 0600))
+ require.Nil(t, os.WriteFile(caCertFile, []byte("cacert"), 0600))
+ assertTLSOptions(t, homeDir, app,
+ vespa.TargetLocal,
+ vespa.TLSOptions{
+ KeyPair: []tls.Certificate{keyPair},
+ CACertificate: []byte("cacert"),
+ CACertificateFile: caCertFile,
+ CertificateFile: certFile,
+ PrivateKeyFile: keyFile,
+ },
+ "VESPA_CLI_DATA_PLANE_CERT_FILE="+certFile,
+ "VESPA_CLI_DATA_PLANE_KEY_FILE="+keyFile,
+ "VESPA_CLI_DATA_PLANE_CA_CERT_FILE="+caCertFile,
+ )
+
+ // Key pair resides in default paths
+ defaultCertFile := filepath.Join(homeDir, app.String(), "data-plane-public-cert.pem")
+ defaultKeyFile := filepath.Join(homeDir, app.String(), "data-plane-private-key.pem")
+ require.Nil(t, os.WriteFile(defaultCertFile, pemCert, 0600))
+ require.Nil(t, os.WriteFile(defaultKeyFile, pemKey, 0600))
+ assertTLSOptions(t, homeDir, app,
+ vespa.TargetLocal,
+ vespa.TLSOptions{
+ KeyPair: []tls.Certificate{keyPair},
+ CertificateFile: defaultCertFile,
+ PrivateKeyFile: defaultKeyFile,
+ },
+ )
+}
+
+func assertTLSOptions(t *testing.T, homeDir string, app vespa.ApplicationID, target string, want vespa.TLSOptions, envVars ...string) {
+ t.Helper()
+ envVars = append(envVars, "VESPA_CLI_HOME="+homeDir)
+ cli, _, _ := newTestCLI(t, envVars...)
+ require.Nil(t, cli.Run("config", "set", "application", app.String()))
+ config, err := cli.config.readTLSOptions(app, vespa.TargetLocal)
+ require.Nil(t, err)
+ assert.Equal(t, want, config)
+}
+
+func createKeyPair(t *testing.T) ([]byte, []byte, tls.Certificate) {
+ privateKey, err := rsa.GenerateKey(rand.Reader, 2048)
+ if err != nil {
+ t.Fatal(err)
+ }
+ notBefore := time.Now()
+ notAfter := notBefore.Add(24 * time.Hour)
+ template := x509.Certificate{
+ SerialNumber: big.NewInt(1),
+ Subject: pkix.Name{CommonName: "example.com"},
+ NotBefore: notBefore,
+ NotAfter: notAfter,
+ }
+ certificateDER, err := x509.CreateCertificate(rand.Reader, &template, &template, &privateKey.PublicKey, privateKey)
+ if err != nil {
+ t.Fatal(err)
+ }
+ privateKeyDER, err := x509.MarshalPKCS8PrivateKey(privateKey)
+ if err != nil {
+ t.Fatal(err)
+ }
+ pemCert := pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: certificateDER})
+ pemKey := pem.EncodeToMemory(&pem.Block{Type: "PRIVATE KEY", Bytes: privateKeyDER})
+ kp, err := tls.X509KeyPair(pemCert, pemKey)
+ if err != nil {
+ t.Fatal(err)
+ }
+ return pemCert, pemKey, kp
+}
diff --git a/client/go/internal/cli/cmd/curl.go b/client/go/internal/cli/cmd/curl.go
index 8fcd1fa6ef7..3d5aaff24dc 100644
--- a/client/go/internal/cli/cmd/curl.go
+++ b/client/go/internal/cli/cmd/curl.go
@@ -4,7 +4,6 @@ package cmd
import (
"fmt"
"log"
- "net/http"
"os"
"strings"
@@ -54,6 +53,7 @@ $ vespa curl -- -v --data-urlencode "yql=select * from music where album contain
return err
}
case vespa.DocumentService, vespa.QueryService:
+ c.CaCertificate = service.TLSOptions.CACertificateFile
c.PrivateKey = service.TLSOptions.PrivateKeyFile
c.Certificate = service.TLSOptions.CertificateFile
default:
@@ -79,15 +79,7 @@ func addAccessToken(cmd *curl.Command, target vespa.Target) error {
if target.Type() != vespa.TargetCloud {
return nil
}
- req := http.Request{}
- if err := target.SignRequest(&req, ""); err != nil {
- return err
- }
- headerValue := req.Header.Get("Authorization")
- if headerValue == "" {
- return fmt.Errorf("no authorization header added when signing request")
- }
- cmd.Header("Authorization", headerValue)
+ cmd.Header("Authorization", "secret")
return nil
}
diff --git a/client/go/internal/cli/cmd/feed.go b/client/go/internal/cli/cmd/feed.go
index 895a22d2be5..f0f82dd80d1 100644
--- a/client/go/internal/cli/cmd/feed.go
+++ b/client/go/internal/cli/cmd/feed.go
@@ -14,16 +14,24 @@ import (
"github.com/vespa-engine/vespa/client/go/internal/vespa/document"
)
-func addFeedFlags(cmd *cobra.Command, verbose *bool, connections *int) {
- cmd.PersistentFlags().IntVarP(connections, "connections", "N", 8, "The number of connections to use")
- cmd.PersistentFlags().BoolVarP(verbose, "verbose", "v", false, "Verbose mode. Print errors as they happen")
+func addFeedFlags(cmd *cobra.Command, options *feedOptions) {
+ cmd.PersistentFlags().IntVar(&options.connections, "connections", 8, "The number of connections to use")
+ cmd.PersistentFlags().StringVar(&options.route, "route", "", "Target Vespa route for feed operations")
+ cmd.PersistentFlags().IntVar(&options.traceLevel, "trace", 0, "The trace level of network traffic. 0 to disable")
+ cmd.PersistentFlags().IntVar(&options.timeoutSecs, "timeout", 0, "Feed operation timeout in seconds. 0 to disable")
+ cmd.PersistentFlags().BoolVar(&options.verbose, "verbose", false, "Verbose mode. Print successful operations in addition to errors")
+}
+
+type feedOptions struct {
+ connections int
+ route string
+ verbose bool
+ traceLevel int
+ timeoutSecs int
}
func newFeedCmd(cli *CLI) *cobra.Command {
- var (
- verbose bool
- connections int
- )
+ var options feedOptions
cmd := &cobra.Command{
Use: "feed FILE",
Short: "Feed documents to a Vespa cluster",
@@ -56,10 +64,10 @@ $ cat documents.jsonl | vespa feed -
defer f.Close()
r = f
}
- return feed(r, cli, verbose, connections)
+ return feed(r, cli, options)
},
}
- addFeedFlags(cmd, &verbose, &connections)
+ addFeedFlags(cmd, &options)
return cmd
}
@@ -67,29 +75,29 @@ func createServiceClients(service *vespa.Service, n int) []util.HTTPClient {
clients := make([]util.HTTPClient, 0, n)
for i := 0; i < n; i++ {
client := service.Client().Clone()
- util.ForceHTTP2(client, service.TLSOptions.KeyPair) // Feeding should always use HTTP/2
+ // Feeding should always use HTTP/2
+ util.ForceHTTP2(client, service.TLSOptions.KeyPair, service.TLSOptions.CACertificate, service.TLSOptions.TrustAll)
clients = append(clients, client)
}
return clients
}
-func feed(r io.Reader, cli *CLI, verbose bool, connections int) error {
+func feed(r io.Reader, cli *CLI, options feedOptions) error {
service, err := documentService(cli)
if err != nil {
return err
}
- clients := createServiceClients(service, connections)
+ clients := createServiceClients(service, options.connections)
client := document.NewClient(document.ClientOptions{
- BaseURL: service.BaseURL,
+ Timeout: time.Duration(options.timeoutSecs) * time.Second,
+ Route: options.route,
+ TraceLevel: options.traceLevel,
+ BaseURL: service.BaseURL,
}, clients)
- throttler := document.NewThrottler(connections)
+ throttler := document.NewThrottler(options.connections)
// TODO(mpolden): Make doom duration configurable
circuitBreaker := document.NewCircuitBreaker(10*time.Second, 0)
- errWriter := io.Discard
- if verbose {
- errWriter = cli.Stderr
- }
- dispatcher := document.NewDispatcher(client, throttler, circuitBreaker, errWriter)
+ dispatcher := document.NewDispatcher(client, throttler, circuitBreaker, cli.Stderr, options.verbose)
dec := document.NewDecoder(r)
start := cli.now()
diff --git a/client/go/internal/cli/cmd/root.go b/client/go/internal/cli/cmd/root.go
index 360af9d0dcf..88d43411983 100644
--- a/client/go/internal/cli/cmd/root.go
+++ b/client/go/internal/cli/cmd/root.go
@@ -2,7 +2,6 @@
package cmd
import (
- "crypto/tls"
"encoding/json"
"fmt"
"io"
@@ -88,18 +87,9 @@ func (c *execSubprocess) Run(name string, args ...string) ([]byte, error) {
return exec.Command(name, args...).Output()
}
-type ztsClient interface {
- AccessToken(domain string, certficiate tls.Certificate) (string, error)
-}
-
-type auth0Client interface {
- AccessToken() (string, error)
- HasCredentials() bool
-}
-
-type auth0Factory func(httpClient util.HTTPClient, options auth0.Options) (auth0Client, error)
+type auth0Factory func(httpClient util.HTTPClient, options auth0.Options) (vespa.Authenticator, error)
-type ztsFactory func(httpClient util.HTTPClient, url string) (ztsClient, error)
+type ztsFactory func(httpClient util.HTTPClient, domain, url string) (vespa.Authenticator, error)
// New creates the Vespa CLI, writing output to stdout and stderr, and reading environment variables from environment.
func New(stdout, stderr io.Writer, environment []string) (*CLI, error) {
@@ -143,11 +133,11 @@ For detailed description of flags and configuration, see 'vespa help config'.
httpClient: util.CreateClient(time.Second * 10),
exec: &execSubprocess{},
now: time.Now,
- auth0Factory: func(httpClient util.HTTPClient, options auth0.Options) (auth0Client, error) {
+ auth0Factory: func(httpClient util.HTTPClient, options auth0.Options) (vespa.Authenticator, error) {
return auth0.NewClient(httpClient, options)
},
- ztsFactory: func(httpClient util.HTTPClient, url string) (ztsClient, error) {
- return zts.NewClient(httpClient, url)
+ ztsFactory: func(httpClient util.HTTPClient, domain, url string) (vespa.Authenticator, error) {
+ return zts.NewClient(httpClient, domain, url)
},
}
cli.isTerminal = func() bool { return isTerminal(cli.Stdout) && isTerminal(cli.Stderr) }
@@ -321,16 +311,34 @@ func (c *CLI) createTarget(opts targetOptions) (vespa.Target, error) {
if err != nil {
return nil, err
}
+ customURL := ""
if strings.HasPrefix(targetType, "http") {
- return vespa.CustomTarget(c.httpClient, targetType), nil
+ customURL = targetType
+ targetType = vespa.TargetCustom
}
switch targetType {
- case vespa.TargetLocal:
- return vespa.LocalTarget(c.httpClient), nil
+ case vespa.TargetLocal, vespa.TargetCustom:
+ return c.createCustomTarget(targetType, customURL)
case vespa.TargetCloud, vespa.TargetHosted:
return c.createCloudTarget(targetType, opts)
+ default:
+ return nil, errHint(fmt.Errorf("invalid target: %s", targetType), "Valid targets are 'local', 'cloud', 'hosted' or an URL")
+ }
+}
+
+func (c *CLI) createCustomTarget(targetType, customURL string) (vespa.Target, error) {
+ tlsOptions, err := c.config.readTLSOptions(vespa.DefaultApplication, targetType)
+ if err != nil {
+ return nil, err
+ }
+ switch targetType {
+ case vespa.TargetLocal:
+ return vespa.LocalTarget(c.httpClient, tlsOptions), nil
+ case vespa.TargetCustom:
+ return vespa.CustomTarget(c.httpClient, customURL, tlsOptions), nil
+ default:
+ return nil, fmt.Errorf("invalid custom target: %s", targetType)
}
- return nil, errHint(fmt.Errorf("invalid target: %s", targetType), "Valid targets are 'local', 'cloud', 'hosted' or an URL")
}
func (c *CLI) createCloudTarget(targetType string, opts targetOptions) (vespa.Target, error) {
@@ -347,48 +355,53 @@ func (c *CLI) createCloudTarget(targetType string, opts targetOptions) (vespa.Ta
return nil, err
}
var (
- apiKey []byte
- authConfigPath string
+ apiAuth vespa.Authenticator
+ deploymentAuth vespa.Authenticator
apiTLSOptions vespa.TLSOptions
deploymentTLSOptions vespa.TLSOptions
)
switch targetType {
case vespa.TargetCloud:
- apiKey, err = c.config.readAPIKey(c, system, deployment.Application.Tenant)
+ apiKey, err := c.config.readAPIKey(c, system, deployment.Application.Tenant)
if err != nil {
return nil, err
}
- authConfigPath = c.config.authConfigPath()
+ if apiKey == nil {
+ authConfigPath := c.config.authConfigPath()
+ auth0, err := c.auth0Factory(c.httpClient, auth0.Options{ConfigPath: authConfigPath, SystemName: system.Name, SystemURL: system.URL})
+ if err != nil {
+ return nil, err
+ }
+ apiAuth = auth0
+ } else {
+ apiAuth = vespa.NewRequestSigner(deployment.Application.SerializedForm(), apiKey)
+ }
deploymentTLSOptions = vespa.TLSOptions{}
if !opts.noCertificate {
- kp, err := c.config.x509KeyPair(deployment.Application, targetType)
+ kp, err := c.config.readTLSOptions(deployment.Application, targetType)
if err != nil {
- return nil, errHint(err, "Deployment to cloud requires a certificate. Try 'vespa auth cert'")
- }
- deploymentTLSOptions = vespa.TLSOptions{
- KeyPair: []tls.Certificate{kp.KeyPair},
- CertificateFile: kp.CertificateFile,
- PrivateKeyFile: kp.PrivateKeyFile,
+ return nil, errHint(err, "Deployment to cloud requires a certificate", "Try 'vespa auth cert' to create a self-signed certificate")
}
+ deploymentTLSOptions = kp
}
case vespa.TargetHosted:
- kp, err := c.config.x509KeyPair(deployment.Application, targetType)
+ kp, err := c.config.readTLSOptions(deployment.Application, targetType)
if err != nil {
return nil, errHint(err, "Deployment to hosted requires an Athenz certificate", "Try renewing certificate with 'athenz-user-cert'")
}
- apiTLSOptions = vespa.TLSOptions{
- KeyPair: []tls.Certificate{kp.KeyPair},
- CertificateFile: kp.CertificateFile,
- PrivateKeyFile: kp.PrivateKeyFile,
+ zts, err := c.ztsFactory(c.httpClient, system.AthenzDomain, zts.DefaultURL)
+ if err != nil {
+ return nil, err
}
- deploymentTLSOptions = apiTLSOptions
+ deploymentAuth = zts
+ apiTLSOptions = kp
+ deploymentTLSOptions = kp
default:
return nil, fmt.Errorf("invalid cloud target: %s", targetType)
}
apiOptions := vespa.APIOptions{
System: system,
TLSOptions: apiTLSOptions,
- APIKey: apiKey,
}
deploymentOptions := vespa.CloudDeploymentOptions{
Deployment: deployment,
@@ -403,15 +416,7 @@ func (c *CLI) createCloudTarget(targetType string, opts targetOptions) (vespa.Ta
Writer: c.Stdout,
Level: vespa.LogLevel(logLevel),
}
- auth0, err := c.auth0Factory(c.httpClient, auth0.Options{ConfigPath: authConfigPath, SystemName: apiOptions.System.Name, SystemURL: apiOptions.System.URL})
- if err != nil {
- return nil, err
- }
- zts, err := c.ztsFactory(c.httpClient, zts.DefaultURL)
- if err != nil {
- return nil, err
- }
- return vespa.CloudTarget(c.httpClient, zts, auth0, apiOptions, deploymentOptions, logOptions)
+ return vespa.CloudTarget(c.httpClient, apiAuth, deploymentAuth, apiOptions, deploymentOptions, logOptions)
}
// system returns the appropiate system for the target configured in this CLI.
@@ -460,7 +465,6 @@ func (c *CLI) createDeploymentOptions(pkg vespa.ApplicationPackage, target vespa
ApplicationPackage: pkg,
Target: target,
Timeout: timeout,
- HTTPClient: c.httpClient,
}, nil
}
diff --git a/client/go/internal/cli/cmd/test.go b/client/go/internal/cli/cmd/test.go
index 05633b1135e..8c4501e2870 100644
--- a/client/go/internal/cli/cmd/test.go
+++ b/client/go/internal/cli/cmd/test.go
@@ -263,7 +263,7 @@ func verify(step step, defaultCluster string, defaultParameters map[string]strin
var response *http.Response
if externalEndpoint {
- util.SetCertificates(context.cli.httpClient, []tls.Certificate{})
+ util.ConfigureTLS(context.cli.httpClient, []tls.Certificate{}, nil, false)
response, err = context.cli.httpClient.Do(request, 60*time.Second)
} else {
response, err = service.Do(request, 600*time.Second) // Vespa should provide a response within the given request timeout
diff --git a/client/go/internal/cli/cmd/testutil_test.go b/client/go/internal/cli/cmd/testutil_test.go
index 61f8dab2264..492e40d8855 100644
--- a/client/go/internal/cli/cmd/testutil_test.go
+++ b/client/go/internal/cli/cmd/testutil_test.go
@@ -3,13 +3,14 @@ package cmd
import (
"bytes"
- "crypto/tls"
+ "net/http"
"path/filepath"
"testing"
"github.com/vespa-engine/vespa/client/go/internal/cli/auth/auth0"
"github.com/vespa-engine/vespa/client/go/internal/mock"
"github.com/vespa-engine/vespa/client/go/internal/util"
+ "github.com/vespa-engine/vespa/client/go/internal/vespa"
)
func newTestCLI(t *testing.T, envVars ...string) (*CLI, *bytes.Buffer, *bytes.Buffer) {
@@ -29,21 +30,15 @@ func newTestCLI(t *testing.T, envVars ...string) (*CLI, *bytes.Buffer, *bytes.Bu
httpClient := &mock.HTTPClient{}
cli.httpClient = httpClient
cli.exec = &mock.Exec{}
- cli.auth0Factory = func(httpClient util.HTTPClient, options auth0.Options) (auth0Client, error) {
- return &mockAuth0{}, nil
+ cli.auth0Factory = func(httpClient util.HTTPClient, options auth0.Options) (vespa.Authenticator, error) {
+ return &mockAuthenticator{}, nil
}
- cli.ztsFactory = func(httpClient util.HTTPClient, url string) (ztsClient, error) {
- return &mockZTS{}, nil
+ cli.ztsFactory = func(httpClient util.HTTPClient, domain, url string) (vespa.Authenticator, error) {
+ return &mockAuthenticator{}, nil
}
return cli, &stdout, &stderr
}
-type mockZTS struct{}
+type mockAuthenticator struct{}
-func (z *mockZTS) AccessToken(domain string, cert tls.Certificate) (string, error) { return "", nil }
-
-type mockAuth0 struct{ hasCredentials bool }
-
-func (a *mockAuth0) AccessToken() (string, error) { return "", nil }
-
-func (a *mockAuth0) HasCredentials() bool { return a.hasCredentials }
+func (a *mockAuthenticator) Authenticate(request *http.Request) error { return nil }
diff --git a/client/go/internal/util/http.go b/client/go/internal/util/http.go
index dcf05ed3a14..8a67b24dffb 100644
--- a/client/go/internal/util/http.go
+++ b/client/go/internal/util/http.go
@@ -4,6 +4,7 @@ package util
import (
"context"
"crypto/tls"
+ "crypto/x509"
"fmt"
"net"
"net/http"
@@ -35,7 +36,7 @@ func (c *defaultHTTPClient) Do(request *http.Request, timeout time.Duration) (re
func (c *defaultHTTPClient) Clone() HTTPClient { return CreateClient(c.client.Timeout) }
-func SetCertificates(client HTTPClient, certificates []tls.Certificate) {
+func ConfigureTLS(client HTTPClient, certificates []tls.Certificate, caCertificate []byte, trustAll bool) {
c, ok := client.(*defaultHTTPClient)
if !ok {
return
@@ -43,8 +44,14 @@ func SetCertificates(client HTTPClient, certificates []tls.Certificate) {
var tlsConfig *tls.Config = nil
if certificates != nil {
tlsConfig = &tls.Config{
- Certificates: certificates,
- MinVersion: tls.VersionTLS12,
+ Certificates: certificates,
+ MinVersion: tls.VersionTLS12,
+ InsecureSkipVerify: trustAll,
+ }
+ if caCertificate != nil {
+ certs := x509.NewCertPool()
+ certs.AppendCertsFromPEM(caCertificate)
+ tlsConfig.RootCAs = certs
}
}
if tr, ok := c.client.Transport.(*http.Transport); ok {
@@ -56,19 +63,13 @@ func SetCertificates(client HTTPClient, certificates []tls.Certificate) {
}
}
-func ForceHTTP2(client HTTPClient, certificates []tls.Certificate) {
+func ForceHTTP2(client HTTPClient, certificates []tls.Certificate, caCertificate []byte, trustAll bool) {
c, ok := client.(*defaultHTTPClient)
if !ok {
return
}
- var tlsConfig *tls.Config = nil
var dialFunc func(ctx context.Context, network, addr string, cfg *tls.Config) (net.Conn, error)
- if certificates != nil {
- tlsConfig = &tls.Config{
- Certificates: certificates,
- MinVersion: tls.VersionTLS12,
- }
- } else {
+ if certificates == nil {
// No certificate, so force H2C (HTTP/2 over clear-text) by using a non-TLS Dialer
dialer := net.Dialer{}
dialFunc = func(ctx context.Context, network, addr string, cfg *tls.Config) (net.Conn, error) {
@@ -80,10 +81,10 @@ func ForceHTTP2(client HTTPClient, certificates []tls.Certificate) {
// https://github.com/golang/go/issues/16582
// https://github.com/golang/go/issues/22091
c.client.Transport = &http2.Transport{
- AllowHTTP: true,
- TLSClientConfig: tlsConfig,
- DialTLSContext: dialFunc,
+ AllowHTTP: true,
+ DialTLSContext: dialFunc,
}
+ ConfigureTLS(client, certificates, caCertificate, trustAll)
}
func CreateClient(timeout time.Duration) HTTPClient {
diff --git a/client/go/internal/vespa/crypto.go b/client/go/internal/vespa/crypto.go
index 9621d0c1180..5e273538869 100644
--- a/client/go/internal/vespa/crypto.go
+++ b/client/go/internal/vespa/crypto.go
@@ -111,6 +111,8 @@ func NewRequestSigner(keyID string, pemPrivateKey []byte) *RequestSigner {
}
}
+func (rs *RequestSigner) Authenticate(request *http.Request) error { return rs.SignRequest(request) }
+
// SignRequest signs the given HTTP request using the private key in rs
func (rs *RequestSigner) SignRequest(request *http.Request) error {
timestamp := rs.now().UTC().Format(time.RFC3339)
diff --git a/client/go/internal/vespa/deploy.go b/client/go/internal/vespa/deploy.go
index 687bfc46124..f633c8ed9ee 100644
--- a/client/go/internal/vespa/deploy.go
+++ b/client/go/internal/vespa/deploy.go
@@ -45,7 +45,6 @@ type DeploymentOptions struct {
ApplicationPackage ApplicationPackage
Timeout time.Duration
Version version.Version
- HTTPClient util.HTTPClient
}
type LogLinePrepareResponse struct {
@@ -130,7 +129,7 @@ func Prepare(deployment DeploymentOptions) (PrepareResult, error) {
return PrepareResult{}, err
}
serviceDescription := "Deploy service"
- response, err := deployment.HTTPClient.Do(req, time.Second*30)
+ response, err := deployServiceDo(req, time.Second*30, deployment)
if err != nil {
return PrepareResult{}, err
}
@@ -171,7 +170,7 @@ func Activate(sessionID int64, deployment DeploymentOptions) error {
return err
}
serviceDescription := "Deploy service"
- response, err := deployment.HTTPClient.Do(req, time.Second*30)
+ response, err := deployServiceDo(req, time.Second*30, deployment)
if err != nil {
return err
}
@@ -263,11 +262,7 @@ func Submit(opts DeploymentOptions) error {
}
request.Header.Set("Content-Type", writer.FormDataContentType())
serviceDescription := "Submit service"
- sigKeyId := opts.Target.Deployment().Application.SerializedForm()
- if err := opts.Target.SignRequest(request, sigKeyId); err != nil {
- return fmt.Errorf("failed to sign api request: %w", err)
- }
- response, err := opts.HTTPClient.Do(request, time.Minute*10)
+ response, err := deployServiceDo(request, time.Minute*10, opts)
if err != nil {
return err
}
@@ -275,6 +270,14 @@ func Submit(opts DeploymentOptions) error {
return checkResponse(request, response, serviceDescription)
}
+func deployServiceDo(request *http.Request, timeout time.Duration, opts DeploymentOptions) (*http.Response, error) {
+ s, err := opts.Target.Service(DeployService, 0, 0, "")
+ if err != nil {
+ return nil, err
+ }
+ return s.Do(request, timeout)
+}
+
func checkDeploymentOpts(opts DeploymentOptions) error {
if opts.Target.Type() == TargetCloud && !opts.ApplicationPackage.HasCertificate() {
return fmt.Errorf("%s: missing certificate in package", opts)
@@ -334,11 +337,6 @@ func uploadApplicationPackage(url *url.URL, opts DeploymentOptions) (PrepareResu
if err != nil {
return PrepareResult{}, err
}
-
- keyID := opts.Target.Deployment().Application.SerializedForm()
- if err := opts.Target.SignRequest(request, keyID); err != nil {
- return PrepareResult{}, err
- }
response, err := service.Do(request, time.Minute*10)
if err != nil {
return PrepareResult{}, err
diff --git a/client/go/internal/vespa/deploy_test.go b/client/go/internal/vespa/deploy_test.go
index 3e74e9ab3b6..da2604282c0 100644
--- a/client/go/internal/vespa/deploy_test.go
+++ b/client/go/internal/vespa/deploy_test.go
@@ -19,12 +19,11 @@ import (
func TestDeploy(t *testing.T) {
httpClient := mock.HTTPClient{}
- target := LocalTarget(&httpClient)
+ target := LocalTarget(&httpClient, TLSOptions{})
appDir, _ := mock.ApplicationPackageDir(t, false, false)
opts := DeploymentOptions{
Target: target,
ApplicationPackage: ApplicationPackage{Path: appDir},
- HTTPClient: &httpClient,
}
_, err := Deploy(opts)
assert.Nil(t, err)
@@ -47,7 +46,6 @@ func TestDeployCloud(t *testing.T) {
opts := DeploymentOptions{
Target: target,
ApplicationPackage: ApplicationPackage{Path: appDir},
- HTTPClient: &httpClient,
}
_, err := Deploy(opts)
require.Nil(t, err)
diff --git a/client/go/internal/vespa/document/dispatcher.go b/client/go/internal/vespa/document/dispatcher.go
index 838a7bc45ee..533ca7a0019 100644
--- a/client/go/internal/vespa/document/dispatcher.go
+++ b/client/go/internal/vespa/document/dispatcher.go
@@ -4,6 +4,7 @@ import (
"container/list"
"fmt"
"io"
+ "strings"
"sync"
"sync/atomic"
"time"
@@ -18,12 +19,15 @@ type Dispatcher struct {
circuitBreaker CircuitBreaker
stats Stats
- started bool
- ready chan Id
- results chan Result
+ started bool
+ ready chan Id
+ results chan Result
+ msgs chan string
+
inflight map[string]*documentGroup
inflightCount int64
- errWriter io.Writer
+ output io.Writer
+ verbose bool
mu sync.RWMutex
wg sync.WaitGroup
@@ -55,13 +59,14 @@ func (g *documentGroup) add(op documentOp, first bool) {
}
}
-func NewDispatcher(feeder Feeder, throttler Throttler, breaker CircuitBreaker, errWriter io.Writer) *Dispatcher {
+func NewDispatcher(feeder Feeder, throttler Throttler, breaker CircuitBreaker, output io.Writer, verbose bool) *Dispatcher {
d := &Dispatcher{
feeder: feeder,
throttler: throttler,
circuitBreaker: breaker,
inflight: make(map[string]*documentGroup),
- errWriter: errWriter,
+ output: output,
+ verbose: verbose,
}
d.start()
return d
@@ -69,8 +74,6 @@ func NewDispatcher(feeder Feeder, throttler Throttler, breaker CircuitBreaker, e
func (d *Dispatcher) sendDocumentIn(group *documentGroup) {
group.mu.Lock()
- defer group.mu.Unlock()
- defer d.releaseSlot()
first := group.ops.Front()
if first == nil {
panic("sending from empty document group, this should not happen")
@@ -79,6 +82,8 @@ func (d *Dispatcher) sendDocumentIn(group *documentGroup) {
op.attempts++
result := d.feeder.Send(op.document)
d.results <- result
+ d.releaseSlot()
+ group.mu.Unlock()
if d.shouldRetry(op, result) {
d.enqueue(op)
}
@@ -86,29 +91,35 @@ func (d *Dispatcher) sendDocumentIn(group *documentGroup) {
func (d *Dispatcher) shouldRetry(op documentOp, result Result) bool {
if result.HTTPStatus/100 == 2 || result.HTTPStatus == 404 || result.HTTPStatus == 412 {
+ if d.verbose {
+ d.msgs <- fmt.Sprintf("feed: successfully fed %s with status %d", op.document.Id, result.HTTPStatus)
+ }
d.throttler.Success()
d.circuitBreaker.Success()
return false
}
if result.HTTPStatus == 429 || result.HTTPStatus == 503 {
- fmt.Fprintf(d.errWriter, "feed: %s was throttled with status %d: retrying\n", op.document, result.HTTPStatus)
+ d.msgs <- fmt.Sprintf("feed: %s was throttled with status %d: retrying\n", op.document, result.HTTPStatus)
d.throttler.Throttled(atomic.LoadInt64(&d.inflightCount))
return true
}
if result.Err != nil || result.HTTPStatus == 500 || result.HTTPStatus == 502 || result.HTTPStatus == 504 {
retry := op.attempts <= maxAttempts
- msg := "feed: " + op.document.String() + " failed with "
+ var msg strings.Builder
+ msg.WriteString("feed: ")
+ msg.WriteString(op.document.String())
if result.Err != nil {
- msg += "error " + result.Err.Error()
+ msg.WriteString("error ")
+ msg.WriteString(result.Err.Error())
} else {
- msg += fmt.Sprintf("status %d", result.HTTPStatus)
+ msg.WriteString(fmt.Sprintf("status %d", result.HTTPStatus))
}
if retry {
- msg += ": retrying"
+ msg.WriteString(": retrying")
} else {
- msg += fmt.Sprintf(": giving up after %d attempts", maxAttempts)
+ msg.WriteString(fmt.Sprintf(": giving up after %d attempts", maxAttempts))
}
- fmt.Fprintln(d.errWriter, msg)
+ d.msgs <- msg.String()
d.circuitBreaker.Error(fmt.Errorf("request failed with status %d", result.HTTPStatus))
if retry {
return true
@@ -125,17 +136,22 @@ func (d *Dispatcher) start() {
}
d.ready = make(chan Id, 4096)
d.results = make(chan Result, 4096)
+ d.msgs = make(chan string, 4096)
d.started = true
d.wg.Add(1)
go func() {
defer d.wg.Done()
d.readDocuments()
}()
- d.resultWg.Add(1)
+ d.resultWg.Add(2)
go func() {
defer d.resultWg.Done()
d.readResults()
}()
+ go func() {
+ defer d.resultWg.Done()
+ d.readMessages()
+ }()
}
func (d *Dispatcher) readDocuments() {
@@ -157,15 +173,22 @@ func (d *Dispatcher) readResults() {
}
}
+func (d *Dispatcher) readMessages() {
+ for msg := range d.msgs {
+ fmt.Fprintln(d.output, msg)
+ }
+}
+
func (d *Dispatcher) enqueue(op documentOp) error {
d.mu.Lock()
if !d.started {
return fmt.Errorf("dispatcher is closed")
}
- group, ok := d.inflight[op.document.Id.String()]
+ key := op.document.Id.String()
+ group, ok := d.inflight[key]
if !ok {
group = &documentGroup{}
- d.inflight[op.document.Id.String()] = group
+ d.inflight[key] = group
}
d.mu.Unlock()
group.add(op, op.attempts > 0)
@@ -188,25 +211,26 @@ func (d *Dispatcher) acquireSlot() {
func (d *Dispatcher) releaseSlot() { atomic.AddInt64(&d.inflightCount, -1) }
-func closeAndWait[T any](ch chan T, wg *sync.WaitGroup, d *Dispatcher, markClosed bool) {
- d.mu.Lock()
- if d.started {
- close(ch)
- if markClosed {
- d.started = false
- }
- }
- d.mu.Unlock()
- wg.Wait()
-}
-
func (d *Dispatcher) Enqueue(doc Document) error { return d.enqueue(documentOp{document: doc}) }
func (d *Dispatcher) Stats() Stats { return d.stats }
// Close closes the dispatcher and waits for all inflight operations to complete.
func (d *Dispatcher) Close() error {
- closeAndWait(d.ready, &d.wg, d, false)
- closeAndWait(d.results, &d.resultWg, d, true)
+ d.mu.Lock()
+ if d.started {
+ close(d.ready)
+ }
+ d.mu.Unlock()
+ d.wg.Wait() // Wait for inflight operations to complete
+
+ d.mu.Lock()
+ if d.started {
+ close(d.results)
+ close(d.msgs)
+ d.started = false
+ }
+ d.mu.Unlock()
+ d.resultWg.Wait() // Wait for results
return nil
}
diff --git a/client/go/internal/vespa/document/dispatcher_test.go b/client/go/internal/vespa/document/dispatcher_test.go
index 80bc5f603ae..d066f5bc9ae 100644
--- a/client/go/internal/vespa/document/dispatcher_test.go
+++ b/client/go/internal/vespa/document/dispatcher_test.go
@@ -41,7 +41,7 @@ func TestDispatcher(t *testing.T) {
clock := &manualClock{tick: time.Second}
throttler := newThrottler(8, clock.now)
breaker := NewCircuitBreaker(time.Second, 0)
- dispatcher := NewDispatcher(feeder, throttler, breaker, io.Discard)
+ dispatcher := NewDispatcher(feeder, throttler, breaker, io.Discard, false)
docs := []Document{
{Id: mustParseId("id:ns:type::doc1"), Operation: OperationPut, Body: []byte(`{"fields":{"foo": "123"}}`)},
{Id: mustParseId("id:ns:type::doc2"), Operation: OperationPut, Body: []byte(`{"fields":{"bar": "456"}}`)},
@@ -74,7 +74,7 @@ func TestDispatcherOrdering(t *testing.T) {
clock := &manualClock{tick: time.Second}
throttler := newThrottler(8, clock.now)
breaker := NewCircuitBreaker(time.Second, 0)
- dispatcher := NewDispatcher(feeder, throttler, breaker, io.Discard)
+ dispatcher := NewDispatcher(feeder, throttler, breaker, io.Discard, false)
for _, d := range docs {
dispatcher.Enqueue(d)
}
@@ -110,7 +110,7 @@ func TestDispatcherOrderingWithFailures(t *testing.T) {
clock := &manualClock{tick: time.Second}
throttler := newThrottler(8, clock.now)
breaker := NewCircuitBreaker(time.Second, 0)
- dispatcher := NewDispatcher(feeder, throttler, breaker, io.Discard)
+ dispatcher := NewDispatcher(feeder, throttler, breaker, io.Discard, false)
for _, d := range docs {
dispatcher.Enqueue(d)
}
diff --git a/client/go/internal/vespa/document/http.go b/client/go/internal/vespa/document/http.go
index 588330a0574..1bcd7eff39e 100644
--- a/client/go/internal/vespa/document/http.go
+++ b/client/go/internal/vespa/document/http.go
@@ -29,7 +29,7 @@ type ClientOptions struct {
BaseURL string
Timeout time.Duration
Route string
- TraceLevel *int
+ TraceLevel int
}
type countingHTTPClient struct {
@@ -72,14 +72,18 @@ func NewClient(options ClientOptions, httpClients []util.HTTPClient) *Client {
func (c *Client) queryParams() url.Values {
params := url.Values{}
- if c.options.Timeout > 0 {
- params.Set("timeout", strconv.FormatInt(c.options.Timeout.Milliseconds(), 10)+"ms")
+ timeout := c.options.Timeout
+ if timeout == 0 {
+ timeout = 200 * time.Second
+ } else {
+ timeout = timeout*11/10 + 1000
}
+ params.Set("timeout", strconv.FormatInt(timeout.Milliseconds(), 10)+"ms")
if c.options.Route != "" {
params.Set("route", c.options.Route)
}
- if c.options.TraceLevel != nil {
- params.Set("tracelevel", strconv.Itoa(*c.options.TraceLevel))
+ if c.options.TraceLevel > 0 {
+ params.Set("tracelevel", strconv.Itoa(c.options.TraceLevel))
}
return params
}
@@ -166,7 +170,7 @@ func (c *Client) Send(document Document) Result {
}
defer resp.Body.Close()
elapsed := c.now().Sub(start)
- return c.resultWithResponse(resp, result, document, elapsed)
+ return resultWithResponse(resp, result, document, elapsed)
}
func resultWithErr(result Result, err error) Result {
@@ -176,7 +180,7 @@ func resultWithErr(result Result, err error) Result {
return result
}
-func (c *Client) resultWithResponse(resp *http.Response, result Result, document Document, elapsed time.Duration) Result {
+func resultWithResponse(resp *http.Response, result Result, document Document, elapsed time.Duration) Result {
result.HTTPStatus = resp.StatusCode
result.Stats.Responses++
result.Stats.ResponsesByCode = map[int]int64{resp.StatusCode: 1}
diff --git a/client/go/internal/vespa/document/http_test.go b/client/go/internal/vespa/document/http_test.go
index 43eaf1bfdf9..8f8394a5d4e 100644
--- a/client/go/internal/vespa/document/http_test.go
+++ b/client/go/internal/vespa/document/http_test.go
@@ -108,7 +108,7 @@ func TestClientSend(t *testing.T) {
if r.Method != http.MethodPut {
t.Errorf("got r.Method = %q, want %q", r.Method, http.MethodPut)
}
- wantURL := fmt.Sprintf("https://example.com:1337/document/v1/ns/type/docid/%s?create=true&timeout=5000ms", doc.Id.UserSpecific)
+ wantURL := fmt.Sprintf("https://example.com:1337/document/v1/ns/type/docid/%s?create=true&timeout=5500ms", doc.Id.UserSpecific)
if r.URL.String() != wantURL {
t.Errorf("got r.URL = %q, want %q", r.URL, wantURL)
}
diff --git a/client/go/internal/vespa/target.go b/client/go/internal/vespa/target.go
index bc936623bcb..9f3fd7f5c65 100644
--- a/client/go/internal/vespa/target.go
+++ b/client/go/internal/vespa/target.go
@@ -7,6 +7,7 @@ import (
"fmt"
"io"
"net/http"
+ "sync"
"time"
"github.com/vespa-engine/vespa/client/go/internal/util"
@@ -17,7 +18,7 @@ const (
// A target for a local Vespa service
TargetLocal = "local"
- // A target for a custom URL
+ // A target for a Vespa service at a custom URL
TargetCustom = "custom"
// A Vespa Cloud target
@@ -38,13 +39,19 @@ const (
retryInterval = 2 * time.Second
)
+// Authenticator authenticates the given HTTP request.
+type Authenticator interface {
+ Authenticate(request *http.Request) error
+}
+
// Service represents a Vespa service.
type Service struct {
BaseURL string
Name string
TLSOptions TLSOptions
- zts zts
+ once sync.Once
+ auth Authenticator
httpClient util.HTTPClient
}
@@ -65,19 +72,19 @@ type Target interface {
// PrintLog writes the logs of this deployment using given options to control output.
PrintLog(options LogOptions) error
- // SignRequest signs request with given keyID as required by the implementation of this target.
- SignRequest(request *http.Request, keyID string) error
-
// CheckVersion verifies whether clientVersion is compatible with this target.
CheckVersion(clientVersion version.Version) error
}
-// TLSOptions configures the client certificate to use for cloud API or service requests.
+// TLSOptions holds the client certificate to use for cloud API or service requests.
type TLSOptions struct {
- KeyPair []tls.Certificate
- CertificateFile string
- PrivateKeyFile string
- AthenzDomain string
+ CACertificate []byte
+ KeyPair []tls.Certificate
+ TrustAll bool
+
+ CACertificateFile string
+ CertificateFile string
+ PrivateKeyFile string
}
// LogOptions configures the log output to produce when writing log messages.
@@ -90,17 +97,15 @@ type LogOptions struct {
Level int
}
-// Do sends request to this service. Any required authentication happens automatically.
+// Do sends request to this service. Authentication of the request happens automatically.
func (s *Service) Do(request *http.Request, timeout time.Duration) (*http.Response, error) {
- if s.TLSOptions.AthenzDomain != "" && s.TLSOptions.KeyPair != nil {
- accessToken, err := s.zts.AccessToken(s.TLSOptions.AthenzDomain, s.TLSOptions.KeyPair[0])
- if err != nil {
+ s.once.Do(func() {
+ util.ConfigureTLS(s.httpClient, s.TLSOptions.KeyPair, s.TLSOptions.CACertificate, s.TLSOptions.TrustAll)
+ })
+ if s.auth != nil {
+ if err := s.auth.Authenticate(request); err != nil {
return nil, err
}
- if request.Header == nil {
- request.Header = make(http.Header)
- }
- request.Header.Add("Authorization", "Bearer "+accessToken)
}
return s.httpClient.Do(request, timeout)
}
@@ -118,7 +123,7 @@ func (s *Service) Wait(timeout time.Duration) (int, error) {
default:
return 0, fmt.Errorf("invalid service: %s", s.Name)
}
- return waitForOK(s.httpClient, url, s.TLSOptions.KeyPair, timeout)
+ return waitForOK(s, url, timeout)
}
func (s *Service) Description() string {
@@ -133,27 +138,40 @@ func (s *Service) Description() string {
return fmt.Sprintf("No description of service %s", s.Name)
}
-func isOK(status int) bool { return status/100 == 2 }
+func isOK(status int) (bool, error) {
+ class := status / 100
+ switch class {
+ case 2: // success
+ return true, nil
+ case 4: // client error
+ return false, fmt.Errorf("request failed with status %d", status)
+ default: // retry
+ return false, nil
+ }
+}
type responseFunc func(status int, response []byte) (bool, error)
type requestFunc func() *http.Request
-// waitForOK queries url and returns its status code. If the url returns a non-200 status code, it is repeatedly queried
+// waitForOK queries url and returns its status code. If response status is not 2xx or 4xx, it is repeatedly queried
// until timeout elapses.
-func waitForOK(client util.HTTPClient, url string, certificates []tls.Certificate, timeout time.Duration) (int, error) {
+func waitForOK(service *Service, url string, timeout time.Duration) (int, error) {
req, err := http.NewRequest("GET", url, nil)
if err != nil {
return 0, err
}
- okFunc := func(status int, response []byte) (bool, error) { return isOK(status), nil }
- return wait(client, okFunc, func() *http.Request { return req }, certificates, timeout)
+ okFunc := func(status int, response []byte) (bool, error) {
+ ok, err := isOK(status)
+ if err != nil {
+ return false, fmt.Errorf("failed to query %s at %s: %w", service.Description(), url, err)
+ }
+ return ok, err
+ }
+ return wait(service, okFunc, func() *http.Request { return req }, timeout)
}
-func wait(client util.HTTPClient, fn responseFunc, reqFn requestFunc, certificates []tls.Certificate, timeout time.Duration) (int, error) {
- if certificates != nil {
- util.SetCertificates(client, certificates)
- }
+func wait(service *Service, fn responseFunc, reqFn requestFunc, timeout time.Duration) (int, error) {
var (
httpErr error
response *http.Response
@@ -163,7 +181,7 @@ func wait(client util.HTTPClient, fn responseFunc, reqFn requestFunc, certificat
loopOnce := timeout == 0
for time.Now().Before(deadline) || loopOnce {
req := reqFn()
- response, httpErr = client.Do(req, 10*time.Second)
+ response, httpErr = service.Do(req, 10*time.Second)
if httpErr == nil {
statusCode = response.StatusCode
body, err := io.ReadAll(response.Body)
diff --git a/client/go/internal/vespa/target_cloud.go b/client/go/internal/vespa/target_cloud.go
index 1fb3edd78c5..928bb788494 100644
--- a/client/go/internal/vespa/target_cloud.go
+++ b/client/go/internal/vespa/target_cloud.go
@@ -2,7 +2,6 @@ package vespa
import (
"bytes"
- "crypto/tls"
"encoding/json"
"fmt"
"math"
@@ -35,8 +34,8 @@ type cloudTarget struct {
deploymentOptions CloudDeploymentOptions
logOptions LogOptions
httpClient util.HTTPClient
- zts zts
- auth0 auth0
+ apiAuth Authenticator
+ deploymentAuth Authenticator
}
type deploymentEndpoint struct {
@@ -62,23 +61,15 @@ type logMessage struct {
Message string `json:"message"`
}
-type zts interface {
- AccessToken(domain string, certficiate tls.Certificate) (string, error)
-}
-
-type auth0 interface {
- AccessToken() (string, error)
-}
-
// CloudTarget creates a Target for the Vespa Cloud or hosted Vespa platform.
-func CloudTarget(httpClient util.HTTPClient, ztsClient zts, auth0Client auth0, apiOptions APIOptions, deploymentOptions CloudDeploymentOptions, logOptions LogOptions) (Target, error) {
+func CloudTarget(httpClient util.HTTPClient, apiAuth Authenticator, deploymentAuth Authenticator, apiOptions APIOptions, deploymentOptions CloudDeploymentOptions, logOptions LogOptions) (Target, error) {
return &cloudTarget{
httpClient: httpClient,
apiOptions: apiOptions,
deploymentOptions: deploymentOptions,
logOptions: logOptions,
- zts: ztsClient,
- auth0: auth0Client,
+ apiAuth: apiAuth,
+ deploymentAuth: deploymentAuth,
}, nil
}
@@ -118,25 +109,25 @@ func (t *cloudTarget) IsCloud() bool { return true }
func (t *cloudTarget) Deployment() Deployment { return t.deploymentOptions.Deployment }
func (t *cloudTarget) Service(name string, timeout time.Duration, runID int64, cluster string) (*Service, error) {
- var service *Service
switch name {
case DeployService:
- service = &Service{
+ service := &Service{
Name: name,
BaseURL: t.apiOptions.System.URL,
TLSOptions: t.apiOptions.TLSOptions,
- zts: t.zts,
httpClient: t.httpClient,
+ auth: t.apiAuth,
}
if timeout > 0 {
status, err := service.Wait(timeout)
if err != nil {
return nil, err
}
- if !isOK(status) {
+ if ok, _ := isOK(status); !ok {
return nil, fmt.Errorf("got status %d from deploy service at %s", status, service.BaseURL)
}
}
+ return service, nil
case QueryService, DocumentService:
if t.deploymentOptions.ClusterURLs == nil {
if err := t.waitForEndpoints(timeout, runID); err != nil {
@@ -147,38 +138,15 @@ func (t *cloudTarget) Service(name string, timeout time.Duration, runID int64, c
if err != nil {
return nil, err
}
- t.deploymentOptions.TLSOptions.AthenzDomain = t.apiOptions.System.AthenzDomain
- service = &Service{
+ return &Service{
Name: name,
BaseURL: url,
TLSOptions: t.deploymentOptions.TLSOptions,
- zts: t.zts,
httpClient: t.httpClient,
- }
-
+ auth: t.deploymentAuth,
+ }, nil
default:
return nil, fmt.Errorf("unknown service: %s", name)
-
- }
- if service.TLSOptions.KeyPair != nil {
- util.SetCertificates(service.httpClient, service.TLSOptions.KeyPair)
- }
- return service, nil
-}
-
-func (t *cloudTarget) SignRequest(req *http.Request, keyID string) error {
- if t.apiOptions.System.IsPublic() {
- if t.apiOptions.APIKey != nil {
- signer := NewRequestSigner(keyID, t.apiOptions.APIKey)
- return signer.SignRequest(req)
- } else {
- return t.addAuth0AccessToken(req)
- }
- } else {
- if t.apiOptions.TLSOptions.KeyPair == nil {
- return fmt.Errorf("system %s requires a certificate for authentication", t.apiOptions.System.Name)
- }
- return nil
}
}
@@ -190,7 +158,11 @@ func (t *cloudTarget) CheckVersion(clientVersion version.Version) error {
if err != nil {
return err
}
- response, err := t.httpClient.Do(req, 10*time.Second)
+ deployService, err := t.Service(DeployService, 0, 0, "")
+ if err != nil {
+ return err
+ }
+ response, err := deployService.Do(req, 10*time.Second)
if err != nil {
return err
}
@@ -212,18 +184,6 @@ func (t *cloudTarget) CheckVersion(clientVersion version.Version) error {
return nil
}
-func (t *cloudTarget) addAuth0AccessToken(request *http.Request) error {
- accessToken, err := t.auth0.AccessToken()
- if err != nil {
- return err
- }
- if request.Header == nil {
- request.Header = make(http.Header)
- }
- request.Header.Set("Authorization", "Bearer "+accessToken)
- return nil
-}
-
func (t *cloudTarget) logsURL() string {
return fmt.Sprintf("%s/application/v4/tenant/%s/application/%s/instance/%s/environment/%s/region/%s/logs",
t.apiOptions.System.URL,
@@ -246,11 +206,10 @@ func (t *cloudTarget) PrintLog(options LogOptions) error {
q.Set("to", strconv.FormatInt(toMillis, 10))
}
req.URL.RawQuery = q.Encode()
- t.SignRequest(req, t.deploymentOptions.Deployment.Application.SerializedForm())
return req
}
logFunc := func(status int, response []byte) (bool, error) {
- if ok, err := isCloudOK(status); !ok {
+ if ok, err := isOK(status); !ok {
return ok, err
}
logEntries, err := ReadLogEntries(bytes.NewReader(response))
@@ -275,10 +234,18 @@ func (t *cloudTarget) PrintLog(options LogOptions) error {
if options.Follow {
timeout = math.MaxInt64 // No timeout
}
- _, err = wait(t.httpClient, logFunc, requestFunc, t.apiOptions.TLSOptions.KeyPair, timeout)
+ _, err = t.deployServiceWait(logFunc, requestFunc, timeout)
return err
}
+func (t *cloudTarget) deployServiceWait(fn responseFunc, reqFn requestFunc, timeout time.Duration) (int, error) {
+ deployService, err := t.Service(DeployService, 0, 0, "")
+ if err != nil {
+ return 0, err
+ }
+ return wait(deployService, fn, reqFn, timeout)
+}
+
func (t *cloudTarget) waitForEndpoints(timeout time.Duration, runID int64) error {
if runID > 0 {
if err := t.waitForRun(runID, timeout); err != nil {
@@ -302,13 +269,10 @@ func (t *cloudTarget) waitForRun(runID int64, timeout time.Duration) error {
q := req.URL.Query()
q.Set("after", strconv.FormatInt(lastID, 10))
req.URL.RawQuery = q.Encode()
- if err := t.SignRequest(req, t.deploymentOptions.Deployment.Application.SerializedForm()); err != nil {
- util.JustExitWith(err)
- }
return req
}
jobSuccessFunc := func(status int, response []byte) (bool, error) {
- if ok, err := isCloudOK(status); !ok {
+ if ok, err := isOK(status); !ok {
return ok, err
}
var resp jobResponse
@@ -326,7 +290,7 @@ func (t *cloudTarget) waitForRun(runID int64, timeout time.Duration) error {
}
return true, nil
}
- _, err = wait(t.httpClient, jobSuccessFunc, requestFunc, t.apiOptions.TLSOptions.KeyPair, timeout)
+ _, err = t.deployServiceWait(jobSuccessFunc, requestFunc, timeout)
return err
}
@@ -361,12 +325,9 @@ func (t *cloudTarget) discoverEndpoints(timeout time.Duration) error {
if err != nil {
return err
}
- if err := t.SignRequest(req, t.deploymentOptions.Deployment.Application.SerializedForm()); err != nil {
- return err
- }
urlsByCluster := make(map[string]string)
endpointFunc := func(status int, response []byte) (bool, error) {
- if ok, err := isCloudOK(status); !ok {
+ if ok, err := isOK(status); !ok {
return ok, err
}
var resp deploymentResponse
@@ -384,7 +345,7 @@ func (t *cloudTarget) discoverEndpoints(timeout time.Duration) error {
}
return true, nil
}
- if _, err = wait(t.httpClient, endpointFunc, func() *http.Request { return req }, t.apiOptions.TLSOptions.KeyPair, timeout); err != nil {
+ if _, err := t.deployServiceWait(endpointFunc, func() *http.Request { return req }, timeout); err != nil {
return err
}
if len(urlsByCluster) == 0 {
@@ -393,11 +354,3 @@ func (t *cloudTarget) discoverEndpoints(timeout time.Duration) error {
t.deploymentOptions.ClusterURLs = urlsByCluster
return nil
}
-
-func isCloudOK(status int) (bool, error) {
- if status == 401 {
- // when retrying we should give up immediately if we're not authorized
- return false, fmt.Errorf("status %d: invalid credentials", status)
- }
- return isOK(status), nil
-}
diff --git a/client/go/internal/vespa/target_custom.go b/client/go/internal/vespa/target_custom.go
index 848d19f0a90..0a3a9d48fed 100644
--- a/client/go/internal/vespa/target_custom.go
+++ b/client/go/internal/vespa/target_custom.go
@@ -15,6 +15,7 @@ type customTarget struct {
targetType string
baseURL string
httpClient util.HTTPClient
+ tlsOptions TLSOptions
}
type serviceConvergeResponse struct {
@@ -22,13 +23,13 @@ type serviceConvergeResponse struct {
}
// LocalTarget creates a target for a Vespa platform running locally.
-func LocalTarget(httpClient util.HTTPClient) Target {
- return &customTarget{targetType: TargetLocal, baseURL: "http://127.0.0.1", httpClient: httpClient}
+func LocalTarget(httpClient util.HTTPClient, tlsOptions TLSOptions) Target {
+ return &customTarget{targetType: TargetLocal, baseURL: "http://127.0.0.1", httpClient: httpClient, tlsOptions: tlsOptions}
}
// CustomTarget creates a Target for a Vespa platform running at baseURL.
-func CustomTarget(httpClient util.HTTPClient, baseURL string) Target {
- return &customTarget{targetType: TargetCustom, baseURL: baseURL, httpClient: httpClient}
+func CustomTarget(httpClient util.HTTPClient, baseURL string, tlsOptions TLSOptions) Target {
+ return &customTarget{targetType: TargetCustom, baseURL: baseURL, httpClient: httpClient, tlsOptions: tlsOptions}
}
func (t *customTarget) Type() string { return t.targetType }
@@ -44,7 +45,7 @@ func (t *customTarget) createService(name string) (*Service, error) {
if err != nil {
return nil, err
}
- return &Service{BaseURL: url, Name: name, httpClient: t.httpClient}, nil
+ return &Service{BaseURL: url, Name: name, httpClient: t.httpClient, TLSOptions: t.tlsOptions}, nil
}
return nil, fmt.Errorf("unknown service: %s", name)
}
@@ -60,7 +61,7 @@ func (t *customTarget) Service(name string, timeout time.Duration, sessionOrRunI
if err != nil {
return nil, err
}
- if !isOK(status) {
+ if ok, _ := isOK(status); !ok {
return nil, fmt.Errorf("got status %d from deploy service at %s", status, service.BaseURL)
}
} else {
@@ -76,8 +77,6 @@ func (t *customTarget) PrintLog(options LogOptions) error {
return fmt.Errorf("log access is only supported on cloud: run vespa-logfmt on the admin node instead")
}
-func (t *customTarget) SignRequest(req *http.Request, sigKeyId string) error { return nil }
-
func (t *customTarget) CheckVersion(version version.Version) error { return nil }
func (t *customTarget) urlWithPort(serviceName string) (string, error) {
@@ -101,19 +100,19 @@ func (t *customTarget) urlWithPort(serviceName string) (string, error) {
}
func (t *customTarget) waitForConvergence(timeout time.Duration) error {
- deployURL, err := t.urlWithPort(DeployService)
+ deployService, err := t.createService(DeployService)
if err != nil {
return err
}
- url := fmt.Sprintf("%s/application/v2/tenant/default/application/default/environment/prod/region/default/instance/default/serviceconverge", deployURL)
+ url := fmt.Sprintf("%s/application/v2/tenant/default/application/default/environment/prod/region/default/instance/default/serviceconverge", deployService.BaseURL)
req, err := http.NewRequest("GET", url, nil)
if err != nil {
return err
}
converged := false
convergedFunc := func(status int, response []byte) (bool, error) {
- if !isOK(status) {
- return false, nil
+ if ok, err := isOK(status); !ok {
+ return ok, err
}
var resp serviceConvergeResponse
if err := json.Unmarshal(response, &resp); err != nil {
@@ -122,7 +121,7 @@ func (t *customTarget) waitForConvergence(timeout time.Duration) error {
converged = resp.Converged
return converged, nil
}
- if _, err := wait(t.httpClient, convergedFunc, func() *http.Request { return req }, nil, timeout); err != nil {
+ if _, err := wait(deployService, convergedFunc, func() *http.Request { return req }, timeout); err != nil {
return err
}
if !converged {
diff --git a/client/go/internal/vespa/target_test.go b/client/go/internal/vespa/target_test.go
index b9d65f3d8a4..bf266e8f9ec 100644
--- a/client/go/internal/vespa/target_test.go
+++ b/client/go/internal/vespa/target_test.go
@@ -3,7 +3,6 @@ package vespa
import (
"bytes"
- "crypto/tls"
"fmt"
"io"
"net/http"
@@ -19,10 +18,16 @@ import (
type mockVespaApi struct {
deploymentConverged bool
+ authFailure bool
serverURL string
}
func (v *mockVespaApi) mockVespaHandler(w http.ResponseWriter, req *http.Request) {
+ if v.authFailure {
+ response := `{"message":"unauthorized"}`
+ w.WriteHeader(401)
+ w.Write([]byte(response))
+ }
switch req.URL.Path {
case "/cli/v1/":
response := `{"minVersion":"8.0.0"}`
@@ -65,17 +70,17 @@ func (v *mockVespaApi) mockVespaHandler(w http.ResponseWriter, req *http.Request
}
func TestCustomTarget(t *testing.T) {
- lt := LocalTarget(&mock.HTTPClient{})
+ lt := LocalTarget(&mock.HTTPClient{}, TLSOptions{})
assertServiceURL(t, "http://127.0.0.1:19071", lt, "deploy")
assertServiceURL(t, "http://127.0.0.1:8080", lt, "query")
assertServiceURL(t, "http://127.0.0.1:8080", lt, "document")
- ct := CustomTarget(&mock.HTTPClient{}, "http://192.0.2.42")
+ ct := CustomTarget(&mock.HTTPClient{}, "http://192.0.2.42", TLSOptions{})
assertServiceURL(t, "http://192.0.2.42:19071", ct, "deploy")
assertServiceURL(t, "http://192.0.2.42:8080", ct, "query")
assertServiceURL(t, "http://192.0.2.42:8080", ct, "document")
- ct2 := CustomTarget(&mock.HTTPClient{}, "http://192.0.2.42:60000")
+ ct2 := CustomTarget(&mock.HTTPClient{}, "http://192.0.2.42:60000", TLSOptions{})
assertServiceURL(t, "http://192.0.2.42:60000", ct2, "deploy")
assertServiceURL(t, "http://192.0.2.42:60000", ct2, "query")
assertServiceURL(t, "http://192.0.2.42:60000", ct2, "document")
@@ -85,7 +90,7 @@ func TestCustomTargetWait(t *testing.T) {
vc := mockVespaApi{}
srv := httptest.NewServer(http.HandlerFunc(vc.mockVespaHandler))
defer srv.Close()
- target := CustomTarget(util.CreateClient(time.Second*10), srv.URL)
+ target := CustomTarget(util.CreateClient(time.Second*10), srv.URL, TLSOptions{})
_, err := target.Service("query", time.Millisecond, 42, "")
assert.NotNil(t, err)
@@ -107,6 +112,9 @@ func TestCloudTargetWait(t *testing.T) {
var logWriter bytes.Buffer
target := createCloudTarget(t, srv.URL, &logWriter)
+ vc.authFailure = true
+ assertServiceWaitErr(t, 401, true, target, "deploy")
+ vc.authFailure = false
assertServiceWait(t, 200, target, "deploy")
_, err := target.Service("query", time.Millisecond, 42, "")
@@ -157,10 +165,11 @@ func createCloudTarget(t *testing.T, url string, logWriter io.Writer) Target {
apiKey, err := CreateAPIKey()
assert.Nil(t, err)
+ auth := &mockAuthenticator{}
target, err := CloudTarget(
util.CreateClient(time.Second*10),
- &mockZTS{},
- &mockAuth0{},
+ auth,
+ auth,
APIOptions{APIKey: apiKey, System: PublicSystem},
CloudDeploymentOptions{
Deployment: Deployment{
@@ -175,7 +184,6 @@ func createCloudTarget(t *testing.T, url string, logWriter io.Writer) Target {
}
if ct, ok := target.(*cloudTarget); ok {
ct.apiOptions.System.URL = url
- ct.zts = &mockZTS{token: "foo bar"}
} else {
t.Fatalf("Wrong target type %T", ct)
}
@@ -189,22 +197,22 @@ func assertServiceURL(t *testing.T, url string, target Target, service string) {
}
func assertServiceWait(t *testing.T, expectedStatus int, target Target, service string) {
+ assertServiceWaitErr(t, expectedStatus, false, target, service)
+}
+
+func assertServiceWaitErr(t *testing.T, expectedStatus int, expectErr bool, target Target, service string) {
s, err := target.Service(service, 0, 42, "")
assert.Nil(t, err)
status, err := s.Wait(0)
- assert.Nil(t, err)
+ if expectErr {
+ assert.NotNil(t, err)
+ } else {
+ assert.Nil(t, err)
+ }
assert.Equal(t, expectedStatus, status)
}
-type mockZTS struct{ token string }
-
-func (c *mockZTS) AccessToken(domain string, certificate tls.Certificate) (string, error) {
- return c.token, nil
-}
-
-type mockAuth0 struct{}
-
-func (a *mockAuth0) AccessToken() (string, error) { return "", nil }
+type mockAuthenticator struct{}
-func (a *mockAuth0) HasCredentials() bool { return true }
+func (a *mockAuthenticator) Authenticate(request *http.Request) error { return nil }