diff options
Diffstat (limited to 'client/go/vespa/target.go')
-rw-r--r-- | client/go/vespa/target.go | 79 |
1 files changed, 59 insertions, 20 deletions
diff --git a/client/go/vespa/target.go b/client/go/vespa/target.go index 8a09440f5cc..f497bf5b3cd 100644 --- a/client/go/vespa/target.go +++ b/client/go/vespa/target.go @@ -6,13 +6,16 @@ import ( "crypto/tls" "encoding/json" "fmt" + "github.com/vespa-engine/vespa/client/go/cli" "io" "io/ioutil" "math" "net/http" "net/url" + "os" "sort" "strconv" + "strings" "time" "github.com/vespa-engine/vespa/client/go/util" @@ -35,6 +38,7 @@ type Service struct { BaseURL string Name string TLSOptions TLSOptions + Target *Target } // Target represents a Vespa platform, running named Vespa services. @@ -47,6 +51,8 @@ type Target interface { // PrintLog writes the logs of this deployment using given options to control output. PrintLog(options LogOptions) error + + PrepareApiRequest(req *http.Request, sigKeyId string) error } // TLSOptions configures the certificate to use for service requests. @@ -66,11 +72,21 @@ type LogOptions struct { Level int } +func Auth0AccessTokenEnabled() bool { + v, present := os.LookupEnv("VESPA_CLI_OAUTH2_DEVICE_FLOW") + if !present { + return false + } + return strings.ToLower(v) == "true" || v == "1" || v == "" +} + type customTarget struct { targetType string baseURL string } +func (t *customTarget) PrepareApiRequest(req *http.Request, sigKeyId string) error { return nil } + // Do sends request to this service. Any required authentication happens automatically. func (s *Service) Do(request *http.Request, timeout time.Duration) (*http.Response, error) { if s.TLSOptions.KeyPair.Certificate != nil { @@ -192,8 +208,9 @@ type cloudTarget struct { tlsOptions TLSOptions logOptions LogOptions - queryURL string - documentURL string + queryURL string + documentURL string + authConfigPath string } func (t *cloudTarget) Type() string { return t.targetType } @@ -221,6 +238,30 @@ func (t *cloudTarget) Service(name string, timeout time.Duration, runID int64) ( return nil, fmt.Errorf("unknown service: %s", name) } +func (t *cloudTarget) PrepareApiRequest(req *http.Request, sigKeyId string) error { + if Auth0AccessTokenEnabled() { + if err := t.addAuth0AccessToken(req); err != nil { + return err + } + } else if t.apiKey != nil { + signer := NewRequestSigner(sigKeyId, t.apiKey) + if err := signer.SignRequest(req); err != nil { + return err + } + } + return nil +} + +func (t *cloudTarget) addAuth0AccessToken(request *http.Request) error { + c, err := cli.GetCli(t.authConfigPath) + tenant, err := c.PrepareTenant(cli.ContextWithCancel()) + if err != nil { + return err + } + request.Header.Set("Authorization", "Bearer "+tenant.AccessToken) + return nil +} + func (t *cloudTarget) logsURL() string { return fmt.Sprintf("%s/application/v4/tenant/%s/application/%s/instance/%s/environment/%s/region/%s/logs", t.apiURL, @@ -233,7 +274,6 @@ func (t *cloudTarget) PrintLog(options LogOptions) error { if err != nil { return err } - signer := NewRequestSigner(t.deployment.Application.SerializedForm(), t.apiKey) lastFrom := options.From requestFunc := func() *http.Request { fromMillis := lastFrom.Unix() * 1000 @@ -244,9 +284,7 @@ func (t *cloudTarget) PrintLog(options LogOptions) error { q.Set("to", strconv.FormatInt(toMillis, 10)) } req.URL.RawQuery = q.Encode() - if err := signer.SignRequest(req); err != nil { - panic(err) - } + t.PrepareApiRequest(req, t.deployment.Application.SerializedForm()) return req } logFunc := func(status int, response []byte) (bool, error) { @@ -280,16 +318,15 @@ func (t *cloudTarget) PrintLog(options LogOptions) error { } func (t *cloudTarget) waitForEndpoints(timeout time.Duration, runID int64) error { - signer := NewRequestSigner(t.deployment.Application.SerializedForm(), t.apiKey) if runID > 0 { - if err := t.waitForRun(signer, runID, timeout); err != nil { + if err := t.waitForRun(runID, timeout); err != nil { return err } } - return t.discoverEndpoints(signer, timeout) + return t.discoverEndpoints(timeout) } -func (t *cloudTarget) waitForRun(signer *RequestSigner, runID int64, timeout time.Duration) error { +func (t *cloudTarget) waitForRun(runID int64, timeout time.Duration) error { runURL := fmt.Sprintf("%s/application/v4/tenant/%s/application/%s/instance/%s/job/%s-%s/run/%d", t.apiURL, t.deployment.Application.Tenant, t.deployment.Application.Application, t.deployment.Application.Instance, @@ -303,7 +340,7 @@ func (t *cloudTarget) waitForRun(signer *RequestSigner, runID int64, timeout tim q := req.URL.Query() q.Set("after", strconv.FormatInt(lastID, 10)) req.URL.RawQuery = q.Encode() - if err := signer.SignRequest(req); err != nil { + if err := t.PrepareApiRequest(req, t.deployment.Application.SerializedForm()); err != nil { panic(err) } return req @@ -353,7 +390,7 @@ func (t *cloudTarget) printLog(response jobResponse, last int64) int64 { return response.LastID } -func (t *cloudTarget) discoverEndpoints(signer *RequestSigner, timeout time.Duration) error { +func (t *cloudTarget) discoverEndpoints(timeout time.Duration) error { deploymentURL := fmt.Sprintf("%s/application/v4/tenant/%s/application/%s/instance/%s/environment/%s/region/%s", t.apiURL, t.deployment.Application.Tenant, t.deployment.Application.Application, t.deployment.Application.Instance, @@ -362,7 +399,7 @@ func (t *cloudTarget) discoverEndpoints(signer *RequestSigner, timeout time.Dura if err != nil { return err } - if err := signer.SignRequest(req); err != nil { + if err := t.PrepareApiRequest(req, t.deployment.Application.SerializedForm()); err != nil { return err } var endpointURL string @@ -409,14 +446,16 @@ func CustomTarget(baseURL string) Target { } // CloudTarget creates a Target for the Vespa Cloud platform. -func CloudTarget(apiURL string, deployment Deployment, apiKey []byte, tlsOptions TLSOptions, logOptions LogOptions) Target { +func CloudTarget(apiURL string, deployment Deployment, apiKey []byte, tlsOptions TLSOptions, logOptions LogOptions, + authConfigPath string) Target { return &cloudTarget{ - apiURL: apiURL, - targetType: cloudTargetType, - deployment: deployment, - apiKey: apiKey, - tlsOptions: tlsOptions, - logOptions: logOptions, + apiURL: apiURL, + targetType: cloudTargetType, + deployment: deployment, + apiKey: apiKey, + tlsOptions: tlsOptions, + logOptions: logOptions, + authConfigPath: authConfigPath, } } |