aboutsummaryrefslogtreecommitdiffstats
path: root/client/go/auth/auth.go
diff options
context:
space:
mode:
Diffstat (limited to 'client/go/auth/auth.go')
-rw-r--r--client/go/auth/auth.go175
1 files changed, 175 insertions, 0 deletions
diff --git a/client/go/auth/auth.go b/client/go/auth/auth.go
new file mode 100644
index 00000000000..397e410924d
--- /dev/null
+++ b/client/go/auth/auth.go
@@ -0,0 +1,175 @@
+// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+
+package auth
+
+import (
+ "context"
+ "encoding/base64"
+ "encoding/json"
+ "errors"
+ "fmt"
+ "net/http"
+ "net/url"
+ "strings"
+ "time"
+)
+
+const (
+ audiencePath = "/api/v2/"
+ waitThresholdInSeconds = 3
+ // SecretsNamespace namespace used to set/get values from the keychain
+ SecretsNamespace = "vespa-cli"
+)
+
+var requiredScopes = []string{"openid", "offline_access"}
+
+type Authenticator struct {
+ Audience string
+ ClientID string
+ DeviceCodeEndpoint string
+ OauthTokenEndpoint string
+}
+
+// SecretStore provides access to stored sensitive data.
+type SecretStore interface {
+ // Get gets the secret
+ Get(namespace, key string) (string, error)
+ // Delete removes the secret
+ Delete(namespace, key string) error
+}
+
+type Result struct {
+ Tenant string
+ Domain string
+ RefreshToken string
+ AccessToken string
+ ExpiresIn int64
+}
+
+type State struct {
+ DeviceCode string `json:"device_code"`
+ UserCode string `json:"user_code"`
+ VerificationURI string `json:"verification_uri_complete"`
+ ExpiresIn int `json:"expires_in"`
+ Interval int `json:"interval"`
+}
+
+// RequiredScopes returns the scopes used for login.
+func RequiredScopes() []string { return requiredScopes }
+
+func (s *State) IntervalDuration() time.Duration {
+ return time.Duration(s.Interval+waitThresholdInSeconds) * time.Second
+}
+
+// Start kicks-off the device authentication flow
+// by requesting a device code from Auth0,
+// The returned state contains the URI for the next step of the flow.
+func (a *Authenticator) Start(ctx context.Context) (State, error) {
+ s, err := a.getDeviceCode(ctx)
+ if err != nil {
+ return State{}, fmt.Errorf("cannot get device code: %w", err)
+ }
+ return s, nil
+}
+
+// Wait waits until the user is logged in on the browser.
+func (a *Authenticator) Wait(ctx context.Context, state State) (Result, error) {
+ t := time.NewTicker(state.IntervalDuration())
+ for {
+ select {
+ case <-ctx.Done():
+ return Result{}, ctx.Err()
+ case <-t.C:
+ data := url.Values{
+ "client_id": {a.ClientID},
+ "grant_type": {"urn:ietf:params:oauth:grant-type:device_code"},
+ "device_code": {state.DeviceCode},
+ }
+ r, err := http.PostForm(a.OauthTokenEndpoint, data)
+ if err != nil {
+ return Result{}, fmt.Errorf("cannot get device code: %w", err)
+ }
+ defer r.Body.Close()
+
+ var res struct {
+ AccessToken string `json:"access_token"`
+ IDToken string `json:"id_token"`
+ RefreshToken string `json:"refresh_token"`
+ Scope string `json:"scope"`
+ ExpiresIn int64 `json:"expires_in"`
+ TokenType string `json:"token_type"`
+ Error *string `json:"error,omitempty"`
+ ErrorDescription string `json:"error_description,omitempty"`
+ }
+
+ err = json.NewDecoder(r.Body).Decode(&res)
+ if err != nil {
+ return Result{}, fmt.Errorf("cannot decode response: %w", err)
+ }
+
+ if res.Error != nil {
+ if *res.Error == "authorization_pending" {
+ continue
+ }
+ return Result{}, errors.New(res.ErrorDescription)
+ }
+
+ ten, domain, err := parseTenant(res.AccessToken)
+ if err != nil {
+ return Result{}, fmt.Errorf("cannot parse tenant from the given access token: %w", err)
+ }
+
+ return Result{
+ RefreshToken: res.RefreshToken,
+ AccessToken: res.AccessToken,
+ ExpiresIn: res.ExpiresIn,
+ Tenant: ten,
+ Domain: domain,
+ }, nil
+ }
+ }
+}
+
+func (a *Authenticator) getDeviceCode(ctx context.Context) (State, error) {
+ data := url.Values{
+ "client_id": {a.ClientID},
+ "scope": {strings.Join(requiredScopes, " ")},
+ "audience": {a.Audience},
+ }
+ r, err := http.PostForm(a.DeviceCodeEndpoint, data)
+ if err != nil {
+ return State{}, fmt.Errorf("cannot get device code: %w", err)
+ }
+ defer r.Body.Close()
+ var res State
+ err = json.NewDecoder(r.Body).Decode(&res)
+ if err != nil {
+ return State{}, fmt.Errorf("cannot decode response: %w", err)
+ }
+ return res, nil
+}
+
+func parseTenant(accessToken string) (tenant, domain string, err error) {
+ parts := strings.Split(accessToken, ".")
+ v, err := base64.RawURLEncoding.DecodeString(parts[1])
+ if err != nil {
+ return "", "", err
+ }
+ var payload struct {
+ AUDs []string `json:"aud"`
+ }
+ if err := json.Unmarshal(v, &payload); err != nil {
+ return "", "", err
+ }
+ for _, aud := range payload.AUDs {
+ u, err := url.Parse(aud)
+ if err != nil {
+ return "", "", err
+ }
+ if u.Path == audiencePath {
+ parts := strings.Split(u.Host, ".")
+ return parts[0], u.Host, nil
+ }
+ }
+ return "", "", fmt.Errorf("audience not found for %s", audiencePath)
+}