From d22fd7f1373e0845108d9dae9434e6fa17151e8a Mon Sep 17 00:00:00 2001 From: Yago Hernandez Date: Sat, 15 Nov 2025 17:37:46 +0000 Subject: [PATCH] fix: browser_azure_credentials saml assertion --- .../plugin/browser_azure_credentials_provider.py | 6 +++--- test/unit/plugin/data/browser_azure_data.py | 4 ++-- test/unit/plugin/test_browser_azure_credentials_provider.py | 2 +- 3 files changed, 6 insertions(+), 6 deletions(-) diff --git a/redshift_connector/plugin/browser_azure_credentials_provider.py b/redshift_connector/plugin/browser_azure_credentials_provider.py index 9f70ee5..2fa07f6 100644 --- a/redshift_connector/plugin/browser_azure_credentials_provider.py +++ b/redshift_connector/plugin/browser_azure_credentials_provider.py @@ -176,7 +176,7 @@ def fetch_saml_response(self: "BrowserAzureCredentialsProvider", token): if missing_padding: saml_assertion += "=" * missing_padding - return str(base64.urlsafe_b64decode(saml_assertion)) + return base64.urlsafe_b64decode(saml_assertion).decode("utf-8") # SAML Response is required to be sent to base class. We need to provide a minimum of: # samlp:Response XML tag with xmlns:samlp protocol value @@ -189,10 +189,10 @@ def wrap_and_encode_assertion(self: "BrowserAzureCredentialsProvider", saml_asse '' "" "{saml_assertion}" - "".format(saml_assertion=saml_assertion[2:-1]) + "".format(saml_assertion=saml_assertion) ) - return str(base64.b64encode(saml_response.encode("utf-8")))[2:-1] + return base64.b64encode(saml_response.encode("utf-8")).decode("utf-8") def run_server( self: "BrowserAzureCredentialsProvider", listen_socket: socket.socket, idp_response_timeout: int, state: str diff --git a/test/unit/plugin/data/browser_azure_data.py b/test/unit/plugin/data/browser_azure_data.py index f1faa62..19e85ed 100644 --- a/test/unit/plugin/data/browser_azure_data.py +++ b/test/unit/plugin/data/browser_azure_data.py @@ -56,7 +56,7 @@ ) -saml_response = b"my_access_token" +saml_response = b"\"my_access_token\"".decode("utf-8") valid_json_response: dict = { "token_type": "Bearer", @@ -64,7 +64,7 @@ "ext_expires_in": "3599", "expires_on": "1602782647", "resource": "spn:1234567891011121314151617181920", - "access_token": "bXlfYWNjZXNzX3Rva2Vu", # base64.urlsafe_64encode(saml_response) + "access_token": "Im15X2FjY2Vzc190b2tlbiI=", # base64.urlsafe_64encode(saml_response) "issued_token_type": "urn:ietf:params:oauth:token-type:saml2", "refresh_token": "my_refresh_token", "id_token": "my_id_token", diff --git a/test/unit/plugin/test_browser_azure_credentials_provider.py b/test/unit/plugin/test_browser_azure_credentials_provider.py index 6020746..1531c17 100644 --- a/test/unit/plugin/test_browser_azure_credentials_provider.py +++ b/test/unit/plugin/test_browser_azure_credentials_provider.py @@ -247,7 +247,7 @@ def mock_get_json() -> typing.Dict: mocked_post.return_value = mock_get_resp() saml_assertion: str = bacp.fetch_saml_response(token="blah") - assert str(browser_azure_data.saml_response) == saml_assertion + assert browser_azure_data.saml_response == saml_assertion malformed_json_responses: typing.List[typing.Tuple[typing.Optional[typing.Dict], str]] = [