diff options
author | Martin Polden <mpolden@mpolden.no> | 2022-03-29 14:25:41 +0200 |
---|---|---|
committer | Martin Polden <mpolden@mpolden.no> | 2022-03-29 14:39:33 +0200 |
commit | 8d85938bfcbb86e6fb720c01f543b0f0e6f07037 (patch) | |
tree | c4faa49276e4b90faf7e154505e0bff6e1b0d941 /client | |
parent | ff9bca389acc5b2c374096b0061eb4ef28e3bdd7 (diff) |
Refactor and test auth0 package
Diffstat (limited to 'client')
-rw-r--r-- | client/go/auth/auth0/auth0.go | 359 | ||||
-rw-r--r-- | client/go/auth/auth0/auth0_test.go | 99 | ||||
-rw-r--r-- | client/go/cmd/config.go | 4 | ||||
-rw-r--r-- | client/go/cmd/curl.go | 26 | ||||
-rw-r--r-- | client/go/cmd/login.go | 10 | ||||
-rw-r--r-- | client/go/cmd/logout.go | 4 | ||||
-rw-r--r-- | client/go/vespa/target_cloud.go | 9 |
7 files changed, 256 insertions, 255 deletions
diff --git a/client/go/auth/auth0/auth0.go b/client/go/auth/auth0/auth0.go index b2749730c1f..6c31b3ba8e4 100644 --- a/client/go/auth/auth0/auth0.go +++ b/client/go/auth/auth0/auth0.go @@ -5,64 +5,64 @@ package auth0 import ( "context" "encoding/json" - "errors" "fmt" "net/http" - "net/url" "os" "os/signal" "path/filepath" "sort" - "sync" "time" - "github.com/lestrrat-go/jwx/jwt" "github.com/vespa-engine/vespa/client/go/auth" + "github.com/vespa-engine/vespa/client/go/util" ) -const accessTokenExpThreshold = 5 * time.Minute - -var errUnauthenticated = errors.New("not logged in. Try 'vespa auth login'") +const ( + accessTokenExpiry = 5 * time.Minute + reauthMessage = "re-authenticate with 'vespa auth login'" +) -type configJsonFormat struct { - Version int `json:"version"` - Providers providers `json:"providers"` +// Credentials holds the credentials retrieved from Auth0. +type Credentials struct { + AccessToken string `json:"access_token,omitempty"` + Scopes []string `json:"scopes,omitempty"` + ExpiresAt time.Time `json:"expires_at"` } -type providers struct { - Config config `json:"auth0"` +// Client is a client for the Auth0 service. +type Client struct { + httpClient util.HTTPClient + Authenticator *auth.Authenticator // TODO: Make this private + configPath string + systemName string + systemURL string + provider auth0Provider } +// config is the root type of the persisted config type config struct { - Version int `json:"version"` - Systems map[string]*System `json:"systems"` + Version int `json:"version"` + Providers providers `json:"providers"` } -type System struct { - Name string `json:"-"` - AccessToken string `json:"access_token,omitempty"` - Scopes []string `json:"scopes,omitempty"` - ExpiresAt time.Time `json:"expires_at"` +type providers struct { + Auth0 auth0Provider `json:"auth0"` } -type Auth0 struct { - Authenticator *auth.Authenticator - system string - systemApiUrl string - initOnce sync.Once - errOnce error - Path string - config config +type auth0Provider struct { + Version int `json:"version"` + Systems map[string]Credentials `json:"systems"` } -type authCfg struct { +// flowConfig represents the authorization flow configuration retrieved from a Vespa system. +type flowConfig struct { Audience string `json:"audience"` ClientID string `json:"client-id"` DeviceCodeEndpoint string `json:"device-code-endpoint"` OauthTokenEndpoint string `json:"oauth-token-endpoint"` } -func ContextWithCancel() context.Context { +func cancelOnInterrupt() context.Context { ctx, cancel := context.WithCancel(context.Background()) ch := make(chan os.Signal, 1) signal.Notify(ch, os.Interrupt) @@ -74,16 +74,15 @@ func ContextWithCancel() context.Context { return ctx } -// GetAuth0 will try to initialize the config context, as well as figure out if -// there's a readily available system. -func GetAuth0(configPath string, systemName string, systemApiUrl string) (*Auth0, error) { - a := Auth0{} - a.Path = configPath - a.system = systemName - a.systemApiUrl = systemApiUrl +func newClient(httpClient util.HTTPClient, configPath, systemName, systemURL string) (*Client, error) { + a := Client{} + a.httpClient = httpClient + a.configPath = configPath + a.systemName = systemName + a.systemURL = systemURL c, err := a.getDeviceFlowConfig() if err != nil { - return nil, fmt.Errorf("cannot get auth config: %w", err) + return nil, err } a.Authenticator = &auth.Authenticator{ Audience: c.Audience, @@ -91,65 +90,51 @@ func GetAuth0(configPath string, systemName string, systemApiUrl string) (*Auth0 DeviceCodeEndpoint: c.DeviceCodeEndpoint, OauthTokenEndpoint: c.OauthTokenEndpoint, } - return &a, nil -} - -func (a *Auth0) getDeviceFlowConfig() (authCfg, error) { - systemApiUrl, _ := url.Parse(a.systemApiUrl + "/auth0/v1/device-flow-config") - r, err := http.Get(systemApiUrl.String()) - if err != nil { - return authCfg{}, fmt.Errorf("cannot get auth config: %w", err) - } - defer r.Body.Close() - var res authCfg - err = json.NewDecoder(r.Body).Decode(&res) + provider, err := readConfig(configPath) if err != nil { - return authCfg{}, fmt.Errorf("cannot decode response: %w", err) + return nil, err } - return res, nil + a.provider = provider + return &a, nil } -// IsLoggedIn encodes the domain logic for determining whether we're -// logged in. This might check our config storage, or just in memory. -func (a *Auth0) IsLoggedIn() bool { - // No need to check errors for initializing context. - _ = a.init() - - if a.system == "" { - return false - } +// New constructs a new Auth0 client, storing configuration in the given configPath. The client will be configured for +// use in the given Vespa system. +func New(configPath string, systemName, systemURL string) (*Client, error) { + return newClient(util.CreateClient(time.Second*30), configPath, systemName, systemURL) +} - // Parse the access token for the system. - token, err := jwt.ParseString(a.config.Systems[a.system].AccessToken) +func (a *Client) getDeviceFlowConfig() (flowConfig, error) { + url := a.systemURL + "/auth0/v1/device-flow-config" + req, err := http.NewRequest("GET", url, nil) if err != nil { - return false + return flowConfig{}, err } - - // Check if token is valid. - // TODO: Choose issuer based on system - if err = jwt.Validate(token, jwt.WithIssuer("https://vespa-cd.auth0.com/")); err != nil { - return false + r, err := a.httpClient.Do(req, time.Second*30) + if err != nil { + return flowConfig{}, fmt.Errorf("failed to get device flow config: %w", err) } - - return true -} - -// PrepareSystem loads the System, refreshing its token if necessary. -// The System access token needs a refresh if the access token has expired. -func (a *Auth0) PrepareSystem(ctx context.Context) (*System, error) { - if err := a.init(); err != nil { - return nil, err + defer r.Body.Close() + if r.StatusCode/100 != 2 { + return flowConfig{}, fmt.Errorf("failed to get device flow config: got response code %d from %s", r.StatusCode, url) } - s, err := a.getSystem() - if err != nil { - return nil, err + var cfg flowConfig + if err := json.NewDecoder(r.Body).Decode(&cfg); err != nil { + return flowConfig{}, fmt.Errorf("failed to decode response: %w", err) } + return cfg, nil +} - if s.AccessToken == "" { - return nil, fmt.Errorf("access token missing: re-authenticate with 'vespa auth login'") - } else if scopesChanged(s) { - return nil, fmt.Errorf("authentication scopes cahnges: re-authenticate with 'vespa auth login'") - } else if isExpired(s.ExpiresAt, accessTokenExpThreshold) { +// GetAccessToken returns an access token for the configured system, refreshing it if necessary. +func (a *Client) GetAccessToken() (string, error) { + creds, ok := a.provider.Systems[a.systemName] + if !ok { + return "", fmt.Errorf("system %s is not configured", a.systemName) + } else if creds.AccessToken == "" { + return "", fmt.Errorf("access token missing: %s", reauthMessage) + } else if scopesChanged(creds) { + return "", fmt.Errorf("authentication scopes changed: %s", reauthMessage) + } else if isExpired(creds.ExpiresAt, accessTokenExpiry) { // check if the stored access token is expired: // use the refresh token to get a new access token: tr := &auth.TokenRetriever{ @@ -157,190 +142,106 @@ func (a *Auth0) PrepareSystem(ctx context.Context) (*System, error) { Secrets: &auth.Keyring{}, Client: http.DefaultClient, } - - res, err := tr.Refresh(ctx, a.system) + resp, err := tr.Refresh(cancelOnInterrupt(), a.systemName) if err != nil { - return nil, fmt.Errorf("failed to renew access token: %w: %s", err, "re-authenticate with 'vespa auth login'") + return "", fmt.Errorf("failed to renew access token: %w: %s", err, reauthMessage) } else { // persist the updated system with renewed access token - s.AccessToken = res.AccessToken - s.ExpiresAt = time.Now().Add( - time.Duration(res.ExpiresIn) * time.Second, - ) - - err = a.AddSystem(s) - if err != nil { - return nil, err + creds.AccessToken = resp.AccessToken + creds.ExpiresAt = time.Now().Add(time.Duration(resp.ExpiresIn) * time.Second) + if err := a.WriteCredentials(creds); err != nil { + return "", err } } } - - return s, nil -} - -// isExpired is true if now() + a threshold is after the given date -func isExpired(t time.Time, threshold time.Duration) bool { - return time.Now().Add(threshold).After(t) + return creds.AccessToken, nil } -// scopesChanged compare the System scopes -// with the currently required scopes. -func scopesChanged(s *System) bool { - want := auth.RequiredScopes() - got := s.Scopes +func isExpired(t time.Time, ttl time.Duration) bool { return time.Now().Add(ttl).After(t) } - sort.Strings(want) - sort.Strings(got) - - if (want == nil) != (got == nil) { - return true - } - - if len(want) != len(got) { +func scopesChanged(s Credentials) bool { + required := auth.RequiredScopes() + current := s.Scopes + if len(required) != len(current) { return true } - + sort.Strings(required) + sort.Strings(current) for i := range s.Scopes { - if want[i] != got[i] { + if required[i] != current[i] { return true } } - return false } -func (a *Auth0) getSystem() (*System, error) { - if err := a.init(); err != nil { - return nil, err - } - - s, ok := a.config.Systems[a.system] - if !ok { - return nil, fmt.Errorf("unable to find system: %s; run 'vespa auth login' to configure a new system", a.system) - } - - return s, nil +// HasCredentials returns true if this client has retrived credentials for the configured system. +func (a *Client) HasCredentials() bool { + _, ok := a.provider.Systems[a.systemName] + return ok } -// HasSystem checks if the system is configured -// TODO: Used to print deprecation warning if we fall back to use tenant API key. -// Remove when this is not longer needed. -func (a *Auth0) HasSystem() bool { - if _, err := a.getSystem(); err != nil { - return false +// WriteCredentials writes given credentials to the configuration file. +func (a *Client) WriteCredentials(credentials Credentials) error { + if a.provider.Systems == nil { + a.provider.Systems = make(map[string]Credentials) } - return true -} - -// AddSystem assigns an existing, or new System. This is expected to be called -// after a login has completed. -func (a *Auth0) AddSystem(s *System) error { - _ = a.init() - - // If we're dealing with an empty file, we'll need to initialize this map. - if a.config.Systems == nil { - a.config.Systems = map[string]*System{} - } - - a.config.Systems[a.system] = s - - if err := a.persistConfig(); err != nil { - return fmt.Errorf("unexpected error persisting config: %w", err) + a.provider.Systems[a.systemName] = credentials + if err := writeConfig(a.provider, a.configPath); err != nil { + return fmt.Errorf("failed to write config: %w", err) } - return nil } -func (a *Auth0) RemoveSystem(s string) error { - _ = a.init() - - // If we're dealing with an empty file, we'll need to initialize this map. - if a.config.Systems == nil { - a.config.Systems = map[string]*System{} - } - - delete(a.config.Systems, s) - - if err := a.persistConfig(); err != nil { - return fmt.Errorf("unexpected error persisting config: %w", err) - } - +// RemoveCredentials removes credentials for the system configured in this client. +func (a *Client) RemoveCredentials() error { tr := &auth.TokenRetriever{Secrets: &auth.Keyring{}} - if err := tr.Delete(s); err != nil { - return fmt.Errorf("unexpected error clearing system information: %w", err) + if err := tr.Delete(a.systemName); err != nil { + return fmt.Errorf("failed to remove system %s from secret storage: %w", a.systemName, err) + } + delete(a.provider.Systems, a.systemName) + if err := writeConfig(a.provider, a.configPath); err != nil { + return fmt.Errorf("failed to write config: %w", err) } - return nil } -func (a *Auth0) persistConfig() error { - dir := filepath.Dir(a.Path) - if _, err := os.Stat(dir); os.IsNotExist(err) { - if err := os.MkdirAll(dir, 0700); err != nil { - return err - } - } - - buf, err := a.configToJson(&a.config) - if err != nil { - return err - } - - if err := os.WriteFile(a.Path, buf, 0600); err != nil { +func writeConfig(provider auth0Provider, path string) error { + dir := filepath.Dir(path) + if err := os.MkdirAll(dir, 0700); err != nil { return err } - - return nil -} - -func (a *Auth0) configToJson(cfg *config) ([]byte, error) { - cfg.Version = 1 - r := configJsonFormat{ - Version: 1, + version := 1 + provider.Version = version + r := config{ + Version: version, Providers: providers{ - Config: *cfg, + Auth0: provider, }, } - return json.MarshalIndent(r, "", " ") -} - -func (a *Auth0) jsonToConfig(buf []byte) (*config, error) { - r := configJsonFormat{} - if err := json.Unmarshal(buf, &r); err != nil { - return nil, err - } - cfg := r.Providers.Config - if cfg.Systems != nil { - for n, s := range cfg.Systems { - s.Name = n - } + jsonConfig, err := json.MarshalIndent(r, "", " ") + if err != nil { + return err } - return &cfg, nil + return os.WriteFile(path, jsonConfig, 0600) } -func (a *Auth0) init() error { - a.initOnce.Do(func() { - if a.errOnce = a.initContext(); a.errOnce != nil { - return +func readConfig(path string) (auth0Provider, error) { + f, err := os.Open(path) + if err != nil { + if os.IsNotExist(err) { + return auth0Provider{}, nil } - }) - return a.errOnce -} - -func (a *Auth0) initContext() (err error) { - if _, err := os.Stat(a.Path); os.IsNotExist(err) { - return errUnauthenticated + return auth0Provider{}, err } - - var buf []byte - if buf, err = os.ReadFile(a.Path); err != nil { - return err + defer f.Close() + cfg := config{} + if err := json.NewDecoder(f).Decode(&cfg); err != nil { + return auth0Provider{}, err } - - cfg, err := a.jsonToConfig(buf) - if err != nil { - return err + auth0Provider := cfg.Providers.Auth0 + if auth0Provider.Systems == nil { + auth0Provider.Systems = make(map[string]Credentials) } - a.config = *cfg - return nil + return auth0Provider, nil } diff --git a/client/go/auth/auth0/auth0_test.go b/client/go/auth/auth0/auth0_test.go new file mode 100644 index 00000000000..2616a62ef55 --- /dev/null +++ b/client/go/auth/auth0/auth0_test.go @@ -0,0 +1,99 @@ +package auth0 + +import ( + "os" + "path/filepath" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "github.com/vespa-engine/vespa/client/go/mock" +) + +func TestConfigWriting(t *testing.T) { + configPath := filepath.Join(t.TempDir(), "config") + httpClient := mock.HTTPClient{} + flowConfigResponse := `{ + "audience": "https://example.com/api/v2/", + "client-id": "some-id", + "device-code-endpoint": "https://example.com/oauth/device/code", + "oauth-token-endpoint": "https://example.com/oauth/token" +}` + httpClient.NextResponseString(200, flowConfigResponse) + client, err := newClient(&httpClient, configPath, "public", "http://example.com") + require.Nil(t, err) + assert.Equal(t, "https://example.com/api/v2/", client.Authenticator.Audience) + assert.Equal(t, "some-id", client.Authenticator.ClientID) + assert.Equal(t, "https://example.com/oauth/device/code", client.Authenticator.DeviceCodeEndpoint) + assert.Equal(t, "https://example.com/oauth/token", client.Authenticator.OauthTokenEndpoint) + + creds1 := Credentials{ + AccessToken: "some-token", + Scopes: []string{"foo", "bar"}, + ExpiresAt: time.Date(2022, 03, 01, 15, 45, 50, 0, time.UTC), + } + require.Nil(t, client.WriteCredentials(creds1)) + expected := `{ + "version": 1, + "providers": { + "auth0": { + "version": 1, + "systems": { + "public": { + "access_token": "some-token", + "scopes": [ + "foo", + "bar" + ], + "expires_at": "2022-03-01T15:45:50Z" + } + } + } + } +}` + assertConfig(t, expected, configPath) + + // Switch to another system + httpClient.NextResponseString(200, flowConfigResponse) + client, err = newClient(&httpClient, configPath, "publiccd", "http://example.com") + require.Nil(t, err) + creds2 := Credentials{ + AccessToken: "another-token", + Scopes: []string{"baz"}, + ExpiresAt: time.Date(2022, 03, 01, 15, 45, 50, 0, time.UTC), + } + require.Nil(t, client.WriteCredentials(creds2)) + expected = `{ + "version": 1, + "providers": { + "auth0": { + "version": 1, + "systems": { + "public": { + "access_token": "some-token", + "scopes": [ + "foo", + "bar" + ], + "expires_at": "2022-03-01T15:45:50Z" + }, + "publiccd": { + "access_token": "another-token", + "scopes": [ + "baz" + ], + "expires_at": "2022-03-01T15:45:50Z" + } + } + } + } +}` + assertConfig(t, expected, configPath) +} + +func assertConfig(t *testing.T, expected, path string) { + data, err := os.ReadFile(path) + require.Nil(t, err) + assert.Equal(t, expected, string(data)) +} diff --git a/client/go/cmd/config.go b/client/go/cmd/config.go index 18af9c89771..447bad14444 100644 --- a/client/go/cmd/config.go +++ b/client/go/cmd/config.go @@ -281,8 +281,8 @@ func (c *Config) useAPIKey(cli *CLI, system vespa.System, tenantName string) boo // If no Auth0 token is created, fall back to tenant api key, but warn that this functionality is deprecated // TODO: Remove this when users have had time to migrate over to Auth0 device flow authentication if !cli.isCI() { - a, err := auth0.GetAuth0(c.authConfigPath(), system.Name, system.URL) - if err != nil || !a.HasSystem() { + a, err := auth0.New(c.authConfigPath(), system.Name, system.URL) + if err != nil || !a.HasCredentials() { cli.printWarning("Use of API key is deprecated", "Authenticate with Auth0 instead: 'vespa auth login'") return util.PathExists(c.apiKeyPath(tenantName)) } diff --git a/client/go/cmd/curl.go b/client/go/cmd/curl.go index e06942196a6..41e37f5319b 100644 --- a/client/go/cmd/curl.go +++ b/client/go/cmd/curl.go @@ -4,11 +4,11 @@ package cmd import ( "fmt" "log" + "net/http" "os" "strings" "github.com/spf13/cobra" - "github.com/vespa-engine/vespa/client/go/auth/auth0" "github.com/vespa-engine/vespa/client/go/curl" "github.com/vespa-engine/vespa/client/go/vespa" ) @@ -50,10 +50,8 @@ $ vespa curl -- -v --data-urlencode "yql=select * from music where album contain } switch curlService { case vespa.DeployService: - if target.Type() == vespa.TargetCloud { - if err := addCloudAuth0Authentication(target.Deployment().System, cli.config, c); err != nil { - return err - } + if err := addAccessToken(c, target); err != nil { + return err } case vespa.DocumentService, vespa.QueryService: c.PrivateKey = service.TLSOptions.PrivateKeyFile @@ -77,17 +75,19 @@ $ vespa curl -- -v --data-urlencode "yql=select * from music where album contain return cmd } -func addCloudAuth0Authentication(system vespa.System, cfg *Config, c *curl.Command) error { - a, err := auth0.GetAuth0(cfg.authConfigPath(), system.Name, system.URL) - if err != nil { - return err +func addAccessToken(cmd *curl.Command, target vespa.Target) error { + if target.Type() != vespa.TargetCloud { + return nil } - - authSystem, err := a.PrepareSystem(auth0.ContextWithCancel()) - if err != nil { + req := http.Request{} + if err := target.SignRequest(&req, ""); err != nil { return err } - c.Header("Authorization", "Bearer "+authSystem.AccessToken) + headerValue := req.Header.Get("Authorization") + if headerValue == "" { + return fmt.Errorf("no authorization header added when signing request") + } + cmd.Header("Authorization", headerValue) return nil } diff --git a/client/go/cmd/login.go b/client/go/cmd/login.go index 5cf471ed8db..f9ea90a6e55 100644 --- a/client/go/cmd/login.go +++ b/client/go/cmd/login.go @@ -34,7 +34,7 @@ func newLoginCmd(cli *CLI) *cobra.Command { if err != nil { return err } - a, err := auth0.GetAuth0(cli.config.authConfigPath(), system.Name, system.URL) + a, err := auth0.New(cli.config.authConfigPath(), system.Name, system.URL) if err != nil { return err } @@ -79,15 +79,13 @@ func newLoginCmd(cli *CLI) *cobra.Command { log.Println("Could not store the refresh token locally, please expect to login again once your access token expired.") } - s := auth0.System{ - Name: system.Name, + creds := auth0.Credentials{ AccessToken: res.AccessToken, ExpiresAt: time.Now().Add(time.Duration(res.ExpiresIn) * time.Second), Scopes: auth.RequiredScopes(), } - err = a.AddSystem(&s) - if err != nil { - return fmt.Errorf("could not add system to config: %w", err) + if err := a.WriteCredentials(creds); err != nil { + return fmt.Errorf("failed to write credentials: %w", err) } return err }, diff --git a/client/go/cmd/logout.go b/client/go/cmd/logout.go index c0a6b60ab2e..de9005c34d8 100644 --- a/client/go/cmd/logout.go +++ b/client/go/cmd/logout.go @@ -24,11 +24,11 @@ func newLogoutCmd(cli *CLI) *cobra.Command { if err != nil { return err } - a, err := auth0.GetAuth0(cli.config.authConfigPath(), system.Name, system.URL) + a, err := auth0.New(cli.config.authConfigPath(), system.Name, system.URL) if err != nil { return err } - if err := a.RemoveSystem(system.Name); err != nil { + if err := a.RemoveCredentials(); err != nil { return err } diff --git a/client/go/vespa/target_cloud.go b/client/go/vespa/target_cloud.go index 697b0f23ba1..696d5109015 100644 --- a/client/go/vespa/target_cloud.go +++ b/client/go/vespa/target_cloud.go @@ -205,15 +205,18 @@ func (t *cloudTarget) CheckVersion(clientVersion version.Version) error { } func (t *cloudTarget) addAuth0AccessToken(request *http.Request) error { - a, err := auth0.GetAuth0(t.apiOptions.AuthConfigPath, t.apiOptions.System.Name, t.apiOptions.System.URL) + client, err := auth0.New(t.apiOptions.AuthConfigPath, t.apiOptions.System.Name, t.apiOptions.System.URL) if err != nil { return err } - system, err := a.PrepareSystem(auth0.ContextWithCancel()) + accessToken, err := client.GetAccessToken() if err != nil { return err } - request.Header.Set("Authorization", "Bearer "+system.AccessToken) + if request.Header == nil { + request.Header = make(http.Header) + } + request.Header.Set("Authorization", "Bearer "+accessToken) return nil } |