summaryrefslogtreecommitdiffstats
path: root/client/go/auth
diff options
context:
space:
mode:
Diffstat (limited to 'client/go/auth')
-rw-r--r--client/go/auth/auth0/auth0.go433
-rw-r--r--client/go/auth/zts/zts.go58
-rw-r--r--client/go/auth/zts/zts_test.go30
3 files changed, 521 insertions, 0 deletions
diff --git a/client/go/auth/auth0/auth0.go b/client/go/auth/auth0/auth0.go
new file mode 100644
index 00000000000..52ba3f085a4
--- /dev/null
+++ b/client/go/auth/auth0/auth0.go
@@ -0,0 +1,433 @@
+// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+
+package auth0
+
+import (
+ "context"
+ "encoding/json"
+ "errors"
+ "fmt"
+ "io/ioutil"
+ "net/http"
+ "net/url"
+ "os"
+ "os/signal"
+ "path/filepath"
+ "sort"
+ "sync"
+ "time"
+
+ "github.com/lestrrat-go/jwx/jwt"
+ "github.com/pkg/browser"
+ "github.com/vespa-engine/vespa/client/go/auth"
+ "github.com/vespa-engine/vespa/client/go/util"
+)
+
+const accessTokenExpThreshold = 5 * time.Minute
+
+var errUnauthenticated = errors.New("not logged in. Try 'vespa auth login'")
+
+type configJsonFormat struct {
+ Version int `json:"version"`
+ Providers providers `json:"providers"`
+}
+
+type providers struct {
+ Config config `json:"auth0"`
+}
+
+type config struct {
+ Version int `json:"version"`
+ Systems map[string]*System `json:"systems"`
+}
+
+type System struct {
+ Name string `json:"-"`
+ AccessToken string `json:"access_token,omitempty"`
+ Scopes []string `json:"scopes,omitempty"`
+ ExpiresAt time.Time `json:"expires_at"`
+}
+
+type Auth0 struct {
+ Authenticator *auth.Authenticator
+ system string
+ systemApiUrl string
+ initOnce sync.Once
+ errOnce error
+ Path string
+ config config
+}
+
+type authCfg struct {
+ Audience string `json:"audience"`
+ ClientID string `json:"client-id"`
+ DeviceCodeEndpoint string `json:"device-code-endpoint"`
+ OauthTokenEndpoint string `json:"oauth-token-endpoint"`
+}
+
+func ContextWithCancel() context.Context {
+ ctx, cancel := context.WithCancel(context.Background())
+ ch := make(chan os.Signal, 1)
+ signal.Notify(ch, os.Interrupt)
+ go func() {
+ <-ch
+ defer cancel()
+ os.Exit(0)
+ }()
+ return ctx
+}
+
+// GetAuth0 will try to initialize the config context, as well as figure out if
+// there's a readily available system.
+func GetAuth0(configPath string, systemName string, systemApiUrl string) (*Auth0, error) {
+ a := Auth0{}
+ a.Path = configPath
+ a.system = systemName
+ a.systemApiUrl = systemApiUrl
+ c, err := a.getDeviceFlowConfig()
+ if err != nil {
+ return nil, fmt.Errorf("cannot get auth config: %w", err)
+ }
+ a.Authenticator = &auth.Authenticator{
+ Audience: c.Audience,
+ ClientID: c.ClientID,
+ DeviceCodeEndpoint: c.DeviceCodeEndpoint,
+ OauthTokenEndpoint: c.OauthTokenEndpoint,
+ }
+ return &a, nil
+}
+
+func (a *Auth0) getDeviceFlowConfig() (authCfg, error) {
+ systemApiUrl, _ := url.Parse(a.systemApiUrl + "/auth0/v1/device-flow-config")
+ r, err := http.Get(systemApiUrl.String())
+ if err != nil {
+ return authCfg{}, fmt.Errorf("cannot get auth config: %w", err)
+ }
+ defer r.Body.Close()
+ var res authCfg
+ err = json.NewDecoder(r.Body).Decode(&res)
+ if err != nil {
+ return authCfg{}, fmt.Errorf("cannot decode response: %w", err)
+ }
+ return res, nil
+}
+
+// IsLoggedIn encodes the domain logic for determining whether we're
+// logged in. This might check our config storage, or just in memory.
+func (a *Auth0) IsLoggedIn() bool {
+ // No need to check errors for initializing context.
+ _ = a.init()
+
+ if a.system == "" {
+ return false
+ }
+
+ // Parse the access token for the system.
+ token, err := jwt.ParseString(a.config.Systems[a.system].AccessToken)
+ if err != nil {
+ return false
+ }
+
+ // Check if token is valid.
+ // TODO: Choose issuer based on system
+ if err = jwt.Validate(token, jwt.WithIssuer("https://vespa-cd.auth0.com/")); err != nil {
+ return false
+ }
+
+ return true
+}
+
+// PrepareSystem loads the System, refreshing its token if necessary.
+// The System access token needs a refresh if:
+// 1. the System scopes are different from the currently required scopes - (auth0 changes).
+// 2. the access token is expired.
+func (a *Auth0) PrepareSystem(ctx context.Context) (*System, error) {
+ if err := a.init(); err != nil {
+ return nil, err
+ }
+ s, err := a.getSystem()
+ if err != nil {
+ return nil, err
+ }
+
+ if s.AccessToken == "" || scopesChanged(s) {
+ s, err = RunLogin(ctx, a, true)
+ if err != nil {
+ return nil, err
+ }
+ } else if isExpired(s.ExpiresAt, accessTokenExpThreshold) {
+ // check if the stored access token is expired:
+ // use the refresh token to get a new access token:
+ tr := &auth.TokenRetriever{
+ Authenticator: a.Authenticator,
+ Secrets: &auth.Keyring{},
+ Client: http.DefaultClient,
+ }
+
+ res, err := tr.Refresh(ctx, a.system)
+ if err != nil {
+ return nil, fmt.Errorf("failed to renew access token: %w: %s", err, "re-authenticate with 'vespa auth login'")
+ } else {
+ // persist the updated system with renewed access token
+ s.AccessToken = res.AccessToken
+ s.ExpiresAt = time.Now().Add(
+ time.Duration(res.ExpiresIn) * time.Second,
+ )
+
+ err = a.AddSystem(s)
+ if err != nil {
+ return nil, err
+ }
+ }
+ }
+
+ return s, nil
+}
+
+// isExpired is true if now() + a threshold is after the given date
+func isExpired(t time.Time, threshold time.Duration) bool {
+ return time.Now().Add(threshold).After(t)
+}
+
+// scopesChanged compare the System scopes
+// with the currently required scopes.
+func scopesChanged(s *System) bool {
+ want := auth.RequiredScopes()
+ got := s.Scopes
+
+ sort.Strings(want)
+ sort.Strings(got)
+
+ if (want == nil) != (got == nil) {
+ return true
+ }
+
+ if len(want) != len(got) {
+ return true
+ }
+
+ for i := range s.Scopes {
+ if want[i] != got[i] {
+ return true
+ }
+ }
+
+ return false
+}
+
+func (a *Auth0) getSystem() (*System, error) {
+ if err := a.init(); err != nil {
+ return nil, err
+ }
+
+ s, ok := a.config.Systems[a.system]
+ if !ok {
+ return nil, fmt.Errorf("unable to find system: %s; run 'vespa auth login' to configure a new system", a.system)
+ }
+
+ return s, nil
+}
+
+// HasSystem checks if the system is configured
+// TODO: Used to print deprecation warning if we fall back to use tenant API key.
+// Remove when this is not longer needed.
+func (a *Auth0) HasSystem() bool {
+ if _, err := a.getSystem(); err != nil {
+ return false
+ }
+ return true
+}
+
+// AddSystem assigns an existing, or new System. This is expected to be called
+// after a login has completed.
+func (a *Auth0) AddSystem(s *System) error {
+ _ = a.init()
+
+ // If we're dealing with an empty file, we'll need to initialize this map.
+ if a.config.Systems == nil {
+ a.config.Systems = map[string]*System{}
+ }
+
+ a.config.Systems[a.system] = s
+
+ if err := a.persistConfig(); err != nil {
+ return fmt.Errorf("unexpected error persisting config: %w", err)
+ }
+
+ return nil
+}
+
+func (a *Auth0) removeSystem(s string) error {
+ _ = a.init()
+
+ // If we're dealing with an empty file, we'll need to initialize this map.
+ if a.config.Systems == nil {
+ a.config.Systems = map[string]*System{}
+ }
+
+ delete(a.config.Systems, s)
+
+ if err := a.persistConfig(); err != nil {
+ return fmt.Errorf("unexpected error persisting config: %w", err)
+ }
+
+ tr := &auth.TokenRetriever{Secrets: &auth.Keyring{}}
+ if err := tr.Delete(s); err != nil {
+ return fmt.Errorf("unexpected error clearing system information: %w", err)
+ }
+
+ return nil
+}
+
+func (a *Auth0) persistConfig() error {
+ dir := filepath.Dir(a.Path)
+ if _, err := os.Stat(dir); os.IsNotExist(err) {
+ if err := os.MkdirAll(dir, 0700); err != nil {
+ return err
+ }
+ }
+
+ buf, err := a.configToJson(&a.config)
+ if err != nil {
+ return err
+ }
+
+ if err := ioutil.WriteFile(a.Path, buf, 0600); err != nil {
+ return err
+ }
+
+ return nil
+}
+
+func (a *Auth0) configToJson(cfg *config) ([]byte, error) {
+ cfg.Version = 1
+ r := configJsonFormat{
+ Version: 1,
+ Providers: providers{
+ Config: *cfg,
+ },
+ }
+ return json.MarshalIndent(r, "", " ")
+}
+
+func (a *Auth0) jsonToConfig(buf []byte) (*config, error) {
+ r := configJsonFormat{}
+ if err := json.Unmarshal(buf, &r); err != nil {
+ return nil, err
+ }
+ cfg := r.Providers.Config
+ if cfg.Systems != nil {
+ for n, s := range cfg.Systems {
+ s.Name = n
+ }
+ }
+ return &cfg, nil
+}
+
+func (a *Auth0) init() error {
+ a.initOnce.Do(func() {
+ if a.errOnce = a.initContext(); a.errOnce != nil {
+ return
+ }
+ })
+ return a.errOnce
+}
+
+func (a *Auth0) initContext() (err error) {
+ if _, err := os.Stat(a.Path); os.IsNotExist(err) {
+ return errUnauthenticated
+ }
+
+ var buf []byte
+ if buf, err = ioutil.ReadFile(a.Path); err != nil {
+ return err
+ }
+
+ cfg, err := a.jsonToConfig(buf)
+ if err != nil {
+ return err
+ }
+ a.config = *cfg
+ return nil
+}
+
+// RunLogin runs the login flow guiding the user through the process
+// by showing the login instructions, opening the browser.
+// Use `expired` to run the login from other commands setup:
+// this will only affect the messages.
+func RunLogin(ctx context.Context, a *Auth0, expired bool) (*System, error) {
+ if expired {
+ fmt.Println("Please sign in to re-authorize the CLI.")
+ }
+
+ state, err := a.Authenticator.Start(ctx)
+ if err != nil {
+ return nil, fmt.Errorf("could not start the authentication process: %w", err)
+ }
+
+ fmt.Printf("Your Device Confirmation code is: %s\n\n", state.UserCode)
+
+ fmt.Println("If you prefer, you can open the URL directly for verification")
+ fmt.Printf("Your Verification URL: %s\n\n", state.VerificationURI)
+
+ fmt.Println("Press Enter to open the browser to log in or ^C to quit...")
+ fmt.Scanln()
+
+ err = browser.OpenURL(state.VerificationURI)
+
+ if err != nil {
+ fmt.Printf("Couldn't open the URL, please do it manually: %s.", state.VerificationURI)
+ }
+
+ var res auth.Result
+ err = util.Spinner(os.Stderr, "Waiting for login to complete in browser ...", func() error {
+ res, err = a.Authenticator.Wait(ctx, state)
+ return err
+ })
+
+ if err != nil {
+ return nil, fmt.Errorf("login error: %w", err)
+ }
+
+ fmt.Print("\n")
+ fmt.Println("Successfully logged in.")
+ fmt.Print("\n")
+
+ // store the refresh token
+ secretsStore := &auth.Keyring{}
+ err = secretsStore.Set(auth.SecretsNamespace, a.system, res.RefreshToken)
+ if err != nil {
+ // log the error but move on
+ fmt.Println("Could not store the refresh token locally, please expect to login again once your access token expired.")
+ }
+
+ s := System{
+ Name: a.system,
+ AccessToken: res.AccessToken,
+ ExpiresAt: time.Now().Add(time.Duration(res.ExpiresIn) * time.Second),
+ Scopes: auth.RequiredScopes(),
+ }
+ err = a.AddSystem(&s)
+ if err != nil {
+ return nil, fmt.Errorf("could not add system to config: %w", err)
+ }
+
+ return &s, nil
+}
+
+func RunLogout(a *Auth0) error {
+ s, err := a.getSystem()
+ if err != nil {
+ return err
+ }
+
+ if err := a.removeSystem(s.Name); err != nil {
+ return err
+ }
+
+ fmt.Print("\n")
+ fmt.Println("Successfully logged out.")
+ fmt.Print("\n")
+
+ return nil
+}
diff --git a/client/go/auth/zts/zts.go b/client/go/auth/zts/zts.go
new file mode 100644
index 00000000000..d288c2050d9
--- /dev/null
+++ b/client/go/auth/zts/zts.go
@@ -0,0 +1,58 @@
+package zts
+
+import (
+ "crypto/tls"
+ "encoding/json"
+ "fmt"
+ "net/http"
+ "net/url"
+ "strings"
+ "time"
+
+ "github.com/vespa-engine/vespa/client/go/util"
+)
+
+const DefaultURL = "https://zts.athenz.ouroath.com:4443"
+
+// Client is a client for Athenz ZTS, an authentication token service.
+type Client struct {
+ client util.HTTPClient
+ tokenURL *url.URL
+}
+
+// NewClient creates a new client for an Athenz ZTS service located at serviceURL.
+func NewClient(serviceURL string, client util.HTTPClient) (*Client, error) {
+ tokenURL, err := url.Parse(serviceURL)
+ if err != nil {
+ return nil, err
+ }
+ tokenURL.Path = "/zts/v1/oauth2/token"
+ return &Client{tokenURL: tokenURL, client: client}, nil
+}
+
+// AccessToken returns an access token within the given domain, using certificate to authenticate with ZTS.
+func (c *Client) AccessToken(domain string, certificate tls.Certificate) (string, error) {
+ data := fmt.Sprintf("grant_type=client_credentials&scope=%s:domain", domain)
+ req, err := http.NewRequest("POST", c.tokenURL.String(), strings.NewReader(data))
+ if err != nil {
+ return "", err
+ }
+ c.client.UseCertificate([]tls.Certificate{certificate})
+ response, err := c.client.Do(req, 10*time.Second)
+ if err != nil {
+ return "", err
+ }
+ defer response.Body.Close()
+
+ if response.StatusCode != http.StatusOK {
+ return "", fmt.Errorf("got status %d from %s", response.StatusCode, c.tokenURL.String())
+ }
+ var ztsResponse struct {
+ AccessToken string `json:"access_token"`
+ }
+ dec := json.NewDecoder(response.Body)
+ if err := dec.Decode(&ztsResponse); err != nil {
+ return "", err
+ }
+ return ztsResponse.AccessToken, nil
+}
diff --git a/client/go/auth/zts/zts_test.go b/client/go/auth/zts/zts_test.go
new file mode 100644
index 00000000000..0eec085aadb
--- /dev/null
+++ b/client/go/auth/zts/zts_test.go
@@ -0,0 +1,30 @@
+package zts
+
+import (
+ "crypto/tls"
+ "testing"
+
+ "github.com/vespa-engine/vespa/client/go/mock"
+)
+
+func TestAccessToken(t *testing.T) {
+ httpClient := mock.HTTPClient{}
+ client, err := NewClient("http://example.com", &httpClient)
+ if err != nil {
+ t.Fatal(err)
+ }
+ httpClient.NextResponse(400, `{"message": "bad request"}`)
+ _, err = client.AccessToken("vespa.vespa", tls.Certificate{})
+ if err == nil {
+ t.Fatal("want error for non-ok response status")
+ }
+ httpClient.NextResponse(200, `{"access_token": "foo bar"}`)
+ token, err := client.AccessToken("vespa.vespa", tls.Certificate{})
+ if err != nil {
+ t.Fatal(err)
+ }
+ want := "foo bar"
+ if token != want {
+ t.Errorf("got %q, want %q", token, want)
+ }
+}