diff --git a/main.go b/main.go index a26b871..89d943c 100644 --- a/main.go +++ b/main.go @@ -28,21 +28,24 @@ import ( ) var ( - serverURL string - clientID string - tokenFile string - tokenStoreMode string - flagServerURL *string - flagClientID *string - flagTokenFile *string - flagTokenStore *string - configInitialized bool - retryClient *retry.Client - tokenStore credstore.Store[credstore.Token] + serverURL string + clientID string + tokenFile string + tokenStoreMode string + flagServerURL *string + flagClientID *string + flagTokenFile *string + flagTokenStore *string + configOnce sync.Once + retryClient *retry.Client + tokenStore credstore.Store[credstore.Token] ) const defaultKeyringService = "authgate-device-cli" +// maxResponseBodySize limits HTTP response body reads to prevent memory exhaustion (DoS). +const maxResponseBodySize = 1 << 20 // 1 MB + // Timeout configuration for different operations const ( deviceCodeRequestTimeout = 10 * time.Second @@ -107,11 +110,12 @@ func init() { // initConfig parses flags and initializes configuration // Separated from init() to avoid conflicts with test flag parsing func initConfig() { - if configInitialized { - return - } - configInitialized = true + configOnce.Do(func() { + doInitConfig() + }) +} +func doInitConfig() { flag.Parse() // Priority: flag > env > default @@ -438,7 +442,7 @@ func requestDeviceCode(ctx context.Context) (*oauth2.DeviceAuthResponse, error) } defer resp.Body.Close() - body, err := io.ReadAll(resp.Body) + body, err := io.ReadAll(io.LimitReader(resp.Body, maxResponseBodySize)) if err != nil { return nil, fmt.Errorf("failed to read response: %w", err) } @@ -638,7 +642,7 @@ func exchangeDeviceCode( } defer resp.Body.Close() - body, err := io.ReadAll(resp.Body) + body, err := io.ReadAll(io.LimitReader(resp.Body, maxResponseBodySize)) if err != nil { return nil, fmt.Errorf("failed to read response: %w", err) } @@ -696,7 +700,7 @@ func verifyToken(ctx context.Context, accessToken string, d tui.Displayer) error } defer resp.Body.Close() - body, err := io.ReadAll(resp.Body) + body, err := io.ReadAll(io.LimitReader(resp.Body, maxResponseBodySize)) if err != nil { return fmt.Errorf("failed to read response: %w", err) } @@ -746,7 +750,7 @@ func refreshAccessToken( } defer resp.Body.Close() - body, err := io.ReadAll(resp.Body) + body, err := io.ReadAll(io.LimitReader(resp.Body, maxResponseBodySize)) if err != nil { return credstore.Token{}, fmt.Errorf("failed to read response: %w", err) } @@ -871,7 +875,7 @@ func makeAPICallWithAutoRefresh( defer resp.Body.Close() } - body, err := io.ReadAll(resp.Body) + body, err := io.ReadAll(io.LimitReader(resp.Body, maxResponseBodySize)) if err != nil { return fmt.Errorf("failed to read response: %w", err) }