diff --git a/msal/oauth2cli/oauth2.py b/msal/oauth2cli/oauth2.py index 680c1bc5..68b0e84e 100644 --- a/msal/oauth2cli/oauth2.py +++ b/msal/oauth2cli/oauth2.py @@ -13,7 +13,7 @@ import base64 import sys import functools -import random +import secrets import string import hashlib @@ -277,8 +277,9 @@ def _scope_set(scope): def _generate_pkce_code_verifier(length=43): assert 43 <= length <= 128 + alphabet = string.ascii_letters + string.digits + "-._~" verifier = "".join( # https://tools.ietf.org/html/rfc7636#section-4.1 - random.sample(string.ascii_letters + string.digits + "-._~", length)) + secrets.choice(alphabet) for _ in range(length)) code_challenge = ( # https://tools.ietf.org/html/rfc7636#section-4.2 base64.urlsafe_b64encode(hashlib.sha256(verifier.encode("ascii")).digest()) @@ -488,7 +489,7 @@ def initiate_auth_code_flow( raise ValueError('response_type="token ..." is not allowed') pkce = _generate_pkce_code_verifier() flow = { # These data are required by obtain_token_by_auth_code_flow() - "state": state or "".join(random.sample(string.ascii_letters, 16)), + "state": state or secrets.token_urlsafe(16), "redirect_uri": redirect_uri, "scope": scope, } diff --git a/msal/oauth2cli/oidc.py b/msal/oauth2cli/oidc.py index 01ee7894..faa2f966 100644 --- a/msal/oauth2cli/oidc.py +++ b/msal/oauth2cli/oidc.py @@ -1,8 +1,7 @@ import json import base64 import time -import random -import string +import secrets import warnings import hashlib import logging @@ -238,7 +237,7 @@ def initiate_auth_code_flow( # Here we just automatically add it. If the caller do not want id_token, # they should simply go with oauth2.Client. _scope.append("openid") - nonce = "".join(random.sample(string.ascii_letters, 16)) + nonce = secrets.token_urlsafe(16) flow = super(Client, self).initiate_auth_code_flow( scope=_scope, nonce=_nonce_hash(nonce), **kwargs) flow["nonce"] = nonce diff --git a/tests/test_oidc.py b/tests/test_oidc.py index 297dfeb5..44e4658b 100644 --- a/tests/test_oidc.py +++ b/tests/test_oidc.py @@ -1,7 +1,80 @@ +import string + from tests import unittest import msal from msal import oauth2cli +from msal.oauth2cli.oauth2 import _generate_pkce_code_verifier + + +class TestCsprngUsage(unittest.TestCase): + """Tests that security-critical parameters use cryptographically secure randomness.""" + + # RFC 7636 ยง4.1: code_verifier = 43*128unreserved + _PKCE_ALPHABET = set(string.ascii_letters + string.digits + "-._~") + + def test_pkce_code_verifier_contains_only_valid_characters(self): + for _ in range(50): + result = _generate_pkce_code_verifier() + self.assertTrue( + set(result["code_verifier"]).issubset(self._PKCE_ALPHABET), + "code_verifier contains invalid characters") + + def test_pkce_code_verifier_has_correct_default_length(self): + result = _generate_pkce_code_verifier() + self.assertEqual(len(result["code_verifier"]), 43) + + def test_pkce_code_verifier_respects_custom_length(self): + for length in (43, 64, 128): + result = _generate_pkce_code_verifier(length) + self.assertEqual(len(result["code_verifier"]), length) + + def test_pkce_code_verifier_can_have_repeated_characters(self): + """secrets.choice() samples with replacement, unlike the old random.sample().""" + result = _generate_pkce_code_verifier(128) + code_verifier = result["code_verifier"] + self.assertLess(len(set(code_verifier)), len(code_verifier), + "At length 128 with a 66-char alphabet, repeated chars are expected") + + def test_pkce_code_verifier_is_not_deterministic(self): + results = {_generate_pkce_code_verifier()["code_verifier"] for _ in range(10)} + self.assertGreater(len(results), 1, "code_verifier should not be deterministic") + + def test_oauth2_state_is_url_safe_and_unpredictable(self): + """State generated by initiate_auth_code_flow should be URL-safe.""" + from msal.oauth2cli.oauth2 import Client + client = Client( + {"authorization_endpoint": "https://example.com/auth", + "token_endpoint": "https://example.com/token"}, + client_id="test_client") + states = set() + for _ in range(10): + flow = client.initiate_auth_code_flow( + redirect_uri="http://localhost", scope=["openid"], + response_mode="form_post") + state = flow["state"] + self.assertRegex(state, r'^[A-Za-z0-9_-]+$', + "state should be URL-safe") + states.add(state) + self.assertGreater(len(states), 1, "state should not be deterministic") + + def test_oidc_nonce_is_url_safe_and_unpredictable(self): + """Nonce generated by OIDC initiate_auth_code_flow should be URL-safe.""" + from msal.oauth2cli.oidc import Client + client = Client( + {"authorization_endpoint": "https://example.com/auth", + "token_endpoint": "https://example.com/token"}, + client_id="test_client") + nonces = set() + for _ in range(10): + flow = client.initiate_auth_code_flow( + redirect_uri="http://localhost", scope=["openid"], + response_mode="form_post") + nonce = flow["nonce"] + self.assertRegex(nonce, r'^[A-Za-z0-9_-]+$', + "nonce should be URL-safe") + nonces.add(nonce) + self.assertGreater(len(nonces), 1, "nonce should not be deterministic") class TestIdToken(unittest.TestCase):