diff options
Diffstat (limited to 'client/go/auth/auth.go')
-rw-r--r-- | client/go/auth/auth.go | 175 |
1 files changed, 175 insertions, 0 deletions
diff --git a/client/go/auth/auth.go b/client/go/auth/auth.go new file mode 100644 index 00000000000..397e410924d --- /dev/null +++ b/client/go/auth/auth.go @@ -0,0 +1,175 @@ +// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. + +package auth + +import ( + "context" + "encoding/base64" + "encoding/json" + "errors" + "fmt" + "net/http" + "net/url" + "strings" + "time" +) + +const ( + audiencePath = "/api/v2/" + waitThresholdInSeconds = 3 + // SecretsNamespace namespace used to set/get values from the keychain + SecretsNamespace = "vespa-cli" +) + +var requiredScopes = []string{"openid", "offline_access"} + +type Authenticator struct { + Audience string + ClientID string + DeviceCodeEndpoint string + OauthTokenEndpoint string +} + +// SecretStore provides access to stored sensitive data. +type SecretStore interface { + // Get gets the secret + Get(namespace, key string) (string, error) + // Delete removes the secret + Delete(namespace, key string) error +} + +type Result struct { + Tenant string + Domain string + RefreshToken string + AccessToken string + ExpiresIn int64 +} + +type State struct { + DeviceCode string `json:"device_code"` + UserCode string `json:"user_code"` + VerificationURI string `json:"verification_uri_complete"` + ExpiresIn int `json:"expires_in"` + Interval int `json:"interval"` +} + +// RequiredScopes returns the scopes used for login. +func RequiredScopes() []string { return requiredScopes } + +func (s *State) IntervalDuration() time.Duration { + return time.Duration(s.Interval+waitThresholdInSeconds) * time.Second +} + +// Start kicks-off the device authentication flow +// by requesting a device code from Auth0, +// The returned state contains the URI for the next step of the flow. +func (a *Authenticator) Start(ctx context.Context) (State, error) { + s, err := a.getDeviceCode(ctx) + if err != nil { + return State{}, fmt.Errorf("cannot get device code: %w", err) + } + return s, nil +} + +// Wait waits until the user is logged in on the browser. +func (a *Authenticator) Wait(ctx context.Context, state State) (Result, error) { + t := time.NewTicker(state.IntervalDuration()) + for { + select { + case <-ctx.Done(): + return Result{}, ctx.Err() + case <-t.C: + data := url.Values{ + "client_id": {a.ClientID}, + "grant_type": {"urn:ietf:params:oauth:grant-type:device_code"}, + "device_code": {state.DeviceCode}, + } + r, err := http.PostForm(a.OauthTokenEndpoint, data) + if err != nil { + return Result{}, fmt.Errorf("cannot get device code: %w", err) + } + defer r.Body.Close() + + var res struct { + AccessToken string `json:"access_token"` + IDToken string `json:"id_token"` + RefreshToken string `json:"refresh_token"` + Scope string `json:"scope"` + ExpiresIn int64 `json:"expires_in"` + TokenType string `json:"token_type"` + Error *string `json:"error,omitempty"` + ErrorDescription string `json:"error_description,omitempty"` + } + + err = json.NewDecoder(r.Body).Decode(&res) + if err != nil { + return Result{}, fmt.Errorf("cannot decode response: %w", err) + } + + if res.Error != nil { + if *res.Error == "authorization_pending" { + continue + } + return Result{}, errors.New(res.ErrorDescription) + } + + ten, domain, err := parseTenant(res.AccessToken) + if err != nil { + return Result{}, fmt.Errorf("cannot parse tenant from the given access token: %w", err) + } + + return Result{ + RefreshToken: res.RefreshToken, + AccessToken: res.AccessToken, + ExpiresIn: res.ExpiresIn, + Tenant: ten, + Domain: domain, + }, nil + } + } +} + +func (a *Authenticator) getDeviceCode(ctx context.Context) (State, error) { + data := url.Values{ + "client_id": {a.ClientID}, + "scope": {strings.Join(requiredScopes, " ")}, + "audience": {a.Audience}, + } + r, err := http.PostForm(a.DeviceCodeEndpoint, data) + if err != nil { + return State{}, fmt.Errorf("cannot get device code: %w", err) + } + defer r.Body.Close() + var res State + err = json.NewDecoder(r.Body).Decode(&res) + if err != nil { + return State{}, fmt.Errorf("cannot decode response: %w", err) + } + return res, nil +} + +func parseTenant(accessToken string) (tenant, domain string, err error) { + parts := strings.Split(accessToken, ".") + v, err := base64.RawURLEncoding.DecodeString(parts[1]) + if err != nil { + return "", "", err + } + var payload struct { + AUDs []string `json:"aud"` + } + if err := json.Unmarshal(v, &payload); err != nil { + return "", "", err + } + for _, aud := range payload.AUDs { + u, err := url.Parse(aud) + if err != nil { + return "", "", err + } + if u.Path == audiencePath { + parts := strings.Split(u.Host, ".") + return parts[0], u.Host, nil + } + } + return "", "", fmt.Errorf("audience not found for %s", audiencePath) +} |