From 994bfc8caaa2a86fc3b48881c56b0e255719b349 Mon Sep 17 00:00:00 2001 From: Bo-Yi Wu Date: Mon, 9 Mar 2026 17:01:45 +0800 Subject: [PATCH 1/2] feat(auth): migrate authentication token storage to shared tokenstore - Replace the custom token storage implementation with the shared sdk-go tokenstore across the codebase - Add configurable token storage backends with auto, file, and keyring modes, including a new CLI flag and environment support - Remove the bespoke file locking logic and related tests in favor of tokenstore-managed persistence - Update authentication flows to return tokens without embedding the flow, passing the flow separately where needed - Refactor browser and device flows to save and load tokens through the tokenstore interface - Simplify and modernize tests to use tokenstore APIs instead of inspecting token files directly - Introduce keyring-based secure storage with automatic fallback and user warnings when unavailable - Bump the Go version and add new dependencies required for secure token storage support Signed-off-by: Bo-Yi Wu --- browser_flow.go | 12 +++--- callback.go | 10 +++-- callback_test.go | 20 +++++---- config.go | 37 +++++++++++++++++ device_flow.go | 8 ++-- filelock.go | 68 ------------------------------ filelock_test.go | 93 ----------------------------------------- go.mod | 9 +++- go.sum | 26 +++++++++++- main.go | 42 ++++++++++++------- main_test.go | 106 ++++++++++++++++++++--------------------------- tokens.go | 81 ------------------------------------ tui_adapter.go | 39 ++++++++++------- 13 files changed, 191 insertions(+), 360 deletions(-) delete mode 100644 filelock.go delete mode 100644 filelock_test.go diff --git a/browser_flow.go b/browser_flow.go index 0b6d30c..82d0334 100644 --- a/browser_flow.go +++ b/browser_flow.go @@ -13,6 +13,7 @@ import ( retry "github.com/appleboy/go-httpretry" "github.com/go-authgate/cli/tui" + "github.com/go-authgate/sdk-go/tokenstore" ) // buildAuthURL constructs the /oauth/authorize URL with all required parameters. @@ -29,7 +30,7 @@ func buildAuthURL(state string, pkce *PKCEParams) string { } // exchangeCode exchanges an authorization code for access + refresh tokens. -func exchangeCode(ctx context.Context, code, codeVerifier string) (*TokenStorage, error) { +func exchangeCode(ctx context.Context, code, codeVerifier string) (*tokenstore.Token, error) { ctx, cancel := context.WithTimeout(ctx, tokenExchangeTimeout) defer cancel() @@ -84,7 +85,7 @@ func exchangeCode(ctx context.Context, code, codeVerifier string) (*TokenStorage return nil, fmt.Errorf("invalid token response: %w", err) } - return &TokenStorage{ + return &tokenstore.Token{ AccessToken: tokenResp.AccessToken, RefreshToken: tokenResp.RefreshToken, TokenType: tokenResp.TokenType, @@ -199,7 +200,7 @@ func performBrowserFlowWithUpdates( }() storage, err := startCallbackServer(ctx, callbackPort, state, - func(callbackCtx context.Context, code string) (*TokenStorage, error) { + func(callbackCtx context.Context, code string) (*tokenstore.Token, error) { updates <- tui.FlowUpdate{ Type: tui.StepStart, Step: 3, @@ -232,9 +233,8 @@ func performBrowserFlowWithUpdates( } updates <- tui.FlowUpdate{Type: tui.CallbackReceived} - storage.Flow = "browser" - if err := saveTokens(storage); err != nil { + if err := tokenStore.Save(storage); err != nil { updates <- tui.FlowUpdate{ Type: tui.StepError, Message: fmt.Sprintf("Warning: Failed to save tokens: %v", err), @@ -247,5 +247,5 @@ func performBrowserFlowWithUpdates( TotalSteps: 3, } - return toTUITokenStorage(storage), true, nil + return toTUITokenStorage(storage, "browser"), true, nil } diff --git a/callback.go b/callback.go index a887bd6..c55d2a8 100644 --- a/callback.go +++ b/callback.go @@ -9,6 +9,8 @@ import ( "net/http" "sync" "time" + + "github.com/go-authgate/sdk-go/tokenstore" ) const ( @@ -23,7 +25,7 @@ var ErrCallbackTimeout = errors.New("browser authorization timed out") // callbackResult holds the outcome of the local callback round-trip. type callbackResult struct { - Storage *TokenStorage + Storage *tokenstore.Token Error string Desc string // Detailed description (for terminal only) SanitizedMsg string // User-friendly message (for browser only) @@ -32,12 +34,12 @@ type callbackResult struct { // startCallbackServer starts a local HTTP server on the given port and waits // for the OAuth callback. It validates the returned state against expectedState, // calls exchangeFn to exchange the code for tokens, and returns the resulting -// TokenStorage (or an error). +// token (or an error). // // The server shuts itself down after the first request. func startCallbackServer(ctx context.Context, port int, expectedState string, - exchangeFn func(context.Context, string) (*TokenStorage, error), -) (*TokenStorage, error) { + exchangeFn func(context.Context, string) (*tokenstore.Token, error), +) (*tokenstore.Token, error) { resultCh := make(chan callbackResult, 1) var once sync.Once diff --git a/callback_test.go b/callback_test.go index 0d4d56f..89e575b 100644 --- a/callback_test.go +++ b/callback_test.go @@ -9,10 +9,12 @@ import ( "strings" "testing" "time" + + "github.com/go-authgate/sdk-go/tokenstore" ) type callbackServerResult struct { - storage *TokenStorage + storage *tokenstore.Token err error } @@ -21,7 +23,7 @@ type callbackServerResult struct { func startCallbackServerAsync( t *testing.T, ctx context.Context, //nolint:revive // t before ctx in test helpers port int, state string, - exchangeFn func(context.Context, string) (*TokenStorage, error), + exchangeFn func(context.Context, string) (*tokenstore.Token, error), ) chan callbackServerResult { t.Helper() ch := make(chan callbackServerResult, 1) @@ -35,22 +37,22 @@ func startCallbackServerAsync( } // noExchangeFn returns an exchange function that fails the test if called. -func noExchangeFn(t *testing.T) func(context.Context, string) (*TokenStorage, error) { +func noExchangeFn(t *testing.T) func(context.Context, string) (*tokenstore.Token, error) { t.Helper() - return func(_ context.Context, _ string) (*TokenStorage, error) { + return func(_ context.Context, _ string) (*tokenstore.Token, error) { t.Error("exchangeFn should not be called") return nil, errors.New("should not be called") } } // stubExchangeFn returns an exchange function that validates the received code -// and returns a minimal TokenStorage on success. -func stubExchangeFn(wantCode string) func(context.Context, string) (*TokenStorage, error) { - return func(_ context.Context, gotCode string) (*TokenStorage, error) { +// and returns a minimal token on success. +func stubExchangeFn(wantCode string) func(context.Context, string) (*tokenstore.Token, error) { + return func(_ context.Context, gotCode string) (*tokenstore.Token, error) { if gotCode != wantCode { return nil, fmt.Errorf("unexpected code: got %q, want %q", gotCode, wantCode) } - return &TokenStorage{AccessToken: "test-token"}, nil + return &tokenstore.Token{AccessToken: "test-token"}, nil } } @@ -197,7 +199,7 @@ func TestCallbackServer_ExchangeFailure(t *testing.T) { state := "state-for-exchange-failure" ch := startCallbackServerAsync(t, context.Background(), port, state, - func(_ context.Context, _ string) (*TokenStorage, error) { + func(_ context.Context, _ string) (*tokenstore.Token, error) { return nil, errors.New("unauthorized_client: backend service authentication failed") }) diff --git a/config.go b/config.go index f027d80..f2fcc97 100644 --- a/config.go +++ b/config.go @@ -13,6 +13,8 @@ import ( "strings" "time" + "github.com/go-authgate/sdk-go/tokenstore" + retry "github.com/appleboy/go-httpretry" "github.com/google/uuid" "github.com/joho/godotenv" @@ -29,9 +31,11 @@ var ( callbackPort int scope string tokenFile string + tokenStoreMode string forceDevice bool configInitialized bool retryClient *retry.Client + tokenStore tokenstore.Store flagServerURL *string flagClientID *string @@ -40,6 +44,7 @@ var ( flagCallbackPort *int flagScope *string flagTokenFile *string + flagTokenStore *string flagDevice *bool flagVersion *bool ) @@ -50,6 +55,7 @@ const ( refreshTokenTimeout = 10 * time.Second deviceCodeRequestTimeout = 10 * time.Second maxResponseBodySize = 1 * 1024 * 1024 // 1 MB — guards against oversized server responses + defaultKeyringService = "authgate-cli" ) func init() { @@ -82,6 +88,11 @@ func init() { "", "Token storage file (default: .authgate-tokens.json or TOKEN_FILE env)", ) + flagTokenStore = flag.String( + "token-store", + "", + "Token storage backend: auto, file, keyring (default: auto or TOKEN_STORE env)", + ) flagDevice = flag.Bool( "device", false, @@ -176,6 +187,32 @@ func initConfig() { if err != nil { panic(fmt.Sprintf("failed to create retry client: %v", err)) } + + // Initialize token store based on mode + tokenStoreMode = getConfig(*flagTokenStore, "TOKEN_STORE", "auto") + fileStore := tokenstore.NewFileStore(tokenFile) + switch tokenStoreMode { + case "file": + tokenStore = fileStore + case "keyring": + tokenStore = tokenstore.NewKeyringStore(defaultKeyringService) + case "auto": + kr := tokenstore.NewKeyringStore(defaultKeyringService) + tokenStore = tokenstore.NewSecureStore(kr, fileStore) + if !tokenStore.(*tokenstore.SecureStore).UseKeyring() { + fmt.Fprintln( + os.Stderr, + "WARNING: OS keyring unavailable, falling back to file-based token storage", + ) + } + default: + fmt.Fprintf( + os.Stderr, + "Error: Invalid token-store value: %s (must be auto, file, or keyring)\n", + tokenStoreMode, + ) + os.Exit(1) + } } func getConfig(flagValue, envKey, defaultValue string) string { diff --git a/device_flow.go b/device_flow.go index ccc7c4e..bf5c0e7 100644 --- a/device_flow.go +++ b/device_flow.go @@ -13,6 +13,7 @@ import ( retry "github.com/appleboy/go-httpretry" "github.com/go-authgate/cli/tui" + "github.com/go-authgate/sdk-go/tokenstore" "golang.org/x/oauth2" ) @@ -289,23 +290,22 @@ func performDeviceFlowWithUpdates( TotalSteps: 2, } - storage := &TokenStorage{ + storage := &tokenstore.Token{ AccessToken: token.AccessToken, RefreshToken: token.RefreshToken, TokenType: token.Type(), ExpiresAt: token.Expiry, ClientID: clientID, - Flow: "device", } - if err := saveTokens(storage); err != nil { + if err := tokenStore.Save(storage); err != nil { updates <- tui.FlowUpdate{ Type: tui.StepError, Message: fmt.Sprintf("Warning: Failed to save tokens: %v", err), } } - return toTUITokenStorage(storage), nil + return toTUITokenStorage(storage, "device"), nil } // pollForTokenWithUpdates polls for a token while sending progress updates. diff --git a/filelock.go b/filelock.go deleted file mode 100644 index 724a5c3..0000000 --- a/filelock.go +++ /dev/null @@ -1,68 +0,0 @@ -package main - -import ( - "fmt" - "os" - "time" -) - -const ( - lockMaxRetries = 50 - lockRetryDelay = 100 * time.Millisecond - staleLockTimeout = 30 * time.Second -) - -// fileLock represents a file lock. -type fileLock struct { - lockFile *os.File - lockPath string -} - -// acquireFileLock acquires an exclusive lock on the token file. -// Uses a separate lock file to coordinate access across processes. -func acquireFileLock(filePath string) (*fileLock, error) { - lockPath := filePath + ".lock" - - for range lockMaxRetries { - lockFile, err := os.OpenFile(lockPath, os.O_CREATE|os.O_EXCL|os.O_WRONLY, 0o600) - if err == nil { - fmt.Fprintf(lockFile, "%d", os.Getpid()) - return &fileLock{ - lockFile: lockFile, - lockPath: lockPath, - }, nil - } - - if os.IsExist(err) { - if info, statErr := os.Stat(lockPath); statErr == nil { - if time.Since(info.ModTime()) > staleLockTimeout { - if remErr := os.Remove(lockPath); remErr != nil && !os.IsNotExist(remErr) { - return nil, fmt.Errorf( - "failed to remove stale lock file %s: %w", - lockPath, - remErr, - ) - } - continue - } - } - time.Sleep(lockRetryDelay) - continue - } - - return nil, fmt.Errorf("failed to acquire file lock: %w", err) - } - - return nil, fmt.Errorf( - "timeout waiting for file lock after %v", - time.Duration(lockMaxRetries)*lockRetryDelay, - ) -} - -// release releases the file lock. -func (fl *fileLock) release() error { - if fl.lockFile != nil { - fl.lockFile.Close() - } - return os.Remove(fl.lockPath) -} diff --git a/filelock_test.go b/filelock_test.go deleted file mode 100644 index 3305c82..0000000 --- a/filelock_test.go +++ /dev/null @@ -1,93 +0,0 @@ -package main - -import ( - "os" - "path/filepath" - "sync" - "testing" - "time" -) - -func TestAcquireAndRelease(t *testing.T) { - dir := t.TempDir() - target := filepath.Join(dir, "tokens.json") - - lock, err := acquireFileLock(target) - if err != nil { - t.Fatalf("acquireFileLock() error: %v", err) - } - - lockPath := target + ".lock" - if _, err := os.Stat(lockPath); os.IsNotExist(err) { - t.Error("lock file was not created") - } - - if err := lock.release(); err != nil { - t.Errorf("release() error: %v", err) - } - - if _, err := os.Stat(lockPath); !os.IsNotExist(err) { - t.Error("lock file was not removed after release") - } -} - -func TestConcurrentLocks(t *testing.T) { - dir := t.TempDir() - target := filepath.Join(dir, "tokens.json") - - const goroutines = 10 - var wg sync.WaitGroup - var mu sync.Mutex - concurrent := 0 - - for i := range goroutines { - wg.Add(1) - go func(idx int) { - defer wg.Done() - - lock, err := acquireFileLock(target) - if err != nil { - t.Errorf("goroutine %d: acquireFileLock() error: %v", idx, err) - return - } - - mu.Lock() - concurrent++ - if concurrent > 1 { - t.Errorf("goroutine %d: more than one lock holder at a time: %d", idx, concurrent) - } - mu.Unlock() - - mu.Lock() - concurrent-- - mu.Unlock() - - _ = lock.release() - }(i) - } - - wg.Wait() -} - -func TestStaleLockRemoval(t *testing.T) { - dir := t.TempDir() - target := filepath.Join(dir, "tokens.json") - lockPath := target + ".lock" - - f, err := os.OpenFile(lockPath, os.O_CREATE|os.O_WRONLY, 0o600) - if err != nil { - t.Fatal(err) - } - f.Close() - - staleTime := time.Now().Add(-60 * time.Second) - if err := os.Chtimes(lockPath, staleTime, staleTime); err != nil { - t.Fatalf("os.Chtimes: %v", err) - } - - lock, err := acquireFileLock(target) - if err != nil { - t.Fatalf("acquireFileLock() with stale lock: %v", err) - } - _ = lock.release() -} diff --git a/go.mod b/go.mod index af8cc94..88a1cde 100644 --- a/go.mod +++ b/go.mod @@ -1,12 +1,13 @@ module github.com/go-authgate/cli -go 1.24.2 +go 1.25.0 require ( charm.land/bubbles/v2 v2.0.0 charm.land/bubbletea/v2 v2.0.0 charm.land/lipgloss/v2 v2.0.0 github.com/appleboy/go-httpretry v0.11.0 + github.com/go-authgate/sdk-go v0.0.0-20260308153218-5d41f853a425 github.com/google/uuid v1.6.0 github.com/joho/godotenv v1.5.1 github.com/mattn/go-isatty v0.0.20 @@ -15,6 +16,7 @@ require ( ) require ( + al.essio.dev/pkg/shellescape v1.6.0 // indirect github.com/charmbracelet/colorprofile v0.4.2 // indirect github.com/charmbracelet/harmonica v0.2.0 // indirect github.com/charmbracelet/ultraviolet v0.0.0-20260223171050-89c142e4aa73 // indirect @@ -24,11 +26,14 @@ require ( github.com/charmbracelet/x/windows v0.2.2 // indirect github.com/clipperhouse/displaywidth v0.11.0 // indirect github.com/clipperhouse/uax29/v2 v2.7.0 // indirect + github.com/danieljoos/wincred v1.2.3 // indirect + github.com/godbus/dbus/v5 v5.2.2 // indirect github.com/lucasb-eyer/go-colorful v1.3.0 // indirect github.com/mattn/go-runewidth v0.0.20 // indirect github.com/muesli/cancelreader v0.2.2 // indirect github.com/rivo/uniseg v0.4.7 // indirect github.com/xo/terminfo v0.0.0-20220910002029-abceb7e1c41e // indirect + github.com/zalando/go-keyring v0.2.6 // indirect golang.org/x/sync v0.19.0 // indirect - golang.org/x/sys v0.41.0 // indirect + golang.org/x/sys v0.42.0 // indirect ) diff --git a/go.sum b/go.sum index 7636236..2b3cb7e 100644 --- a/go.sum +++ b/go.sum @@ -1,3 +1,5 @@ +al.essio.dev/pkg/shellescape v1.6.0 h1:NxFcEqzFSEVCGN2yq7Huv/9hyCEGVa/TncnOOBBeXHA= +al.essio.dev/pkg/shellescape v1.6.0/go.mod h1:6sIqp7X2P6mThCQ7twERpZTuigpr6KbZWtls1U8I890= charm.land/bubbles/v2 v2.0.0 h1:tE3eK/pHjmtrDiRdoC9uGNLgpopOd8fjhEe31B/ai5s= charm.land/bubbles/v2 v2.0.0/go.mod h1:rCHoleP2XhU8um45NTuOWBPNVHxnkXKTiZqcclL/qOI= charm.land/bubbletea/v2 v2.0.0 h1:p0d6CtWyJXJ9GfzMpUUqbP/XUUhhlk06+vCKWmox1wQ= @@ -28,6 +30,16 @@ github.com/clipperhouse/displaywidth v0.11.0 h1:lBc6kY44VFw+TDx4I8opi/EtL9m20WSE github.com/clipperhouse/displaywidth v0.11.0/go.mod h1:bkrFNkf81G8HyVqmKGxsPufD3JhNl3dSqnGhOoSD/o0= github.com/clipperhouse/uax29/v2 v2.7.0 h1:+gs4oBZ2gPfVrKPthwbMzWZDaAFPGYK72F0NJv2v7Vk= github.com/clipperhouse/uax29/v2 v2.7.0/go.mod h1:EFJ2TJMRUaplDxHKj1qAEhCtQPW2tJSwu5BF98AuoVM= +github.com/danieljoos/wincred v1.2.3 h1:v7dZC2x32Ut3nEfRH+vhoZGvN72+dQ/snVXo/vMFLdQ= +github.com/danieljoos/wincred v1.2.3/go.mod h1:6qqX0WNrS4RzPZ1tnroDzq9kY3fu1KwE7MRLQK4X0bs= +github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= +github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/go-authgate/sdk-go v0.0.0-20260308153218-5d41f853a425 h1:MpcgC9DmMMdnLAF1Vlr3dax0VBmB2uLeDHIqBPMCnxs= +github.com/go-authgate/sdk-go v0.0.0-20260308153218-5d41f853a425/go.mod h1:RGqvrFdrPnOumndoQQV8qzu8zP1KFUZPdhX0IkWduho= +github.com/godbus/dbus/v5 v5.2.2 h1:TUR3TgtSVDmjiXOgAAyaZbYmIeP3DPkld3jgKGV8mXQ= +github.com/godbus/dbus/v5 v5.2.2/go.mod h1:3AAv2+hPq5rdnr5txxxRwiGjPXamgoIHgz9FPBfOp3c= +github.com/google/shlex v0.0.0-20191202100458-e7afc7fbc510 h1:El6M4kTTCOh6aBiKaUGG7oYTSPP8MxqL4YI3kZKwcP4= +github.com/google/shlex v0.0.0-20191202100458-e7afc7fbc510/go.mod h1:pupxD2MaaD3pAXIBCelhxNneeOaAeabZDe5s4K6zSpQ= github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0= github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= github.com/joho/godotenv v1.5.1 h1:7eLL/+HRGLY0ldzfGMeQkb7vMd0as4CfYvUVzLqw0N0= @@ -40,10 +52,18 @@ github.com/mattn/go-runewidth v0.0.20 h1:WcT52H91ZUAwy8+HUkdM3THM6gXqXuLJi9O3rjc github.com/mattn/go-runewidth v0.0.20/go.mod h1:XBkDxAl56ILZc9knddidhrOlY5R/pDhgLpndooCuJAs= github.com/muesli/cancelreader v0.2.2 h1:3I4Kt4BQjOR54NavqnDogx/MIoWBFa0StPA8ELUXHmA= github.com/muesli/cancelreader v0.2.2/go.mod h1:3XuTXfFS2VjM+HTLZY9Ak0l6eUKfijIfMUZ4EgX0QYo= +github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= +github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/rivo/uniseg v0.4.7 h1:WUdvkW8uEhrYfLC4ZzdpI2ztxP1I582+49Oc5Mq64VQ= github.com/rivo/uniseg v0.4.7/go.mod h1:FN3SvrM+Zdj16jyLfmOkMNblXMcoc8DfTHruCPUcx88= +github.com/stretchr/objx v0.5.2 h1:xuMeJ0Sdp5ZMRXx/aWO6RZxdr3beISkG5/G/aIRr3pY= +github.com/stretchr/objx v0.5.2/go.mod h1:FRsXN1f5AsAjCGJKqEizvkpNtU+EGNCLh3NxZ/8L+MA= +github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu7U= +github.com/stretchr/testify v1.11.1/go.mod h1:wZwfW3scLgRK+23gO65QZefKpKQRnfz6sD981Nm4B6U= github.com/xo/terminfo v0.0.0-20220910002029-abceb7e1c41e h1:JVG44RsyaB9T2KIHavMF/ppJZNG9ZpyihvCd0w101no= github.com/xo/terminfo v0.0.0-20220910002029-abceb7e1c41e/go.mod h1:RbqR21r5mrJuqunuUZ/Dhy/avygyECGrLceyNeo4LiM= +github.com/zalando/go-keyring v0.2.6 h1:r7Yc3+H+Ux0+M72zacZoItR3UDxeWfKTcabvkI8ua9s= +github.com/zalando/go-keyring v0.2.6/go.mod h1:2TCrxYrbUNYfNS/Kgy/LSrkSQzZ5UPVH85RwfczwvcI= golang.org/x/exp v0.0.0-20231006140011-7918f672742d h1:jtJma62tbqLibJ5sFQz8bKtEM8rJBtfilJ2qTU199MI= golang.org/x/exp v0.0.0-20231006140011-7918f672742d/go.mod h1:ldy0pHrwJyGW56pPQzzkH36rKxoZW1tw7ZJpeKx+hdo= golang.org/x/oauth2 v0.35.0 h1:Mv2mzuHuZuY2+bkyWXIHMfhNdJAdwW3FuWeCPYN5GVQ= @@ -51,7 +71,9 @@ golang.org/x/oauth2 v0.35.0/go.mod h1:lzm5WQJQwKZ3nwavOZ3IS5Aulzxi68dUSgRHujetwE golang.org/x/sync v0.19.0 h1:vV+1eWNmZ5geRlYjzm2adRgW2/mcpevXNg50YZtPCE4= golang.org/x/sync v0.19.0/go.mod h1:9KTHXmSnoGruLpwFjVSX0lNNA75CykiMECbovNTZqGI= golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.41.0 h1:Ivj+2Cp/ylzLiEU89QhWblYnOE9zerudt9Ftecq2C6k= -golang.org/x/sys v0.41.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks= +golang.org/x/sys v0.42.0 h1:omrd2nAlyT5ESRdCLYdm3+fMfNFE/+Rf4bDIQImRJeo= +golang.org/x/sys v0.42.0/go.mod h1:4GL1E5IUh+htKOUEOaiffhrAeqysfVGipDYzABqnCmw= golang.org/x/term v0.40.0 h1:36e4zGLqU4yhjlmxEaagx2KuYbJq3EwY8K943ZsHcvg= golang.org/x/term v0.40.0/go.mod h1:w2P8uVp06p2iyKKuvXIm7N/y0UCRt3UfJTfZ7oOpglM= +gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= +gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= diff --git a/main.go b/main.go index 58b30a4..6f271d0 100644 --- a/main.go +++ b/main.go @@ -3,6 +3,7 @@ package main import ( "context" "encoding/json" + "errors" "fmt" "io" "net/http" @@ -13,6 +14,8 @@ import ( "syscall" "time" + "github.com/go-authgate/sdk-go/tokenstore" + retry "github.com/appleboy/go-httpretry" "github.com/go-authgate/cli/tui" ) @@ -36,10 +39,14 @@ func run(ctx context.Context, ui tui.Manager) int { } ui.ShowHeader(clientMode, serverURL, clientID) - var storage *TokenStorage + var storage *tokenstore.Token // Try to reuse or refresh existing tokens. - existing, err := loadTokens() + existing, err := tokenStore.Load(clientID) + if err != nil && !errors.Is(err, tokenstore.ErrNotFound) { + fmt.Fprintf(os.Stderr, "Error: Failed to load tokens: %v\n", err) + return 1 + } if err == nil && existing != nil { ui.ShowExistingTokens() if time.Now().Before(existing.ExpiresAt) { @@ -60,8 +67,9 @@ func run(ctx context.Context, ui tui.Manager) int { } // No valid tokens — select and run the appropriate flow. + var flow string if storage == nil { - storage, err = authenticate(ctx, ui) + storage, flow, err = authenticate(ctx, ui) if err != nil { // Error details are already displayed by the UI manager // Just exit with error code @@ -70,7 +78,7 @@ func run(ctx context.Context, ui tui.Manager) int { } // Display token info. - ui.ShowTokenInfo(toTUITokenStorage(storage)) + ui.ShowTokenInfo(toTUITokenStorage(storage, flow)) // Verify token against server. info, err := verifyToken(ctx, storage.AccessToken) @@ -93,7 +101,7 @@ func run(ctx context.Context, ui tui.Manager) int { if err := makeAPICallWithAutoRefresh(ctx, storage, ui); err != nil { if err == ErrRefreshTokenExpired { ui.ShowRefreshTokenExpired() - storage, err = authenticate(ctx, ui) + storage, _, err = authenticate(ctx, ui) if err != nil { fmt.Fprintf(os.Stderr, "Re-authentication failed: %v\n", err) return 1 @@ -116,39 +124,39 @@ func run(ctx context.Context, ui tui.Manager) int { // 2. Environment signals (SSH, no display, port busy) → Device Code Flow // 3. Browser available → Authorization Code Flow with PKCE // - openBrowser() error → immediate fallback to Device Code Flow -func authenticate(ctx context.Context, ui tui.Manager) (*TokenStorage, error) { +func authenticate(ctx context.Context, ui tui.Manager) (*tokenstore.Token, string, error) { if forceDevice { ui.ShowFlowSelection("Device Code Flow (forced via flag)") tuiStorage, err := ui.RunDeviceFlow(ctx, performDeviceFlowWithUpdates) - return fromTUITokenStorage(tuiStorage), err + return fromTUITokenStorage(tuiStorage), flowFromTUI(tuiStorage), err } avail := checkBrowserAvailability(ctx, callbackPort) if !avail.Available { ui.ShowFlowSelection(fmt.Sprintf("Device Code Flow (%s)", avail.Reason)) tuiStorage, err := ui.RunDeviceFlow(ctx, performDeviceFlowWithUpdates) - return fromTUITokenStorage(tuiStorage), err + return fromTUITokenStorage(tuiStorage), flowFromTUI(tuiStorage), err } ui.ShowFlowSelection("Authorization Code Flow (browser)") tuiStorage, ok, err := ui.RunBrowserFlow(ctx, performBrowserFlowWithUpdates) if err != nil { - return nil, err + return nil, "", err } if !ok { // openBrowser() failed; fall back to Device Code Flow immediately. ui.ShowFlowSelection("Device Code Flow (browser unavailable)") tuiStorage, err := ui.RunDeviceFlow(ctx, performDeviceFlowWithUpdates) - return fromTUITokenStorage(tuiStorage), err + return fromTUITokenStorage(tuiStorage), flowFromTUI(tuiStorage), err } - return fromTUITokenStorage(tuiStorage), nil + return fromTUITokenStorage(tuiStorage), flowFromTUI(tuiStorage), nil } // ----------------------------------------------------------------------- // Token refresh // ----------------------------------------------------------------------- -func refreshAccessToken(ctx context.Context, refreshToken string) (*TokenStorage, error) { +func refreshAccessToken(ctx context.Context, refreshToken string) (*tokenstore.Token, error) { ctx, cancel := context.WithTimeout(ctx, refreshTokenTimeout) defer cancel() @@ -203,7 +211,7 @@ func refreshAccessToken(ctx context.Context, refreshToken string) (*TokenStorage newRefreshToken = refreshToken } - storage := &TokenStorage{ + storage := &tokenstore.Token{ AccessToken: tokenResp.AccessToken, RefreshToken: newRefreshToken, TokenType: tokenResp.TokenType, @@ -211,7 +219,7 @@ func refreshAccessToken(ctx context.Context, refreshToken string) (*TokenStorage ClientID: clientID, } - if err := saveTokens(storage); err != nil { + if err := tokenStore.Save(storage); err != nil { fmt.Printf("Warning: Failed to save refreshed tokens: %v\n", err) } return storage, nil @@ -250,7 +258,11 @@ func verifyToken(ctx context.Context, accessToken string) (string, error) { } // makeAPICallWithAutoRefresh demonstrates the 401 → refresh → retry pattern. -func makeAPICallWithAutoRefresh(ctx context.Context, storage *TokenStorage, ui tui.Manager) error { +func makeAPICallWithAutoRefresh( + ctx context.Context, + storage *tokenstore.Token, + ui tui.Manager, +) error { resp, err := retryClient.Get(ctx, serverURL+"/oauth/tokeninfo", retry.WithHeader("Authorization", "Bearer "+storage.AccessToken), ) diff --git a/main_test.go b/main_test.go index 222d9fa..92eb59c 100644 --- a/main_test.go +++ b/main_test.go @@ -6,7 +6,6 @@ import ( "fmt" "net/http" "net/http/httptest" - "os" "path/filepath" "sync" "sync/atomic" @@ -14,6 +13,7 @@ import ( "time" retry "github.com/appleboy/go-httpretry" + "github.com/go-authgate/sdk-go/tokenstore" ) func init() { @@ -37,6 +37,9 @@ func init() { panic(fmt.Sprintf("failed to create retry client: %v", err)) } } + if tokenStore == nil { + tokenStore = tokenstore.NewFileStore(tokenFile) + } } // ----------------------------------------------------------------------- @@ -133,92 +136,79 @@ func TestValidateTokenResponse(t *testing.T) { // ----------------------------------------------------------------------- func TestSaveAndLoadTokens(t *testing.T) { - tmpFile, err := os.CreateTemp(t.TempDir(), "tokens-*.json") - if err != nil { - t.Fatal(err) - } - tmpFile.Close() - - origTokenFile := tokenFile + origTokenStore := tokenStore origClientID := clientID t.Cleanup(func() { - tokenFile = origTokenFile + tokenStore = origTokenStore clientID = origClientID }) - tokenFile = tmpFile.Name() + tokenStore = tokenstore.NewFileStore(filepath.Join(t.TempDir(), "tokens.json")) clientID = "test-client-id" - storage := &TokenStorage{ + storage := &tokenstore.Token{ AccessToken: "access-token-value", RefreshToken: "refresh-token-value", TokenType: "Bearer", ExpiresAt: time.Now().Add(time.Hour).UTC().Truncate(time.Second), ClientID: clientID, - Flow: "browser", } - if err := saveTokens(storage); err != nil { - t.Fatalf("saveTokens() error: %v", err) + if err := tokenStore.Save(storage); err != nil { + t.Fatalf("Save() error: %v", err) } - loaded, err := loadTokens() + loaded, err := tokenStore.Load(clientID) if err != nil { - t.Fatalf("loadTokens() error: %v", err) + t.Fatalf("Load() error: %v", err) } if loaded.AccessToken != storage.AccessToken { t.Errorf("AccessToken mismatch: got %q, want %q", loaded.AccessToken, storage.AccessToken) } - if loaded.Flow != storage.Flow { - t.Errorf("Flow mismatch: got %q, want %q", loaded.Flow, storage.Flow) + if loaded.RefreshToken != storage.RefreshToken { + t.Errorf( + "RefreshToken mismatch: got %q, want %q", + loaded.RefreshToken, + storage.RefreshToken, + ) } } func TestSaveTokens_MultipleClients(t *testing.T) { - tmpFile, err := os.CreateTemp(t.TempDir(), "tokens-multi-*.json") - if err != nil { - t.Fatal(err) - } - tmpFile.Close() - - origTokenFile := tokenFile - origClientID := clientID - t.Cleanup(func() { - tokenFile = origTokenFile - clientID = origClientID - }) + origTokenStore := tokenStore + t.Cleanup(func() { tokenStore = origTokenStore }) - tokenFile = tmpFile.Name() + tokenStore = tokenstore.NewFileStore(filepath.Join(t.TempDir(), "tokens.json")) for _, id := range []string{"client-a", "client-b"} { - clientID = id - if err := saveTokens(&TokenStorage{ + if err := tokenStore.Save(&tokenstore.Token{ AccessToken: "token-" + id, RefreshToken: "refresh-" + id, TokenType: "Bearer", ExpiresAt: time.Now().Add(time.Hour), ClientID: id, }); err != nil { - t.Fatalf("saveTokens(%s) error: %v", id, err) + t.Fatalf("Save(%s) error: %v", id, err) } } - data, _ := os.ReadFile(tokenFile) - var sm TokenStorageMap - if err := json.Unmarshal(data, &sm); err != nil { - t.Fatalf("unmarshal error: %v", err) - } - if len(sm.Tokens) != 2 { - t.Errorf("expected 2 tokens, got %d", len(sm.Tokens)) + for _, id := range []string{"client-a", "client-b"} { + loaded, err := tokenStore.Load(id) + if err != nil { + t.Fatalf("Load(%s) error: %v", id, err) + } + if loaded.AccessToken != "token-"+id { + t.Errorf("Load(%s): AccessToken = %q, want %q", id, loaded.AccessToken, "token-"+id) + } } } func TestSaveTokens_ConcurrentWrites(t *testing.T) { - tempDir := t.TempDir() - origTokenFile := tokenFile - t.Cleanup(func() { tokenFile = origTokenFile }) - tokenFile = filepath.Join(tempDir, "tokens.json") + origTokenStore := tokenStore + t.Cleanup(func() { tokenStore = origTokenStore }) + + tokenStore = tokenstore.NewFileStore(filepath.Join(t.TempDir(), "tokens.json")) const goroutines = 10 var wg sync.WaitGroup @@ -226,29 +216,24 @@ func TestSaveTokens_ConcurrentWrites(t *testing.T) { for i := range goroutines { go func(id int) { defer wg.Done() - if err := saveTokens(&TokenStorage{ + if err := tokenStore.Save(&tokenstore.Token{ AccessToken: fmt.Sprintf("access-token-%d", id), RefreshToken: fmt.Sprintf("refresh-token-%d", id), TokenType: "Bearer", ExpiresAt: time.Now().Add(time.Hour), ClientID: fmt.Sprintf("client-%d", id), }); err != nil { - t.Errorf("goroutine %d: saveTokens() error: %v", id, err) + t.Errorf("goroutine %d: Save() error: %v", id, err) } }(i) } wg.Wait() - data, err := os.ReadFile(tokenFile) - if err != nil { - t.Fatalf("failed to read token file: %v", err) - } - var sm TokenStorageMap - if err := json.Unmarshal(data, &sm); err != nil { - t.Fatalf("failed to parse token file: %v", err) - } - if len(sm.Tokens) != goroutines { - t.Errorf("expected %d tokens, got %d", goroutines, len(sm.Tokens)) + for i := range goroutines { + id := fmt.Sprintf("client-%d", i) + if _, err := tokenStore.Load(id); err != nil { + t.Errorf("Load(%s) error: %v", id, err) + } } } @@ -304,15 +289,14 @@ func TestBuildAuthURL_ContainsRequiredParams(t *testing.T) { func TestRefreshAccessToken_RotationMode(t *testing.T) { origServerURL := serverURL origClientID := clientID - origTokenFile := tokenFile + origTokenStore := tokenStore t.Cleanup(func() { serverURL = origServerURL clientID = origClientID - tokenFile = origTokenFile + tokenStore = origTokenStore }) - tempDir := t.TempDir() - tokenFile = filepath.Join(tempDir, "tokens.json") + tokenStore = tokenstore.NewFileStore(filepath.Join(t.TempDir(), "tokens.json")) clientID = "test-client-rotation" tests := []struct { diff --git a/tokens.go b/tokens.go index 03d600c..3ef1198 100644 --- a/tokens.go +++ b/tokens.go @@ -1,11 +1,8 @@ package main import ( - "encoding/json" "errors" "fmt" - "os" - "time" ) // ErrorResponse is an OAuth error payload. @@ -17,84 +14,6 @@ type ErrorResponse struct { // ErrRefreshTokenExpired indicates the refresh token has expired or is invalid. var ErrRefreshTokenExpired = errors.New("refresh token expired or invalid") -// TokenStorage holds persisted OAuth tokens for one client. -type TokenStorage struct { - AccessToken string `json:"access_token"` - RefreshToken string `json:"refresh_token"` - TokenType string `json:"token_type"` - ExpiresAt time.Time `json:"expires_at"` - ClientID string `json:"client_id"` - Flow string `json:"flow,omitempty"` // "browser" or "device" -} - -// TokenStorageMap manages tokens for multiple clients in one file. -type TokenStorageMap struct { - Tokens map[string]*TokenStorage `json:"tokens"` -} - -func loadTokens() (*TokenStorage, error) { - data, err := os.ReadFile(tokenFile) - if err != nil { - return nil, err - } - var storageMap TokenStorageMap - if err := json.Unmarshal(data, &storageMap); err != nil { - return nil, fmt.Errorf("failed to parse token file: %w", err) - } - if storageMap.Tokens == nil { - return nil, errors.New("no tokens in file") - } - if storage, ok := storageMap.Tokens[clientID]; ok { - return storage, nil - } - return nil, fmt.Errorf("no tokens found for client_id: %s", clientID) -} - -func saveTokens(storage *TokenStorage) error { - if storage.ClientID == "" { - storage.ClientID = clientID - } - - lock, err := acquireFileLock(tokenFile) - if err != nil { - return fmt.Errorf("failed to acquire lock: %w", err) - } - defer func() { _ = lock.release() }() - - var storageMap TokenStorageMap - if existing, err := os.ReadFile(tokenFile); err == nil { - if unmarshalErr := json.Unmarshal(existing, &storageMap); unmarshalErr != nil { - storageMap.Tokens = make(map[string]*TokenStorage) - } - } - if storageMap.Tokens == nil { - storageMap.Tokens = make(map[string]*TokenStorage) - } - - storageMap.Tokens[storage.ClientID] = storage - - data, err := json.MarshalIndent(storageMap, "", " ") - if err != nil { - return err - } - - tempFile := tokenFile + ".tmp" - if err := os.WriteFile(tempFile, data, 0o600); err != nil { - return fmt.Errorf("failed to write temp file: %w", err) - } - if err := os.Rename(tempFile, tokenFile); err != nil { - if removeErr := os.Remove(tempFile); removeErr != nil { - return fmt.Errorf( - "failed to rename temp file: %v; also failed to remove temp file: %w", - err, - removeErr, - ) - } - return fmt.Errorf("failed to rename temp file: %w", err) - } - return nil -} - // validateTokenResponse performs basic sanity checks on a token response. func validateTokenResponse(accessToken, tokenType string, expiresIn int) error { if accessToken == "" { diff --git a/tui_adapter.go b/tui_adapter.go index 76c6c77..214cd66 100644 --- a/tui_adapter.go +++ b/tui_adapter.go @@ -1,34 +1,43 @@ package main -import "github.com/go-authgate/cli/tui" +import ( + "github.com/go-authgate/cli/tui" + "github.com/go-authgate/sdk-go/tokenstore" +) -// toTUITokenStorage converts main.TokenStorage to tui.TokenStorage. -func toTUITokenStorage(storage *TokenStorage) *tui.TokenStorage { - if storage == nil { +// toTUITokenStorage converts tokenstore.Token to tui.TokenStorage. +func toTUITokenStorage(token *tokenstore.Token, flow string) *tui.TokenStorage { + if token == nil { return nil } return &tui.TokenStorage{ - AccessToken: storage.AccessToken, - RefreshToken: storage.RefreshToken, - TokenType: storage.TokenType, - ExpiresAt: storage.ExpiresAt, - ClientID: storage.ClientID, - Flow: storage.Flow, + AccessToken: token.AccessToken, + RefreshToken: token.RefreshToken, + TokenType: token.TokenType, + ExpiresAt: token.ExpiresAt, + ClientID: token.ClientID, + Flow: flow, } } -// fromTUITokenStorage converts tui.TokenStorage to main.TokenStorage. -func fromTUITokenStorage(tuiStorage *tui.TokenStorage) *TokenStorage { +// flowFromTUI extracts the Flow field from tui.TokenStorage (nil-safe). +func flowFromTUI(ts *tui.TokenStorage) string { + if ts == nil { + return "" + } + return ts.Flow +} + +// fromTUITokenStorage converts tui.TokenStorage to tokenstore.Token. +func fromTUITokenStorage(tuiStorage *tui.TokenStorage) *tokenstore.Token { if tuiStorage == nil { return nil } - - return &TokenStorage{ + return &tokenstore.Token{ AccessToken: tuiStorage.AccessToken, RefreshToken: tuiStorage.RefreshToken, TokenType: tuiStorage.TokenType, ExpiresAt: tuiStorage.ExpiresAt, ClientID: tuiStorage.ClientID, - Flow: tuiStorage.Flow, } } From 3c484c42262489e1ae6459f44c2c0ebf3f5803cf Mon Sep 17 00:00:00 2001 From: Bo-Yi Wu Date: Mon, 9 Mar 2026 22:51:39 +0800 Subject: [PATCH 2/2] refactor(auth): extract token store helper and improve test coverage - Extract token store backend selection into testable newTokenStore() function - Set flow value to "cached" or "refreshed" when reusing existing tokens - Add unit tests for newTokenStore() covering all modes and error cases - Add unit tests for TUI adapter conversion functions Co-Authored-By: Claude Opus 4.6 --- config.go | 38 ++++++++++------- config_test.go | 80 ++++++++++++++++++++++++++++++++++ main.go | 7 +++ tui_adapter_test.go | 102 ++++++++++++++++++++++++++++++++++++++++++++ 4 files changed, 211 insertions(+), 16 deletions(-) create mode 100644 config_test.go create mode 100644 tui_adapter_test.go diff --git a/config.go b/config.go index f2fcc97..cd9c68d 100644 --- a/config.go +++ b/config.go @@ -190,28 +190,34 @@ func initConfig() { // Initialize token store based on mode tokenStoreMode = getConfig(*flagTokenStore, "TOKEN_STORE", "auto") - fileStore := tokenstore.NewFileStore(tokenFile) - switch tokenStoreMode { - case "file": - tokenStore = fileStore - case "keyring": - tokenStore = tokenstore.NewKeyringStore(defaultKeyringService) - case "auto": - kr := tokenstore.NewKeyringStore(defaultKeyringService) - tokenStore = tokenstore.NewSecureStore(kr, fileStore) - if !tokenStore.(*tokenstore.SecureStore).UseKeyring() { + tokenStore, err = newTokenStore(tokenStoreMode, tokenFile, defaultKeyringService) + if err != nil { + fmt.Fprintln(os.Stderr, err) + os.Exit(1) + } + if tokenStoreMode == "auto" { + if ss, ok := tokenStore.(*tokenstore.SecureStore); ok && !ss.UseKeyring() { fmt.Fprintln( os.Stderr, "WARNING: OS keyring unavailable, falling back to file-based token storage", ) } + } +} + +// newTokenStore creates a token store backend based on the given mode. +func newTokenStore(mode, tokenFilePath, keyringService string) (tokenstore.Store, error) { + fileStore := tokenstore.NewFileStore(tokenFilePath) + switch mode { + case "file": + return fileStore, nil + case "keyring": + return tokenstore.NewKeyringStore(keyringService), nil + case "auto": + kr := tokenstore.NewKeyringStore(keyringService) + return tokenstore.NewSecureStore(kr, fileStore), nil default: - fmt.Fprintf( - os.Stderr, - "Error: Invalid token-store value: %s (must be auto, file, or keyring)\n", - tokenStoreMode, - ) - os.Exit(1) + return nil, fmt.Errorf("invalid token store mode: %q (valid: auto, file, keyring)", mode) } } diff --git a/config_test.go b/config_test.go new file mode 100644 index 0000000..0d36520 --- /dev/null +++ b/config_test.go @@ -0,0 +1,80 @@ +package main + +import ( + "testing" + + "github.com/go-authgate/sdk-go/tokenstore" +) + +func TestNewTokenStore(t *testing.T) { + tmpFile := t.TempDir() + "/tokens.json" + + tests := []struct { + name string + mode string + wantType string + wantErr bool + errSubstr string + }{ + { + name: "file mode returns FileStore", + mode: "file", + wantType: "*tokenstore.FileStore", + }, + { + name: "keyring mode returns KeyringStore", + mode: "keyring", + wantType: "*tokenstore.KeyringStore", + }, + { + name: "auto mode returns SecureStore", + mode: "auto", + wantType: "*tokenstore.SecureStore", + }, + { + name: "invalid mode returns error", + mode: "invalid", + wantErr: true, + errSubstr: "invalid token store mode", + }, + { + name: "empty mode returns error", + mode: "", + wantErr: true, + errSubstr: "invalid token store mode", + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + store, err := newTokenStore(tc.mode, tmpFile, "test-service") + if tc.wantErr { + if err == nil { + t.Fatal("expected error, got nil") + } + if !containsSubstring(err.Error(), tc.errSubstr) { + t.Errorf("error %q should contain %q", err.Error(), tc.errSubstr) + } + return + } + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + var gotType string + switch store.(type) { + case *tokenstore.FileStore: + gotType = "*tokenstore.FileStore" + case *tokenstore.KeyringStore: + gotType = "*tokenstore.KeyringStore" + case *tokenstore.SecureStore: + gotType = "*tokenstore.SecureStore" + default: + gotType = "unknown" + } + if gotType != tc.wantType { + t.Errorf("got type %s, want %s", gotType, tc.wantType) + } + }) + } +} diff --git a/main.go b/main.go index 6f271d0..55b9226 100644 --- a/main.go +++ b/main.go @@ -68,6 +68,13 @@ func run(ctx context.Context, ui tui.Manager) int { // No valid tokens — select and run the appropriate flow. var flow string + if storage != nil && err == nil && existing != nil { + if time.Now().Before(existing.ExpiresAt) { + flow = "cached" + } else { + flow = "refreshed" + } + } if storage == nil { storage, flow, err = authenticate(ctx, ui) if err != nil { diff --git a/tui_adapter_test.go b/tui_adapter_test.go new file mode 100644 index 0000000..62a9215 --- /dev/null +++ b/tui_adapter_test.go @@ -0,0 +1,102 @@ +package main + +import ( + "testing" + "time" + + "github.com/go-authgate/cli/tui" + "github.com/go-authgate/sdk-go/tokenstore" +) + +func TestToTUITokenStorage(t *testing.T) { + t.Run("nil token returns nil", func(t *testing.T) { + if got := toTUITokenStorage(nil, "test"); got != nil { + t.Errorf("expected nil, got %+v", got) + } + }) + + t.Run("converts all fields", func(t *testing.T) { + now := time.Now().Truncate(time.Second) + token := &tokenstore.Token{ + AccessToken: "access", + RefreshToken: "refresh", + TokenType: "Bearer", + ExpiresAt: now, + ClientID: "client-1", + } + + got := toTUITokenStorage(token, "browser") + + if got.AccessToken != "access" { + t.Errorf("AccessToken = %q, want %q", got.AccessToken, "access") + } + if got.RefreshToken != "refresh" { + t.Errorf("RefreshToken = %q, want %q", got.RefreshToken, "refresh") + } + if got.TokenType != "Bearer" { + t.Errorf("TokenType = %q, want %q", got.TokenType, "Bearer") + } + if !got.ExpiresAt.Equal(now) { + t.Errorf("ExpiresAt = %v, want %v", got.ExpiresAt, now) + } + if got.ClientID != "client-1" { + t.Errorf("ClientID = %q, want %q", got.ClientID, "client-1") + } + if got.Flow != "browser" { + t.Errorf("Flow = %q, want %q", got.Flow, "browser") + } + }) +} + +func TestFromTUITokenStorage(t *testing.T) { + t.Run("nil returns nil", func(t *testing.T) { + if got := fromTUITokenStorage(nil); got != nil { + t.Errorf("expected nil, got %+v", got) + } + }) + + t.Run("converts all fields", func(t *testing.T) { + now := time.Now().Truncate(time.Second) + tuiStorage := &tui.TokenStorage{ + AccessToken: "access", + RefreshToken: "refresh", + TokenType: "Bearer", + ExpiresAt: now, + ClientID: "client-1", + Flow: "device", + } + + got := fromTUITokenStorage(tuiStorage) + + if got.AccessToken != "access" { + t.Errorf("AccessToken = %q, want %q", got.AccessToken, "access") + } + if got.RefreshToken != "refresh" { + t.Errorf("RefreshToken = %q, want %q", got.RefreshToken, "refresh") + } + if got.TokenType != "Bearer" { + t.Errorf("TokenType = %q, want %q", got.TokenType, "Bearer") + } + if !got.ExpiresAt.Equal(now) { + t.Errorf("ExpiresAt = %v, want %v", got.ExpiresAt, now) + } + if got.ClientID != "client-1" { + t.Errorf("ClientID = %q, want %q", got.ClientID, "client-1") + } + }) +} + +func TestFlowFromTUI(t *testing.T) { + t.Run("nil returns empty", func(t *testing.T) { + if got := flowFromTUI(nil); got != "" { + t.Errorf("expected empty, got %q", got) + } + }) + + t.Run("returns flow field", func(t *testing.T) { + ts := &tui.TokenStorage{Flow: "device"} + if got := flowFromTUI(ts); got != "device" { + t.Errorf("got %q, want %q", got, "device") + } + }) +}