44from django .contrib .auth import get_user_model
55from django .contrib .auth .backends import ModelBackend
66from django .contrib .auth .models import Group
7- from django .core .exceptions import (ImproperlyConfigured , ObjectDoesNotExist ,
8- PermissionDenied )
7+ from django .core .exceptions import (
8+ ImproperlyConfigured ,
9+ ObjectDoesNotExist ,
10+ PermissionDenied ,
11+ )
912
1013from django_auth_adfs import signals
1114from django_auth_adfs .config import provider_config , settings
1518
1619
1720class AdfsBaseBackend (ModelBackend ):
18-
1921 def _ms_request (self , action , url , data = None , ** kwargs ):
2022 """
2123 Make a Microsoft Entra/GraphQL request
@@ -36,7 +38,10 @@ def _ms_request(self, action, url, data=None, **kwargs):
3638 if response .status_code == 400 :
3739 if response .json ().get ("error_description" , "" ).startswith ("AADSTS50076" ):
3840 raise MFARequired
39- logger .error ("ADFS server returned an error: %s" , response .json ()["error_description" ])
41+ logger .error (
42+ "ADFS server returned an error: %s" ,
43+ response .json ()["error_description" ],
44+ )
4045 raise PermissionDenied
4146
4247 if response .status_code != 200 :
@@ -47,16 +52,18 @@ def _ms_request(self, action, url, data=None, **kwargs):
4752 def exchange_auth_code (self , authorization_code , request ):
4853 logger .debug ("Received authorization code: %s" , authorization_code )
4954 data = {
50- ' grant_type' : ' authorization_code' ,
51- ' client_id' : settings .CLIENT_ID ,
52- ' redirect_uri' : provider_config .redirect_uri (request ),
53- ' code' : authorization_code ,
55+ " grant_type" : " authorization_code" ,
56+ " client_id" : settings .CLIENT_ID ,
57+ " redirect_uri" : provider_config .redirect_uri (request ),
58+ " code" : authorization_code ,
5459 }
5560 if settings .CLIENT_SECRET :
56- data [' client_secret' ] = settings .CLIENT_SECRET
61+ data [" client_secret" ] = settings .CLIENT_SECRET
5762
5863 logger .debug ("Getting access token at: %s" , provider_config .token_endpoint )
59- response = self ._ms_request (provider_config .session .post , provider_config .token_endpoint , data )
64+ response = self ._ms_request (
65+ provider_config .session .post , provider_config .token_endpoint , data
66+ )
6067 adfs_response = response .json ()
6168 return adfs_response
6269
@@ -79,11 +86,13 @@ def get_obo_access_token(self, access_token):
7986 "requested_token_use" : "on_behalf_of" ,
8087 }
8188 if provider_config .token_endpoint .endswith ("/v2.0/token" ):
82- data ["scope" ] = ' GroupMember.Read.All'
89+ data ["scope" ] = " GroupMember.Read.All"
8390 else :
84- data ["resource" ] = ' https://graph.microsoft.com'
91+ data ["resource" ] = " https://graph.microsoft.com"
8592
86- response = self ._ms_request (provider_config .session .get , provider_config .token_endpoint , data )
93+ response = self ._ms_request (
94+ provider_config .session .get , provider_config .token_endpoint , data
95+ )
8796 obo_access_token = response .json ()["access_token" ]
8897 logger .debug ("Received OBO access token: %s" , obo_access_token )
8998 return obo_access_token
@@ -117,8 +126,10 @@ def get_group_memberships_from_ms_graph(self, obo_access_token):
117126 Returns:
118127 claim_groups (list): List of the users group memberships
119128 """
120- graph_url = "https://{}/v1.0/me/transitiveMemberOf/microsoft.graph.group" .format (
121- provider_config .msgraph_endpoint
129+ graph_url = (
130+ "https://{}/v1.0/me/transitiveMemberOf/microsoft.graph.group" .format (
131+ provider_config .msgraph_endpoint
132+ )
122133 )
123134 headers = {"Authorization" : "Bearer {}" .format (obo_access_token )}
124135 response = self ._ms_request (
@@ -147,25 +158,25 @@ def validate_access_token(self, access_token):
147158 # Explicit is better then implicit and it protects against
148159 # changes in the defaults the jwt module uses.
149160 options = {
150- ' verify_signature' : True ,
151- ' verify_exp' : True ,
152- ' verify_nbf' : True ,
153- ' verify_iat' : True ,
154- ' verify_aud' : True ,
155- ' verify_iss' : True ,
156- ' require_exp' : False ,
157- ' require_iat' : False ,
158- ' require_nbf' : False
161+ " verify_signature" : True ,
162+ " verify_exp" : True ,
163+ " verify_nbf" : True ,
164+ " verify_iat" : True ,
165+ " verify_aud" : True ,
166+ " verify_iss" : True ,
167+ " require_exp" : False ,
168+ " require_iat" : False ,
169+ " require_nbf" : False ,
159170 }
160171 # Validate token and return claims
161172 return jwt .decode (
162173 access_token ,
163174 key = key ,
164- algorithms = [' RS256' , ' RS384' , ' RS512' ],
175+ algorithms = [" RS256" , " RS384" , " RS512" ],
165176 audience = settings .AUDIENCE ,
166177 issuer = provider_config .issuer ,
167178 options = options ,
168- leeway = settings .JWT_LEEWAY
179+ leeway = settings .JWT_LEEWAY ,
169180 )
170181 except jwt .ExpiredSignatureError as error :
171182 logger .info ("Signature has expired: %s" , error )
@@ -175,7 +186,7 @@ def validate_access_token(self, access_token):
175186 if idx < len (provider_config .signing_keys ) - 1 :
176187 continue
177188 else :
178- logger .info (' Error decoding signature: %s' , error )
189+ logger .info (" Error decoding signature: %s" , error )
179190 raise PermissionDenied
180191 except jwt .InvalidTokenError as error :
181192 logger .info (str (error ))
@@ -187,12 +198,8 @@ def process_access_token(self, access_token, adfs_response=None):
187198
188199 logger .debug ("Received access token: %s" , access_token )
189200 claims = self .validate_access_token (access_token )
190- if (
191- settings .BLOCK_GUEST_USERS
192- and claims .get ('tid' )
193- != settings .TENANT_ID
194- ):
195- logger .info ('Guest user denied' )
201+ if settings .BLOCK_GUEST_USERS and claims .get ("tid" ) != settings .TENANT_ID :
202+ logger .info ("Guest user denied" )
196203 raise PermissionDenied
197204 if not claims :
198205 raise PermissionDenied
@@ -204,10 +211,7 @@ def process_access_token(self, access_token, adfs_response=None):
204211 self .update_user_flags (user , claims , groups )
205212
206213 signals .post_authenticate .send (
207- sender = self ,
208- user = user ,
209- claims = claims ,
210- adfs_response = adfs_response
214+ sender = self , user = user , claims = claims , adfs_response = adfs_response
211215 )
212216
213217 user .full_clean ()
@@ -235,7 +239,9 @@ def process_user_groups(self, claims, access_token):
235239 if settings .GROUPS_CLAIM in claims :
236240 groups = claims [settings .GROUPS_CLAIM ]
237241 if not isinstance (groups , list ):
238- groups = [groups , ]
242+ groups = [
243+ groups ,
244+ ]
239245 elif (
240246 settings .TENANT_ID != "adfs"
241247 and "_claim_names" in claims
@@ -244,8 +250,10 @@ def process_user_groups(self, claims, access_token):
244250 obo_access_token = self .get_obo_access_token (access_token )
245251 groups = self .get_group_memberships_from_ms_graph (obo_access_token )
246252 else :
247- logger .debug ("The configured groups claim %s was not found in the access token" ,
248- settings .GROUPS_CLAIM )
253+ logger .debug (
254+ "The configured groups claim %s was not found in the access token" ,
255+ settings .GROUPS_CLAIM ,
256+ )
249257
250258 return groups
251259
@@ -264,19 +272,21 @@ def create_user(self, claims):
264272 guest_username_claim = settings .GUEST_USERNAME_CLAIM
265273 usermodel = get_user_model ()
266274
267- iss = claims .get (' iss' )
268- idp = claims .get (' idp' , iss )
275+ iss = claims .get (" iss" )
276+ idp = claims .get (" idp" , iss )
269277 if (
270278 guest_username_claim
271279 and not claims .get (username_claim )
272280 and not settings .BLOCK_GUEST_USERS
273- and (claims .get (' tid' ) != settings .TENANT_ID or iss != idp )
281+ and (claims .get (" tid" ) != settings .TENANT_ID or iss != idp )
274282 ):
275283 username_claim = guest_username_claim
276284
277285 if not claims .get (username_claim ):
278- logger .error ("User claim's doesn't have the claim '%s' in his claims: %s" %
279- (username_claim , claims ))
286+ logger .error (
287+ "User claim's doesn't have the claim '%s' in his claims: %s"
288+ % (username_claim , claims )
289+ )
280290 raise PermissionDenied
281291
282292 userdata = {usermodel .USERNAME_FIELD : claims [username_claim ]}
@@ -288,7 +298,10 @@ def create_user(self, claims):
288298 user = usermodel .objects .create (** userdata )
289299 logger .debug ("User '%s' has been created." , claims [username_claim ])
290300 else :
291- logger .debug ("User '%s' doesn't exist and creating users is disabled." , claims [username_claim ])
301+ logger .debug (
302+ "User '%s' doesn't exist and creating users is disabled." ,
303+ claims [username_claim ],
304+ )
292305 raise PermissionDenied
293306 if not user .password :
294307 user .set_unusable_password ()
@@ -308,27 +321,47 @@ def update_user_attributes(self, user, claims, claim_mapping=None):
308321 """
309322 if claim_mapping is None :
310323 claim_mapping = settings .CLAIM_MAPPING
311- required_fields = [field .name for field in user ._meta .get_fields () if getattr (field , 'blank' , True ) is False ]
324+ required_fields = [
325+ field .name
326+ for field in user ._meta .get_fields ()
327+ if getattr (field , "blank" , True ) is False
328+ ]
312329
313330 for field , claim in claim_mapping .items ():
314331 if hasattr (user , field ) or user ._meta .fields_map .get (field ):
315332 if not isinstance (claim , dict ):
316333 if claim in claims :
317334 setattr (user , field , claims [claim ])
318- logger .debug ("Attribute '%s' for instance '%s' was set to '%s'." , field , user , claims [claim ])
335+ logger .debug (
336+ "Attribute '%s' for instance '%s' was set to '%s'." ,
337+ field ,
338+ user ,
339+ claims [claim ],
340+ )
319341 else :
320342 if field in required_fields :
321343 msg = "Claim not found in access token: '{}'. Check ADFS claims mapping."
322344 raise ImproperlyConfigured (msg .format (claim ))
323345 else :
324- logger .warning ("Claim '%s' for field '%s' was not found in "
325- "the access token for instance '%s'. "
326- "Field is not required and will be left empty" , claim , field , user )
346+ logger .warning (
347+ "Claim '%s' for field '%s' was not found in "
348+ "the access token for instance '%s'. "
349+ "Field is not required and will be left empty" ,
350+ claim ,
351+ field ,
352+ user ,
353+ )
327354 else :
328355 try :
329- self .update_user_attributes (getattr (user , field ), claims , claim_mapping = claim )
356+ self .update_user_attributes (
357+ getattr (user , field ), claims , claim_mapping = claim
358+ )
330359 except ObjectDoesNotExist :
331- logger .warning ("Object for field '{}' does not exist for: '{}'." .format (field , user ))
360+ logger .warning (
361+ "Object for field '{}' does not exist for: '{}'." .format (
362+ field , user
363+ )
364+ )
332365
333366 else :
334367 msg = "Model '{}' has no field named '{}'. Check ADFS claims mapping."
@@ -358,10 +391,13 @@ def update_user_groups(self, user, claim_groups):
358391 # bulk_create could have been used here but we want to send signals.
359392 new_claimed_groups = [
360393 Group .objects .get_or_create (name = name )[0 ]
361- for name in claim_groups if name not in existing_claimed_group_names
394+ for name in claim_groups
395+ if name not in existing_claimed_group_names
362396 ]
363397 # Associate the users to all claimed groups
364- user .groups .set (tuple (existing_claimed_groups ) + tuple (new_claimed_groups ))
398+ user .groups .set (
399+ tuple (existing_claimed_groups ) + tuple (new_claimed_groups )
400+ )
365401 else :
366402 # Associate the user to only existing claimed groups
367403 user .groups .set (existing_claimed_groups )
@@ -381,23 +417,42 @@ def update_user_flags(self, user, claims, claim_groups):
381417 if not isinstance (group , list ):
382418 group = [group ]
383419
384- if any (group_list_item in claim_groups for group_list_item in group ):
420+ if any (
421+ group_list_item in claim_groups for group_list_item in group
422+ ):
385423 value = True
386424 else :
387425 value = False
388426 setattr (user , flag , value )
389- logger .debug ("Attribute '%s' for user '%s' was set to '%s'." , flag , user , value )
427+ logger .debug (
428+ "Attribute '%s' for user '%s' was set to '%s'." ,
429+ flag ,
430+ user ,
431+ value ,
432+ )
390433 else :
391434 msg = "User model has no field named '{}'. Check ADFS boolean claims mapping."
392435 raise ImproperlyConfigured (msg .format (flag ))
393436
394437 for field , claim in settings .BOOLEAN_CLAIM_MAPPING .items ():
395438 if hasattr (user , field ):
396439 bool_val = False
397- if claim in claims and str (claims [claim ]).lower () in ['y' , 'yes' , 't' , 'true' , 'on' , '1' ]:
440+ if claim in claims and str (claims [claim ]).lower () in [
441+ "y" ,
442+ "yes" ,
443+ "t" ,
444+ "true" ,
445+ "on" ,
446+ "1" ,
447+ ]:
398448 bool_val = True
399449 setattr (user , field , bool_val )
400- logger .debug ("Attribute '%s' for user '%s' was set to '%s'." , field , user , bool_val )
450+ logger .debug (
451+ "Attribute '%s' for user '%s' was set to '%s'." ,
452+ field ,
453+ user ,
454+ bool_val ,
455+ )
401456 else :
402457 msg = "User model has no field named '{}'. Check ADFS boolean claims mapping."
403458 raise ImproperlyConfigured (msg .format (field ))
@@ -411,8 +466,10 @@ class AdfsAuthCodeBackend(AdfsBaseBackend):
411466
412467 def authenticate (self , request = None , authorization_code = None , ** kwargs ):
413468 # If there's no token or code, we pass control to the next authentication backend
414- if authorization_code is None or authorization_code == '' :
415- logger .debug ("Authentication backend was called but no authorization code was received" )
469+ if authorization_code is None or authorization_code == "" :
470+ logger .debug (
471+ "Authentication backend was called but no authorization code was received"
472+ )
416473 return
417474
418475 # If loaded data is too old, reload it again
@@ -435,8 +492,10 @@ def authenticate(self, request=None, access_token=None, **kwargs):
435492 provider_config .load_config ()
436493
437494 # If there's no token or code, we pass control to the next authentication backend
438- if access_token is None or access_token == '' :
439- logger .debug ("Authentication backend was called but no access token was received" )
495+ if access_token is None or access_token == "" :
496+ logger .debug (
497+ "Authentication backend was called but no access token was received"
498+ )
440499 return
441500
442501 access_token = access_token .decode ()
@@ -445,5 +504,6 @@ def authenticate(self, request=None, access_token=None, **kwargs):
445504
446505
447506class AdfsBackend (AdfsAuthCodeBackend ):
448- """ Backwards compatible class name """
507+ """Backwards compatible class name"""
508+
449509 pass
0 commit comments