diff options
Diffstat (limited to 'client/go/auth')
-rw-r--r-- | client/go/auth/auth0/auth0.go | 433 | ||||
-rw-r--r-- | client/go/auth/zts/zts.go | 58 | ||||
-rw-r--r-- | client/go/auth/zts/zts_test.go | 30 |
3 files changed, 521 insertions, 0 deletions
diff --git a/client/go/auth/auth0/auth0.go b/client/go/auth/auth0/auth0.go new file mode 100644 index 00000000000..52ba3f085a4 --- /dev/null +++ b/client/go/auth/auth0/auth0.go @@ -0,0 +1,433 @@ +// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. + +package auth0 + +import ( + "context" + "encoding/json" + "errors" + "fmt" + "io/ioutil" + "net/http" + "net/url" + "os" + "os/signal" + "path/filepath" + "sort" + "sync" + "time" + + "github.com/lestrrat-go/jwx/jwt" + "github.com/pkg/browser" + "github.com/vespa-engine/vespa/client/go/auth" + "github.com/vespa-engine/vespa/client/go/util" +) + +const accessTokenExpThreshold = 5 * time.Minute + +var errUnauthenticated = errors.New("not logged in. Try 'vespa auth login'") + +type configJsonFormat struct { + Version int `json:"version"` + Providers providers `json:"providers"` +} + +type providers struct { + Config config `json:"auth0"` +} + +type config struct { + Version int `json:"version"` + Systems map[string]*System `json:"systems"` +} + +type System struct { + Name string `json:"-"` + AccessToken string `json:"access_token,omitempty"` + Scopes []string `json:"scopes,omitempty"` + ExpiresAt time.Time `json:"expires_at"` +} + +type Auth0 struct { + Authenticator *auth.Authenticator + system string + systemApiUrl string + initOnce sync.Once + errOnce error + Path string + config config +} + +type authCfg struct { + Audience string `json:"audience"` + ClientID string `json:"client-id"` + DeviceCodeEndpoint string `json:"device-code-endpoint"` + OauthTokenEndpoint string `json:"oauth-token-endpoint"` +} + +func ContextWithCancel() context.Context { + ctx, cancel := context.WithCancel(context.Background()) + ch := make(chan os.Signal, 1) + signal.Notify(ch, os.Interrupt) + go func() { + <-ch + defer cancel() + os.Exit(0) + }() + return ctx +} + +// GetAuth0 will try to initialize the config context, as well as figure out if +// there's a readily available system. +func GetAuth0(configPath string, systemName string, systemApiUrl string) (*Auth0, error) { + a := Auth0{} + a.Path = configPath + a.system = systemName + a.systemApiUrl = systemApiUrl + c, err := a.getDeviceFlowConfig() + if err != nil { + return nil, fmt.Errorf("cannot get auth config: %w", err) + } + a.Authenticator = &auth.Authenticator{ + Audience: c.Audience, + ClientID: c.ClientID, + DeviceCodeEndpoint: c.DeviceCodeEndpoint, + OauthTokenEndpoint: c.OauthTokenEndpoint, + } + return &a, nil +} + +func (a *Auth0) getDeviceFlowConfig() (authCfg, error) { + systemApiUrl, _ := url.Parse(a.systemApiUrl + "/auth0/v1/device-flow-config") + r, err := http.Get(systemApiUrl.String()) + if err != nil { + return authCfg{}, fmt.Errorf("cannot get auth config: %w", err) + } + defer r.Body.Close() + var res authCfg + err = json.NewDecoder(r.Body).Decode(&res) + if err != nil { + return authCfg{}, fmt.Errorf("cannot decode response: %w", err) + } + return res, nil +} + +// IsLoggedIn encodes the domain logic for determining whether we're +// logged in. This might check our config storage, or just in memory. +func (a *Auth0) IsLoggedIn() bool { + // No need to check errors for initializing context. + _ = a.init() + + if a.system == "" { + return false + } + + // Parse the access token for the system. + token, err := jwt.ParseString(a.config.Systems[a.system].AccessToken) + if err != nil { + return false + } + + // Check if token is valid. + // TODO: Choose issuer based on system + if err = jwt.Validate(token, jwt.WithIssuer("https://vespa-cd.auth0.com/")); err != nil { + return false + } + + return true +} + +// PrepareSystem loads the System, refreshing its token if necessary. +// The System access token needs a refresh if: +// 1. the System scopes are different from the currently required scopes - (auth0 changes). +// 2. the access token is expired. +func (a *Auth0) PrepareSystem(ctx context.Context) (*System, error) { + if err := a.init(); err != nil { + return nil, err + } + s, err := a.getSystem() + if err != nil { + return nil, err + } + + if s.AccessToken == "" || scopesChanged(s) { + s, err = RunLogin(ctx, a, true) + if err != nil { + return nil, err + } + } else if isExpired(s.ExpiresAt, accessTokenExpThreshold) { + // check if the stored access token is expired: + // use the refresh token to get a new access token: + tr := &auth.TokenRetriever{ + Authenticator: a.Authenticator, + Secrets: &auth.Keyring{}, + Client: http.DefaultClient, + } + + res, err := tr.Refresh(ctx, a.system) + if err != nil { + return nil, fmt.Errorf("failed to renew access token: %w: %s", err, "re-authenticate with 'vespa auth login'") + } else { + // persist the updated system with renewed access token + s.AccessToken = res.AccessToken + s.ExpiresAt = time.Now().Add( + time.Duration(res.ExpiresIn) * time.Second, + ) + + err = a.AddSystem(s) + if err != nil { + return nil, err + } + } + } + + return s, nil +} + +// isExpired is true if now() + a threshold is after the given date +func isExpired(t time.Time, threshold time.Duration) bool { + return time.Now().Add(threshold).After(t) +} + +// scopesChanged compare the System scopes +// with the currently required scopes. +func scopesChanged(s *System) bool { + want := auth.RequiredScopes() + got := s.Scopes + + sort.Strings(want) + sort.Strings(got) + + if (want == nil) != (got == nil) { + return true + } + + if len(want) != len(got) { + return true + } + + for i := range s.Scopes { + if want[i] != got[i] { + return true + } + } + + return false +} + +func (a *Auth0) getSystem() (*System, error) { + if err := a.init(); err != nil { + return nil, err + } + + s, ok := a.config.Systems[a.system] + if !ok { + return nil, fmt.Errorf("unable to find system: %s; run 'vespa auth login' to configure a new system", a.system) + } + + return s, nil +} + +// HasSystem checks if the system is configured +// TODO: Used to print deprecation warning if we fall back to use tenant API key. +// Remove when this is not longer needed. +func (a *Auth0) HasSystem() bool { + if _, err := a.getSystem(); err != nil { + return false + } + return true +} + +// AddSystem assigns an existing, or new System. This is expected to be called +// after a login has completed. +func (a *Auth0) AddSystem(s *System) error { + _ = a.init() + + // If we're dealing with an empty file, we'll need to initialize this map. + if a.config.Systems == nil { + a.config.Systems = map[string]*System{} + } + + a.config.Systems[a.system] = s + + if err := a.persistConfig(); err != nil { + return fmt.Errorf("unexpected error persisting config: %w", err) + } + + return nil +} + +func (a *Auth0) removeSystem(s string) error { + _ = a.init() + + // If we're dealing with an empty file, we'll need to initialize this map. + if a.config.Systems == nil { + a.config.Systems = map[string]*System{} + } + + delete(a.config.Systems, s) + + if err := a.persistConfig(); err != nil { + return fmt.Errorf("unexpected error persisting config: %w", err) + } + + tr := &auth.TokenRetriever{Secrets: &auth.Keyring{}} + if err := tr.Delete(s); err != nil { + return fmt.Errorf("unexpected error clearing system information: %w", err) + } + + return nil +} + +func (a *Auth0) persistConfig() error { + dir := filepath.Dir(a.Path) + if _, err := os.Stat(dir); os.IsNotExist(err) { + if err := os.MkdirAll(dir, 0700); err != nil { + return err + } + } + + buf, err := a.configToJson(&a.config) + if err != nil { + return err + } + + if err := ioutil.WriteFile(a.Path, buf, 0600); err != nil { + return err + } + + return nil +} + +func (a *Auth0) configToJson(cfg *config) ([]byte, error) { + cfg.Version = 1 + r := configJsonFormat{ + Version: 1, + Providers: providers{ + Config: *cfg, + }, + } + return json.MarshalIndent(r, "", " ") +} + +func (a *Auth0) jsonToConfig(buf []byte) (*config, error) { + r := configJsonFormat{} + if err := json.Unmarshal(buf, &r); err != nil { + return nil, err + } + cfg := r.Providers.Config + if cfg.Systems != nil { + for n, s := range cfg.Systems { + s.Name = n + } + } + return &cfg, nil +} + +func (a *Auth0) init() error { + a.initOnce.Do(func() { + if a.errOnce = a.initContext(); a.errOnce != nil { + return + } + }) + return a.errOnce +} + +func (a *Auth0) initContext() (err error) { + if _, err := os.Stat(a.Path); os.IsNotExist(err) { + return errUnauthenticated + } + + var buf []byte + if buf, err = ioutil.ReadFile(a.Path); err != nil { + return err + } + + cfg, err := a.jsonToConfig(buf) + if err != nil { + return err + } + a.config = *cfg + return nil +} + +// RunLogin runs the login flow guiding the user through the process +// by showing the login instructions, opening the browser. +// Use `expired` to run the login from other commands setup: +// this will only affect the messages. +func RunLogin(ctx context.Context, a *Auth0, expired bool) (*System, error) { + if expired { + fmt.Println("Please sign in to re-authorize the CLI.") + } + + state, err := a.Authenticator.Start(ctx) + if err != nil { + return nil, fmt.Errorf("could not start the authentication process: %w", err) + } + + fmt.Printf("Your Device Confirmation code is: %s\n\n", state.UserCode) + + fmt.Println("If you prefer, you can open the URL directly for verification") + fmt.Printf("Your Verification URL: %s\n\n", state.VerificationURI) + + fmt.Println("Press Enter to open the browser to log in or ^C to quit...") + fmt.Scanln() + + err = browser.OpenURL(state.VerificationURI) + + if err != nil { + fmt.Printf("Couldn't open the URL, please do it manually: %s.", state.VerificationURI) + } + + var res auth.Result + err = util.Spinner(os.Stderr, "Waiting for login to complete in browser ...", func() error { + res, err = a.Authenticator.Wait(ctx, state) + return err + }) + + if err != nil { + return nil, fmt.Errorf("login error: %w", err) + } + + fmt.Print("\n") + fmt.Println("Successfully logged in.") + fmt.Print("\n") + + // store the refresh token + secretsStore := &auth.Keyring{} + err = secretsStore.Set(auth.SecretsNamespace, a.system, res.RefreshToken) + if err != nil { + // log the error but move on + fmt.Println("Could not store the refresh token locally, please expect to login again once your access token expired.") + } + + s := System{ + Name: a.system, + AccessToken: res.AccessToken, + ExpiresAt: time.Now().Add(time.Duration(res.ExpiresIn) * time.Second), + Scopes: auth.RequiredScopes(), + } + err = a.AddSystem(&s) + if err != nil { + return nil, fmt.Errorf("could not add system to config: %w", err) + } + + return &s, nil +} + +func RunLogout(a *Auth0) error { + s, err := a.getSystem() + if err != nil { + return err + } + + if err := a.removeSystem(s.Name); err != nil { + return err + } + + fmt.Print("\n") + fmt.Println("Successfully logged out.") + fmt.Print("\n") + + return nil +} diff --git a/client/go/auth/zts/zts.go b/client/go/auth/zts/zts.go new file mode 100644 index 00000000000..d288c2050d9 --- /dev/null +++ b/client/go/auth/zts/zts.go @@ -0,0 +1,58 @@ +package zts + +import ( + "crypto/tls" + "encoding/json" + "fmt" + "net/http" + "net/url" + "strings" + "time" + + "github.com/vespa-engine/vespa/client/go/util" +) + +const DefaultURL = "https://zts.athenz.ouroath.com:4443" + +// Client is a client for Athenz ZTS, an authentication token service. +type Client struct { + client util.HTTPClient + tokenURL *url.URL +} + +// NewClient creates a new client for an Athenz ZTS service located at serviceURL. +func NewClient(serviceURL string, client util.HTTPClient) (*Client, error) { + tokenURL, err := url.Parse(serviceURL) + if err != nil { + return nil, err + } + tokenURL.Path = "/zts/v1/oauth2/token" + return &Client{tokenURL: tokenURL, client: client}, nil +} + +// AccessToken returns an access token within the given domain, using certificate to authenticate with ZTS. +func (c *Client) AccessToken(domain string, certificate tls.Certificate) (string, error) { + data := fmt.Sprintf("grant_type=client_credentials&scope=%s:domain", domain) + req, err := http.NewRequest("POST", c.tokenURL.String(), strings.NewReader(data)) + if err != nil { + return "", err + } + c.client.UseCertificate([]tls.Certificate{certificate}) + response, err := c.client.Do(req, 10*time.Second) + if err != nil { + return "", err + } + defer response.Body.Close() + + if response.StatusCode != http.StatusOK { + return "", fmt.Errorf("got status %d from %s", response.StatusCode, c.tokenURL.String()) + } + var ztsResponse struct { + AccessToken string `json:"access_token"` + } + dec := json.NewDecoder(response.Body) + if err := dec.Decode(&ztsResponse); err != nil { + return "", err + } + return ztsResponse.AccessToken, nil +} diff --git a/client/go/auth/zts/zts_test.go b/client/go/auth/zts/zts_test.go new file mode 100644 index 00000000000..0eec085aadb --- /dev/null +++ b/client/go/auth/zts/zts_test.go @@ -0,0 +1,30 @@ +package zts + +import ( + "crypto/tls" + "testing" + + "github.com/vespa-engine/vespa/client/go/mock" +) + +func TestAccessToken(t *testing.T) { + httpClient := mock.HTTPClient{} + client, err := NewClient("http://example.com", &httpClient) + if err != nil { + t.Fatal(err) + } + httpClient.NextResponse(400, `{"message": "bad request"}`) + _, err = client.AccessToken("vespa.vespa", tls.Certificate{}) + if err == nil { + t.Fatal("want error for non-ok response status") + } + httpClient.NextResponse(200, `{"access_token": "foo bar"}`) + token, err := client.AccessToken("vespa.vespa", tls.Certificate{}) + if err != nil { + t.Fatal(err) + } + want := "foo bar" + if token != want { + t.Errorf("got %q, want %q", token, want) + } +} |