Skip to content

Commit d14fa1a

Browse files
authored
fix: lock OAuthConfig init (#412)
1 parent c004b01 commit d14fa1a

File tree

4 files changed

+156
-26
lines changed

4 files changed

+156
-26
lines changed

pkg/app/app_test.go

Lines changed: 131 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,10 @@
11
package app
22

33
import (
4+
"crypto/rand"
5+
"crypto/rsa"
46
"encoding/base64"
7+
"encoding/json"
58
"errors"
69
"fmt"
710
"log"
@@ -14,15 +17,19 @@ import (
1417

1518
"github.com/golang/mock/gomock"
1619
zlog "github.com/rs/zerolog/log"
17-
"github.com/stretchr/testify/assert"
18-
1920
"github.com/snyk/go-application-framework/internal/api"
2021
"github.com/snyk/go-application-framework/internal/constants"
2122
"github.com/snyk/go-application-framework/internal/mocks"
23+
"github.com/snyk/go-application-framework/pkg/analytics"
2224
"github.com/snyk/go-application-framework/pkg/auth"
2325
"github.com/snyk/go-application-framework/pkg/configuration"
26+
localworkflows "github.com/snyk/go-application-framework/pkg/local_workflows"
27+
pkgMocks "github.com/snyk/go-application-framework/pkg/mocks"
2428
"github.com/snyk/go-application-framework/pkg/runtimeinfo"
2529
"github.com/snyk/go-application-framework/pkg/workflow"
30+
"github.com/stretchr/testify/assert"
31+
"golang.org/x/oauth2"
32+
"golang.org/x/oauth2/jws"
2633
)
2734

2835
func Test_AddsDefaultFunctionForCustomConfigFiles(t *testing.T) {
@@ -589,3 +596,125 @@ func TestDefaultInputDirectory(t *testing.T) {
589596
assert.IsType(t, defaultFunction, defaultFunction)
590597
})
591598
}
599+
600+
func Test_auth_oauth(t *testing.T) {
601+
mockCtl := gomock.NewController(t)
602+
engine := CreateAppEngine()
603+
logger := engine.GetLogger()
604+
analytics := analytics.New()
605+
606+
t.Run("oauth token is set on global config", func(t *testing.T) {
607+
// Create separate configs for invocation and global
608+
globalConfig := engine.GetConfiguration()
609+
invocationConfig := globalConfig.Clone()
610+
611+
// Expected OAuth token that will be set after authentication
612+
expectedOAuthToken := "test-oauth-token-12345"
613+
614+
invocationConfig.Set(localworkflows.AuthTypeParameter, auth.AUTH_TYPE_OAUTH)
615+
616+
// Create mocks
617+
mockInvocationContext := pkgMocks.NewMockInvocationContext(mockCtl)
618+
mockInvocationContext.EXPECT().GetConfiguration().Return(invocationConfig).AnyTimes()
619+
mockInvocationContext.EXPECT().GetEnhancedLogger().Return(logger).AnyTimes()
620+
mockInvocationContext.EXPECT().GetAnalytics().Return(analytics).AnyTimes()
621+
mockInvocationContext.EXPECT().GetEngine().Return(engine).AnyTimes()
622+
623+
mockAuthenticator := pkgMocks.NewMockAuthenticator(mockCtl)
624+
mockAuthenticator.EXPECT().Authenticate().DoAndReturn(func() error {
625+
// Simulate successful OAuth authentication by setting the token
626+
invocationConfig.Set(auth.CONFIG_KEY_OAUTH_TOKEN, expectedOAuthToken)
627+
return nil
628+
})
629+
630+
// Execute the auth workflow
631+
err := localworkflows.AuthEntryPointDI(mockInvocationContext, logger, engine, mockAuthenticator)
632+
assert.NoError(t, err)
633+
634+
// Verify that the OAuth token was set on the global config
635+
actualToken := globalConfig.Get(auth.CONFIG_KEY_OAUTH_TOKEN)
636+
assert.Equal(t, expectedOAuthToken, actualToken, "OAuth token should be set on global config after successful authentication")
637+
638+
// Verify that the authentication token is not set (should be cleared)
639+
assert.Empty(t, globalConfig.GetString(configuration.AUTHENTICATION_TOKEN), "Legacy authentication token should be cleared")
640+
})
641+
642+
t.Run("oauth token change updates api url extraction", func(t *testing.T) {
643+
createOAuthTokenWithAudience := func(audience string) string {
644+
header := &jws.Header{}
645+
claims := &jws.ClaimSet{
646+
Aud: audience,
647+
}
648+
pk, err := rsa.GenerateKey(rand.Reader, 2048)
649+
assert.NoError(t, err)
650+
651+
accessToken, err := jws.Encode(header, claims, pk)
652+
assert.NoError(t, err)
653+
654+
token := oauth2.Token{
655+
AccessToken: accessToken,
656+
}
657+
658+
tokenBytes, err := json.Marshal(token)
659+
assert.NoError(t, err)
660+
661+
return string(tokenBytes)
662+
}
663+
664+
globalConfig := engine.GetConfiguration()
665+
globalConfig.Set(configuration.API_URL, "https://api.snyk.io")
666+
invocationConfig := globalConfig.Clone()
667+
668+
firstAPIURL := "https://api.eu.snyk.io"
669+
firstOAuthToken := createOAuthTokenWithAudience(firstAPIURL)
670+
671+
// Set auth type to OAuth
672+
invocationConfig.Set(localworkflows.AuthTypeParameter, auth.AUTH_TYPE_OAUTH)
673+
674+
// Create mocks
675+
mockInvocationContext := pkgMocks.NewMockInvocationContext(mockCtl)
676+
mockInvocationContext.EXPECT().GetConfiguration().Return(invocationConfig).AnyTimes()
677+
mockInvocationContext.EXPECT().GetEnhancedLogger().Return(logger).AnyTimes()
678+
mockInvocationContext.EXPECT().GetAnalytics().Return(analytics).AnyTimes()
679+
680+
// First authentication with US token
681+
mockAuthenticator := pkgMocks.NewMockAuthenticator(mockCtl)
682+
mockAuthenticator.EXPECT().Authenticate().DoAndReturn(func() error {
683+
invocationConfig.Set(auth.CONFIG_KEY_OAUTH_TOKEN, firstOAuthToken)
684+
return nil
685+
})
686+
687+
// Execute first auth workflow
688+
err := localworkflows.AuthEntryPointDI(mockInvocationContext, logger, engine, mockAuthenticator)
689+
assert.NoError(t, err)
690+
691+
// Verify first OAuth token was set
692+
actualToken := globalConfig.GetString(auth.CONFIG_KEY_OAUTH_TOKEN)
693+
assert.Equal(t, firstOAuthToken, actualToken)
694+
695+
actualApiUrl := globalConfig.GetString(configuration.API_URL)
696+
assert.Equal(t, firstAPIURL, actualApiUrl, "First OAuth token should contain EU API URL")
697+
698+
// Now simulate re-authentication with EU
699+
secondAPIURL := "https://api.us.snyk.io"
700+
secondOAuthToken := createOAuthTokenWithAudience(secondAPIURL)
701+
702+
// Mock second authentication
703+
mockAuthenticator.EXPECT().Authenticate().DoAndReturn(func() error {
704+
invocationConfig.Set(auth.CONFIG_KEY_OAUTH_TOKEN, secondOAuthToken)
705+
return nil
706+
})
707+
708+
// Execute second workflow
709+
err = localworkflows.AuthEntryPointDI(mockInvocationContext, logger, engine, mockAuthenticator)
710+
assert.NoError(t, err)
711+
712+
// Verify second OAuth token was set
713+
actualToken = globalConfig.GetString(auth.CONFIG_KEY_OAUTH_TOKEN)
714+
assert.Equal(t, secondOAuthToken, actualToken)
715+
716+
// Extract and verify second API URL from token
717+
actualApiUrl = globalConfig.GetString(configuration.API_URL)
718+
assert.Equal(t, secondAPIURL, actualApiUrl, "Second OAuth token should contain US API URL")
719+
})
720+
}

pkg/auth/oauth2authenticator.go

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -273,7 +273,9 @@ func (o *oAuth2Authenticator) Authenticate() error {
273273
func (o *oAuth2Authenticator) CancelableAuthenticate(ctx context.Context) error {
274274
var err error
275275

276+
globalRefreshMutex.Lock()
276277
o.oauthConfig = getOAuthConfiguration(o.config)
278+
globalRefreshMutex.Unlock()
277279

278280
if o.grantType == ClientCredentialsGrant {
279281
err = o.authenticateWithClientCredentialsGrant(ctx)

pkg/local_workflows/auth_workflow.go

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@ import (
1818
const (
1919
workflowNameAuth = "auth"
2020
headlessFlag = "headless"
21-
authTypeParameter = "auth-type"
21+
AuthTypeParameter = "auth-type"
2222
)
2323

2424
var authTypeDescription = fmt.Sprint("Authentication type (", auth.AUTH_TYPE_TOKEN, ", ", auth.AUTH_TYPE_OAUTH, ")")
@@ -39,7 +39,7 @@ var ConfigurationNewAuthenticationToken = "internal_new_snyk_token"
3939
// InitAuth initializes the auth workflow before registering it with the engine.
4040
func InitAuth(engine workflow.Engine) error {
4141
config := pflag.NewFlagSet(workflowNameAuth, pflag.ExitOnError)
42-
config.String(authTypeParameter, "", authTypeDescription)
42+
config.String(AuthTypeParameter, "", authTypeDescription)
4343
config.Bool(headlessFlag, false, "Enable headless OAuth authentication")
4444
config.String(auth.PARAMETER_CLIENT_SECRET, "", "Client Credential Grant, client secret")
4545
config.String(auth.PARAMETER_CLIENT_ID, "", "Client Credential Grant, client id")
@@ -59,7 +59,7 @@ func authEntryPoint(invocationCtx workflow.InvocationContext, _ []workflow.Data)
5959
config := invocationCtx.GetConfiguration()
6060
logger := invocationCtx.GetEnhancedLogger()
6161
engine := invocationCtx.GetEngine()
62-
globalConfig := invocationCtx.GetEngine().GetConfiguration()
62+
globalConfig := engine.GetConfiguration()
6363

6464
// cache always interferes with auth
6565
globalConfig.ClearCache()
@@ -76,7 +76,7 @@ func authEntryPoint(invocationCtx workflow.InvocationContext, _ []workflow.Data)
7676
auth.WithLogger(logger),
7777
)
7878

79-
err = entryPointDI(invocationCtx, logger, engine, authenticator)
79+
err = AuthEntryPointDI(invocationCtx, logger, engine, authenticator)
8080
return nil, err
8181
}
8282

@@ -102,18 +102,18 @@ func autoDetectAuthType(config configuration.Configuration) string {
102102
return auth.AUTH_TYPE_OAUTH
103103
}
104104

105-
func entryPointDI(invocationCtx workflow.InvocationContext, logger *zerolog.Logger, engine workflow.Engine, authenticator auth.Authenticator) (err error) {
105+
func AuthEntryPointDI(invocationCtx workflow.InvocationContext, logger *zerolog.Logger, engine workflow.Engine, authenticator auth.Authenticator) (err error) {
106106
analytics := invocationCtx.GetAnalytics()
107107
globalConfig := engine.GetConfiguration()
108108
config := invocationCtx.GetConfiguration()
109109

110-
authType := config.GetString(authTypeParameter)
110+
authType := config.GetString(AuthTypeParameter)
111111
if len(authType) == 0 {
112112
authType = autoDetectAuthType(config)
113113
}
114114

115115
logger.Printf("Authentication Type: %s", authType)
116-
analytics.AddExtensionStringValue(authTypeParameter, authType)
116+
analytics.AddExtensionStringValue(AuthTypeParameter, authType)
117117

118118
existingSnykToken := config.GetString(configuration.AUTHENTICATION_TOKEN)
119119
// always attempt to clear existing tokens before triggering auth for current config clone and global config

pkg/local_workflows/auth_workflow_test.go

Lines changed: 16 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -7,12 +7,11 @@ import (
77

88
"github.com/golang/mock/gomock"
99
"github.com/rs/zerolog"
10-
"github.com/stretchr/testify/assert"
11-
1210
"github.com/snyk/go-application-framework/pkg/analytics"
1311
"github.com/snyk/go-application-framework/pkg/auth"
1412
"github.com/snyk/go-application-framework/pkg/configuration"
1513
"github.com/snyk/go-application-framework/pkg/mocks"
14+
"github.com/stretchr/testify/assert"
1615
)
1716

1817
func Test_auth_oauth(t *testing.T) {
@@ -34,26 +33,26 @@ func Test_auth_oauth(t *testing.T) {
3433
assert.NoError(t, err)
3534

3635
t.Run("happy", func(t *testing.T) {
37-
config.Set(authTypeParameter, nil)
36+
config.Set(AuthTypeParameter, nil)
3837
authenticator.EXPECT().Authenticate().Times(2).Return(nil)
3938
mockInvocationContext := mocks.NewMockInvocationContext(mockCtl)
4039
mockInvocationContext.EXPECT().GetConfiguration().Return(config).AnyTimes()
4140
mockInvocationContext.EXPECT().GetEnhancedLogger().Return(&logger).AnyTimes()
4241
mockInvocationContext.EXPECT().GetAnalytics().Return(analytics).AnyTimes()
43-
err = entryPointDI(mockInvocationContext, &logger, engine, authenticator)
44-
err = entryPointDI(mockInvocationContext, &logger, engine, authenticator)
42+
err = AuthEntryPointDI(mockInvocationContext, &logger, engine, authenticator)
43+
err = AuthEntryPointDI(mockInvocationContext, &logger, engine, authenticator)
4544
assert.NoError(t, err)
4645
})
4746

4847
t.Run("unhappy", func(t *testing.T) {
49-
config.Set(authTypeParameter, nil)
48+
config.Set(AuthTypeParameter, nil)
5049
expectedErr := fmt.Errorf("someting went wrong")
5150
authenticator.EXPECT().Authenticate().Times(1).Return(expectedErr)
5251
mockInvocationContext := mocks.NewMockInvocationContext(mockCtl)
5352
mockInvocationContext.EXPECT().GetConfiguration().Return(config).AnyTimes()
5453
mockInvocationContext.EXPECT().GetEnhancedLogger().Return(&logger).AnyTimes()
5554
mockInvocationContext.EXPECT().GetAnalytics().Return(analytics).AnyTimes()
56-
err = entryPointDI(mockInvocationContext, &logger, engine, authenticator)
55+
err = AuthEntryPointDI(mockInvocationContext, &logger, engine, authenticator)
5756
assert.Equal(t, expectedErr, err)
5857
})
5958
}
@@ -77,27 +76,27 @@ func Test_auth_token(t *testing.T) {
7776
assert.NoError(t, err)
7877

7978
t.Run("happy", func(t *testing.T) {
80-
config.Set(authTypeParameter, auth.AUTH_TYPE_TOKEN)
79+
config.Set(AuthTypeParameter, auth.AUTH_TYPE_TOKEN)
8180
engine.EXPECT().InvokeWithConfig(gomock.Any(), gomock.Any())
8281
mockInvocationContext := mocks.NewMockInvocationContext(mockCtl)
8382
mockInvocationContext.EXPECT().GetConfiguration().Return(config).AnyTimes()
8483
mockInvocationContext.EXPECT().GetEnhancedLogger().Return(&logger).AnyTimes()
8584
mockInvocationContext.EXPECT().GetAnalytics().Return(analytics).AnyTimes()
8685

87-
err = entryPointDI(mockInvocationContext, &logger, engine, authenticator)
86+
err = AuthEntryPointDI(mockInvocationContext, &logger, engine, authenticator)
8887
assert.NoError(t, err)
8988
})
9089

9190
t.Run("automatically switch to token when API token is given", func(t *testing.T) {
92-
config.Set(authTypeParameter, nil)
91+
config.Set(AuthTypeParameter, nil)
9392
config.Set(ConfigurationNewAuthenticationToken, "00000000-0000-0000-0000-000000000000")
9493
engine.EXPECT().InvokeWithConfig(gomock.Any(), gomock.Any())
9594
mockInvocationContext := mocks.NewMockInvocationContext(mockCtl)
9695
mockInvocationContext.EXPECT().GetConfiguration().Return(config).AnyTimes()
9796
mockInvocationContext.EXPECT().GetEnhancedLogger().Return(&logger).AnyTimes()
9897
mockInvocationContext.EXPECT().GetAnalytics().Return(analytics).AnyTimes()
9998

100-
err = entryPointDI(mockInvocationContext, &logger, engine, authenticator)
99+
err = AuthEntryPointDI(mockInvocationContext, &logger, engine, authenticator)
101100
assert.NoError(t, err)
102101
})
103102
}
@@ -116,7 +115,7 @@ func Test_pat(t *testing.T) {
116115

117116
t.Run("happy", func(t *testing.T) {
118117
config := configuration.NewWithOpts()
119-
config.Set(authTypeParameter, auth.AUTH_TYPE_PAT)
118+
config.Set(AuthTypeParameter, auth.AUTH_TYPE_PAT)
120119
config.Set(ConfigurationNewAuthenticationToken, pat)
121120

122121
config.Set(auth.CONFIG_KEY_OAUTH_TOKEN, "some-oauth-token")
@@ -130,7 +129,7 @@ func Test_pat(t *testing.T) {
130129
engine.EXPECT().GetConfiguration().Return(config).AnyTimes()
131130
engine.EXPECT().InvokeWithConfig(gomock.Any(), gomock.Any())
132131

133-
err := entryPointDI(mockInvocationContext, &logger, engine, authenticator)
132+
err := AuthEntryPointDI(mockInvocationContext, &logger, engine, authenticator)
134133
assert.NoError(t, err)
135134

136135
assert.Empty(t, config.GetString(auth.CONFIG_KEY_OAUTH_TOKEN))
@@ -139,7 +138,7 @@ func Test_pat(t *testing.T) {
139138

140139
t.Run("invalid pat should fail", func(t *testing.T) {
141140
config := configuration.NewWithOpts()
142-
config.Set(authTypeParameter, auth.AUTH_TYPE_PAT)
141+
config.Set(AuthTypeParameter, auth.AUTH_TYPE_PAT)
143142
config.Set(ConfigurationNewAuthenticationToken, pat)
144143

145144
config.Set(auth.CONFIG_KEY_OAUTH_TOKEN, "some-oauth-token")
@@ -155,7 +154,7 @@ func Test_pat(t *testing.T) {
155154
mockWhoAmIError := fmt.Errorf("mock whoami failure")
156155
engine.EXPECT().InvokeWithConfig(gomock.Any(), gomock.Any()).Return(nil, mockWhoAmIError)
157156

158-
err := entryPointDI(mockInvocationContext, &logger, engine, authenticator)
157+
err := AuthEntryPointDI(mockInvocationContext, &logger, engine, authenticator)
159158
assert.ErrorIs(t, err, mockWhoAmIError)
160159

161160
assert.Empty(t, config.GetString(auth.CONFIG_KEY_OAUTH_TOKEN))
@@ -205,7 +204,7 @@ func Test_clearAllCredentialsBeforeAuth(t *testing.T) {
205204
for _, tc := range testCases {
206205
t.Run(tc.name, func(t *testing.T) {
207206
config := configuration.NewWithOpts()
208-
config.Set(authTypeParameter, tc.authType)
207+
config.Set(AuthTypeParameter, tc.authType)
209208
if tc.authType == auth.AUTH_TYPE_PAT {
210209
config.Set(ConfigurationNewAuthenticationToken, "snyk_uat.12345678.abcdefg-hijklmnop.qrstuvwxyz-123456")
211210
}
@@ -221,7 +220,7 @@ func Test_clearAllCredentialsBeforeAuth(t *testing.T) {
221220

222221
tc.setupMocks()
223222

224-
err := entryPointDI(mockInvocationContext, &logger, engine, authenticator)
223+
err := AuthEntryPointDI(mockInvocationContext, &logger, engine, authenticator)
225224
assert.NoError(t, err)
226225

227226
// Verify both tokens are cleared regardless of auth type

0 commit comments

Comments
 (0)