Skip to content

Commit bea1827

Browse files
authored
feat(config): Introduce concept of configuration dependencies (#418)
1 parent 979c983 commit bea1827

File tree

7 files changed

+327
-59
lines changed

7 files changed

+327
-59
lines changed

pkg/app/app.go

Lines changed: 35 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,11 @@ import (
2929
)
3030

3131
func defaultFuncOrganizationSlug(engine workflow.Engine, config configuration.Configuration, logger *zerolog.Logger, apiClientFactory func(url string, client *http.Client) api.ApiClient) configuration.DefaultValueFunction {
32+
err := config.AddKeyDependency(configuration.ORGANIZATION_SLUG, configuration.ORGANIZATION)
33+
if err != nil {
34+
logger.Print("Failed to add dependency for ORGANIZATION_SLUG:", err)
35+
}
36+
3237
callback := func(_ configuration.Configuration, existingValue interface{}) (interface{}, error) {
3338
client := engine.GetNetworkAccess().GetHttpClient()
3439
url := config.GetString(configuration.API_URL)
@@ -47,6 +52,11 @@ func defaultFuncOrganizationSlug(engine workflow.Engine, config configuration.Co
4752
}
4853

4954
func defaultFuncOrganization(engine workflow.Engine, config configuration.Configuration, logger *zerolog.Logger, apiClientFactory func(url string, client *http.Client) api.ApiClient) configuration.DefaultValueFunction {
55+
err := config.AddKeyDependency(configuration.ORGANIZATION, configuration.API_URL)
56+
if err != nil {
57+
logger.Print("Failed to add dependency for ORGANIZATION:", err)
58+
}
59+
5060
callback := func(_ configuration.Configuration, existingValue interface{}) (interface{}, error) {
5161
client := engine.GetNetworkAccess().GetHttpClient()
5262
url := config.GetString(configuration.API_URL)
@@ -78,7 +88,16 @@ func defaultFuncOrganization(engine workflow.Engine, config configuration.Config
7888
return callback
7989
}
8090

81-
func defaultFuncApiUrl(_ configuration.Configuration, logger *zerolog.Logger) configuration.DefaultValueFunction {
91+
func defaultFuncApiUrl(globalConfig configuration.Configuration, logger *zerolog.Logger) configuration.DefaultValueFunction {
92+
err := globalConfig.AddKeyDependency(configuration.API_URL, configuration.AUTHENTICATION_TOKEN)
93+
if err != nil {
94+
logger.Print("Failed to add dependency for API_URL:", err)
95+
}
96+
err = globalConfig.AddKeyDependency(configuration.API_URL, auth.CONFIG_KEY_OAUTH_TOKEN)
97+
if err != nil {
98+
logger.Print("Failed to add dependency for API_URL:", err)
99+
}
100+
82101
callback := func(config configuration.Configuration, existingValue interface{}) (interface{}, error) {
83102
urlString := constants.SNYK_DEFAULT_API_URL
84103
authToken := config.GetString(configuration.AUTHENTICATION_TOKEN)
@@ -236,14 +255,20 @@ func initConfiguration(engine workflow.Engine, config configuration.Configuratio
236255

237256
// set default filesize threshold to 512MB
238257
config.AddDefaultValue(configuration.IN_MEMORY_THRESHOLD_BYTES, configuration.StandardDefaultValueFunction(constants.SNYK_DEFAULT_IN_MEMORY_THRESHOLD_MB))
239-
config.AddDefaultValue(configuration.API_URL, defaultFuncApiUrl(config, logger))
240258
config.AddDefaultValue(configuration.TEMP_DIR_PATH, defaultTempDirectory(engine, config, logger))
241259

260+
config.AddDefaultValue(configuration.API_URL, defaultFuncApiUrl(config, logger))
261+
262+
err = config.AddKeyDependency(configuration.WEB_APP_URL, configuration.API_URL)
263+
if err != nil {
264+
logger.Print("Failed to add dependency for WEB_APP_URL:", err)
265+
}
266+
242267
config.AddDefaultValue(configuration.WEB_APP_URL, func(c configuration.Configuration, existingValue any) (any, error) {
243268
canonicalApiUrl := c.GetString(configuration.API_URL)
244-
appUrl, err := api.DeriveAppUrl(canonicalApiUrl)
245-
if err != nil {
246-
logger.Print("Failed to determine default value for \"WEB_APP_URL\":", err)
269+
appUrl, appUrlErr := api.DeriveAppUrl(canonicalApiUrl)
270+
if appUrlErr != nil {
271+
logger.Print("Failed to determine default value for \"WEB_APP_URL\":", appUrlErr)
247272
}
248273

249274
return appUrl, nil
@@ -260,6 +285,11 @@ func initConfiguration(engine workflow.Engine, config configuration.Configuratio
260285
}
261286
})
262287

288+
err = config.AddKeyDependency(configuration.IS_FEDRAMP, configuration.API_URL)
289+
if err != nil {
290+
logger.Print("Failed to add dependency for IS_FEDRAMP:", err)
291+
}
292+
263293
config.AddDefaultValue(configuration.IS_FEDRAMP, func(_ configuration.Configuration, existingValue any) (any, error) {
264294
if existingValue == nil {
265295
return api.IsFedramp(config.GetString(configuration.API_URL)), nil

pkg/app/app_test.go

Lines changed: 80 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,28 @@ import (
3232
"golang.org/x/oauth2/jws"
3333
)
3434

35+
func createOAuthTokenWithAudience(t *testing.T, audience string) string {
36+
t.Helper()
37+
header := &jws.Header{}
38+
claims := &jws.ClaimSet{
39+
Aud: audience,
40+
}
41+
pk, err := rsa.GenerateKey(rand.Reader, 2048)
42+
assert.NoError(t, err)
43+
44+
accessToken, err := jws.Encode(header, claims, pk)
45+
assert.NoError(t, err)
46+
47+
token := oauth2.Token{
48+
AccessToken: accessToken,
49+
}
50+
51+
tokenBytes, err := json.Marshal(token)
52+
assert.NoError(t, err)
53+
54+
return string(tokenBytes)
55+
}
56+
3557
func Test_AddsDefaultFunctionForCustomConfigFiles(t *testing.T) {
3658
t.Run("should load default config files (without given command line)", func(t *testing.T) {
3759
localConfig := configuration.NewWithOpts()
@@ -792,33 +814,12 @@ func Test_auth_oauth(t *testing.T) {
792814
})
793815

794816
t.Run("oauth token change updates api url extraction", func(t *testing.T) {
795-
createOAuthTokenWithAudience := func(audience string) string {
796-
header := &jws.Header{}
797-
claims := &jws.ClaimSet{
798-
Aud: audience,
799-
}
800-
pk, err := rsa.GenerateKey(rand.Reader, 2048)
801-
assert.NoError(t, err)
802-
803-
accessToken, err := jws.Encode(header, claims, pk)
804-
assert.NoError(t, err)
805-
806-
token := oauth2.Token{
807-
AccessToken: accessToken,
808-
}
809-
810-
tokenBytes, err := json.Marshal(token)
811-
assert.NoError(t, err)
812-
813-
return string(tokenBytes)
814-
}
815-
816817
globalConfig := engine.GetConfiguration()
817818
globalConfig.Set(configuration.API_URL, "https://api.snyk.io")
818819
invocationConfig := globalConfig.Clone()
819820

820821
firstAPIURL := "https://api.eu.snyk.io"
821-
firstOAuthToken := createOAuthTokenWithAudience(firstAPIURL)
822+
firstOAuthToken := createOAuthTokenWithAudience(t, firstAPIURL)
822823

823824
// Set auth type to OAuth
824825
invocationConfig.Set(localworkflows.AuthTypeParameter, auth.AUTH_TYPE_OAUTH)
@@ -849,7 +850,7 @@ func Test_auth_oauth(t *testing.T) {
849850

850851
// Now simulate re-authentication with EU
851852
secondAPIURL := "https://api.us.snyk.io"
852-
secondOAuthToken := createOAuthTokenWithAudience(secondAPIURL)
853+
secondOAuthToken := createOAuthTokenWithAudience(t, secondAPIURL)
853854

854855
// Mock second authentication
855856
mockAuthenticator.EXPECT().Authenticate().DoAndReturn(func() error {
@@ -870,3 +871,59 @@ func Test_auth_oauth(t *testing.T) {
870871
assert.Equal(t, secondAPIURL, actualApiUrl, "Second OAuth token should contain US API URL")
871872
})
872873
}
874+
875+
// this tests compares the behavior of the config when it has caching enabled and when it doesn't
876+
func Test_config_compareCachedAndUncachedConfig(t *testing.T) {
877+
tests := []struct {
878+
name string
879+
config configuration.Configuration
880+
}{
881+
{
882+
name: "Cached config",
883+
config: configuration.NewWithOpts(configuration.WithCachingEnabled(time.Hour * 1)),
884+
},
885+
{
886+
name: "Uncached config",
887+
config: configuration.NewWithOpts(),
888+
},
889+
}
890+
891+
for _, tt := range tests {
892+
t.Run(tt.name, func(t *testing.T) {
893+
engine := CreateAppEngineWithOptions(WithConfiguration(tt.config))
894+
assert.NotNil(t, engine)
895+
896+
// Default API URL
897+
assert.Equal(t, constants.SNYK_DEFAULT_API_URL, tt.config.GetString(configuration.API_URL))
898+
899+
// set API URL explicitly
900+
tt.config.Set(configuration.API_URL, "https://api.us.snyk.io")
901+
assert.Equal(t, "https://api.us.snyk.io", tt.config.GetString(configuration.API_URL))
902+
assert.Equal(t, "https://app.us.snyk.io", tt.config.GetString(configuration.WEB_APP_URL))
903+
904+
// set PAT and derive API URL
905+
tt.config.Set(configuration.AUTHENTICATION_TOKEN, createMockPAT(t, `{"h":"api.au.snyk.io"}`))
906+
assert.Equal(t, "https://api.au.snyk.io", tt.config.GetString(configuration.API_URL))
907+
assert.Equal(t, "https://app.au.snyk.io", tt.config.GetString(configuration.WEB_APP_URL))
908+
tt.config.Unset(configuration.AUTHENTICATION_TOKEN)
909+
910+
// set OAuth token and derive API URL
911+
tt.config.Set(auth.CONFIG_KEY_OAUTH_TOKEN, createOAuthTokenWithAudience(t, "https://api.snykgov.io"))
912+
assert.Equal(t, "https://api.snykgov.io", tt.config.GetString(configuration.API_URL))
913+
assert.Equal(t, "https://app.snykgov.io", tt.config.GetString(configuration.WEB_APP_URL))
914+
915+
// set PAT and derive API URL
916+
tt.config.Set(configuration.AUTHENTICATION_TOKEN, createMockPAT(t, `{"h":"api.eu.snyk.io"}`))
917+
assert.Equal(t, "https://api.eu.snyk.io", tt.config.GetString(configuration.API_URL))
918+
assert.Equal(t, "https://app.eu.snyk.io", tt.config.GetString(configuration.WEB_APP_URL))
919+
920+
// unset PAT and OAuth token
921+
tt.config.Unset(configuration.AUTHENTICATION_TOKEN)
922+
tt.config.Unset(auth.CONFIG_KEY_OAUTH_TOKEN)
923+
924+
// exlicitly set API URL is restored
925+
assert.Equal(t, "https://api.us.snyk.io", tt.config.GetString(configuration.API_URL))
926+
assert.Equal(t, "https://app.us.snyk.io", tt.config.GetString(configuration.WEB_APP_URL))
927+
})
928+
}
929+
}

pkg/auth/oauth2authenticator.go

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -90,7 +90,6 @@ func getRedirectUri(port int) string {
9090
}
9191

9292
func getOAuthConfiguration(config configuration.Configuration) *oauth2.Config {
93-
config.ClearCache()
9493
apiUrl := config.GetString(configuration.API_URL)
9594
appUrl := config.GetString(configuration.WEB_APP_URL)
9695
tokenUrl := apiUrl + "/oauth2/token"

pkg/configuration/configuration.go

Lines changed: 58 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,10 @@ type Configuration interface {
6262
GetAllKeysThatContainValues(key string) []string
6363
GetKeyType(key string) KeyType
6464

65+
// AddKeyDependency can be used to describe that a certain key and its values actually depend on another value, this can then be used to clear the cache of a key when a depending key changes.
66+
// In words: key depends on dependencyKey.
67+
AddKeyDependency(key string, dependencyKey string) error
68+
6569
// PersistInStorage ensures that when Set is called with the given key, it will be persisted in the config file.
6670
PersistInStorage(key string)
6771
SetStorage(storage Storage)
@@ -103,6 +107,8 @@ type extendedViper struct {
103107

104108
// supportedEnvVars store the env vars that should be supported REGARDLESS of its prefix. e.g. NODE_EXTRA_CA_CERTS
105109
supportedEnvVars []string
110+
111+
interkeyDependencies map[string][]string
106112
}
107113

108114
// StandardDefaultValueFunction is a default value function that returns the default value if the existing value is nil.
@@ -238,10 +244,11 @@ func NewInMemory() Configuration {
238244
func createViperDefaultConfig(opts ...Opts) *extendedViper {
239245
// prepare environment variables
240246
config := &extendedViper{
241-
viper: viper.New(),
242-
alternativeKeys: make(map[string][]string),
243-
defaultValues: make(map[string]DefaultValueFunction),
244-
persistedKeys: make(map[string]bool),
247+
viper: viper.New(),
248+
alternativeKeys: make(map[string][]string),
249+
defaultValues: make(map[string]DefaultValueFunction),
250+
persistedKeys: make(map[string]bool),
251+
interkeyDependencies: make(map[string][]string),
245252
}
246253
config.viper.SetEnvKeyReplacer(strings.NewReplacer("-", "_"))
247254

@@ -331,20 +338,31 @@ func (ev *extendedViper) Clone() Configuration {
331338
items := ev.defaultCache.Items()
332339
clonedCache := cache.NewFrom(cacheTTL, defaultCacheCleanupInterval, items)
333340
evClone.setCache(clonedCache)
341+
evClone.interkeyDependencies = ev.interkeyDependencies
334342
}
335343

336344
return clone
337345
}
338346

347+
func (ev *extendedViper) clearCache(key string) {
348+
if ev.defaultCache != nil {
349+
ev.defaultCache.Delete(key)
350+
351+
// clear cache of all keys that depend on the current key
352+
for _, dependencyKey := range ev.interkeyDependencies[key] {
353+
ev.clearCache(dependencyKey)
354+
}
355+
}
356+
}
357+
339358
// Set sets a configuration value.
340359
func (ev *extendedViper) Set(key string, value interface{}) {
341360
ev.mutex.Lock()
342361
localStorage := ev.storage
343362
isPersisted := ev.persistedKeys[key]
344363
ev.viper.Set(key, value)
345-
if ev.defaultCache != nil {
346-
ev.defaultCache.Delete(key)
347-
}
364+
ev.clearCache(key)
365+
348366
ev.mutex.Unlock()
349367

350368
if localStorage != nil && isPersisted {
@@ -863,6 +881,39 @@ func (ev *extendedViper) getCacheSettings() (bool, time.Duration, error) {
863881
return enabled, duration, nil
864882
}
865883

884+
func (ev *extendedViper) detectCircularDependency(key string, dependencyKey string) error {
885+
circularDependencyError := errors.New("circular dependency detected")
886+
887+
if key == dependencyKey {
888+
return circularDependencyError
889+
}
890+
891+
for _, existingKey := range ev.interkeyDependencies[key] {
892+
if existingKey == dependencyKey {
893+
return circularDependencyError
894+
}
895+
896+
if err := ev.detectCircularDependency(existingKey, dependencyKey); err != nil {
897+
return err
898+
}
899+
}
900+
901+
return nil
902+
}
903+
904+
func (ev *extendedViper) AddKeyDependency(key string, dependencyKey string) error {
905+
ev.mutex.Lock()
906+
defer ev.mutex.Unlock()
907+
908+
// detect circular dependencies
909+
if err := ev.detectCircularDependency(key, dependencyKey); err != nil {
910+
return err
911+
}
912+
913+
ev.interkeyDependencies[dependencyKey] = append(ev.interkeyDependencies[dependencyKey], key)
914+
return nil
915+
}
916+
866917
func toBool(result interface{}) (bool, error) {
867918
switch v := result.(type) {
868919
case bool:

0 commit comments

Comments
 (0)