diff options
author | Martin Polden <mpolden@mpolden.no> | 2023-04-13 15:21:18 +0200 |
---|---|---|
committer | Martin Polden <mpolden@mpolden.no> | 2023-04-17 10:31:40 +0200 |
commit | 96d8aae1ec9b4f6130b6b610ce23d2bbdb79298a (patch) | |
tree | f482eaa488eb5d5925b49d665b29c07ab516ef7f /client/go/internal/vespa/target_cloud.go | |
parent | cce3b08cbe1864e80d5b9e57891622706b1d8181 (diff) |
Support TLS in custom target
Diffstat (limited to 'client/go/internal/vespa/target_cloud.go')
-rw-r--r-- | client/go/internal/vespa/target_cloud.go | 93 |
1 files changed, 27 insertions, 66 deletions
diff --git a/client/go/internal/vespa/target_cloud.go b/client/go/internal/vespa/target_cloud.go index 1fb3edd78c5..e9dca55f654 100644 --- a/client/go/internal/vespa/target_cloud.go +++ b/client/go/internal/vespa/target_cloud.go @@ -2,7 +2,6 @@ package vespa import ( "bytes" - "crypto/tls" "encoding/json" "fmt" "math" @@ -35,8 +34,8 @@ type cloudTarget struct { deploymentOptions CloudDeploymentOptions logOptions LogOptions httpClient util.HTTPClient - zts zts - auth0 auth0 + apiAuth Authenticator + deploymentAuth Authenticator } type deploymentEndpoint struct { @@ -62,23 +61,15 @@ type logMessage struct { Message string `json:"message"` } -type zts interface { - AccessToken(domain string, certficiate tls.Certificate) (string, error) -} - -type auth0 interface { - AccessToken() (string, error) -} - // CloudTarget creates a Target for the Vespa Cloud or hosted Vespa platform. -func CloudTarget(httpClient util.HTTPClient, ztsClient zts, auth0Client auth0, apiOptions APIOptions, deploymentOptions CloudDeploymentOptions, logOptions LogOptions) (Target, error) { +func CloudTarget(httpClient util.HTTPClient, apiAuth Authenticator, deploymentAuth Authenticator, apiOptions APIOptions, deploymentOptions CloudDeploymentOptions, logOptions LogOptions) (Target, error) { return &cloudTarget{ httpClient: httpClient, apiOptions: apiOptions, deploymentOptions: deploymentOptions, logOptions: logOptions, - zts: ztsClient, - auth0: auth0Client, + apiAuth: apiAuth, + deploymentAuth: deploymentAuth, }, nil } @@ -118,15 +109,14 @@ func (t *cloudTarget) IsCloud() bool { return true } func (t *cloudTarget) Deployment() Deployment { return t.deploymentOptions.Deployment } func (t *cloudTarget) Service(name string, timeout time.Duration, runID int64, cluster string) (*Service, error) { - var service *Service switch name { case DeployService: - service = &Service{ + service := &Service{ Name: name, BaseURL: t.apiOptions.System.URL, TLSOptions: t.apiOptions.TLSOptions, - zts: t.zts, httpClient: t.httpClient, + auth: t.apiAuth, } if timeout > 0 { status, err := service.Wait(timeout) @@ -137,6 +127,7 @@ func (t *cloudTarget) Service(name string, timeout time.Duration, runID int64, c return nil, fmt.Errorf("got status %d from deploy service at %s", status, service.BaseURL) } } + return service, nil case QueryService, DocumentService: if t.deploymentOptions.ClusterURLs == nil { if err := t.waitForEndpoints(timeout, runID); err != nil { @@ -147,38 +138,15 @@ func (t *cloudTarget) Service(name string, timeout time.Duration, runID int64, c if err != nil { return nil, err } - t.deploymentOptions.TLSOptions.AthenzDomain = t.apiOptions.System.AthenzDomain - service = &Service{ + return &Service{ Name: name, BaseURL: url, TLSOptions: t.deploymentOptions.TLSOptions, - zts: t.zts, httpClient: t.httpClient, - } - + auth: t.deploymentAuth, + }, nil default: return nil, fmt.Errorf("unknown service: %s", name) - - } - if service.TLSOptions.KeyPair != nil { - util.SetCertificates(service.httpClient, service.TLSOptions.KeyPair) - } - return service, nil -} - -func (t *cloudTarget) SignRequest(req *http.Request, keyID string) error { - if t.apiOptions.System.IsPublic() { - if t.apiOptions.APIKey != nil { - signer := NewRequestSigner(keyID, t.apiOptions.APIKey) - return signer.SignRequest(req) - } else { - return t.addAuth0AccessToken(req) - } - } else { - if t.apiOptions.TLSOptions.KeyPair == nil { - return fmt.Errorf("system %s requires a certificate for authentication", t.apiOptions.System.Name) - } - return nil } } @@ -190,7 +158,11 @@ func (t *cloudTarget) CheckVersion(clientVersion version.Version) error { if err != nil { return err } - response, err := t.httpClient.Do(req, 10*time.Second) + deployService, err := t.Service(DeployService, 0, 0, "") + if err != nil { + return err + } + response, err := deployService.Do(req, 10*time.Second) if err != nil { return err } @@ -212,18 +184,6 @@ func (t *cloudTarget) CheckVersion(clientVersion version.Version) error { return nil } -func (t *cloudTarget) addAuth0AccessToken(request *http.Request) error { - accessToken, err := t.auth0.AccessToken() - if err != nil { - return err - } - if request.Header == nil { - request.Header = make(http.Header) - } - request.Header.Set("Authorization", "Bearer "+accessToken) - return nil -} - func (t *cloudTarget) logsURL() string { return fmt.Sprintf("%s/application/v4/tenant/%s/application/%s/instance/%s/environment/%s/region/%s/logs", t.apiOptions.System.URL, @@ -246,7 +206,6 @@ func (t *cloudTarget) PrintLog(options LogOptions) error { q.Set("to", strconv.FormatInt(toMillis, 10)) } req.URL.RawQuery = q.Encode() - t.SignRequest(req, t.deploymentOptions.Deployment.Application.SerializedForm()) return req } logFunc := func(status int, response []byte) (bool, error) { @@ -275,10 +234,18 @@ func (t *cloudTarget) PrintLog(options LogOptions) error { if options.Follow { timeout = math.MaxInt64 // No timeout } - _, err = wait(t.httpClient, logFunc, requestFunc, t.apiOptions.TLSOptions.KeyPair, timeout) + _, err = t.deployServiceWait(logFunc, requestFunc, timeout) return err } +func (t *cloudTarget) deployServiceWait(fn responseFunc, reqFn requestFunc, timeout time.Duration) (int, error) { + deployService, err := t.Service(DeployService, 0, 0, "") + if err != nil { + return 0, err + } + return wait(deployService, fn, reqFn, timeout) +} + func (t *cloudTarget) waitForEndpoints(timeout time.Duration, runID int64) error { if runID > 0 { if err := t.waitForRun(runID, timeout); err != nil { @@ -302,9 +269,6 @@ func (t *cloudTarget) waitForRun(runID int64, timeout time.Duration) error { q := req.URL.Query() q.Set("after", strconv.FormatInt(lastID, 10)) req.URL.RawQuery = q.Encode() - if err := t.SignRequest(req, t.deploymentOptions.Deployment.Application.SerializedForm()); err != nil { - util.JustExitWith(err) - } return req } jobSuccessFunc := func(status int, response []byte) (bool, error) { @@ -326,7 +290,7 @@ func (t *cloudTarget) waitForRun(runID int64, timeout time.Duration) error { } return true, nil } - _, err = wait(t.httpClient, jobSuccessFunc, requestFunc, t.apiOptions.TLSOptions.KeyPair, timeout) + _, err = t.deployServiceWait(jobSuccessFunc, requestFunc, timeout) return err } @@ -361,9 +325,6 @@ func (t *cloudTarget) discoverEndpoints(timeout time.Duration) error { if err != nil { return err } - if err := t.SignRequest(req, t.deploymentOptions.Deployment.Application.SerializedForm()); err != nil { - return err - } urlsByCluster := make(map[string]string) endpointFunc := func(status int, response []byte) (bool, error) { if ok, err := isCloudOK(status); !ok { @@ -384,7 +345,7 @@ func (t *cloudTarget) discoverEndpoints(timeout time.Duration) error { } return true, nil } - if _, err = wait(t.httpClient, endpointFunc, func() *http.Request { return req }, t.apiOptions.TLSOptions.KeyPair, timeout); err != nil { + if _, err := t.deployServiceWait(endpointFunc, func() *http.Request { return req }, timeout); err != nil { return err } if len(urlsByCluster) == 0 { |