diff --git a/NEXT_CHANGELOG.md b/NEXT_CHANGELOG.md index c196e389c0..ce5c0f45c6 100644 --- a/NEXT_CHANGELOG.md +++ b/NEXT_CHANGELOG.md @@ -4,6 +4,8 @@ ### CLI +* Moved file-based OAuth token cache management from the SDK to the CLI. No user-visible change; part of a three-PR sequence that makes the CLI the sole owner of its token cache. + ### Bundles ### Dependency updates diff --git a/cmd/auth/login.go b/cmd/auth/login.go index 7157f797b7..21c3aa3608 100644 --- a/cmd/auth/login.go +++ b/cmd/auth/login.go @@ -10,6 +10,7 @@ import ( "time" "github.com/databricks/cli/libs/auth" + "github.com/databricks/cli/libs/auth/storage" "github.com/databricks/cli/libs/browser" "github.com/databricks/cli/libs/cmdio" "github.com/databricks/cli/libs/databrickscfg" @@ -21,6 +22,7 @@ import ( "github.com/databricks/databricks-sdk-go/config" "github.com/databricks/databricks-sdk-go/config/experimental/auth/authconv" "github.com/databricks/databricks-sdk-go/credentials/u2m" + "github.com/databricks/databricks-sdk-go/credentials/u2m/cache" "github.com/spf13/cobra" "golang.org/x/oauth2" ) @@ -189,13 +191,18 @@ a new profile is created. return err } + tokenCache, err := storage.NewFileTokenCache(ctx) + if err != nil { + return fmt.Errorf("opening token cache: %w", err) + } + // If no host is available from any source, use the discovery flow // via login.databricks.com. if shouldUseDiscovery(authArguments.Host, args, existingProfile) { if err := validateDiscoveryFlagCompatibility(cmd); err != nil { return err } - return discoveryLogin(ctx, &defaultDiscoveryClient{}, profileName, loginTimeout, scopes, existingProfile, getBrowserFunc(cmd)) + return discoveryLogin(ctx, &defaultDiscoveryClient{}, tokenCache, profileName, loginTimeout, scopes, existingProfile, getBrowserFunc(cmd)) } // Load unified host flag from the profile if not explicitly set via CLI flag. @@ -228,6 +235,7 @@ a new profile is created. return err } persistentAuthOpts := []u2m.PersistentAuthOption{ + u2m.WithTokenCache(storage.NewDualWritingTokenCache(tokenCache, oauthArgument)), u2m.WithOAuthArgument(oauthArgument), u2m.WithBrowser(getBrowserFunc(cmd)), } @@ -562,7 +570,7 @@ func validateDiscoveryFlagCompatibility(cmd *cobra.Command) error { // discoveryLogin runs the login.databricks.com discovery flow. The user // authenticates in the browser, selects a workspace, and the CLI receives // the workspace host from the OAuth callback's iss parameter. -func discoveryLogin(ctx context.Context, dc discoveryClient, profileName string, timeout time.Duration, scopes string, existingProfile *profile.Profile, browserFunc func(string) error) error { +func discoveryLogin(ctx context.Context, dc discoveryClient, tokenCache cache.TokenCache, profileName string, timeout time.Duration, scopes string, existingProfile *profile.Profile, browserFunc func(string) error) error { arg, err := dc.NewOAuthArgument(profileName) if err != nil { return discoveryErr("setting up login.databricks.com", err) @@ -574,6 +582,7 @@ func discoveryLogin(ctx context.Context, dc discoveryClient, profileName string, } opts := []u2m.PersistentAuthOption{ + u2m.WithTokenCache(storage.NewDualWritingTokenCache(tokenCache, arg)), u2m.WithOAuthArgument(arg), u2m.WithBrowser(browserFunc), u2m.WithDiscoveryLogin(), diff --git a/cmd/auth/login_test.go b/cmd/auth/login_test.go index 81924f027a..2b8d473f51 100644 --- a/cmd/auth/login_test.go +++ b/cmd/auth/login_test.go @@ -20,12 +20,19 @@ import ( "github.com/databricks/cli/libs/env" "github.com/databricks/cli/libs/log" "github.com/databricks/databricks-sdk-go/credentials/u2m" + "github.com/databricks/databricks-sdk-go/credentials/u2m/cache" "github.com/spf13/cobra" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "golang.org/x/oauth2" ) +// newTestTokenCache returns an in-memory token cache for tests so that +// discoveryLogin and other login helpers don't touch ~/.databricks/token-cache.json. +func newTestTokenCache() cache.TokenCache { + return &inMemoryTokenCache{Tokens: map[string]*oauth2.Token{}} +} + // logBuffer is a thread-safe bytes.Buffer for capturing log output in tests. type logBuffer struct { mu sync.Mutex @@ -623,7 +630,7 @@ func TestDiscoveryLogin_IntrospectionFailureStillSavesProfile(t *testing.T) { } ctx, _ := cmdio.NewTestContextWithStdout(t.Context()) - err = discoveryLogin(ctx, dc, "DISCOVERY", time.Second, "all-apis, ,sql,", nil, func(string) error { return nil }) + err = discoveryLogin(ctx, dc, newTestTokenCache(), "DISCOVERY", time.Second, "all-apis, ,sql,", nil, func(string) error { return nil }) require.NoError(t, err) assert.Equal(t, "https://workspace.example.com", dc.introspectHost) @@ -671,7 +678,7 @@ func TestDiscoveryLogin_AccountIDMismatchWarning(t *testing.T) { AccountID: "old-account-id", } - err = discoveryLogin(ctx, dc, "DISCOVERY", time.Second, "", existingProfile, func(string) error { return nil }) + err = discoveryLogin(ctx, dc, newTestTokenCache(), "DISCOVERY", time.Second, "", existingProfile, func(string) error { return nil }) require.NoError(t, err) // Verify warning about mismatched account IDs was logged. @@ -719,7 +726,7 @@ func TestDiscoveryLogin_NoWarningWhenAccountIDsMatch(t *testing.T) { AccountID: "same-account-id", } - err = discoveryLogin(ctx, dc, "DISCOVERY", time.Second, "", existingProfile, func(string) error { return nil }) + err = discoveryLogin(ctx, dc, newTestTokenCache(), "DISCOVERY", time.Second, "", existingProfile, func(string) error { return nil }) require.NoError(t, err) // No warning should be logged when account IDs match. @@ -739,7 +746,7 @@ func TestDiscoveryLogin_EmptyDiscoveredHostReturnsError(t *testing.T) { } ctx, _ := cmdio.NewTestContextWithStdout(t.Context()) - err = discoveryLogin(ctx, dc, "DISCOVERY", time.Second, "", nil, func(string) error { return nil }) + err = discoveryLogin(ctx, dc, newTestTokenCache(), "DISCOVERY", time.Second, "", nil, func(string) error { return nil }) require.Error(t, err) assert.Contains(t, err.Error(), "no workspace host was discovered") } @@ -771,7 +778,7 @@ func TestDiscoveryLogin_ReloginPreservesExistingProfileScopes(t *testing.T) { // No --scopes flag (empty string), should fall back to existing profile scopes. ctx, _ := cmdio.NewTestContextWithStdout(t.Context()) - err = discoveryLogin(ctx, dc, "DISCOVERY", time.Second, "", existingProfile, func(string) error { return nil }) + err = discoveryLogin(ctx, dc, newTestTokenCache(), "DISCOVERY", time.Second, "", existingProfile, func(string) error { return nil }) require.NoError(t, err) savedProfile, err := loadProfileByName(ctx, "DISCOVERY", profile.DefaultProfiler) @@ -808,7 +815,7 @@ func TestDiscoveryLogin_ExplicitScopesOverrideExistingProfile(t *testing.T) { // Explicit --scopes flag should override existing profile scopes. ctx, _ := cmdio.NewTestContextWithStdout(t.Context()) - err = discoveryLogin(ctx, dc, "DISCOVERY", time.Second, "all-apis", existingProfile, func(string) error { return nil }) + err = discoveryLogin(ctx, dc, newTestTokenCache(), "DISCOVERY", time.Second, "all-apis", existingProfile, func(string) error { return nil }) require.NoError(t, err) savedProfile, err := loadProfileByName(ctx, "DISCOVERY", profile.DefaultProfiler) @@ -848,7 +855,7 @@ func TestDiscoveryLogin_SPOGHostPopulatesAccountIDFromDiscovery(t *testing.T) { } ctx, _ := cmdio.NewTestContextWithStdout(t.Context()) - err = discoveryLogin(ctx, dc, "DISCOVERY", time.Second, "", nil, func(string) error { return nil }) + err = discoveryLogin(ctx, dc, newTestTokenCache(), "DISCOVERY", time.Second, "", nil, func(string) error { return nil }) require.NoError(t, err) savedProfile, err := loadProfileByName(ctx, "DISCOVERY", profile.DefaultProfiler) @@ -883,7 +890,7 @@ func TestDiscoveryLogin_IntrospectionFallsBackWhenDiscoveryFails(t *testing.T) { } ctx, _ := cmdio.NewTestContextWithStdout(t.Context()) - err = discoveryLogin(ctx, dc, "DISCOVERY", time.Second, "", nil, func(string) error { return nil }) + err = discoveryLogin(ctx, dc, newTestTokenCache(), "DISCOVERY", time.Second, "", nil, func(string) error { return nil }) require.NoError(t, err) savedProfile, err := loadProfileByName(ctx, "DISCOVERY", profile.DefaultProfiler) @@ -932,7 +939,7 @@ auth_type = databricks-cli } ctx, _ := cmdio.NewTestContextWithStdout(t.Context()) - err = discoveryLogin(ctx, dc, "DISCOVERY", time.Second, "", existingProfile, func(string) error { return nil }) + err = discoveryLogin(ctx, dc, newTestTokenCache(), "DISCOVERY", time.Second, "", existingProfile, func(string) error { return nil }) require.NoError(t, err) savedProfile, err := loadProfileByName(ctx, "DISCOVERY", profile.DefaultProfiler) @@ -982,7 +989,7 @@ auth_type = databricks-cli } ctx, _ := cmdio.NewTestContextWithStdout(t.Context()) - err = discoveryLogin(ctx, dc, "DISCOVERY", time.Second, "", existingProfile, func(string) error { return nil }) + err = discoveryLogin(ctx, dc, newTestTokenCache(), "DISCOVERY", time.Second, "", existingProfile, func(string) error { return nil }) require.NoError(t, err) savedProfile, err := loadProfileByName(ctx, "DISCOVERY", profile.DefaultProfiler) diff --git a/cmd/auth/logout.go b/cmd/auth/logout.go index 3beeeefec9..67829ec169 100644 --- a/cmd/auth/logout.go +++ b/cmd/auth/logout.go @@ -22,6 +22,7 @@ import ( "strings" "github.com/databricks/cli/libs/auth" + "github.com/databricks/cli/libs/auth/storage" "github.com/databricks/cli/libs/cmdio" "github.com/databricks/cli/libs/databrickscfg" "github.com/databricks/cli/libs/databrickscfg/profile" @@ -132,7 +133,7 @@ to specify it explicitly. profileName = selected } - tokenCache, err := cache.NewFileTokenCache() + tokenCache, err := storage.NewFileTokenCache(ctx) if err != nil { return fmt.Errorf("failed to open token cache, please check if the file version is up-to-date and that the file is not corrupted: %w", err) } diff --git a/cmd/auth/token.go b/cmd/auth/token.go index fbdd8811e8..d82bcb2c9a 100644 --- a/cmd/auth/token.go +++ b/cmd/auth/token.go @@ -11,6 +11,7 @@ import ( "github.com/databricks/cli/cmd/root" "github.com/databricks/cli/libs/auth" + "github.com/databricks/cli/libs/auth/storage" "github.com/databricks/cli/libs/browser" "github.com/databricks/cli/libs/cmdio" "github.com/databricks/cli/libs/databrickscfg" @@ -78,6 +79,11 @@ and secret is not supported.`, ctx := cmd.Context() profileName := cmd.Flag("profile").Value.String() + tokenCache, err := storage.NewFileTokenCache(ctx) + if err != nil { + return fmt.Errorf("opening token cache: %w", err) + } + t, err := loadToken(ctx, loadTokenArgs{ authArguments: authArguments, profileName: profileName, @@ -85,6 +91,7 @@ and secret is not supported.`, tokenTimeout: tokenTimeout, forceRefresh: forceRefresh, profiler: profile.DefaultProfiler, + tokenCache: tokenCache, persistentAuthOpts: nil, }) if err != nil { @@ -133,6 +140,10 @@ type loadTokenArgs struct { // profiler is the profiler to use for reading the host and account ID from the .databrickscfg file. profiler profile.Profiler + // tokenCache is the underlying TokenCache used for OAuth tokens. The caller is + // responsible for construction so that tests can substitute an in-memory cache. + tokenCache cache.TokenCache + // persistentAuthOpts are the options to pass to the persistent auth client. persistentAuthOpts []u2m.PersistentAuthOption } @@ -184,7 +195,7 @@ func loadToken(ctx context.Context, args loadTokenArgs) (*oauth2.Token, error) { // resolve the target through environment variables or interactive profile selection. if args.profileName == "" && args.authArguments.Host == "" && len(args.args) == 0 { var resolvedProfile string - resolvedProfile, existingProfile, err = resolveNoArgsToken(ctx, args.profiler, args.authArguments) + resolvedProfile, existingProfile, err = resolveNoArgsToken(ctx, args.profiler, args.authArguments, args.tokenCache) if err != nil { return nil, err } @@ -273,7 +284,9 @@ func loadToken(ctx context.Context, args loadTokenArgs) (*oauth2.Token, error) { if err != nil { return nil, err } - allArgs := append(args.persistentAuthOpts, u2m.WithOAuthArgument(oauthArgument)) + wrappedCache := storage.NewDualWritingTokenCache(args.tokenCache, oauthArgument) + allArgs := append([]u2m.PersistentAuthOption{u2m.WithTokenCache(wrappedCache)}, args.persistentAuthOpts...) + allArgs = append(allArgs, u2m.WithOAuthArgument(oauthArgument)) persistentAuth, err := u2m.NewPersistentAuth(ctx, allArgs...) if err != nil { helpMsg := helpfulError(ctx, args.profileName, oauthArgument) @@ -314,7 +327,7 @@ func loadToken(ctx context.Context, args loadTokenArgs) (*oauth2.Token, error) { // // Returns the resolved profile name and profile (if any). The host and related // fields on authArgs are updated in place when resolved via environment variables. -func resolveNoArgsToken(ctx context.Context, profiler profile.Profiler, authArgs *auth.AuthArguments) (string, *profile.Profile, error) { +func resolveNoArgsToken(ctx context.Context, profiler profile.Profiler, authArgs *auth.AuthArguments, tokenCache cache.TokenCache) (string, *profile.Profile, error) { // Step 1: Try DATABRICKS_HOST env var (highest priority). if envHost := env.Get(ctx, "DATABRICKS_HOST"); envHost != "" { authArgs.Host = envHost @@ -363,7 +376,7 @@ func resolveNoArgsToken(ctx context.Context, profiler profile.Profiler, authArgs // Fall through — setHostAndAccountId will prompt for the host. return "", nil, nil case createNewSelected: - return runInlineLogin(ctx, profiler) + return runInlineLogin(ctx, profiler, tokenCache) default: p, err := loadProfileByName(ctx, selectedName, profiler) if err != nil { @@ -427,7 +440,7 @@ func promptForProfileSelection(ctx context.Context, profiles profile.Profiles) ( // runInlineLogin runs a minimal interactive login flow: prompts for a profile // name and host, performs the OAuth challenge, saves the profile to // .databrickscfg, and returns the new profile name and profile. -func runInlineLogin(ctx context.Context, profiler profile.Profiler) (string, *profile.Profile, error) { +func runInlineLogin(ctx context.Context, profiler profile.Profiler, tokenCache cache.TokenCache) (string, *profile.Profile, error) { profileName, err := promptForProfile(ctx, "DEFAULT") if err != nil { return "", nil, err @@ -460,6 +473,7 @@ func runInlineLogin(ctx context.Context, profiler profile.Profiler) (string, *pr return "", nil, err } persistentAuthOpts := []u2m.PersistentAuthOption{ + u2m.WithTokenCache(storage.NewDualWritingTokenCache(tokenCache, oauthArgument)), u2m.WithOAuthArgument(oauthArgument), u2m.WithBrowser(func(url string) error { return browser.Open(ctx, url) }), } diff --git a/cmd/auth/token_test.go b/cmd/auth/token_test.go index ec7fe2004a..3dfa4e5d21 100644 --- a/cmd/auth/token_test.go +++ b/cmd/auth/token_test.go @@ -221,6 +221,7 @@ func TestToken_loadToken(t *testing.T) { args: []string{}, tokenTimeout: 1 * time.Hour, profiler: profiler, + tokenCache: tokenCache, persistentAuthOpts: []u2m.PersistentAuthOption{ u2m.WithTokenCache(tokenCache), u2m.WithOAuthEndpointSupplier(&MockApiClient{}), @@ -241,6 +242,7 @@ func TestToken_loadToken(t *testing.T) { args: []string{}, tokenTimeout: 1 * time.Hour, profiler: profiler, + tokenCache: tokenCache, persistentAuthOpts: []u2m.PersistentAuthOption{ u2m.WithTokenCache(tokenCache), u2m.WithOAuthEndpointSupplier(&MockApiClient{}), @@ -258,6 +260,7 @@ func TestToken_loadToken(t *testing.T) { args: []string{}, tokenTimeout: 1 * time.Hour, profiler: profiler, + tokenCache: tokenCache, persistentAuthOpts: []u2m.PersistentAuthOption{ u2m.WithTokenCache(tokenCache), u2m.WithOAuthEndpointSupplier(&MockApiClient{}), @@ -275,6 +278,7 @@ func TestToken_loadToken(t *testing.T) { args: []string{}, tokenTimeout: 1 * time.Hour, profiler: profiler, + tokenCache: tokenCache, persistentAuthOpts: []u2m.PersistentAuthOption{ u2m.WithTokenCache(tokenCache), u2m.WithOAuthEndpointSupplier(&MockApiClient{}), @@ -292,6 +296,7 @@ func TestToken_loadToken(t *testing.T) { args: []string{}, tokenTimeout: 1 * time.Hour, profiler: profiler, + tokenCache: tokenCache, persistentAuthOpts: []u2m.PersistentAuthOption{ u2m.WithTokenCache(tokenCache), u2m.WithOAuthEndpointSupplier(&MockApiClient{}), @@ -308,6 +313,7 @@ func TestToken_loadToken(t *testing.T) { args: []string{}, tokenTimeout: 1 * time.Hour, profiler: profiler, + tokenCache: tokenCache, persistentAuthOpts: []u2m.PersistentAuthOption{ u2m.WithTokenCache(tokenCache), u2m.WithOAuthEndpointSupplier(&MockApiClient{}), @@ -324,6 +330,7 @@ func TestToken_loadToken(t *testing.T) { args: []string{}, tokenTimeout: 1 * time.Hour, profiler: profiler, + tokenCache: tokenCache, persistentAuthOpts: []u2m.PersistentAuthOption{ u2m.WithTokenCache(tokenCache), u2m.WithOAuthEndpointSupplier(&MockApiClient{}), @@ -340,6 +347,7 @@ func TestToken_loadToken(t *testing.T) { args: []string{"workspace-a"}, tokenTimeout: 1 * time.Hour, profiler: profiler, + tokenCache: tokenCache, persistentAuthOpts: []u2m.PersistentAuthOption{ u2m.WithTokenCache(tokenCache), u2m.WithOAuthEndpointSupplier(&MockApiClient{}), @@ -356,6 +364,7 @@ func TestToken_loadToken(t *testing.T) { args: []string{"workspace-a.cloud.databricks.com"}, tokenTimeout: 1 * time.Hour, profiler: profiler, + tokenCache: tokenCache, persistentAuthOpts: []u2m.PersistentAuthOption{ u2m.WithTokenCache(tokenCache), u2m.WithOAuthEndpointSupplier(&MockApiClient{}), @@ -372,6 +381,7 @@ func TestToken_loadToken(t *testing.T) { args: []string{"default.dev"}, tokenTimeout: 1 * time.Hour, profiler: profiler, + tokenCache: tokenCache, persistentAuthOpts: []u2m.PersistentAuthOption{ u2m.WithTokenCache(tokenCache), u2m.WithOAuthEndpointSupplier(&MockApiClient{}), @@ -388,6 +398,7 @@ func TestToken_loadToken(t *testing.T) { args: []string{"nonexistent.cloud.databricks.com"}, tokenTimeout: 1 * time.Hour, profiler: profiler, + tokenCache: tokenCache, persistentAuthOpts: []u2m.PersistentAuthOption{ u2m.WithTokenCache(tokenCache), u2m.WithOAuthEndpointSupplier(&MockApiClient{}), @@ -419,6 +430,7 @@ func TestToken_loadToken(t *testing.T) { args: []string{}, tokenTimeout: 1 * time.Hour, profiler: profiler, + tokenCache: tokenCache, persistentAuthOpts: []u2m.PersistentAuthOption{ u2m.WithTokenCache(tokenCache), u2m.WithOAuthEndpointSupplier(&MockApiClient{}), @@ -436,6 +448,7 @@ func TestToken_loadToken(t *testing.T) { args: []string{}, tokenTimeout: 1 * time.Hour, profiler: profiler, + tokenCache: tokenCache, persistentAuthOpts: []u2m.PersistentAuthOption{ u2m.WithTokenCache(tokenCache), u2m.WithOAuthEndpointSupplier(&MockApiClient{}), @@ -454,6 +467,7 @@ func TestToken_loadToken(t *testing.T) { args: []string{}, tokenTimeout: 1 * time.Hour, profiler: profiler, + tokenCache: tokenCache, persistentAuthOpts: []u2m.PersistentAuthOption{ u2m.WithTokenCache(tokenCache), u2m.WithOAuthEndpointSupplier(&MockApiClient{}), @@ -473,6 +487,7 @@ func TestToken_loadToken(t *testing.T) { args: []string{}, tokenTimeout: 1 * time.Hour, profiler: profiler, + tokenCache: tokenCache, persistentAuthOpts: []u2m.PersistentAuthOption{ u2m.WithTokenCache(tokenCache), u2m.WithOAuthEndpointSupplier(&MockApiClient{}), @@ -490,6 +505,7 @@ func TestToken_loadToken(t *testing.T) { args: []string{}, tokenTimeout: 1 * time.Hour, profiler: profiler, + tokenCache: tokenCache, persistentAuthOpts: []u2m.PersistentAuthOption{ u2m.WithTokenCache(tokenCache), u2m.WithOAuthEndpointSupplier(&MockApiClient{}), @@ -508,6 +524,7 @@ func TestToken_loadToken(t *testing.T) { args: []string{}, tokenTimeout: 1 * time.Hour, profiler: profiler, + tokenCache: tokenCache, persistentAuthOpts: []u2m.PersistentAuthOption{ u2m.WithTokenCache(tokenCache), u2m.WithOAuthEndpointSupplier(&MockApiClient{}), @@ -526,6 +543,7 @@ func TestToken_loadToken(t *testing.T) { args: []string{}, tokenTimeout: 1 * time.Hour, profiler: profiler, + tokenCache: tokenCache, persistentAuthOpts: []u2m.PersistentAuthOption{ u2m.WithTokenCache(tokenCache), u2m.WithOAuthEndpointSupplier(&MockApiClient{}), @@ -542,6 +560,7 @@ func TestToken_loadToken(t *testing.T) { args: []string{"workspace-a"}, tokenTimeout: 1 * time.Hour, profiler: profiler, + tokenCache: tokenCache, persistentAuthOpts: []u2m.PersistentAuthOption{ u2m.WithTokenCache(tokenCache), u2m.WithOAuthEndpointSupplier(&MockApiClient{}), @@ -557,6 +576,7 @@ func TestToken_loadToken(t *testing.T) { args: []string{}, tokenTimeout: 1 * time.Hour, profiler: profiler, + tokenCache: tokenCache, persistentAuthOpts: nil, }, wantErr: "no profile specified. Use --profile to specify which profile to use", @@ -569,6 +589,7 @@ func TestToken_loadToken(t *testing.T) { args: []string{}, tokenTimeout: 1 * time.Hour, profiler: profile.InMemoryProfiler{}, + tokenCache: tokenCache, persistentAuthOpts: nil, }, wantErr: "no profiles configured. Run 'databricks auth login' to create a profile", @@ -581,6 +602,7 @@ func TestToken_loadToken(t *testing.T) { args: []string{}, tokenTimeout: 1 * time.Hour, profiler: errProfiler{err: profile.ErrNoConfiguration}, + tokenCache: tokenCache, persistentAuthOpts: nil, }, wantErr: "no profiles configured. Run 'databricks auth login' to create a profile", @@ -638,6 +660,7 @@ func TestToken_loadToken(t *testing.T) { args: []string{}, tokenTimeout: 1 * time.Hour, profiler: profiler, + tokenCache: tokenCache, persistentAuthOpts: []u2m.PersistentAuthOption{ u2m.WithTokenCache(tokenCache), u2m.WithOAuthEndpointSupplier(&MockApiClient{}), @@ -658,6 +681,7 @@ func TestToken_loadToken(t *testing.T) { args: []string{}, tokenTimeout: 1 * time.Hour, profiler: profiler, + tokenCache: tokenCache, persistentAuthOpts: []u2m.PersistentAuthOption{ u2m.WithTokenCache(tokenCache), u2m.WithOAuthEndpointSupplier(&MockApiClient{}), @@ -678,6 +702,7 @@ func TestToken_loadToken(t *testing.T) { args: []string{}, tokenTimeout: 1 * time.Hour, profiler: profiler, + tokenCache: tokenCache, persistentAuthOpts: []u2m.PersistentAuthOption{ u2m.WithTokenCache(tokenCache), u2m.WithOAuthEndpointSupplier(&MockApiClient{}), @@ -699,6 +724,7 @@ func TestToken_loadToken(t *testing.T) { args: []string{}, tokenTimeout: 1 * time.Hour, profiler: profiler, + tokenCache: tokenCache, persistentAuthOpts: []u2m.PersistentAuthOption{ u2m.WithTokenCache(tokenCache), u2m.WithOAuthEndpointSupplier(&MockApiClient{}), @@ -734,6 +760,7 @@ func TestToken_loadToken(t *testing.T) { args: []string{}, tokenTimeout: 1 * time.Hour, profiler: profiler, + tokenCache: tokenCache, persistentAuthOpts: []u2m.PersistentAuthOption{ u2m.WithTokenCache(tokenCache), u2m.WithOAuthEndpointSupplier(&MockApiClient{}), @@ -750,6 +777,7 @@ func TestToken_loadToken(t *testing.T) { args: []string{}, tokenTimeout: 1 * time.Hour, profiler: profiler, + tokenCache: tokenCache, persistentAuthOpts: []u2m.PersistentAuthOption{ u2m.WithTokenCache(tokenCache), u2m.WithOAuthEndpointSupplier(&MockApiClient{}), @@ -769,6 +797,7 @@ func TestToken_loadToken(t *testing.T) { tokenTimeout: 1 * time.Hour, forceRefresh: true, profiler: profiler, + tokenCache: tokenCache, persistentAuthOpts: []u2m.PersistentAuthOption{ u2m.WithTokenCache(tokenCache), u2m.WithOAuthEndpointSupplier(&MockApiClient{}), @@ -786,6 +815,7 @@ func TestToken_loadToken(t *testing.T) { tokenTimeout: 1 * time.Hour, forceRefresh: true, profiler: profiler, + tokenCache: tokenCache, persistentAuthOpts: []u2m.PersistentAuthOption{ u2m.WithTokenCache(tokenCache), u2m.WithOAuthEndpointSupplier(&MockApiClient{}), diff --git a/libs/auth/credentials.go b/libs/auth/credentials.go index 7ab6eb2a85..a406955dd3 100644 --- a/libs/auth/credentials.go +++ b/libs/auth/credentials.go @@ -3,7 +3,9 @@ package auth import ( "context" "errors" + "fmt" + "github.com/databricks/cli/libs/auth/storage" "github.com/databricks/databricks-sdk-go/config" "github.com/databricks/databricks-sdk-go/config/credentials" "github.com/databricks/databricks-sdk-go/config/experimental/auth" @@ -97,7 +99,7 @@ func (c CLICredentials) Configure(ctx context.Context, cfg *config.Config) (cred if err != nil { return nil, err } - ts, err := c.persistentAuth(ctx, u2m.WithOAuthArgument(oauthArg)) + ts, err := c.persistentAuth(ctx, oauthArg) if err != nil { return nil, err } @@ -107,14 +109,22 @@ func (c CLICredentials) Configure(ctx context.Context, cfg *config.Config) (cred return cp, nil } -// persistentAuth returns a token source. It is a convenience function that -// overrides the default implementation of the persistent auth client if -// an alternative implementation is provided for testing. -func (c CLICredentials) persistentAuth(ctx context.Context, opts ...u2m.PersistentAuthOption) (auth.TokenSource, error) { +// persistentAuth returns a token source. It wraps the file-backed token +// cache with a dual-writing cache so every token write (Challenge, refresh, +// discovery) mirrors to the legacy host key for cross-SDK compatibility. +// The persistentAuthFn override is used in tests. +func (c CLICredentials) persistentAuth(ctx context.Context, arg u2m.OAuthArgument) (auth.TokenSource, error) { if c.persistentAuthFn != nil { - return c.persistentAuthFn(ctx, opts...) + return c.persistentAuthFn(ctx, u2m.WithOAuthArgument(arg)) } - ts, err := u2m.NewPersistentAuth(ctx, opts...) + tc, err := storage.NewFileTokenCache(ctx) + if err != nil { + return nil, fmt.Errorf("opening token cache: %w", err) + } + ts, err := u2m.NewPersistentAuth(ctx, + u2m.WithTokenCache(storage.NewDualWritingTokenCache(tc, arg)), + u2m.WithOAuthArgument(arg), + ) if err != nil { return nil, err } diff --git a/libs/auth/storage/dual_writing_cache.go b/libs/auth/storage/dual_writing_cache.go new file mode 100644 index 0000000000..874429cf31 --- /dev/null +++ b/libs/auth/storage/dual_writing_cache.go @@ -0,0 +1,66 @@ +package storage + +import ( + "github.com/databricks/databricks-sdk-go/credentials/u2m" + u2m_cache "github.com/databricks/databricks-sdk-go/credentials/u2m/cache" + "golang.org/x/oauth2" +) + +// DualWritingTokenCache wraps a TokenCache so that every write under the +// primary OAuth cache key is also mirrored under the legacy host-based key. +// This preserves the cross-SDK compatibility convention historically +// implemented inside PersistentAuth.dualWrite in the SDK, now moved +// caller-side per the cache-ownership split between SDK and CLI. +// +// Mirroring happens inside Store, so every SDK-internal write (Challenge, +// refresh, discovery) dual-writes without requiring each call site to invoke +// a helper explicitly. +type DualWritingTokenCache struct { + inner u2m_cache.TokenCache + arg u2m.OAuthArgument +} + +// NewDualWritingTokenCache returns a TokenCache wrapping inner that mirrors +// writes made under arg.GetCacheKey() to the argument's host key when one +// can be derived (via DiscoveryOAuthArgument.GetDiscoveredHost or +// HostCacheKeyProvider.GetHostCacheKey). +func NewDualWritingTokenCache(inner u2m_cache.TokenCache, arg u2m.OAuthArgument) *DualWritingTokenCache { + return &DualWritingTokenCache{inner: inner, arg: arg} +} + +// Store implements [u2m_cache.TokenCache]. Writes under the primary key are +// also mirrored under the host key (when distinct); writes under any other +// key pass through unchanged so that a Store(hostKey, t) from an older SDK +// that still dual-writes internally does not recursively re-expand. +func (c *DualWritingTokenCache) Store(key string, t *oauth2.Token) error { + if err := c.inner.Store(key, t); err != nil { + return err + } + primaryKey := c.arg.GetCacheKey() + if key != primaryKey { + return nil + } + hostKey := hostCacheKey(c.arg) + if hostKey == "" || hostKey == primaryKey { + return nil + } + return c.inner.Store(hostKey, t) +} + +// Lookup implements [u2m_cache.TokenCache]; delegates to the inner cache. +func (c *DualWritingTokenCache) Lookup(key string) (*oauth2.Token, error) { + return c.inner.Lookup(key) +} + +// hostCacheKey mirrors the SDK's former PersistentAuth.hostCacheKey: +// discovery arguments expose the host via GetDiscoveredHost (populated by +// Challenge); static arguments expose it via HostCacheKeyProvider. +func hostCacheKey(arg u2m.OAuthArgument) string { + if discoveryArg, ok := arg.(u2m.DiscoveryOAuthArgument); ok { + return discoveryArg.GetDiscoveredHost() + } + if hcp, ok := arg.(u2m.HostCacheKeyProvider); ok { + return hcp.GetHostCacheKey() + } + return "" +} diff --git a/libs/auth/storage/dual_writing_cache_test.go b/libs/auth/storage/dual_writing_cache_test.go new file mode 100644 index 0000000000..884e7285e9 --- /dev/null +++ b/libs/auth/storage/dual_writing_cache_test.go @@ -0,0 +1,169 @@ +package storage + +import ( + "sync" + "testing" + + "github.com/databricks/databricks-sdk-go/credentials/u2m" + u2m_cache "github.com/databricks/databricks-sdk-go/credentials/u2m/cache" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "golang.org/x/oauth2" +) + +// memoryCache is a minimal in-memory TokenCache used only by wrapper tests. +type memoryCache struct { + mu sync.Mutex + tokens map[string]*oauth2.Token +} + +func newMemoryCache() *memoryCache { + return &memoryCache{tokens: map[string]*oauth2.Token{}} +} + +func (c *memoryCache) Store(key string, t *oauth2.Token) error { + c.mu.Lock() + defer c.mu.Unlock() + if t == nil { + delete(c.tokens, key) + return nil + } + c.tokens[key] = t + return nil +} + +func (c *memoryCache) Lookup(key string) (*oauth2.Token, error) { + c.mu.Lock() + defer c.mu.Unlock() + t, ok := c.tokens[key] + if !ok { + return nil, u2m_cache.ErrNotFound + } + return t, nil +} + +// plainArg implements OAuthArgument only, exercising the "no host key" branch. +type plainArg struct { + key string +} + +func (a plainArg) GetCacheKey() string { return a.key } + +// hostArg implements HostCacheKeyProvider so the wrapper mirrors the token +// to the configured host key. +type hostArg struct { + key string + hostKey string +} + +func (a hostArg) GetCacheKey() string { return a.key } +func (a hostArg) GetHostCacheKey() string { return a.hostKey } + +func TestDualWritingCacheStorePrimaryMirrorsHost(t *testing.T) { + inner := newMemoryCache() + arg := hostArg{key: "profile-a", hostKey: "https://example.databricks.com"} + c := NewDualWritingTokenCache(inner, arg) + tok := &oauth2.Token{AccessToken: "abc", RefreshToken: "r"} + + require.NoError(t, c.Store("profile-a", tok)) + + primary, err := inner.Lookup("profile-a") + require.NoError(t, err) + assert.Equal(t, tok, primary) + + host, err := inner.Lookup("https://example.databricks.com") + require.NoError(t, err) + assert.Equal(t, tok, host) +} + +func TestDualWritingCacheStoreNonPrimaryDoesNotMirror(t *testing.T) { + // An older SDK still running its internal dualWrite will follow up the + // primary Store with a Store(hostKey, t). The wrapper must pass that + // second write through without re-expanding into another pair. + inner := newMemoryCache() + arg := hostArg{key: "profile-a", hostKey: "https://example.databricks.com"} + c := NewDualWritingTokenCache(inner, arg) + tok := &oauth2.Token{AccessToken: "abc"} + + require.NoError(t, c.Store("https://example.databricks.com", tok)) + + host, err := inner.Lookup("https://example.databricks.com") + require.NoError(t, err) + assert.Equal(t, tok, host) + _, err = inner.Lookup("profile-a") + require.ErrorIs(t, err, u2m_cache.ErrNotFound) +} + +func TestDualWritingCacheStoreNoHostKey(t *testing.T) { + inner := newMemoryCache() + arg := plainArg{key: "profile-a"} + c := NewDualWritingTokenCache(inner, arg) + tok := &oauth2.Token{AccessToken: "abc"} + + require.NoError(t, c.Store("profile-a", tok)) + + got, err := inner.Lookup("profile-a") + require.NoError(t, err) + assert.Equal(t, tok, got) + assert.Len(t, inner.tokens, 1) +} + +func TestDualWritingCacheStoreHostKeyEqualsPrimary(t *testing.T) { + inner := newMemoryCache() + arg := hostArg{key: "https://example.databricks.com", hostKey: "https://example.databricks.com"} + c := NewDualWritingTokenCache(inner, arg) + tok := &oauth2.Token{AccessToken: "abc"} + + require.NoError(t, c.Store("https://example.databricks.com", tok)) + + assert.Len(t, inner.tokens, 1) +} + +func TestDualWritingCacheDiscoveryArgWithDiscoveredHost(t *testing.T) { + inner := newMemoryCache() + arg, err := u2m.NewBasicDiscoveryOAuthArgument("profile-a") + require.NoError(t, err) + arg.SetDiscoveredHost("https://example.databricks.com") + c := NewDualWritingTokenCache(inner, arg) + tok := &oauth2.Token{AccessToken: "abc"} + + require.NoError(t, c.Store("profile-a", tok)) + + primary, err := inner.Lookup("profile-a") + require.NoError(t, err) + assert.Equal(t, tok, primary) + + host, err := inner.Lookup("https://example.databricks.com") + require.NoError(t, err) + assert.Equal(t, tok, host) +} + +func TestDualWritingCacheDiscoveryArgWithEmptyDiscoveredHost(t *testing.T) { + inner := newMemoryCache() + arg, err := u2m.NewBasicDiscoveryOAuthArgument("profile-a") + require.NoError(t, err) + c := NewDualWritingTokenCache(inner, arg) + tok := &oauth2.Token{AccessToken: "abc"} + + require.NoError(t, c.Store("profile-a", tok)) + + assert.Len(t, inner.tokens, 1) + primary, err := inner.Lookup("profile-a") + require.NoError(t, err) + assert.Equal(t, tok, primary) +} + +func TestDualWritingCacheLookupDelegates(t *testing.T) { + inner := newMemoryCache() + arg := hostArg{key: "profile-a", hostKey: "https://example.databricks.com"} + c := NewDualWritingTokenCache(inner, arg) + tok := &oauth2.Token{AccessToken: "abc"} + require.NoError(t, inner.Store("profile-a", tok)) + + got, err := c.Lookup("profile-a") + require.NoError(t, err) + assert.Equal(t, tok, got) + + _, err = c.Lookup("missing") + require.ErrorIs(t, err, u2m_cache.ErrNotFound) +} diff --git a/libs/auth/storage/file_cache.go b/libs/auth/storage/file_cache.go new file mode 100644 index 0000000000..f64e233b01 --- /dev/null +++ b/libs/auth/storage/file_cache.go @@ -0,0 +1,218 @@ +package storage + +import ( + "context" + "encoding/json" + "errors" + "fmt" + "io/fs" + "os" + "path/filepath" + "sync" + + "github.com/databricks/cli/libs/env" + u2m_cache "github.com/databricks/databricks-sdk-go/credentials/u2m/cache" + "golang.org/x/oauth2" +) + +const ( + // tokenCacheFile is the path of the default token cache, relative to the + // user's home directory. + tokenCacheFilePath = ".databricks/token-cache.json" + + // ownerExecReadWrite is the permission for the .databricks directory. + ownerExecReadWrite = 0o700 + + // ownerReadWrite is the permission for the token-cache.json file. + ownerReadWrite = 0o600 + + // tokenCacheVersion is the version of the token cache file format. + // + // Version 1 format: + // + // { + // "version": 1, + // "tokens": { + // "": { + // "access_token": "", + // "token_type": "", + // "refresh_token": "", + // "expiry": "" + // } + // } + // } + tokenCacheVersion = 1 +) + +// tokenCacheFile is the format of the token cache file. +type tokenCacheFile struct { + Version int `json:"version"` + Tokens map[string]*oauth2.Token `json:"tokens"` +} + +type FileTokenCacheOption func(*fileTokenCache) + +func WithFileLocation(fileLocation string) FileTokenCacheOption { + return func(c *fileTokenCache) { + c.fileLocation = fileLocation + } +} + +// fileTokenCache caches tokens in "~/.databricks/token-cache.json". fileTokenCache +// implements the TokenCache interface. +type fileTokenCache struct { + fileLocation string + + // locker protects the token cache file from concurrent reads and writes. + locker sync.Mutex +} + +// NewFileTokenCache creates a new FileTokenCache. By default, the cache is +// stored in "~/.databricks/token-cache.json". The cache file is created if it +// does not already exist. The cache file is created with owner permissions +// 0600 and the directory is created with owner permissions 0700. If the cache +// file is corrupt or if its version does not match tokenCacheVersion, an error +// is returned. +func NewFileTokenCache(ctx context.Context, opts ...FileTokenCacheOption) (u2m_cache.TokenCache, error) { + c := &fileTokenCache{} + for _, opt := range opts { + opt(c) + } + if err := c.init(ctx); err != nil { + return nil, err + } + // Fail fast if the cache is not working. + if _, err := c.load(); err != nil { + return nil, fmt.Errorf("load: %w", err) + } + return c, nil +} + +// Store implements the TokenCache interface. +func (c *fileTokenCache) Store(key string, t *oauth2.Token) error { + c.locker.Lock() + defer c.locker.Unlock() + f, err := c.load() + if err != nil { + return fmt.Errorf("load: %w", err) + } + if f.Tokens == nil { + f.Tokens = map[string]*oauth2.Token{} + } + if t == nil { + delete(f.Tokens, key) + } else { + f.Tokens[key] = t + } + raw, err := json.MarshalIndent(f, "", " ") + if err != nil { + return fmt.Errorf("marshal: %w", err) + } + if err := c.atomicWriteFile(raw); err != nil { + return fmt.Errorf("error storing token in local cache: %w", err) + } + return nil +} + +// Lookup implements the TokenCache interface. +func (c *fileTokenCache) Lookup(key string) (*oauth2.Token, error) { + c.locker.Lock() + defer c.locker.Unlock() + f, err := c.load() + if err != nil { + return nil, fmt.Errorf("load: %w", err) + } + t, ok := f.Tokens[key] + if !ok { + return nil, u2m_cache.ErrNotFound + } + return t, nil +} + +// init initializes the token cache file. It creates the file and directory if +// they do not already exist. +func (c *fileTokenCache) init(ctx context.Context) error { + // set the default file location + if c.fileLocation == "" { + home, err := env.UserHomeDir(ctx) + if err != nil { + return fmt.Errorf("failed loading home directory: %w", err) + } + c.fileLocation = filepath.Join(home, tokenCacheFilePath) + } + // Create the cache file if it does not exist. + if _, err := os.Stat(c.fileLocation); err != nil { + if !errors.Is(err, fs.ErrNotExist) { + return fmt.Errorf("stat file: %w", err) + } + // Create the parent directories if needed. + if err := os.MkdirAll(filepath.Dir(c.fileLocation), ownerExecReadWrite); err != nil { + return fmt.Errorf("mkdir: %w", err) + } + + // Create an empty cache file. + f := &tokenCacheFile{ + Version: tokenCacheVersion, + Tokens: map[string]*oauth2.Token{}, + } + raw, err := json.MarshalIndent(f, "", " ") + if err != nil { + return fmt.Errorf("marshal: %w", err) + } + if err := c.atomicWriteFile(raw); err != nil { + return fmt.Errorf("error creating token cache file: %w", err) + } + } + return nil +} + +// load loads the token cache file from disk. If the file is corrupt or if its +// version does not match tokenCacheVersion, it returns an error. +func (c *fileTokenCache) load() (*tokenCacheFile, error) { + raw, err := os.ReadFile(c.fileLocation) + if err != nil { + return nil, fmt.Errorf("read: %w", err) + } + f := &tokenCacheFile{} + if err := json.Unmarshal(raw, &f); err != nil { + return nil, fmt.Errorf("parse: %w", err) + } + if f.Version != tokenCacheVersion { + // in the later iterations we could do state upgraders, + // so that we transform token cache from v1 to v2 without + // losing the tokens and asking the user to re-authenticate. + return nil, fmt.Errorf("needs version %d, got version %d", tokenCacheVersion, f.Version) + } + return f, nil +} + +// atomicWriteFile writes data to the file atomically by first writing to a +// temporary file in the same directory and then renaming it to the target. +// This prevents corruption from interrupted writes. +func (c *fileTokenCache) atomicWriteFile(data []byte) error { + tmp, err := c.writeTmpFile(data) + if err != nil { + return err + } + defer os.Remove(tmp) + return os.Rename(tmp, c.fileLocation) +} + +func (c *fileTokenCache) writeTmpFile(data []byte) (string, error) { + tmp, err := os.CreateTemp(filepath.Dir(c.fileLocation), ".token-cache-*.tmp") + if err != nil { + return "", fmt.Errorf("create temp file: %w", err) + } + defer tmp.Close() + + if _, err := tmp.Write(data); err != nil { + return "", err + } + if err := tmp.Chmod(ownerReadWrite); err != nil { + return "", err + } + if err := tmp.Close(); err != nil { + return "", err + } + return tmp.Name(), nil +} diff --git a/libs/auth/storage/file_cache_test.go b/libs/auth/storage/file_cache_test.go new file mode 100644 index 0000000000..4df7576c69 --- /dev/null +++ b/libs/auth/storage/file_cache_test.go @@ -0,0 +1,67 @@ +package storage + +import ( + "os" + "path/filepath" + "testing" + + u2m_cache "github.com/databricks/databricks-sdk-go/credentials/u2m/cache" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "golang.org/x/oauth2" +) + +func setup(t *testing.T) string { + tempHomeDir := t.TempDir() + return filepath.Join(tempHomeDir, "token-cache.json") +} + +func TestStoreAndLookup(t *testing.T) { + c, err := NewFileTokenCache(t.Context(), WithFileLocation(setup(t))) + require.NoError(t, err) + err = c.Store("x", &oauth2.Token{ + AccessToken: "abc", + }) + require.NoError(t, err) + + err = c.Store("y", &oauth2.Token{ + AccessToken: "bcd", + }) + require.NoError(t, err) + + tok, err := c.Lookup("x") + require.NoError(t, err) + assert.Equal(t, "abc", tok.AccessToken) + + _, err = c.Lookup("z") + assert.Equal(t, u2m_cache.ErrNotFound, err) +} + +func TestNoCacheFileReturnsErrNotConfigured(t *testing.T) { + l, err := NewFileTokenCache(t.Context(), WithFileLocation(setup(t))) + require.NoError(t, err) + _, err = l.Lookup("x") + assert.Equal(t, u2m_cache.ErrNotFound, err) +} + +func TestLoadCorruptFile(t *testing.T) { + f := setup(t) + err := os.MkdirAll(filepath.Dir(f), ownerExecReadWrite) + require.NoError(t, err) + err = os.WriteFile(f, []byte("abc"), ownerExecReadWrite) + require.NoError(t, err) + + _, err = NewFileTokenCache(t.Context(), WithFileLocation(f)) + assert.EqualError(t, err, "load: parse: invalid character 'a' looking for beginning of value") +} + +func TestLoadWrongVersion(t *testing.T) { + f := setup(t) + err := os.MkdirAll(filepath.Dir(f), ownerExecReadWrite) + require.NoError(t, err) + err = os.WriteFile(f, []byte(`{"version": 823, "things": []}`), ownerExecReadWrite) + require.NoError(t, err) + + _, err = NewFileTokenCache(t.Context(), WithFileLocation(f)) + assert.EqualError(t, err, "load: needs version 1, got version 823") +}