From 9823509ba8f8a39f6a340f9880879657617d3d3e Mon Sep 17 00:00:00 2001 From: Paulina Kalicka <71526180+paulinek13@users.noreply.github.com> Date: Mon, 22 Dec 2025 20:43:09 +0100 Subject: [PATCH 01/38] initial, minimal working version of WebSocketCopilotTarget --- doc/api.rst | 1 + pyrit/prompt_target/__init__.py | 2 + .../prompt_target/websocket_copilot_target.py | 297 ++++++++++++++++++ websocket_copilot_simple_example.py | 31 ++ 4 files changed, 331 insertions(+) create mode 100644 pyrit/prompt_target/websocket_copilot_target.py create mode 100644 websocket_copilot_simple_example.py diff --git a/doc/api.rst b/doc/api.rst index 1b9bdc775..cd812de72 100644 --- a/doc/api.rst +++ b/doc/api.rst @@ -500,6 +500,7 @@ API Reference PromptTarget RealtimeTarget TextTarget + WebSocketCopilotTarget :py:mod:`pyrit.score` ===================== diff --git a/pyrit/prompt_target/__init__.py b/pyrit/prompt_target/__init__.py index cdbbdb0ff..eee3ff6cb 100644 --- a/pyrit/prompt_target/__init__.py +++ b/pyrit/prompt_target/__init__.py @@ -37,6 +37,7 @@ from pyrit.prompt_target.playwright_copilot_target import CopilotType, PlaywrightCopilotTarget from pyrit.prompt_target.prompt_shield_target import PromptShieldTarget from pyrit.prompt_target.text_target import TextTarget +from pyrit.prompt_target.websocket_copilot_target import WebSocketCopilotTarget __all__ = [ "AzureBlobStorageTarget", @@ -66,4 +67,5 @@ "PromptTarget", "RealtimeTarget", "TextTarget", + "WebSocketCopilotTarget", ] diff --git a/pyrit/prompt_target/websocket_copilot_target.py b/pyrit/prompt_target/websocket_copilot_target.py new file mode 100644 index 000000000..f15994f56 --- /dev/null +++ b/pyrit/prompt_target/websocket_copilot_target.py @@ -0,0 +1,297 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +import json +import logging +import os +import uuid +from enum import Enum +from typing import Optional + +import websockets + +from pyrit.models import Message, construct_response_from_request +from pyrit.prompt_target import PromptTarget, limit_requests_per_minute + +logger = logging.getLogger(__name__) + +""" +Useful links: +https://github.com/mbrg/power-pwn/blob/main/src/powerpwn/copilot/copilot_connector/copilot_connector.py +https://labs.zenity.io/p/access-copilot-m365-terminal +""" + + +class CopilotMessageType(Enum): + """Enumeration for Copilot WebSocket message types.""" + + UNKNOWN = -1 + NEXT_DATA_FRAME = 1 # streaming Copilot responses + LAST_DATA_FRAME = 2 # the last data frame with final content + USER_PROMPT = 4 + PING = 6 + + +class WebSocketCopilotTarget(PromptTarget): + """ + A WebSocket-based prompt target for Microsoft Copilot integration. + + This target enables communication with Microsoft Copilot through a WebSocket connection. + Currently, authentication requires manually extracting a WebSocket URL from an active browser session. + In the future, more flexible authentication mechanisms will be added. + + To obtain the WebSocket URL: + 1. Ensure you are logged into Microsoft 365 with access to Copilot + 2. Navigate to https://m365.cloud.microsoft/chat or open Copilot in https://teams.microsoft.com/v2 + 3. Open browser developer tools and switch to the Network tab + 4. Begin typing or send a message to Copilot to establish the WebSocket connection + 5. Search the network requests for "chathub", "conversation", or "access_token" + 6. Identify the WebSocket connection (look for WS protocol) and copy its full URL + + Warning: + All target instances using the same `WEBSOCKET_URL` will share a single conversation session. + Only works with licensed Microsoft 365 Copilot. The free Copilot version is not compatible. + """ + + # TODO: add more flexible auth, use puppeteer? https://github.com/mbrg/power-pwn/blob/main/src/powerpwn/copilot/copilot_connector/copilot_connector.py#L248 + # TODO: add useful message for: "Error during WebSocket communication: server rejected WebSocket connection: HTTP 401" + + SUPPORTED_DATA_TYPES = {"text"} # TODO: support more types? + + # TODO: implement timeouts and retries + MAX_WAIT_TIME_SECONDS: int = 300 + POLL_INTERVAL_MS: int = 2000 + + def __init__( + self, + *, + verbose: bool = False, + max_requests_per_minute: Optional[int] = None, + model_name: str = "copilot", + ) -> None: + """ + Initialize the WebSocketCopilotTarget. + + Args: + verbose (bool): Enable verbose logging. Defaults to False. + max_requests_per_minute (int, Optional): Maximum number of requests per minute. + model_name (str): The model name. Defaults to "copilot". + + Raises: + ValueError: If WebSocket URL is not provided as env variable. + """ + self._websocket_url = os.getenv("WEBSOCKET_URL") + if not self._websocket_url: + raise ValueError("WebSocket URL must be provided through the WEBSOCKET_URL environment variable") + + if not "ConversationId=" in self._websocket_url: + raise ValueError("`ConversationId` parameter not found in URL.") + self._conversation_id = self._websocket_url.split("ConversationId=")[1].split("&")[0] + + if not "X-SessionId=" in self._websocket_url: + raise ValueError("`X-SessionId` parameter not found in URL.") + self._session_id = self._websocket_url.split("X-SessionId=")[1].split("&")[0] + + super().__init__( + verbose=verbose, + max_requests_per_minute=max_requests_per_minute, + endpoint=self._websocket_url.split("?")[0], # wss://substrate.office.com/m365Copilot/Chathub/... + model_name=model_name, + ) + + if self._verbose: + logger.info(f"WebSocketCopilotTarget initialized with conversation_id: {self._conversation_id}") + logger.info(f"Session ID: {self._session_id}") + + @staticmethod + def _dict_to_websocket(data: dict) -> str: + # Produce the smallest possible JSON string, followed by record separator + return json.dumps(data, separators=(",", ":")) + "\x1e" + + @staticmethod + def _parse_message(raw_message: str) -> tuple[int, str, dict]: + """ + Extract actionable content from raw WebSocket frames. + + Args: + raw_message (str): The raw WebSocket message string. + + Returns: + tuple: (message_type, content_text, full_data) + """ + try: + # https://github.com/dotnet/aspnetcore/blob/main/src/SignalR/docs/specs/HubProtocol.md#json-encoding + message = message = raw_message.split("\x1e")[0] # record separator + if not message: + return (-1, "", {}) + + data = json.loads(message) + msg_type = data.get("type", -1) + + if msg_type == 6: # PING + return (6, "", data) + + if msg_type == 2: # LAST_DATA_FRAME + item = data.get("item", {}) + if item: + messages = item.get("messages", []) + if messages: + for msg in reversed(messages): + if msg.get("author") == "bot": + text = msg.get("text", "") + if text: + return (2, text, data) + # TODO: maybe treat this as error? + logger.warning("LAST_DATA_FRAME received but no parseable content found.") + return (2, "", data) + + if msg_type == 1: # NEXT_DATA_FRAME + # Streamed updates are not needed for this target + return (1, "", data) + + return (msg_type, "", data) + + except json.JSONDecodeError as e: + logger.error(f"Failed to decode JSON message: {str(e)}") + return (-1, "", {}) + + def _build_prompt_message(self, prompt: str) -> dict: + return { + "arguments": [ + { + "source": "officeweb", # TODO: support 'teamshub' as well + # TODO: not sure whether to uuid.uuid4() or use a static like it's done in power-pwn + # https://github.com/mbrg/power-pwn/blob/main/src/powerpwn/copilot/copilot_connector/copilot_connector.py#L156 + "clientCorrelationId": str(uuid.uuid4()), + "sessionId": self._session_id, + "optionsSets": [ + "enterprise_flux_web", + "enterprise_flux_work", + "enable_request_response_interstitials", + "enterprise_flux_image_v1", + "enterprise_toolbox_with_skdsstore", + "enterprise_toolbox_with_skdsstore_search_message_extensions", + "enable_ME_auth_interstitial", + "skdsstorethirdparty", + "enable_confirmation_interstitial", + "enable_plugin_auth_interstitial", + "enable_response_action_processing", + "enterprise_flux_work_gptv", + "enterprise_flux_work_code_interpreter", + "enable_batch_token_processing", + ], + "options": {}, + "allowedMessageTypes": [ + "Chat", + "Suggestion", + "InternalSearchQuery", + "InternalSearchResult", + "Disengaged", + "InternalLoaderMessage", + "RenderCardRequest", + "AdsQuery", + "SemanticSerp", + "GenerateContentQuery", + "SearchQuery", + "ConfirmationCard", + "AuthError", + "DeveloperLogs", + ], + "sliceIds": [], + # TODO: enable using agents https://github.com/mbrg/power-pwn/blob/main/src/powerpwn/copilot/copilot_connector/copilot_connector.py#L192 + "threadLevelGptId": {}, + "conversationId": self._conversation_id, + "traceId": str(uuid.uuid4()).replace("-", ""), # TODO: same case as clientCorrelationId + "isStartOfSession": 0, + "productThreadType": "Office", + "clientInfo": {"clientPlatform": "web"}, + "message": { + "author": "user", + "inputMethod": "Keyboard", + "text": prompt, + "entityAnnotationTypes": ["People", "File", "Event", "Email", "TeamsMessage"], + "requestId": str(uuid.uuid4()).replace("-", ""), + "locationInfo": {"timeZoneOffset": 0, "timeZone": "UTC"}, + "locale": "en-US", + "messageType": "Chat", + "experienceType": "Default", + }, + "plugins": [], # TODO: support enabling some plugins? + } + ], + "invocationId": "0", # TODO: should be dynamic? + "target": "chat", + "type": 4, + } + + async def _connect_and_send(self, prompt: str) -> str: + protocol_msg = {"protocol": "json", "version": 1} + prompt_dict = self._build_prompt_message(prompt) + + inputs = [protocol_msg, prompt_dict] + last_response = "" + + async with websockets.connect(self._websocket_url) as websocket: + for input_msg in inputs: + payload = self._dict_to_websocket(input_msg) + is_user_input = input_msg.get("type") == 4 # USER_PROMPT + + await websocket.send(payload) + + stop_polling = False + while not stop_polling: + response = await websocket.recv() + msg_type, content, data = self._parse_message(response) + + if ( + msg_type in (-1, 2) # UNKNOWN or LAST_DATA_FRAME + or msg_type == 6 + and not is_user_input + ): + stop_polling = True + + if msg_type == 2: # LAST_DATA_FRAME - final response + last_response = content + elif msg_type == -1: # UNKNOWN/NONE + logger.debug("Received unknown or empty message type.") + + return last_response + + def _validate_request(self, *, message: Message) -> None: + n_pieces = len(message.message_pieces) + if n_pieces != 1: + raise ValueError(f"This target only supports a single message piece. Received: {n_pieces} pieces.") + + piece_type = message.message_pieces[0].converted_value_data_type + if piece_type != "text": + raise ValueError(f"This target only supports text prompt input. Received: {piece_type}.") + + @limit_requests_per_minute + async def send_prompt_async(self, *, message: Message) -> list[Message]: + """ + Asynchronously send a message to Microsoft Copilot using WebSocket. + + Args: + message (Message): A message to be sent to the target. + + Returns: + list[Message]: A list containing the response from Copilot. + + Raises: + RuntimeError: If an error occurs during WebSocket communication. + """ + self._validate_request(message=message) + request_piece = message.message_pieces[0] + + try: + prompt_text = request_piece.converted_value + response_text = await self._connect_and_send(prompt_text) + + response_entry = construct_response_from_request( + request=request_piece, response_text_pieces=[response_text] + ) + + return [response_entry] + + except Exception as e: + raise RuntimeError(f"An error occurred during WebSocket communication: {str(e)}") from e diff --git a/websocket_copilot_simple_example.py b/websocket_copilot_simple_example.py new file mode 100644 index 000000000..a1e13831a --- /dev/null +++ b/websocket_copilot_simple_example.py @@ -0,0 +1,31 @@ +""" +# TODO +THIS WILL BE REMOVED after proper unit tests are in place :) +""" + +import asyncio + +from pyrit.models import Message, MessagePiece +from pyrit.prompt_target import WebSocketCopilotTarget +from pyrit.setup import IN_MEMORY, initialize_pyrit_async + + +async def main(): + await initialize_pyrit_async(memory_db_type=IN_MEMORY) + target = WebSocketCopilotTarget() + + message_piece = MessagePiece( + role="user", + original_value="say only one random word", + original_value_data_type="text", + converted_value_data_type="text", + ) + message = Message(message_pieces=[message_piece]) + + responses = await target.send_prompt_async(message=message) + for response in responses: + print(f"{response.get_value()}") + + +if __name__ == "__main__": + asyncio.run(main()) From 62fc335afce98e2769bbf0d54bb7e02cb91f33ae Mon Sep 17 00:00:00 2001 From: Paulina Kalicka <71526180+paulinek13@users.noreply.github.com> Date: Tue, 23 Dec 2025 21:07:33 +0100 Subject: [PATCH 02/38] add useful error message for "server rejected WebSocket connection: HTTP 401" --- pyrit/prompt_target/websocket_copilot_target.py | 12 ++++++++++-- 1 file changed, 10 insertions(+), 2 deletions(-) diff --git a/pyrit/prompt_target/websocket_copilot_target.py b/pyrit/prompt_target/websocket_copilot_target.py index f15994f56..da2b71a67 100644 --- a/pyrit/prompt_target/websocket_copilot_target.py +++ b/pyrit/prompt_target/websocket_copilot_target.py @@ -54,7 +54,6 @@ class WebSocketCopilotTarget(PromptTarget): """ # TODO: add more flexible auth, use puppeteer? https://github.com/mbrg/power-pwn/blob/main/src/powerpwn/copilot/copilot_connector/copilot_connector.py#L248 - # TODO: add useful message for: "Error during WebSocket communication: server rejected WebSocket connection: HTTP 401" SUPPORTED_DATA_TYPES = {"text"} # TODO: support more types? @@ -278,7 +277,8 @@ async def send_prompt_async(self, *, message: Message) -> list[Message]: list[Message]: A list containing the response from Copilot. Raises: - RuntimeError: If an error occurs during WebSocket communication. + websockets.exceptions.InvalidStatus: If the WebSocket connection fails. + RuntimeError: If any other error occurs during WebSocket communication. """ self._validate_request(message=message) request_piece = message.message_pieces[0] @@ -293,5 +293,13 @@ async def send_prompt_async(self, *, message: Message) -> list[Message]: return [response_entry] + except websockets.exceptions.InvalidStatus as e: + logger.error( + f"WebSocket connection failed: {str(e)}\n" + "Ensure the WEBSOCKET_URL environment variable is correct and valid." + " For more details about authentication, refer to the class documentation." + ) + raise e + except Exception as e: raise RuntimeError(f"An error occurred during WebSocket communication: {str(e)}") from e From 107b715e920441ed512fc6086a2182dc55a18692 Mon Sep 17 00:00:00 2001 From: Paulina Kalicka <71526180+paulinek13@users.noreply.github.com> Date: Tue, 23 Dec 2025 21:38:42 +0100 Subject: [PATCH 03/38] improve error handling and logging --- .../prompt_target/websocket_copilot_target.py | 22 ++++++++++++++++--- 1 file changed, 19 insertions(+), 3 deletions(-) diff --git a/pyrit/prompt_target/websocket_copilot_target.py b/pyrit/prompt_target/websocket_copilot_target.py index da2b71a67..13669bffb 100644 --- a/pyrit/prompt_target/websocket_copilot_target.py +++ b/pyrit/prompt_target/websocket_copilot_target.py @@ -10,6 +10,10 @@ import websockets +from pyrit.exceptions import ( + EmptyResponseException, + pyrit_target_retry, +) from pyrit.models import Message, construct_response_from_request from pyrit.prompt_target import PromptTarget, limit_requests_per_minute @@ -266,6 +270,7 @@ def _validate_request(self, *, message: Message) -> None: raise ValueError(f"This target only supports text prompt input. Received: {piece_type}.") @limit_requests_per_minute + @pyrit_target_retry async def send_prompt_async(self, *, message: Message) -> list[Message]: """ Asynchronously send a message to Microsoft Copilot using WebSocket. @@ -277,16 +282,25 @@ async def send_prompt_async(self, *, message: Message) -> list[Message]: list[Message]: A list containing the response from Copilot. Raises: - websockets.exceptions.InvalidStatus: If the WebSocket connection fails. + EmptyResponseException: If the response from Copilot is empty. + InvalidStatus: If the WebSocket handshake fails with an HTTP status error. + WebSocketException: If the WebSocket connection fails. RuntimeError: If any other error occurs during WebSocket communication. """ self._validate_request(message=message) request_piece = message.message_pieces[0] + logger.info(f"Sending the following prompt to WebSocketCopilotTarget: {request_piece}") + try: prompt_text = request_piece.converted_value response_text = await self._connect_and_send(prompt_text) + if not response_text or not response_text.strip(): + logger.error("Empty response received from Copilot.") + raise EmptyResponseException(message="Copilot returned an empty response.") + logger.info(f"Received the following response from WebSocketCopilotTarget: {response_text[:100]}...") + response_entry = construct_response_from_request( request=request_piece, response_text_pieces=[response_text] ) @@ -299,7 +313,9 @@ async def send_prompt_async(self, *, message: Message) -> list[Message]: "Ensure the WEBSOCKET_URL environment variable is correct and valid." " For more details about authentication, refer to the class documentation." ) - raise e + raise + except websockets.exceptions.WebSocketException as e: + raise RuntimeError(f"WebSocket communication error: {str(e)}") from e except Exception as e: - raise RuntimeError(f"An error occurred during WebSocket communication: {str(e)}") from e + raise RuntimeError(f"Unexpected error during WebSocket communication: {str(e)}") from e From 41013a22c66c5209e78fa7b62f77300757cf0862 Mon Sep 17 00:00:00 2001 From: Paulina Kalicka <71526180+paulinek13@users.noreply.github.com> Date: Tue, 23 Dec 2025 21:51:53 +0100 Subject: [PATCH 04/38] enhance WebSocket URL validation --- .../prompt_target/websocket_copilot_target.py | 23 ++++++++++++++----- 1 file changed, 17 insertions(+), 6 deletions(-) diff --git a/pyrit/prompt_target/websocket_copilot_target.py b/pyrit/prompt_target/websocket_copilot_target.py index 13669bffb..87436004b 100644 --- a/pyrit/prompt_target/websocket_copilot_target.py +++ b/pyrit/prompt_target/websocket_copilot_target.py @@ -81,19 +81,30 @@ def __init__( model_name (str): The model name. Defaults to "copilot". Raises: - ValueError: If WebSocket URL is not provided as env variable. + ValueError: If WebSocket URL is not provided, is empty, or has invalid format. + ValueError: If required parameters are missing or empty in the WebSocket URL. """ self._websocket_url = os.getenv("WEBSOCKET_URL") - if not self._websocket_url: + if not self._websocket_url or self._websocket_url.strip() == "": raise ValueError("WebSocket URL must be provided through the WEBSOCKET_URL environment variable") - if not "ConversationId=" in self._websocket_url: - raise ValueError("`ConversationId` parameter not found in URL.") + if not self._websocket_url.startswith(("wss://", "ws://")): + raise ValueError( + "WebSocket URL must start with 'wss://' or 'ws://'. " + f"Received URL starting with: {self._websocket_url[:10]}" + ) + + if "ConversationId=" not in self._websocket_url: + raise ValueError("`ConversationId` parameter not found in WebSocket URL.") self._conversation_id = self._websocket_url.split("ConversationId=")[1].split("&")[0] + if not self._conversation_id: + raise ValueError("`ConversationId` parameter is empty in WebSocket URL.") - if not "X-SessionId=" in self._websocket_url: - raise ValueError("`X-SessionId` parameter not found in URL.") + if "X-SessionId=" not in self._websocket_url: + raise ValueError("`X-SessionId` parameter not found in WebSocket URL.") self._session_id = self._websocket_url.split("X-SessionId=")[1].split("&")[0] + if not self._session_id: + raise ValueError("`X-SessionId` parameter is empty in WebSocket URL.") super().__init__( verbose=verbose, From c8e7e831bd494146076b018ae8d63fd69f374973 Mon Sep 17 00:00:00 2001 From: Paulina Kalicka <71526180+paulinek13@users.noreply.github.com> Date: Tue, 23 Dec 2025 22:57:35 +0100 Subject: [PATCH 05/38] small fix --- pyrit/prompt_target/websocket_copilot_target.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/pyrit/prompt_target/websocket_copilot_target.py b/pyrit/prompt_target/websocket_copilot_target.py index 87436004b..a2622e362 100644 --- a/pyrit/prompt_target/websocket_copilot_target.py +++ b/pyrit/prompt_target/websocket_copilot_target.py @@ -295,7 +295,6 @@ async def send_prompt_async(self, *, message: Message) -> list[Message]: Raises: EmptyResponseException: If the response from Copilot is empty. InvalidStatus: If the WebSocket handshake fails with an HTTP status error. - WebSocketException: If the WebSocket connection fails. RuntimeError: If any other error occurs during WebSocket communication. """ self._validate_request(message=message) @@ -326,7 +325,5 @@ async def send_prompt_async(self, *, message: Message) -> list[Message]: ) raise - except websockets.exceptions.WebSocketException as e: - raise RuntimeError(f"WebSocket communication error: {str(e)}") from e except Exception as e: - raise RuntimeError(f"Unexpected error during WebSocket communication: {str(e)}") from e + raise RuntimeError(f"An error occurred during WebSocket communication: {str(e)}") from e From af636b22f9792b1a125cb452963cb03d07df8a1f Mon Sep 17 00:00:00 2001 From: Paulina Kalicka <71526180+paulinek13@users.noreply.github.com> Date: Tue, 23 Dec 2025 23:01:21 +0100 Subject: [PATCH 06/38] fix --- pyrit/prompt_target/websocket_copilot_target.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyrit/prompt_target/websocket_copilot_target.py b/pyrit/prompt_target/websocket_copilot_target.py index a2622e362..c3b3ebffc 100644 --- a/pyrit/prompt_target/websocket_copilot_target.py +++ b/pyrit/prompt_target/websocket_copilot_target.py @@ -135,7 +135,7 @@ def _parse_message(raw_message: str) -> tuple[int, str, dict]: """ try: # https://github.com/dotnet/aspnetcore/blob/main/src/SignalR/docs/specs/HubProtocol.md#json-encoding - message = message = raw_message.split("\x1e")[0] # record separator + message = raw_message.split("\x1e")[0] # record separator if not message: return (-1, "", {}) From c9ebcee1c923f969613bbbfe3077728fe7feaf2a Mon Sep 17 00:00:00 2001 From: Paulina Kalicka <71526180+paulinek13@users.noreply.github.com> Date: Tue, 23 Dec 2025 23:25:09 +0100 Subject: [PATCH 07/38] improve `_parse_message` --- .../prompt_target/websocket_copilot_target.py | 39 ++++++++++--------- 1 file changed, 20 insertions(+), 19 deletions(-) diff --git a/pyrit/prompt_target/websocket_copilot_target.py b/pyrit/prompt_target/websocket_copilot_target.py index c3b3ebffc..b8f86e076 100644 --- a/pyrit/prompt_target/websocket_copilot_target.py +++ b/pyrit/prompt_target/websocket_copilot_target.py @@ -123,7 +123,7 @@ def _dict_to_websocket(data: dict) -> str: return json.dumps(data, separators=(",", ":")) + "\x1e" @staticmethod - def _parse_message(raw_message: str) -> tuple[int, str, dict]: + def _parse_message(raw_message: str) -> tuple[int, str]: """ Extract actionable content from raw WebSocket frames. @@ -131,43 +131,38 @@ def _parse_message(raw_message: str) -> tuple[int, str, dict]: raw_message (str): The raw WebSocket message string. Returns: - tuple: (message_type, content_text, full_data) + tuple[int, str]: A tuple containing the message type and extracted content. """ try: # https://github.com/dotnet/aspnetcore/blob/main/src/SignalR/docs/specs/HubProtocol.md#json-encoding message = raw_message.split("\x1e")[0] # record separator if not message: - return (-1, "", {}) + return (-1, "") data = json.loads(message) msg_type = data.get("type", -1) - if msg_type == 6: # PING - return (6, "", data) + if msg_type in (6, 1): # PING/NEXT_DATA_FRAME + return (msg_type, "") if msg_type == 2: # LAST_DATA_FRAME item = data.get("item", {}) - if item: + if item and isinstance(item, dict): messages = item.get("messages", []) - if messages: + if messages and isinstance(messages, list): for msg in reversed(messages): - if msg.get("author") == "bot": + if isinstance(msg, dict) and msg.get("author") == "bot": text = msg.get("text", "") - if text: - return (2, text, data) - # TODO: maybe treat this as error? + if text and isinstance(text, str): + return (2, text) logger.warning("LAST_DATA_FRAME received but no parseable content found.") - return (2, "", data) + return (2, "") - if msg_type == 1: # NEXT_DATA_FRAME - # Streamed updates are not needed for this target - return (1, "", data) - - return (msg_type, "", data) + return (msg_type, "") except json.JSONDecodeError as e: logger.error(f"Failed to decode JSON message: {str(e)}") - return (-1, "", {}) + return (-1, "") def _build_prompt_message(self, prompt: str) -> dict: return { @@ -255,7 +250,13 @@ async def _connect_and_send(self, prompt: str) -> str: stop_polling = False while not stop_polling: response = await websocket.recv() - msg_type, content, data = self._parse_message(response) + + if response is None: + raise RuntimeError( + "WebSocket connection closed unexpectedly: received None from websocket.recv()" + ) + + msg_type, content = self._parse_message(response) if ( msg_type in (-1, 2) # UNKNOWN or LAST_DATA_FRAME From 99040d0e94238e7dbe61523e7bc23f4549d5f330 Mon Sep 17 00:00:00 2001 From: Paulina Kalicka <71526180+paulinek13@users.noreply.github.com> Date: Tue, 23 Dec 2025 23:33:08 +0100 Subject: [PATCH 08/38] useful links --- pyrit/prompt_target/websocket_copilot_target.py | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/pyrit/prompt_target/websocket_copilot_target.py b/pyrit/prompt_target/websocket_copilot_target.py index b8f86e076..06d848dbb 100644 --- a/pyrit/prompt_target/websocket_copilot_target.py +++ b/pyrit/prompt_target/websocket_copilot_target.py @@ -19,11 +19,9 @@ logger = logging.getLogger(__name__) -""" -Useful links: -https://github.com/mbrg/power-pwn/blob/main/src/powerpwn/copilot/copilot_connector/copilot_connector.py -https://labs.zenity.io/p/access-copilot-m365-terminal -""" +# Useful links: +# https://github.com/mbrg/power-pwn/blob/main/src/powerpwn/copilot/copilot_connector/copilot_connector.py +# https://labs.zenity.io/p/access-copilot-m365-terminal class CopilotMessageType(Enum): From 1f21ebf87e6c14a845f484e34bd5f97b1cce02f6 Mon Sep 17 00:00:00 2001 From: Paulina Kalicka <71526180+paulinek13@users.noreply.github.com> Date: Wed, 24 Dec 2025 23:19:24 +0100 Subject: [PATCH 09/38] add timeouts for responses and connection --- .../prompt_target/websocket_copilot_target.py | 22 ++++++++++++++----- 1 file changed, 17 insertions(+), 5 deletions(-) diff --git a/pyrit/prompt_target/websocket_copilot_target.py b/pyrit/prompt_target/websocket_copilot_target.py index 06d848dbb..3a51c3009 100644 --- a/pyrit/prompt_target/websocket_copilot_target.py +++ b/pyrit/prompt_target/websocket_copilot_target.py @@ -1,6 +1,7 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT license. +import asyncio import json import logging import os @@ -59,9 +60,8 @@ class WebSocketCopilotTarget(PromptTarget): SUPPORTED_DATA_TYPES = {"text"} # TODO: support more types? - # TODO: implement timeouts and retries - MAX_WAIT_TIME_SECONDS: int = 300 - POLL_INTERVAL_MS: int = 2000 + RESPONSE_TIMEOUT_SECONDS: int = 60 + CONNECTION_TIMEOUT_SECONDS: int = 30 def __init__( self, @@ -238,7 +238,11 @@ async def _connect_and_send(self, prompt: str) -> str: inputs = [protocol_msg, prompt_dict] last_response = "" - async with websockets.connect(self._websocket_url) as websocket: + async with websockets.connect( + self._websocket_url, + open_timeout=self.CONNECTION_TIMEOUT_SECONDS, + close_timeout=self.CONNECTION_TIMEOUT_SECONDS, + ) as websocket: for input_msg in inputs: payload = self._dict_to_websocket(input_msg) is_user_input = input_msg.get("type") == 4 # USER_PROMPT @@ -247,7 +251,15 @@ async def _connect_and_send(self, prompt: str) -> str: stop_polling = False while not stop_polling: - response = await websocket.recv() + try: + response = await asyncio.wait_for( + websocket.recv(), + timeout=self.RECEIVE_TIMEOUT_SECONDS, + ) + except asyncio.TimeoutError: + raise TimeoutError( + f"Timed out waiting for Copilot response after {self.RECEIVE_TIMEOUT_SECONDS} seconds." + ) if response is None: raise RuntimeError( From 1806e7970f00977e241f49a70f50e9b999aa8c1b Mon Sep 17 00:00:00 2001 From: Paulina Kalicka <71526180+paulinek13@users.noreply.github.com> Date: Thu, 25 Dec 2025 15:28:10 +0100 Subject: [PATCH 10/38] start with tests --- .../target/test_websocket_copilot_target.py | 67 +++++++++++++++++++ 1 file changed, 67 insertions(+) create mode 100644 tests/unit/target/test_websocket_copilot_target.py diff --git a/tests/unit/target/test_websocket_copilot_target.py b/tests/unit/target/test_websocket_copilot_target.py new file mode 100644 index 000000000..ab45dbb8f --- /dev/null +++ b/tests/unit/target/test_websocket_copilot_target.py @@ -0,0 +1,67 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +import os +from unittest.mock import patch + +import pytest + +from pyrit.prompt_target import WebSocketCopilotTarget + + +VALID_WEBSOCKET_URL = ( + "wss://substrate.office.com/m365Copilot/Chathub/test_chat_id" + "?ClientRequestId=test_client_request_id" + "&X-SessionId=test_session_id&token=abc123" + "&ConversationId=test_conversation_id" + "&access_token=test_access_token" + # "&variants=feature.test_feature_one,feature.test_feature_two" + # "&agent=web" + # "&scenario=OfficeWebIncludedCopilot" +) + + +@pytest.fixture +def mock_env_websocket_url(): + """Fixture to set the WEBSOCKET_URL environment variable.""" + with patch.dict(os.environ, {"WEBSOCKET_URL": VALID_WEBSOCKET_URL}): + yield + + +@pytest.mark.usefixtures("patch_central_database") +class TestWebSocketCopilotTargetInit: + def test_init_with_valid_wss_url(self, mock_env_websocket_url): + target = WebSocketCopilotTarget() + + assert target._websocket_url == VALID_WEBSOCKET_URL + assert target._conversation_id == "test_conversation_id" + assert target._session_id == "test_session_id" + assert target._model_name == "copilot" + + def test_init_with_missing_or_invalid_wss_url(self): + for env_vars in [{}, {"WEBSOCKET_URL": ""}, {"WEBSOCKET_URL": " "}]: + with patch.dict(os.environ, env_vars, clear=True): + with pytest.raises(ValueError, match="WebSocket URL must be provided"): + WebSocketCopilotTarget() + + for invalid_url in ["invalid_websocket_url", "ws://example.com", "https://example.com"]: + with patch.dict(os.environ, {"WEBSOCKET_URL": invalid_url}, clear=True): + with pytest.raises(ValueError, match="WebSocket URL must start with 'wss://'"): + WebSocketCopilotTarget() + + def test_init_with_missing_or_empty_required_params(self): + urls = [ + ("wss://example.com/?X-SessionId=session123", "`ConversationId` parameter not found"), + ("wss://example.com/?ConversationId=conv123", "`X-SessionId` parameter not found"), + ("wss://example.com/?ConversationId=&X-SessionId=session123", "`ConversationId` parameter is empty"), + ("wss://example.com/?ConversationId=conv123&X-SessionId=", "`X-SessionId` parameter is empty"), + ] + + for url, error_msg in urls: + with patch.dict(os.environ, {"WEBSOCKET_URL": url}, clear=True): + with pytest.raises(ValueError, match=error_msg): + WebSocketCopilotTarget() + + def test_init_sets_endpoint_correctly(self, mock_env_websocket_url): + target = WebSocketCopilotTarget() + assert target._endpoint == "wss://substrate.office.com/m365Copilot/Chathub/test_chat_id" From 3962bbbab536d45accd9c2a413215394be05d546 Mon Sep 17 00:00:00 2001 From: Paulina Kalicka <71526180+paulinek13@users.noreply.github.com> Date: Thu, 25 Dec 2025 15:28:39 +0100 Subject: [PATCH 11/38] require `wss://` only --- pyrit/prompt_target/websocket_copilot_target.py | 7 ++----- 1 file changed, 2 insertions(+), 5 deletions(-) diff --git a/pyrit/prompt_target/websocket_copilot_target.py b/pyrit/prompt_target/websocket_copilot_target.py index 3a51c3009..ea99198a9 100644 --- a/pyrit/prompt_target/websocket_copilot_target.py +++ b/pyrit/prompt_target/websocket_copilot_target.py @@ -86,11 +86,8 @@ def __init__( if not self._websocket_url or self._websocket_url.strip() == "": raise ValueError("WebSocket URL must be provided through the WEBSOCKET_URL environment variable") - if not self._websocket_url.startswith(("wss://", "ws://")): - raise ValueError( - "WebSocket URL must start with 'wss://' or 'ws://'. " - f"Received URL starting with: {self._websocket_url[:10]}" - ) + if not self._websocket_url.startswith("wss://"): + raise ValueError(f"WebSocket URL must start with 'wss://'. Received: {self._websocket_url[:10]}") if "ConversationId=" not in self._websocket_url: raise ValueError("`ConversationId` parameter not found in WebSocket URL.") From 7588e8d1917dc5947c4f2bb4cff085cf34b2bd1d Mon Sep 17 00:00:00 2001 From: Paulina Kalicka <71526180+paulinek13@users.noreply.github.com> Date: Thu, 25 Dec 2025 15:41:24 +0100 Subject: [PATCH 12/38] add configurable response timeout --- pyrit/prompt_target/websocket_copilot_target.py | 10 ++++++++-- tests/unit/target/test_websocket_copilot_target.py | 9 +++++++++ 2 files changed, 17 insertions(+), 2 deletions(-) diff --git a/pyrit/prompt_target/websocket_copilot_target.py b/pyrit/prompt_target/websocket_copilot_target.py index ea99198a9..decd8c453 100644 --- a/pyrit/prompt_target/websocket_copilot_target.py +++ b/pyrit/prompt_target/websocket_copilot_target.py @@ -69,6 +69,7 @@ def __init__( verbose: bool = False, max_requests_per_minute: Optional[int] = None, model_name: str = "copilot", + response_timeout_seconds: int = RESPONSE_TIMEOUT_SECONDS, ) -> None: """ Initialize the WebSocketCopilotTarget. @@ -77,6 +78,7 @@ def __init__( verbose (bool): Enable verbose logging. Defaults to False. max_requests_per_minute (int, Optional): Maximum number of requests per minute. model_name (str): The model name. Defaults to "copilot". + response_timeout_seconds (int): Timeout for receiving responses in seconds. Defaults to 60s. Raises: ValueError: If WebSocket URL is not provided, is empty, or has invalid format. @@ -108,6 +110,10 @@ def __init__( model_name=model_name, ) + if response_timeout_seconds <= 0: + raise ValueError("response_timeout_seconds must be a positive integer.") + self._response_timeout_seconds = response_timeout_seconds + if self._verbose: logger.info(f"WebSocketCopilotTarget initialized with conversation_id: {self._conversation_id}") logger.info(f"Session ID: {self._session_id}") @@ -251,11 +257,11 @@ async def _connect_and_send(self, prompt: str) -> str: try: response = await asyncio.wait_for( websocket.recv(), - timeout=self.RECEIVE_TIMEOUT_SECONDS, + timeout=self._response_timeout_seconds, ) except asyncio.TimeoutError: raise TimeoutError( - f"Timed out waiting for Copilot response after {self.RECEIVE_TIMEOUT_SECONDS} seconds." + f"Timed out waiting for Copilot response after {self._response_timeout_seconds} seconds." ) if response is None: diff --git a/tests/unit/target/test_websocket_copilot_target.py b/tests/unit/target/test_websocket_copilot_target.py index ab45dbb8f..37fe2f345 100644 --- a/tests/unit/target/test_websocket_copilot_target.py +++ b/tests/unit/target/test_websocket_copilot_target.py @@ -65,3 +65,12 @@ def test_init_with_missing_or_empty_required_params(self): def test_init_sets_endpoint_correctly(self, mock_env_websocket_url): target = WebSocketCopilotTarget() assert target._endpoint == "wss://substrate.office.com/m365Copilot/Chathub/test_chat_id" + + def test_init_with_custom_response_timeout(self, mock_env_websocket_url): + target = WebSocketCopilotTarget(response_timeout_seconds=120) + assert target._response_timeout_seconds == 120 + + for invalid_timeout in [0, -10]: + with pytest.raises(ValueError, match="response_timeout_seconds must be a positive integer."): + WebSocketCopilotTarget(response_timeout_seconds=invalid_timeout) + From b98f6754d7e4d68091fddbacb31982f35284e0a1 Mon Sep 17 00:00:00 2001 From: Paulina Kalicka <71526180+paulinek13@users.noreply.github.com> Date: Thu, 25 Dec 2025 15:43:27 +0100 Subject: [PATCH 13/38] fix --- tests/unit/target/test_websocket_copilot_target.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tests/unit/target/test_websocket_copilot_target.py b/tests/unit/target/test_websocket_copilot_target.py index 37fe2f345..deae55ca2 100644 --- a/tests/unit/target/test_websocket_copilot_target.py +++ b/tests/unit/target/test_websocket_copilot_target.py @@ -30,7 +30,7 @@ def mock_env_websocket_url(): @pytest.mark.usefixtures("patch_central_database") class TestWebSocketCopilotTargetInit: - def test_init_with_valid_wss_url(self, mock_env_websocket_url): + def test_init_with_valid_wss_url(self): target = WebSocketCopilotTarget() assert target._websocket_url == VALID_WEBSOCKET_URL @@ -62,11 +62,11 @@ def test_init_with_missing_or_empty_required_params(self): with pytest.raises(ValueError, match=error_msg): WebSocketCopilotTarget() - def test_init_sets_endpoint_correctly(self, mock_env_websocket_url): + def test_init_sets_endpoint_correctly(self): target = WebSocketCopilotTarget() assert target._endpoint == "wss://substrate.office.com/m365Copilot/Chathub/test_chat_id" - def test_init_with_custom_response_timeout(self, mock_env_websocket_url): + def test_init_with_custom_response_timeout(self): target = WebSocketCopilotTarget(response_timeout_seconds=120) assert target._response_timeout_seconds == 120 From 0dab7bb0eefe1834bc7b28cf07b2e9bbffa9b82f Mon Sep 17 00:00:00 2001 From: Paulina Kalicka <71526180+paulinek13@users.noreply.github.com> Date: Thu, 25 Dec 2025 16:02:50 +0100 Subject: [PATCH 14/38] replace Enum with IntEnum and actually use it --- .../prompt_target/websocket_copilot_target.py | 34 +++++++++---------- 1 file changed, 17 insertions(+), 17 deletions(-) diff --git a/pyrit/prompt_target/websocket_copilot_target.py b/pyrit/prompt_target/websocket_copilot_target.py index decd8c453..9af6f7fd1 100644 --- a/pyrit/prompt_target/websocket_copilot_target.py +++ b/pyrit/prompt_target/websocket_copilot_target.py @@ -6,7 +6,7 @@ import logging import os import uuid -from enum import Enum +from enum import IntEnum from typing import Optional import websockets @@ -25,7 +25,7 @@ # https://labs.zenity.io/p/access-copilot-m365-terminal -class CopilotMessageType(Enum): +class CopilotMessageType(IntEnum): """Enumeration for Copilot WebSocket message types.""" UNKNOWN = -1 @@ -124,7 +124,7 @@ def _dict_to_websocket(data: dict) -> str: return json.dumps(data, separators=(",", ":")) + "\x1e" @staticmethod - def _parse_message(raw_message: str) -> tuple[int, str]: + def _parse_message(raw_message: str) -> tuple[CopilotMessageType, str]: """ Extract actionable content from raw WebSocket frames. @@ -132,21 +132,21 @@ def _parse_message(raw_message: str) -> tuple[int, str]: raw_message (str): The raw WebSocket message string. Returns: - tuple[int, str]: A tuple containing the message type and extracted content. + tuple[CopilotMessageType, str]: A tuple containing the message type and extracted content. """ try: # https://github.com/dotnet/aspnetcore/blob/main/src/SignalR/docs/specs/HubProtocol.md#json-encoding message = raw_message.split("\x1e")[0] # record separator if not message: - return (-1, "") + return (CopilotMessageType.UNKNOWN, "") data = json.loads(message) - msg_type = data.get("type", -1) + msg_type = CopilotMessageType(data.get("type", -1)) - if msg_type in (6, 1): # PING/NEXT_DATA_FRAME + if msg_type in (CopilotMessageType.PING, CopilotMessageType.NEXT_DATA_FRAME): return (msg_type, "") - if msg_type == 2: # LAST_DATA_FRAME + if msg_type == CopilotMessageType.LAST_DATA_FRAME: item = data.get("item", {}) if item and isinstance(item, dict): messages = item.get("messages", []) @@ -155,15 +155,15 @@ def _parse_message(raw_message: str) -> tuple[int, str]: if isinstance(msg, dict) and msg.get("author") == "bot": text = msg.get("text", "") if text and isinstance(text, str): - return (2, text) + return (CopilotMessageType.LAST_DATA_FRAME, text) logger.warning("LAST_DATA_FRAME received but no parseable content found.") - return (2, "") + return (CopilotMessageType.LAST_DATA_FRAME, "") return (msg_type, "") except json.JSONDecodeError as e: logger.error(f"Failed to decode JSON message: {str(e)}") - return (-1, "") + return (CopilotMessageType.UNKNOWN, "") def _build_prompt_message(self, prompt: str) -> dict: return { @@ -231,7 +231,7 @@ def _build_prompt_message(self, prompt: str) -> dict: ], "invocationId": "0", # TODO: should be dynamic? "target": "chat", - "type": 4, + "type": CopilotMessageType.USER_PROMPT, } async def _connect_and_send(self, prompt: str) -> str: @@ -248,7 +248,7 @@ async def _connect_and_send(self, prompt: str) -> str: ) as websocket: for input_msg in inputs: payload = self._dict_to_websocket(input_msg) - is_user_input = input_msg.get("type") == 4 # USER_PROMPT + is_user_input = input_msg.get("type") == CopilotMessageType.USER_PROMPT await websocket.send(payload) @@ -272,15 +272,15 @@ async def _connect_and_send(self, prompt: str) -> str: msg_type, content = self._parse_message(response) if ( - msg_type in (-1, 2) # UNKNOWN or LAST_DATA_FRAME - or msg_type == 6 + msg_type in (CopilotMessageType.UNKNOWN, CopilotMessageType.LAST_DATA_FRAME) + or msg_type == CopilotMessageType.PING and not is_user_input ): stop_polling = True - if msg_type == 2: # LAST_DATA_FRAME - final response + if msg_type == CopilotMessageType.LAST_DATA_FRAME: last_response = content - elif msg_type == -1: # UNKNOWN/NONE + elif msg_type == CopilotMessageType.UNKNOWN: logger.debug("Received unknown or empty message type.") return last_response From c2df619c5727593068012ebb5eeea46ffeba42a4 Mon Sep 17 00:00:00 2001 From: Paulina Kalicka <71526180+paulinek13@users.noreply.github.com> Date: Thu, 25 Dec 2025 16:18:35 +0100 Subject: [PATCH 15/38] test_dict_to_websocket_static_method --- .../target/test_websocket_copilot_target.py | 20 ++++++++++++------- 1 file changed, 13 insertions(+), 7 deletions(-) diff --git a/tests/unit/target/test_websocket_copilot_target.py b/tests/unit/target/test_websocket_copilot_target.py index deae55ca2..663ce5145 100644 --- a/tests/unit/target/test_websocket_copilot_target.py +++ b/tests/unit/target/test_websocket_copilot_target.py @@ -21,13 +21,6 @@ ) -@pytest.fixture -def mock_env_websocket_url(): - """Fixture to set the WEBSOCKET_URL environment variable.""" - with patch.dict(os.environ, {"WEBSOCKET_URL": VALID_WEBSOCKET_URL}): - yield - - @pytest.mark.usefixtures("patch_central_database") class TestWebSocketCopilotTargetInit: def test_init_with_valid_wss_url(self): @@ -74,3 +67,16 @@ def test_init_with_custom_response_timeout(self): with pytest.raises(ValueError, match="response_timeout_seconds must be a positive integer."): WebSocketCopilotTarget(response_timeout_seconds=invalid_timeout) + +@pytest.mark.parametrize( + "data,expected", + [ + ({"key": "value"}, '{"key":"value"}\x1e'), + ({"protocol": "json", "version": 1}, '{"protocol":"json","version":1}\x1e'), + ({"outer": {"inner": "value"}}, '{"outer":{"inner":"value"}}\x1e'), + ({"items": [1, 2, 3]}, '{"items":[1,2,3]}\x1e'), + ], +) +def test_dict_to_websocket_static_method(data, expected): + result = WebSocketCopilotTarget._dict_to_websocket(data) + assert result == expected From 18fd2387d0cb9c06df15a2bd0f6a600cc3bb675e Mon Sep 17 00:00:00 2001 From: Paulina Kalicka <71526180+paulinek13@users.noreply.github.com> Date: Thu, 25 Dec 2025 16:20:56 +0100 Subject: [PATCH 16/38] fix --- tests/unit/target/test_websocket_copilot_target.py | 12 +++++++++--- 1 file changed, 9 insertions(+), 3 deletions(-) diff --git a/tests/unit/target/test_websocket_copilot_target.py b/tests/unit/target/test_websocket_copilot_target.py index 663ce5145..aadbc1ed5 100644 --- a/tests/unit/target/test_websocket_copilot_target.py +++ b/tests/unit/target/test_websocket_copilot_target.py @@ -21,9 +21,15 @@ ) +@pytest.fixture +def mock_env_websocket_url(): + with patch.dict(os.environ, {"WEBSOCKET_URL": VALID_WEBSOCKET_URL}): + yield + + @pytest.mark.usefixtures("patch_central_database") class TestWebSocketCopilotTargetInit: - def test_init_with_valid_wss_url(self): + def test_init_with_valid_wss_url(self, mock_env_websocket_url): target = WebSocketCopilotTarget() assert target._websocket_url == VALID_WEBSOCKET_URL @@ -55,11 +61,11 @@ def test_init_with_missing_or_empty_required_params(self): with pytest.raises(ValueError, match=error_msg): WebSocketCopilotTarget() - def test_init_sets_endpoint_correctly(self): + def test_init_sets_endpoint_correctly(self, mock_env_websocket_url): target = WebSocketCopilotTarget() assert target._endpoint == "wss://substrate.office.com/m365Copilot/Chathub/test_chat_id" - def test_init_with_custom_response_timeout(self): + def test_init_with_custom_response_timeout(self, mock_env_websocket_url): target = WebSocketCopilotTarget(response_timeout_seconds=120) assert target._response_timeout_seconds == 120 From 73b07d00272adaa089acf888fb57bb954ffb487d Mon Sep 17 00:00:00 2001 From: Paulina Kalicka <71526180+paulinek13@users.noreply.github.com> Date: Thu, 25 Dec 2025 18:04:05 +0100 Subject: [PATCH 17/38] Refactor WebSocket message parser to handle multiple frames per message - Rename _parse_message() to _parse_raw_message() - Split on record separator (\x1e) and processes all frames, not just first - Add FINAL_DATA_FRAME (type 3) enum value for completion signals - Extract bot message parsing logic into lambda for reusability - Fixes stop condition to handle FINAL_DATA_FRAME and remove flawed "is_user_input and PING" logic --- .../prompt_target/websocket_copilot_target.py | 111 ++++++++++-------- 1 file changed, 64 insertions(+), 47 deletions(-) diff --git a/pyrit/prompt_target/websocket_copilot_target.py b/pyrit/prompt_target/websocket_copilot_target.py index 9af6f7fd1..c55ea051a 100644 --- a/pyrit/prompt_target/websocket_copilot_target.py +++ b/pyrit/prompt_target/websocket_copilot_target.py @@ -31,6 +31,7 @@ class CopilotMessageType(IntEnum): UNKNOWN = -1 NEXT_DATA_FRAME = 1 # streaming Copilot responses LAST_DATA_FRAME = 2 # the last data frame with final content + FINAL_DATA_FRAME = 3 # the final data frame indicating completion USER_PROMPT = 4 PING = 6 @@ -124,46 +125,63 @@ def _dict_to_websocket(data: dict) -> str: return json.dumps(data, separators=(",", ":")) + "\x1e" @staticmethod - def _parse_message(raw_message: str) -> tuple[CopilotMessageType, str]: + def _parse_raw_message(message: str) -> list[tuple[CopilotMessageType, str]]: """ - Extract actionable content from raw WebSocket frames. + Extract actionable content from a raw WebSocket message. + Returns more than one JSON message if multiple are found. Args: - raw_message (str): The raw WebSocket message string. + message (str): The raw WebSocket message string. Returns: - tuple[CopilotMessageType, str]: A tuple containing the message type and extracted content. + list[tuple[CopilotMessageType, str]]: A list of tuples where each tuple contains + message type and extracted content. """ - try: - # https://github.com/dotnet/aspnetcore/blob/main/src/SignalR/docs/specs/HubProtocol.md#json-encoding - message = raw_message.split("\x1e")[0] # record separator - if not message: - return (CopilotMessageType.UNKNOWN, "") - - data = json.loads(message) - msg_type = CopilotMessageType(data.get("type", -1)) - - if msg_type in (CopilotMessageType.PING, CopilotMessageType.NEXT_DATA_FRAME): - return (msg_type, "") - - if msg_type == CopilotMessageType.LAST_DATA_FRAME: - item = data.get("item", {}) - if item and isinstance(item, dict): - messages = item.get("messages", []) - if messages and isinstance(messages, list): - for msg in reversed(messages): - if isinstance(msg, dict) and msg.get("author") == "bot": - text = msg.get("text", "") - if text and isinstance(text, str): - return (CopilotMessageType.LAST_DATA_FRAME, text) - logger.warning("LAST_DATA_FRAME received but no parseable content found.") - return (CopilotMessageType.LAST_DATA_FRAME, "") - - return (msg_type, "") - - except json.JSONDecodeError as e: - logger.error(f"Failed to decode JSON message: {str(e)}") - return (CopilotMessageType.UNKNOWN, "") + # Find the last chat message with text content + extract_bot_message = lambda data: next( + ( + msg.get("text", "") + for msg in reversed(data.get("item", {}).get("messages", [])) + if isinstance(msg, dict) and msg.get("author") == "bot" and msg.get("text") + ), + "", + ) + + results: list[tuple[CopilotMessageType, str]] = [] + + # https://github.com/dotnet/aspnetcore/blob/main/src/SignalR/docs/specs/HubProtocol.md#json-encoding + messages = message.split("\x1e") # record separator + + for message in messages: + if not message or not message.strip(): + continue + + try: + data = json.loads(message) + msg_type = CopilotMessageType(data.get("type", -1)) + + if msg_type in ( + CopilotMessageType.PING, + CopilotMessageType.NEXT_DATA_FRAME, + CopilotMessageType.FINAL_DATA_FRAME, + ): + results.append((msg_type, "")) + continue + + if msg_type == CopilotMessageType.LAST_DATA_FRAME: + bot_text = extract_bot_message(data) + if not bot_text: + logger.warning("LAST_DATA_FRAME received but no parseable content found.") + results.append((CopilotMessageType.LAST_DATA_FRAME, bot_text)) + continue + + results.append((msg_type, "")) + + except json.JSONDecodeError as e: + logger.error(f"Failed to decode JSON message: {str(e)}") + results.append((CopilotMessageType.UNKNOWN, "")) + + return results if results else [(CopilotMessageType.UNKNOWN, "")] def _build_prompt_message(self, prompt: str) -> dict: return { @@ -248,8 +266,6 @@ async def _connect_and_send(self, prompt: str) -> str: ) as websocket: for input_msg in inputs: payload = self._dict_to_websocket(input_msg) - is_user_input = input_msg.get("type") == CopilotMessageType.USER_PROMPT - await websocket.send(payload) stop_polling = False @@ -269,19 +285,20 @@ async def _connect_and_send(self, prompt: str) -> str: "WebSocket connection closed unexpectedly: received None from websocket.recv()" ) - msg_type, content = self._parse_message(response) + parsed_messages = self._parse_raw_message(response) - if ( - msg_type in (CopilotMessageType.UNKNOWN, CopilotMessageType.LAST_DATA_FRAME) - or msg_type == CopilotMessageType.PING - and not is_user_input - ): - stop_polling = True + for msg_type, content in parsed_messages: + if msg_type in ( + CopilotMessageType.UNKNOWN, + CopilotMessageType.LAST_DATA_FRAME, + CopilotMessageType.FINAL_DATA_FRAME, + ): + stop_polling = True - if msg_type == CopilotMessageType.LAST_DATA_FRAME: - last_response = content - elif msg_type == CopilotMessageType.UNKNOWN: - logger.debug("Received unknown or empty message type.") + if msg_type == CopilotMessageType.LAST_DATA_FRAME: + last_response = content + elif msg_type == CopilotMessageType.UNKNOWN: + logger.debug("Received unknown or empty message type.") return last_response From 9a8a878ebc45c9fc7c7f251a449ca401f6246b1c Mon Sep 17 00:00:00 2001 From: Paulina Kalicka <71526180+paulinek13@users.noreply.github.com> Date: Thu, 25 Dec 2025 18:10:05 +0100 Subject: [PATCH 18/38] rename message types in the enum --- .../prompt_target/websocket_copilot_target.py | 22 +++++++++---------- 1 file changed, 11 insertions(+), 11 deletions(-) diff --git a/pyrit/prompt_target/websocket_copilot_target.py b/pyrit/prompt_target/websocket_copilot_target.py index c55ea051a..5956154d5 100644 --- a/pyrit/prompt_target/websocket_copilot_target.py +++ b/pyrit/prompt_target/websocket_copilot_target.py @@ -29,9 +29,9 @@ class CopilotMessageType(IntEnum): """Enumeration for Copilot WebSocket message types.""" UNKNOWN = -1 - NEXT_DATA_FRAME = 1 # streaming Copilot responses - LAST_DATA_FRAME = 2 # the last data frame with final content - FINAL_DATA_FRAME = 3 # the final data frame indicating completion + PARTIAL_RESPONSE = 1 + FINAL_CONTENT = 2 + STREAM_END = 3 USER_PROMPT = 4 PING = 6 @@ -162,17 +162,17 @@ def _parse_raw_message(message: str) -> list[tuple[CopilotMessageType, str]]: if msg_type in ( CopilotMessageType.PING, - CopilotMessageType.NEXT_DATA_FRAME, - CopilotMessageType.FINAL_DATA_FRAME, + CopilotMessageType.PARTIAL_RESPONSE, + CopilotMessageType.STREAM_END, ): results.append((msg_type, "")) continue - if msg_type == CopilotMessageType.LAST_DATA_FRAME: + if msg_type == CopilotMessageType.FINAL_CONTENT: bot_text = extract_bot_message(data) if not bot_text: - logger.warning("LAST_DATA_FRAME received but no parseable content found.") - results.append((CopilotMessageType.LAST_DATA_FRAME, bot_text)) + logger.warning("FINAL_CONTENT received but no parseable content found.") + results.append((CopilotMessageType.FINAL_CONTENT, bot_text)) continue results.append((msg_type, "")) @@ -290,12 +290,12 @@ async def _connect_and_send(self, prompt: str) -> str: for msg_type, content in parsed_messages: if msg_type in ( CopilotMessageType.UNKNOWN, - CopilotMessageType.LAST_DATA_FRAME, - CopilotMessageType.FINAL_DATA_FRAME, + CopilotMessageType.FINAL_CONTENT, + CopilotMessageType.STREAM_END, ): stop_polling = True - if msg_type == CopilotMessageType.LAST_DATA_FRAME: + if msg_type == CopilotMessageType.FINAL_CONTENT: last_response = content elif msg_type == CopilotMessageType.UNKNOWN: logger.debug("Received unknown or empty message type.") From 4d3c15dd1c575a5f48d272847b2d9b1106c4c7bc Mon Sep 17 00:00:00 2001 From: Paulina Kalicka <71526180+paulinek13@users.noreply.github.com> Date: Thu, 25 Dec 2025 18:26:34 +0100 Subject: [PATCH 19/38] add raw WebSocket messages for testing --- tests/unit/target/test_websocket_copilot_target.py | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/tests/unit/target/test_websocket_copilot_target.py b/tests/unit/target/test_websocket_copilot_target.py index aadbc1ed5..a1588b413 100644 --- a/tests/unit/target/test_websocket_copilot_target.py +++ b/tests/unit/target/test_websocket_copilot_target.py @@ -86,3 +86,14 @@ def test_init_with_custom_response_timeout(self, mock_env_websocket_url): def test_dict_to_websocket_static_method(data, expected): result = WebSocketCopilotTarget._dict_to_websocket(data) assert result == expected + + +RAW_WEBSOCKET_MESSAGES = [ + "{}\x1e", + '{"type":6}\x1e', + '{"type":1,"target":"update","arguments":[{"messages":[{"text":"Apple","author":"bot","responseIdentifier":"Default"}]}]}\x1e', + '{"type":3,"invocationId":"0"}\x1e', + '{"type":2,"invocationId":"0","item":{"messages":[{"text":"Name a fruit","author":"user"},{"text":"Apple. 🍎 \n\nWould you like me to list more fruits or give you some interesting facts about apples?","turnState":"Completed","author":"bot","turnCount":1}],"firstNewMessageIndex":1,"conversationId":"conversationId","requestId":"requestId","result":{"value":"Success","message":"Apple. 🍎 \n\nWould you like me to list more fruits or give you some interesting facts about apples?","serviceVersion":"1.0.03273.12483"}}}\x1e', +] + +# TODO: add tests for _parse_raw_message From b095d742476ae416189dc177815a786a42543e91 Mon Sep 17 00:00:00 2001 From: Paulina Kalicka <71526180+paulinek13@users.noreply.github.com> Date: Thu, 25 Dec 2025 19:43:20 +0100 Subject: [PATCH 20/38] remove emojis --- tests/unit/target/test_websocket_copilot_target.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/unit/target/test_websocket_copilot_target.py b/tests/unit/target/test_websocket_copilot_target.py index a1588b413..50b6681fc 100644 --- a/tests/unit/target/test_websocket_copilot_target.py +++ b/tests/unit/target/test_websocket_copilot_target.py @@ -93,7 +93,7 @@ def test_dict_to_websocket_static_method(data, expected): '{"type":6}\x1e', '{"type":1,"target":"update","arguments":[{"messages":[{"text":"Apple","author":"bot","responseIdentifier":"Default"}]}]}\x1e', '{"type":3,"invocationId":"0"}\x1e', - '{"type":2,"invocationId":"0","item":{"messages":[{"text":"Name a fruit","author":"user"},{"text":"Apple. 🍎 \n\nWould you like me to list more fruits or give you some interesting facts about apples?","turnState":"Completed","author":"bot","turnCount":1}],"firstNewMessageIndex":1,"conversationId":"conversationId","requestId":"requestId","result":{"value":"Success","message":"Apple. 🍎 \n\nWould you like me to list more fruits or give you some interesting facts about apples?","serviceVersion":"1.0.03273.12483"}}}\x1e', + '{"type":2,"invocationId":"0","item":{"messages":[{"text":"Name a fruit","author":"user"},{"text":"Apple. \n\nWould you like me to list more fruits or give you some interesting facts about apples?","turnState":"Completed","author":"bot","turnCount":1}],"firstNewMessageIndex":1,"conversationId":"conversationId","requestId":"requestId","result":{"value":"Success","message":"Apple. \n\nWould you like me to list more fruits or give you some interesting facts about apples?","serviceVersion":"1.0.03273.12483"}}}\x1e', ] # TODO: add tests for _parse_raw_message From 38e686827a05bb9fbb20ad03b09a9dd13da557ec Mon Sep 17 00:00:00 2001 From: Paulina Kalicka <71526180+paulinek13@users.noreply.github.com> Date: Thu, 25 Dec 2025 20:12:56 +0100 Subject: [PATCH 21/38] simpler way to get the final result --- pyrit/prompt_target/websocket_copilot_target.py | 13 ++----------- 1 file changed, 2 insertions(+), 11 deletions(-) diff --git a/pyrit/prompt_target/websocket_copilot_target.py b/pyrit/prompt_target/websocket_copilot_target.py index 5956154d5..47631f694 100644 --- a/pyrit/prompt_target/websocket_copilot_target.py +++ b/pyrit/prompt_target/websocket_copilot_target.py @@ -137,16 +137,6 @@ def _parse_raw_message(message: str) -> list[tuple[CopilotMessageType, str]]: list[tuple[CopilotMessageType, str]]: A list of tuples where each tuple contains message type and extracted content. """ - # Find the last chat message with text content - extract_bot_message = lambda data: next( - ( - msg.get("text", "") - for msg in reversed(data.get("item", {}).get("messages", [])) - if isinstance(msg, dict) and msg.get("author") == "bot" and msg.get("text") - ), - "", - ) - results: list[tuple[CopilotMessageType, str]] = [] # https://github.com/dotnet/aspnetcore/blob/main/src/SignalR/docs/specs/HubProtocol.md#json-encoding @@ -169,8 +159,9 @@ def _parse_raw_message(message: str) -> list[tuple[CopilotMessageType, str]]: continue if msg_type == CopilotMessageType.FINAL_CONTENT: - bot_text = extract_bot_message(data) + bot_text = data.get("item", {}).get("result", {}).get("message", "") if not bot_text: + # In this case, EmptyResponseException will be raised anyway logger.warning("FINAL_CONTENT received but no parseable content found.") results.append((CopilotMessageType.FINAL_CONTENT, bot_text)) continue From 2430dbecba63b95a34c79efcf161240cb894b4b7 Mon Sep 17 00:00:00 2001 From: Paulina Kalicka <71526180+paulinek13@users.noreply.github.com> Date: Thu, 25 Dec 2025 20:27:01 +0100 Subject: [PATCH 22/38] log full raw message when no parseable content found --- pyrit/prompt_target/websocket_copilot_target.py | 1 + 1 file changed, 1 insertion(+) diff --git a/pyrit/prompt_target/websocket_copilot_target.py b/pyrit/prompt_target/websocket_copilot_target.py index 47631f694..fae6fef35 100644 --- a/pyrit/prompt_target/websocket_copilot_target.py +++ b/pyrit/prompt_target/websocket_copilot_target.py @@ -163,6 +163,7 @@ def _parse_raw_message(message: str) -> list[tuple[CopilotMessageType, str]]: if not bot_text: # In this case, EmptyResponseException will be raised anyway logger.warning("FINAL_CONTENT received but no parseable content found.") + logger.debug(f"Full raw message: {message}") results.append((CopilotMessageType.FINAL_CONTENT, bot_text)) continue From 5b2c54a47bc8e62e2a1c997ba579a9bd5771a0e9 Mon Sep 17 00:00:00 2001 From: Paulina Kalicka <71526180+paulinek13@users.noreply.github.com> Date: Thu, 25 Dec 2025 21:50:01 +0100 Subject: [PATCH 23/38] _value2member_map_ --- pyrit/prompt_target/websocket_copilot_target.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyrit/prompt_target/websocket_copilot_target.py b/pyrit/prompt_target/websocket_copilot_target.py index fae6fef35..58a380ecd 100644 --- a/pyrit/prompt_target/websocket_copilot_target.py +++ b/pyrit/prompt_target/websocket_copilot_target.py @@ -148,7 +148,7 @@ def _parse_raw_message(message: str) -> list[tuple[CopilotMessageType, str]]: try: data = json.loads(message) - msg_type = CopilotMessageType(data.get("type", -1)) + msg_type = CopilotMessageType._value2member_map_.get(data.get("type", -1), CopilotMessageType.UNKNOWN) if msg_type in ( CopilotMessageType.PING, From 4a7a7b8bd922f612cb26bf3c2d2ded97c9f33ad9 Mon Sep 17 00:00:00 2001 From: Paulina Kalicka <71526180+paulinek13@users.noreply.github.com> Date: Thu, 25 Dec 2025 21:50:55 +0100 Subject: [PATCH 24/38] TestParseRawMessage --- .../target/test_websocket_copilot_target.py | 71 ++++++++++++++++--- 1 file changed, 62 insertions(+), 9 deletions(-) diff --git a/tests/unit/target/test_websocket_copilot_target.py b/tests/unit/target/test_websocket_copilot_target.py index 50b6681fc..661cede37 100644 --- a/tests/unit/target/test_websocket_copilot_target.py +++ b/tests/unit/target/test_websocket_copilot_target.py @@ -88,12 +88,65 @@ def test_dict_to_websocket_static_method(data, expected): assert result == expected -RAW_WEBSOCKET_MESSAGES = [ - "{}\x1e", - '{"type":6}\x1e', - '{"type":1,"target":"update","arguments":[{"messages":[{"text":"Apple","author":"bot","responseIdentifier":"Default"}]}]}\x1e', - '{"type":3,"invocationId":"0"}\x1e', - '{"type":2,"invocationId":"0","item":{"messages":[{"text":"Name a fruit","author":"user"},{"text":"Apple. \n\nWould you like me to list more fruits or give you some interesting facts about apples?","turnState":"Completed","author":"bot","turnCount":1}],"firstNewMessageIndex":1,"conversationId":"conversationId","requestId":"requestId","result":{"value":"Success","message":"Apple. \n\nWould you like me to list more fruits or give you some interesting facts about apples?","serviceVersion":"1.0.03273.12483"}}}\x1e', -] - -# TODO: add tests for _parse_raw_message +class TestParseRawMessage: + from pyrit.prompt_target.websocket_copilot_target import CopilotMessageType + + @pytest.mark.parametrize( + "message,expected_types,expected_content", + [ + ("", [CopilotMessageType.UNKNOWN], [""]), + (" \n\t ", [CopilotMessageType.UNKNOWN], [""]), + ("{}\x1e", [CopilotMessageType.UNKNOWN], [""]), + ('{"type":6}\x1e', [CopilotMessageType.PING], [""]), + ( + '{"type":1,"target":"update","arguments":[{"messages":[{"text":"Partial","author":"bot"}]}]}\x1e', + [CopilotMessageType.PARTIAL_RESPONSE], + [""], + ), + ( + '{"type":2,"item":{"result":{"message":"Final."}}}\x1e{"type":3,"invocationId":"0"}\x1e', + [CopilotMessageType.FINAL_CONTENT, CopilotMessageType.STREAM_END], + [ + "Final.", + "", + ], + ), + ], + ) + def test_parse_raw_message_with_valid_data(self, message, expected_types, expected_content): + result = WebSocketCopilotTarget._parse_raw_message(message) + + assert len(result) == len(expected_types) + for i, expected_type in enumerate(expected_types): + assert result[i][0] == expected_type + assert result[i][1] == expected_content[i] + + def test_parse_final_message_without_content(self): + from pyrit.prompt_target.websocket_copilot_target import CopilotMessageType + + with patch("pyrit.prompt_target.websocket_copilot_target.logger") as mock_logger: + message = '{"type":2,"invocationId":"0"}\x1e' + result = WebSocketCopilotTarget._parse_raw_message(message) + + assert len(result) == 1 + assert result[0][0] == CopilotMessageType.FINAL_CONTENT + assert result[0][1] == "" + + mock_logger.warning.assert_called_with("FINAL_CONTENT received but no parseable content found.") + mock_logger.debug.assert_called_with(f"Full raw message: {message[:-1]}") + + @pytest.mark.parametrize( + "message", + [ + '{"type":99,"data":"unknown"}\x1e', + '{"data":"no type field"}\x1e', + '{"invalid json structure\x1e', + ], + ) + def test_parse_unknown_or_invalid_messages(self, message): + from pyrit.prompt_target.websocket_copilot_target import CopilotMessageType + + result = WebSocketCopilotTarget._parse_raw_message(message) + assert len(result) == 1 + assert result[0][0] == CopilotMessageType.UNKNOWN + assert result[0][1] == "" From acb0a6da8f6c162f5887819316542c356c70e167 Mon Sep 17 00:00:00 2001 From: Paulina Kalicka <71526180+paulinek13@users.noreply.github.com> Date: Thu, 25 Dec 2025 23:30:56 +0100 Subject: [PATCH 25/38] test fix --- tests/unit/target/test_websocket_copilot_target.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/unit/target/test_websocket_copilot_target.py b/tests/unit/target/test_websocket_copilot_target.py index 661cede37..0b4e713e1 100644 --- a/tests/unit/target/test_websocket_copilot_target.py +++ b/tests/unit/target/test_websocket_copilot_target.py @@ -10,7 +10,7 @@ VALID_WEBSOCKET_URL = ( - "wss://substrate.office.com/m365Copilot/Chathub/test_chat_id" + "wss://substrate.office.com/m365Copilot/Chathub/test_object_id@test_tenant_id" "?ClientRequestId=test_client_request_id" "&X-SessionId=test_session_id&token=abc123" "&ConversationId=test_conversation_id" @@ -63,7 +63,7 @@ def test_init_with_missing_or_empty_required_params(self): def test_init_sets_endpoint_correctly(self, mock_env_websocket_url): target = WebSocketCopilotTarget() - assert target._endpoint == "wss://substrate.office.com/m365Copilot/Chathub/test_chat_id" + assert target._endpoint == "wss://substrate.office.com/m365Copilot/Chathub/test_object_id@test_tenant_id" def test_init_with_custom_response_timeout(self, mock_env_websocket_url): target = WebSocketCopilotTarget(response_timeout_seconds=120) From 276290f7dc309731fb1cba38c14132026833ee51 Mon Sep 17 00:00:00 2001 From: Paulina Kalicka <71526180+paulinek13@users.noreply.github.com> Date: Thu, 25 Dec 2025 23:32:07 +0100 Subject: [PATCH 26/38] TODO: use msal for auth --- .../prompt_target/websocket_copilot_target.py | 102 ++++++++---------- 1 file changed, 46 insertions(+), 56 deletions(-) diff --git a/pyrit/prompt_target/websocket_copilot_target.py b/pyrit/prompt_target/websocket_copilot_target.py index 58a380ecd..7f6137535 100644 --- a/pyrit/prompt_target/websocket_copilot_target.py +++ b/pyrit/prompt_target/websocket_copilot_target.py @@ -57,8 +57,6 @@ class WebSocketCopilotTarget(PromptTarget): Only works with licensed Microsoft 365 Copilot. The free Copilot version is not compatible. """ - # TODO: add more flexible auth, use puppeteer? https://github.com/mbrg/power-pwn/blob/main/src/powerpwn/copilot/copilot_connector/copilot_connector.py#L248 - SUPPORTED_DATA_TYPES = {"text"} # TODO: support more types? RESPONSE_TIMEOUT_SECONDS: int = 60 @@ -179,67 +177,59 @@ def _build_prompt_message(self, prompt: str) -> dict: return { "arguments": [ { - "source": "officeweb", # TODO: support 'teamshub' as well - # TODO: not sure whether to uuid.uuid4() or use a static like it's done in power-pwn - # https://github.com/mbrg/power-pwn/blob/main/src/powerpwn/copilot/copilot_connector/copilot_connector.py#L156 - "clientCorrelationId": str(uuid.uuid4()), - "sessionId": self._session_id, - "optionsSets": [ - "enterprise_flux_web", - "enterprise_flux_work", - "enable_request_response_interstitials", - "enterprise_flux_image_v1", - "enterprise_toolbox_with_skdsstore", - "enterprise_toolbox_with_skdsstore_search_message_extensions", - "enable_ME_auth_interstitial", - "skdsstorethirdparty", - "enable_confirmation_interstitial", - "enable_plugin_auth_interstitial", - "enable_response_action_processing", - "enterprise_flux_work_gptv", - "enterprise_flux_work_code_interpreter", - "enable_batch_token_processing", - ], - "options": {}, - "allowedMessageTypes": [ - "Chat", - "Suggestion", - "InternalSearchQuery", - "InternalSearchResult", - "Disengaged", - "InternalLoaderMessage", - "RenderCardRequest", - "AdsQuery", - "SemanticSerp", - "GenerateContentQuery", - "SearchQuery", - "ConfirmationCard", - "AuthError", - "DeveloperLogs", - ], - "sliceIds": [], - # TODO: enable using agents https://github.com/mbrg/power-pwn/blob/main/src/powerpwn/copilot/copilot_connector/copilot_connector.py#L192 - "threadLevelGptId": {}, - "conversationId": self._conversation_id, - "traceId": str(uuid.uuid4()).replace("-", ""), # TODO: same case as clientCorrelationId - "isStartOfSession": 0, - "productThreadType": "Office", - "clientInfo": {"clientPlatform": "web"}, + # TODO: use msal for auth, then set these fields properly, as with current approach they are not really needed + # "source": "officeweb", + # "clientCorrelationId": str(uuid.uuid4()), + # "sessionId": self._session_id, + # "optionsSets": [ + # "enterprise_flux_web", + # "enterprise_flux_work", + # "enable_request_response_interstitials", + # "enterprise_flux_image_v1", + # "enterprise_toolbox_with_skdsstore", + # "enterprise_toolbox_with_skdsstore_search_message_extensions", + # "enable_ME_auth_interstitial", + # "skdsstorethirdparty", + # "enable_confirmation_interstitial", + # "enable_plugin_auth_interstitial", + # "enable_response_action_processing", + # "enterprise_flux_work_gptv", + # "enterprise_flux_work_code_interpreter", + # "enable_batch_token_processing", + # ], + # "options": {}, + # "allowedMessageTypes": [ + # "Chat", + # "Suggestion", + # "InternalSearchQuery", + # "InternalSearchResult", + # "Disengaged", + # "InternalLoaderMessage", + # "RenderCardRequest", + # "AdsQuery", + # "SemanticSerp", + # "GenerateContentQuery", + # "SearchQuery", + # "ConfirmationCard", + # "AuthError", + # "DeveloperLogs", + # ], + # "sliceIds": [], + # "threadLevelGptId": {}, + # "conversationId": self._conversation_id, + # "traceId": str(uuid.uuid4()).replace("-", ""), + # "isStartOfSession": 0, + # "productThreadType": "Office", + # "clientInfo": {"clientPlatform": "web"}, "message": { "author": "user", - "inputMethod": "Keyboard", "text": prompt, - "entityAnnotationTypes": ["People", "File", "Event", "Email", "TeamsMessage"], "requestId": str(uuid.uuid4()).replace("-", ""), - "locationInfo": {"timeZoneOffset": 0, "timeZone": "UTC"}, - "locale": "en-US", - "messageType": "Chat", - "experienceType": "Default", }, - "plugins": [], # TODO: support enabling some plugins? + # "plugins": [], } ], - "invocationId": "0", # TODO: should be dynamic? + "invocationId": "0", "target": "chat", "type": CopilotMessageType.USER_PROMPT, } From ded56c6e7f2ea4f2e2e2f5eec7bc255383659237 Mon Sep 17 00:00:00 2001 From: Paulina Kalicka <71526180+paulinek13@users.noreply.github.com> Date: Fri, 26 Dec 2025 10:09:30 +0100 Subject: [PATCH 27/38] add device code flow authentication method --- pyrit/auth/azure_auth.py | 47 ++++++++++++++++++++++++++++++++++++++++ 1 file changed, 47 insertions(+) diff --git a/pyrit/auth/azure_auth.py b/pyrit/auth/azure_auth.py index 34d31f667..b854963a3 100644 --- a/pyrit/auth/azure_auth.py +++ b/pyrit/auth/azure_auth.py @@ -5,6 +5,7 @@ import time from typing import Callable, Union from urllib.parse import urlparse +import textwrap import msal from azure.core.credentials import AccessToken @@ -179,6 +180,52 @@ def get_access_token_from_msa_public_client(*, client_id: str, scope: str): raise +def get_access_token_from_device_code( + *, client_id: str, scope: str, authority: str = "https://login.microsoftonline.com/common" +): + """ + Use Device Code Flow to authenticate. User will be prompted to visit a URL and enter a code. + This method is useful for headless environments or when interactive browser login is not available. + + Args: + client_id (str): The client ID of the service. + scope (str): The scope to request. + authority (str): The MSAL authority URL. Defaults to common tenant. + + Returns: + str: Authentication token. + + Raises: + RuntimeError: If device flow initiation or authentication fails. + """ + try: + app = msal.PublicClientApplication(client_id=client_id, authority=authority) + flow = app.initiate_device_flow(scopes=[scope]) + + if "user_code" not in flow: + error_msg = flow.get("error_description", "Unknown error") + raise RuntimeError(f"Failed to initiate device flow: {error_msg}") + + print("\n" + "=" * 80) + print(" DEVICE CODE AUTHENTICATION".center(80)) + print("=" * 80) + print("\n" + textwrap.fill(flow["message"], width=76, initial_indent=" ", subsequent_indent=" ")) + print("\n ⏳ Waiting for authentication to complete...") + print("=" * 80 + "\n") + + result = app.acquire_token_by_device_flow(flow) + + if "access_token" not in result: + error = result.get("error", "Unknown error") + error_desc = result.get("error_description", "") + raise RuntimeError(f"Authentication failed: {error} - {error_desc}") + + return result["access_token"] + except Exception as e: + logger.error(f"Failed to obtain token for '{scope}' with client ID '{client_id}': {e}") + raise + + def get_access_token_from_interactive_login(scope: str) -> str: """ Connect to an OpenAI endpoint with an interactive login from Azure. A browser window will From 558f48fa62bd1532be166b0dfe5c02cc0afa2f37 Mon Sep 17 00:00:00 2001 From: Paulina Kalicka <71526180+paulinek13@users.noreply.github.com> Date: Fri, 26 Dec 2025 22:56:14 +0100 Subject: [PATCH 28/38] Revert "TODO: use msal for auth" -- as we need browser automation anyway because of `enterprise-prod-first-party-app-policy` This reverts commit 276290f7dc309731fb1cba38c14132026833ee51. --- .../prompt_target/websocket_copilot_target.py | 102 ++++++++++-------- 1 file changed, 56 insertions(+), 46 deletions(-) diff --git a/pyrit/prompt_target/websocket_copilot_target.py b/pyrit/prompt_target/websocket_copilot_target.py index 7f6137535..58a380ecd 100644 --- a/pyrit/prompt_target/websocket_copilot_target.py +++ b/pyrit/prompt_target/websocket_copilot_target.py @@ -57,6 +57,8 @@ class WebSocketCopilotTarget(PromptTarget): Only works with licensed Microsoft 365 Copilot. The free Copilot version is not compatible. """ + # TODO: add more flexible auth, use puppeteer? https://github.com/mbrg/power-pwn/blob/main/src/powerpwn/copilot/copilot_connector/copilot_connector.py#L248 + SUPPORTED_DATA_TYPES = {"text"} # TODO: support more types? RESPONSE_TIMEOUT_SECONDS: int = 60 @@ -177,59 +179,67 @@ def _build_prompt_message(self, prompt: str) -> dict: return { "arguments": [ { - # TODO: use msal for auth, then set these fields properly, as with current approach they are not really needed - # "source": "officeweb", - # "clientCorrelationId": str(uuid.uuid4()), - # "sessionId": self._session_id, - # "optionsSets": [ - # "enterprise_flux_web", - # "enterprise_flux_work", - # "enable_request_response_interstitials", - # "enterprise_flux_image_v1", - # "enterprise_toolbox_with_skdsstore", - # "enterprise_toolbox_with_skdsstore_search_message_extensions", - # "enable_ME_auth_interstitial", - # "skdsstorethirdparty", - # "enable_confirmation_interstitial", - # "enable_plugin_auth_interstitial", - # "enable_response_action_processing", - # "enterprise_flux_work_gptv", - # "enterprise_flux_work_code_interpreter", - # "enable_batch_token_processing", - # ], - # "options": {}, - # "allowedMessageTypes": [ - # "Chat", - # "Suggestion", - # "InternalSearchQuery", - # "InternalSearchResult", - # "Disengaged", - # "InternalLoaderMessage", - # "RenderCardRequest", - # "AdsQuery", - # "SemanticSerp", - # "GenerateContentQuery", - # "SearchQuery", - # "ConfirmationCard", - # "AuthError", - # "DeveloperLogs", - # ], - # "sliceIds": [], - # "threadLevelGptId": {}, - # "conversationId": self._conversation_id, - # "traceId": str(uuid.uuid4()).replace("-", ""), - # "isStartOfSession": 0, - # "productThreadType": "Office", - # "clientInfo": {"clientPlatform": "web"}, + "source": "officeweb", # TODO: support 'teamshub' as well + # TODO: not sure whether to uuid.uuid4() or use a static like it's done in power-pwn + # https://github.com/mbrg/power-pwn/blob/main/src/powerpwn/copilot/copilot_connector/copilot_connector.py#L156 + "clientCorrelationId": str(uuid.uuid4()), + "sessionId": self._session_id, + "optionsSets": [ + "enterprise_flux_web", + "enterprise_flux_work", + "enable_request_response_interstitials", + "enterprise_flux_image_v1", + "enterprise_toolbox_with_skdsstore", + "enterprise_toolbox_with_skdsstore_search_message_extensions", + "enable_ME_auth_interstitial", + "skdsstorethirdparty", + "enable_confirmation_interstitial", + "enable_plugin_auth_interstitial", + "enable_response_action_processing", + "enterprise_flux_work_gptv", + "enterprise_flux_work_code_interpreter", + "enable_batch_token_processing", + ], + "options": {}, + "allowedMessageTypes": [ + "Chat", + "Suggestion", + "InternalSearchQuery", + "InternalSearchResult", + "Disengaged", + "InternalLoaderMessage", + "RenderCardRequest", + "AdsQuery", + "SemanticSerp", + "GenerateContentQuery", + "SearchQuery", + "ConfirmationCard", + "AuthError", + "DeveloperLogs", + ], + "sliceIds": [], + # TODO: enable using agents https://github.com/mbrg/power-pwn/blob/main/src/powerpwn/copilot/copilot_connector/copilot_connector.py#L192 + "threadLevelGptId": {}, + "conversationId": self._conversation_id, + "traceId": str(uuid.uuid4()).replace("-", ""), # TODO: same case as clientCorrelationId + "isStartOfSession": 0, + "productThreadType": "Office", + "clientInfo": {"clientPlatform": "web"}, "message": { "author": "user", + "inputMethod": "Keyboard", "text": prompt, + "entityAnnotationTypes": ["People", "File", "Event", "Email", "TeamsMessage"], "requestId": str(uuid.uuid4()).replace("-", ""), + "locationInfo": {"timeZoneOffset": 0, "timeZone": "UTC"}, + "locale": "en-US", + "messageType": "Chat", + "experienceType": "Default", }, - # "plugins": [], + "plugins": [], # TODO: support enabling some plugins? } ], - "invocationId": "0", + "invocationId": "0", # TODO: should be dynamic? "target": "chat", "type": CopilotMessageType.USER_PROMPT, } From 02e3a4e3bde7a05aa1b0d9278a080cd3a4463254 Mon Sep 17 00:00:00 2001 From: Paulina Kalicka <71526180+paulinek13@users.noreply.github.com> Date: Fri, 26 Dec 2025 22:58:59 +0100 Subject: [PATCH 29/38] Revert "add device code flow authentication method" This reverts commit ded56c6e7f2ea4f2e2e2f5eec7bc255383659237. --- pyrit/auth/azure_auth.py | 47 ---------------------------------------- 1 file changed, 47 deletions(-) diff --git a/pyrit/auth/azure_auth.py b/pyrit/auth/azure_auth.py index b854963a3..34d31f667 100644 --- a/pyrit/auth/azure_auth.py +++ b/pyrit/auth/azure_auth.py @@ -5,7 +5,6 @@ import time from typing import Callable, Union from urllib.parse import urlparse -import textwrap import msal from azure.core.credentials import AccessToken @@ -180,52 +179,6 @@ def get_access_token_from_msa_public_client(*, client_id: str, scope: str): raise -def get_access_token_from_device_code( - *, client_id: str, scope: str, authority: str = "https://login.microsoftonline.com/common" -): - """ - Use Device Code Flow to authenticate. User will be prompted to visit a URL and enter a code. - This method is useful for headless environments or when interactive browser login is not available. - - Args: - client_id (str): The client ID of the service. - scope (str): The scope to request. - authority (str): The MSAL authority URL. Defaults to common tenant. - - Returns: - str: Authentication token. - - Raises: - RuntimeError: If device flow initiation or authentication fails. - """ - try: - app = msal.PublicClientApplication(client_id=client_id, authority=authority) - flow = app.initiate_device_flow(scopes=[scope]) - - if "user_code" not in flow: - error_msg = flow.get("error_description", "Unknown error") - raise RuntimeError(f"Failed to initiate device flow: {error_msg}") - - print("\n" + "=" * 80) - print(" DEVICE CODE AUTHENTICATION".center(80)) - print("=" * 80) - print("\n" + textwrap.fill(flow["message"], width=76, initial_indent=" ", subsequent_indent=" ")) - print("\n ⏳ Waiting for authentication to complete...") - print("=" * 80 + "\n") - - result = app.acquire_token_by_device_flow(flow) - - if "access_token" not in result: - error = result.get("error", "Unknown error") - error_desc = result.get("error_description", "") - raise RuntimeError(f"Authentication failed: {error} - {error_desc}") - - return result["access_token"] - except Exception as e: - logger.error(f"Failed to obtain token for '{scope}' with client ID '{client_id}': {e}") - raise - - def get_access_token_from_interactive_login(scope: str) -> str: """ Connect to an OpenAI endpoint with an interactive login from Azure. A browser window will From 0a9ee34e2efc38052e3091ee76f5fbf52cdac7f5 Mon Sep 17 00:00:00 2001 From: Paulina Kalicka <71526180+paulinek13@users.noreply.github.com> Date: Sat, 27 Dec 2025 23:33:20 +0100 Subject: [PATCH 30/38] add Playwright-based way of getting sydney access token --- pyrit/auth/copilot_authenticator.py | 200 ++++++++++++++++++++++++++++ 1 file changed, 200 insertions(+) create mode 100644 pyrit/auth/copilot_authenticator.py diff --git a/pyrit/auth/copilot_authenticator.py b/pyrit/auth/copilot_authenticator.py new file mode 100644 index 000000000..aeb6d2d04 --- /dev/null +++ b/pyrit/auth/copilot_authenticator.py @@ -0,0 +1,200 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +import logging +import asyncio +from typing import Optional + +import json +import re + +from pyrit.auth.authenticator import Authenticator + +logger = logging.getLogger(__name__) + + +class CopilotAuthenticator(Authenticator): + """ + Playwright-based authenticator for Microsoft Copilot. + + You need to have playwright installed and set up: ``pip install playwright && playwright install chromium``. + """ + + def __init__(self, *, headless: bool = False, maximized: bool = True, timeout: int = 10): + """ + Initialize the CopilotAuthenticator. + + You must set the following environment variables for authentication: + - COPILOT_USERNAME: Your Microsoft account username (email). + - COPILOT_PASSWORD: Your Microsoft account password. + + Args: + headless (bool): Whether to run the browser in headless mode. Default is False. + maximized (bool): Whether to start the browser maximized. Default is True. + timeout (int): Timeout for page operations in seconds. Default is 10. + + Raises: + ValueError: If the required environment variables are not set. + """ + super().__init__() + + import os + + self._username = os.getenv("COPILOT_USERNAME") + self._password = os.getenv("COPILOT_PASSWORD") + + self._headless = headless + self._maximized = maximized + self._timeout = timeout * 1000 # ms + + if not self._username or not self._password: + raise ValueError("COPILOT_USERNAME and COPILOT_PASSWORD environment variables must be set.") + + def refresh_token(self) -> str: + """ + Refresh the authentication token. + + Returns: + str: The refreshed authentication token. + """ + raise NotImplementedError("refresh_token method not implemented") + + async def get_token(self) -> str: + """ + Get the current authentication token. + + Returns: + str: The current authentication token. + """ + return await self._fetch_access_token_with_playwright() + + async def _fetch_access_token_with_playwright(self) -> Optional[str]: + """ + Fetch access token using Playwright browser automation. + + Raises: + RuntimeError: If Playwright is not installed. + + Returns: + Optional[str]: The bearer token if successfully retrieved, else None. + """ + try: + from playwright.async_api import async_playwright + + pass + except ImportError: + raise RuntimeError("Playwright is not installed. Please install it with 'pip install playwright'.") + + bearer_token = None + + async with async_playwright() as playwright: + browser = None + context = None + + try: + logger.info(f"Launching browser for authentication (headless={self._headless})...") + browser = await playwright.chromium.launch( + headless=self._headless, args=["--start-maximized"] if self._maximized else [] + ) + + context = await browser.new_context(no_viewport=True) + page = await context.new_page() + + # response_handler >>> + async def response_handler(response): + nonlocal bearer_token + + try: + url = response.url + + if "/oauth2/v2.0/token" in url: + try: + text = await response.text() + + if ( + '"token_type":"Bearer"' in text or '"tokenType":"Bearer"' in text + ) and "sydney" in text: + try: + data = json.loads(text) + if "access_token" in data: + bearer_token = data["access_token"] + + except json.JSONDecodeError: + logger.info("Response JSON decode failed, trying regex extraction...") + + match = re.search(r'"access_token"\s*:\s*"([^"]+)"', text) + if match: + bearer_token = match.group(1) + logger.info("Captured bearer token using regex.") + else: + logger.error("Failed to extract bearer token using regex.") + + except Exception as e: + logger.error(f"Error reading response: {e}") + + except Exception as e: + logger.error(f"Error handling response: {e}") + + # ^^^ response_handler + + page.on("response", response_handler) + + logger.info("Navigating to Office.com for authentication...") + await page.goto("https://www.office.com/") + + logger.info("Waiting for profile icon...") + await page.wait_for_selector("#mectrl_headerPicture", timeout=self._timeout) + await page.click("#mectrl_headerPicture") + + logger.info("Waiting for email input...") + await page.wait_for_selector("#i0116", timeout=self._timeout) + await page.fill("#i0116", self._username) + await page.click("#idSIButton9") + + logger.info("Waiting for password input...") + await page.wait_for_selector("#i0118", timeout=self._timeout) + await page.fill("#i0118", self._password) + await page.click("#idSIButton9") + + logger.info("Waiting for 'Stay signed in?' prompt...") + await page.wait_for_selector("#idSIButton9", timeout=self._timeout) + logger.info("Clicking 'Yes' to stay signed in...") + await page.click("#idSIButton9") + + logger.info("Successfully logged in.") + logger.info("Navigating to Copilot...") + + logger.info("Waiting for Copilot button and clicking it...") + await page.wait_for_selector('div[aria-label="M365 Copilot"]', timeout=self._timeout) + await page.click('div[aria-label="M365 Copilot"]', timeout=self._timeout) + + logger.info("Waiting 60 seconds for bearer token to be captured...") + for _ in range(60): + if bearer_token: + break + await asyncio.sleep(1) + + if bearer_token: + logger.info( + f"Bearer token successfully retrieved. Preview: {bearer_token[:16]}...{bearer_token[-16:]}" + ) + else: + logger.error("Failed to retrieve bearer token within 60 seconds.") + + return bearer_token + except Exception as e: + logger.error("Failed to retrieve access token using Playwright.") + + if str(e).startswith("BrowserType.launch"): + logger.error("Playwright browser launch failed. Did you run 'playwright install chromium'?") + else: + logger.error(f"Error details: {e}") + + return None + finally: + logger.info("Gracefully closing Playwright browser instance...") + + if context: + await context.close() + if browser: + await browser.close() From cbc06f05a1718e53c1e78bb5d41aa8b58da3a49b Mon Sep 17 00:00:00 2001 From: Paulina Kalicka <71526180+paulinek13@users.noreply.github.com> Date: Sun, 28 Dec 2025 18:43:28 +0100 Subject: [PATCH 31/38] use `msal-extensions` for encrypted token persistence --- .gitignore | 1 + pyrit/auth/copilot_authenticator.py | 164 +++++++++++++++++++++++++--- pyrit/common/path.py | 4 + 3 files changed, 153 insertions(+), 16 deletions(-) diff --git a/.gitignore b/.gitignore index 24e00cf0b..2d6f8bfef 100644 --- a/.gitignore +++ b/.gitignore @@ -169,6 +169,7 @@ cython_debug/ # PyRIT secrets file .env +.pyrit_cache/ # Cache for generating docs doc/generate_docs/cache/* diff --git a/pyrit/auth/copilot_authenticator.py b/pyrit/auth/copilot_authenticator.py index aeb6d2d04..907916041 100644 --- a/pyrit/auth/copilot_authenticator.py +++ b/pyrit/auth/copilot_authenticator.py @@ -3,70 +3,195 @@ import logging import asyncio +import os +from datetime import datetime, timezone from typing import Optional import json import re +from msal_extensions import build_encrypted_persistence, FilePersistence from pyrit.auth.authenticator import Authenticator +from pyrit.common.path import PYRIT_CACHE_PATH logger = logging.getLogger(__name__) class CopilotAuthenticator(Authenticator): """ - Playwright-based authenticator for Microsoft Copilot. + Playwright-based authenticator for Microsoft Copilot. Used by WebSocketCopilotTarget. - You need to have playwright installed and set up: ``pip install playwright && playwright install chromium``. - """ + This authenticator automates browser login to obtain and refresh access tokens that are necessary for accessing + Microsoft Copilot via WebSocket connections. It uses Playwright to simulate user interactions for authentication, and msal-extensions for encrypted token persistence. - def __init__(self, *, headless: bool = False, maximized: bool = True, timeout: int = 10): - """ - Initialize the CopilotAuthenticator. + An access token acquired by this authenticator is usually valid for about 60 minutes. + + Note: + To be able to use this authenticator, you must set the following environment variables: - You must set the following environment variables for authentication: - COPILOT_USERNAME: Your Microsoft account username (email). - COPILOT_PASSWORD: Your Microsoft account password. + Additionally, you need to have playwright installed and set up: + ``pip install playwright && playwright install chromium``. + """ + + CACHE_FILE_NAME: str = "copilot_token_cache.bin" + + def __init__( + self, + *, + headless: bool = False, + maximized: bool = True, + timeout_for_elements: int = 10, + fallback_to_plaintext: bool = False, + ): + """ + Initialize the CopilotAuthenticator. + Args: headless (bool): Whether to run the browser in headless mode. Default is False. maximized (bool): Whether to start the browser maximized. Default is True. - timeout (int): Timeout for page operations in seconds. Default is 10. + timeout_for_elements (int): Timeout used when waiting for page elements, in seconds. Default is 10. + fallback_to_plaintext (bool): Whether to fallback to plaintext storage if encryption is unavailable. + If set to False (default), an exception will be raised if encryption cannot be used. Raises: ValueError: If the required environment variables are not set. """ super().__init__() - import os - self._username = os.getenv("COPILOT_USERNAME") self._password = os.getenv("COPILOT_PASSWORD") self._headless = headless self._maximized = maximized - self._timeout = timeout * 1000 # ms + self._timeout = timeout_for_elements * 1000 # ms + self._fallback_to_plaintext = fallback_to_plaintext + + self._cache_dir = PYRIT_CACHE_PATH + self._cache_dir.mkdir(parents=True, exist_ok=True) + self._cache_file = str(self._cache_dir / self.CACHE_FILE_NAME) if not self._username or not self._password: raise ValueError("COPILOT_USERNAME and COPILOT_PASSWORD environment variables must be set.") - def refresh_token(self) -> str: + self._token_cache = self._create_persistent_cache(self._cache_file, self._fallback_to_plaintext) + + @staticmethod + def _create_persistent_cache(cache_file: str, fallback_to_plaintext: bool = False): + # https://github.com/AzureAD/microsoft-authentication-extensions-for-python + + try: + logger.info(f"Using encrypted persistent token cache: {cache_file}") + return build_encrypted_persistence(cache_file) + except Exception as e: + if fallback_to_plaintext: + logger.warning(f"Encryption unavailable ({e}). Opting in to plain text.") + return FilePersistence(cache_file) + logger.error("Encryption unavailable and fallback_to_plaintext is False.") + raise + + def _get_cached_token_if_available_and_valid(self) -> Optional[dict]: + try: + cache_data = self._token_cache.load() + if not cache_data: + logger.info("No cached token data found.") + return None + + token_data = json.loads(cache_data) + if "access_token" not in token_data: + logger.info("No access token in cache.") + return None + + expires_at = token_data.get("expires_at") + if expires_at: + expiry_time = datetime.fromtimestamp(expires_at, tz=timezone.utc) + current_time = datetime.now(timezone.utc) + + # TODO: add n-minute buffer to avoid using tokens about to expire + if current_time >= expiry_time: + logger.info("Cached token has expired.") + return None + + minutes_left = (expiry_time - current_time).total_seconds() / 60 + logger.info(f"Cached token is valid for another {minutes_left:.2f} minutes") + + return token_data + + except Exception as e: + error_name = type(e).__name__ + if "PersistenceNotFound" in error_name or "FileNotFoundError" in error_name: + logger.info("Cache file does not exist yet. Will be created on first token save.") + else: + logger.error(f"Failed to load cached token ({error_name}): {e}") + return None + + def _save_token_to_cache(self, *, token: str, expires_in: Optional[int] = None) -> None: + token_data = { + "access_token": token, + "token_type": "Bearer", + "cached_at": datetime.now(timezone.utc).timestamp(), + } + + if expires_in: + expires_at = datetime.now(timezone.utc).timestamp() + expires_in + token_data["expires_at"] = expires_at + token_data["expires_in"] = expires_in + + try: + self._token_cache.save(json.dumps(token_data)) + logger.info("Token successfully cached.") + except Exception as e: + logger.error(f"Failed to cache token: {e}") + + def _clear_token_cache(self) -> None: + try: + self._token_cache.save(json.dumps({})) + logger.info("Token cache cleared.") + except Exception as e: + logger.error(f"Failed to clear cache: {e}") + + async def refresh_token(self) -> str: """ - Refresh the authentication token. + Refresh the authentication token asynchronously. + + This will clear the existing token cache and fetch a new token with automated browser login. Returns: str: The refreshed authentication token. + + Raises: + RuntimeError: If token refresh fails. """ - raise NotImplementedError("refresh_token method not implemented") + logger.info("Refreshing access token...") + self._clear_token_cache() + token = await self._fetch_access_token_with_playwright() + + if not token: + raise RuntimeError("Failed to refresh access token.") + + return token async def get_token(self) -> str: """ Get the current authentication token. + This will check the cache first and only launch the browser if no valid token is found. + Returns: str: The current authentication token. + + Raises: + RuntimeError: If token retrieval fails. """ - return await self._fetch_access_token_with_playwright() + cached_token = self._get_cached_token_if_available_and_valid() + if cached_token and "access_token" in cached_token: + logger.info("Using cached access token.") + return cached_token["access_token"] + + logger.info("No valid cached token found.") + return await self.refresh_token() async def _fetch_access_token_with_playwright(self) -> Optional[str]: """ @@ -86,6 +211,7 @@ async def _fetch_access_token_with_playwright(self) -> Optional[str]: raise RuntimeError("Playwright is not installed. Please install it with 'pip install playwright'.") bearer_token = None + token_expires_in = None async with async_playwright() as playwright: browser = None @@ -102,7 +228,7 @@ async def _fetch_access_token_with_playwright(self) -> Optional[str]: # response_handler >>> async def response_handler(response): - nonlocal bearer_token + nonlocal bearer_token, token_expires_in try: url = response.url @@ -118,6 +244,7 @@ async def response_handler(response): data = json.loads(text) if "access_token" in data: bearer_token = data["access_token"] + token_expires_in = data.get("expires_in") except json.JSONDecodeError: logger.info("Response JSON decode failed, trying regex extraction...") @@ -126,6 +253,10 @@ async def response_handler(response): if match: bearer_token = match.group(1) logger.info("Captured bearer token using regex.") + + expires_match = re.search(r'"expires_in"\s*:\s*(\d+)', text) + if expires_match: + token_expires_in = int(expires_match.group(1)) else: logger.error("Failed to extract bearer token using regex.") @@ -178,6 +309,7 @@ async def response_handler(response): logger.info( f"Bearer token successfully retrieved. Preview: {bearer_token[:16]}...{bearer_token[-16:]}" ) + self._save_token_to_cache(token=bearer_token, expires_in=token_expires_in) else: logger.error("Failed to retrieve bearer token within 60 seconds.") diff --git a/pyrit/common/path.py b/pyrit/common/path.py index 40340f28a..14158d3ff 100644 --- a/pyrit/common/path.py +++ b/pyrit/common/path.py @@ -41,6 +41,10 @@ def in_git_repo() -> bool: DB_DATA_PATH = get_default_data_path("dbdata") DB_DATA_PATH.mkdir(parents=True, exist_ok=True) +# Path to where cache files are stored, i.e. token cache, etc. +PYRIT_CACHE_PATH = get_default_data_path(".pyrit_cache") +PYRIT_CACHE_PATH.mkdir(parents=True, exist_ok=True) + # Path to where the logs are located LOG_PATH = pathlib.Path(DB_DATA_PATH, "logs.txt").resolve() LOG_PATH.touch(exist_ok=True) From 4299673c210d591ff78620d0d103bffab914c74d Mon Sep 17 00:00:00 2001 From: Paulina Kalicka <71526180+paulinek13@users.noreply.github.com> Date: Sun, 28 Dec 2025 22:04:25 +0100 Subject: [PATCH 32/38] add CopilotAuthenticator to WebSocketCopilotTarget for automated authentication --- pyrit/auth/__init__.py | 2 + pyrit/auth/copilot_authenticator.py | 3 + .../prompt_target/websocket_copilot_target.py | 154 +++++++++++------- 3 files changed, 98 insertions(+), 61 deletions(-) diff --git a/pyrit/auth/__init__.py b/pyrit/auth/__init__.py index 813f680ab..1f273e26c 100644 --- a/pyrit/auth/__init__.py +++ b/pyrit/auth/__init__.py @@ -15,11 +15,13 @@ get_default_azure_scope, ) from pyrit.auth.azure_storage_auth import AzureStorageAuth +from pyrit.auth.copilot_authenticator import CopilotAuthenticator __all__ = [ "Authenticator", "AzureAuth", "AzureStorageAuth", + "CopilotAuthenticator", "TokenProviderCredential", "get_azure_token_provider", "get_azure_async_token_provider", diff --git a/pyrit/auth/copilot_authenticator.py b/pyrit/auth/copilot_authenticator.py index 907916041..1c2bb5e49 100644 --- a/pyrit/auth/copilot_authenticator.py +++ b/pyrit/auth/copilot_authenticator.py @@ -93,6 +93,7 @@ def _create_persistent_cache(cache_file: str, fallback_to_plaintext: bool = Fals raise def _get_cached_token_if_available_and_valid(self) -> Optional[dict]: + # TODO: make sure the cached token matches the proper user account try: cache_data = self._token_cache.load() if not cache_data: @@ -185,6 +186,7 @@ async def get_token(self) -> str: Raises: RuntimeError: If token retrieval fails. """ + # TODO: make sure multiple concurrent calls don't launch multiple browsers cached_token = self._get_cached_token_if_available_and_valid() if cached_token and "access_token" in cached_token: logger.info("Using cached access token.") @@ -203,6 +205,7 @@ async def _fetch_access_token_with_playwright(self) -> Optional[str]: Returns: Optional[str]: The bearer token if successfully retrieved, else None. """ + # TODO: it's a long function, maybe split into smaller ones? try: from playwright.async_api import async_playwright diff --git a/pyrit/prompt_target/websocket_copilot_target.py b/pyrit/prompt_target/websocket_copilot_target.py index 58a380ecd..7e83c0ee4 100644 --- a/pyrit/prompt_target/websocket_copilot_target.py +++ b/pyrit/prompt_target/websocket_copilot_target.py @@ -4,13 +4,14 @@ import asyncio import json import logging -import os import uuid from enum import IntEnum from typing import Optional +import jwt import websockets +from pyrit.auth import CopilotAuthenticator from pyrit.exceptions import ( EmptyResponseException, pyrit_target_retry, @@ -41,26 +42,26 @@ class WebSocketCopilotTarget(PromptTarget): A WebSocket-based prompt target for Microsoft Copilot integration. This target enables communication with Microsoft Copilot through a WebSocket connection. - Currently, authentication requires manually extracting a WebSocket URL from an active browser session. - In the future, more flexible authentication mechanisms will be added. - - To obtain the WebSocket URL: - 1. Ensure you are logged into Microsoft 365 with access to Copilot - 2. Navigate to https://m365.cloud.microsoft/chat or open Copilot in https://teams.microsoft.com/v2 - 3. Open browser developer tools and switch to the Network tab - 4. Begin typing or send a message to Copilot to establish the WebSocket connection - 5. Search the network requests for "chathub", "conversation", or "access_token" - 6. Identify the WebSocket connection (look for WS protocol) and copy its full URL - - Warning: - All target instances using the same `WEBSOCKET_URL` will share a single conversation session. + Authentication is handled automatically using CopilotAuthenticator, which uses Playwright + to automate browser login and obtain access tokens. + + Requirements: + Set the following environment variables: + - COPILOT_USERNAME: Your Microsoft account username (email). + - COPILOT_PASSWORD: Your Microsoft account password. + + Install Playwright and its browser dependencies: + pip install playwright + playwright install chromium + + Note: Only works with licensed Microsoft 365 Copilot. The free Copilot version is not compatible. + Each target instance creates a new conversation session with unique conversation and session IDs. """ - # TODO: add more flexible auth, use puppeteer? https://github.com/mbrg/power-pwn/blob/main/src/powerpwn/copilot/copilot_connector/copilot_connector.py#L248 - SUPPORTED_DATA_TYPES = {"text"} # TODO: support more types? + WEBSOCKET_BASE_URL: str = "wss://substrate.office.com/m365Copilot/Chathub" RESPONSE_TIMEOUT_SECONDS: int = 60 CONNECTION_TIMEOUT_SECONDS: int = 30 @@ -71,53 +72,41 @@ def __init__( max_requests_per_minute: Optional[int] = None, model_name: str = "copilot", response_timeout_seconds: int = RESPONSE_TIMEOUT_SECONDS, + authenticator: Optional[CopilotAuthenticator] = None, ) -> None: """ Initialize the WebSocketCopilotTarget. Args: verbose (bool): Enable verbose logging. Defaults to False. - max_requests_per_minute (int, Optional): Maximum number of requests per minute. + max_requests_per_minute (Optional[int]): Maximum number of requests per minute. model_name (str): The model name. Defaults to "copilot". response_timeout_seconds (int): Timeout for receiving responses in seconds. Defaults to 60s. + authenticator (Optional[CopilotAuthenticator]): Authenticator instance for token management. + If None, a new CopilotAuthenticator instance will be created with default settings. Raises: - ValueError: If WebSocket URL is not provided, is empty, or has invalid format. - ValueError: If required parameters are missing or empty in the WebSocket URL. + ValueError: If ``response_timeout_seconds`` is not a positive integer. """ - self._websocket_url = os.getenv("WEBSOCKET_URL") - if not self._websocket_url or self._websocket_url.strip() == "": - raise ValueError("WebSocket URL must be provided through the WEBSOCKET_URL environment variable") - - if not self._websocket_url.startswith("wss://"): - raise ValueError(f"WebSocket URL must start with 'wss://'. Received: {self._websocket_url[:10]}") + if response_timeout_seconds <= 0: + raise ValueError("response_timeout_seconds must be a positive integer.") - if "ConversationId=" not in self._websocket_url: - raise ValueError("`ConversationId` parameter not found in WebSocket URL.") - self._conversation_id = self._websocket_url.split("ConversationId=")[1].split("&")[0] - if not self._conversation_id: - raise ValueError("`ConversationId` parameter is empty in WebSocket URL.") + self._authenticator = authenticator or CopilotAuthenticator() + self._response_timeout_seconds = response_timeout_seconds - if "X-SessionId=" not in self._websocket_url: - raise ValueError("`X-SessionId` parameter not found in WebSocket URL.") - self._session_id = self._websocket_url.split("X-SessionId=")[1].split("&")[0] - if not self._session_id: - raise ValueError("`X-SessionId` parameter is empty in WebSocket URL.") + # These will be generated fresh for each request + self._session_id: Optional[str] = None + self._conversation_id: Optional[str] = None super().__init__( verbose=verbose, max_requests_per_minute=max_requests_per_minute, - endpoint=self._websocket_url.split("?")[0], # wss://substrate.office.com/m365Copilot/Chathub/... + endpoint=self.WEBSOCKET_BASE_URL, model_name=model_name, ) - if response_timeout_seconds <= 0: - raise ValueError("response_timeout_seconds must be a positive integer.") - self._response_timeout_seconds = response_timeout_seconds - if self._verbose: - logger.info(f"WebSocketCopilotTarget initialized with conversation_id: {self._conversation_id}") - logger.info(f"Session ID: {self._session_id}") + logger.info("WebSocketCopilotTarget initialized") @staticmethod def _dict_to_websocket(data: dict) -> str: @@ -175,14 +164,53 @@ def _parse_raw_message(message: str) -> list[tuple[CopilotMessageType, str]]: return results if results else [(CopilotMessageType.UNKNOWN, "")] + async def _build_websocket_url_async(self) -> str: + access_token = await self._authenticator.get_token() + + try: + parsed_token = jwt.decode(access_token, algorithms=["RS256"], options={"verify_signature": False}) + except Exception as e: + raise ValueError(f"Failed to decode access token: {str(e)}") from e + + tenant_id = parsed_token.get("tid") + object_id = parsed_token.get("oid") + + if not tenant_id or not object_id: + raise ValueError( + "Failed to extract tenant_id (tid) or object_id (oid) from bearer token. " + f"Token claims: {list(parsed_token.keys())}" + ) + + self._session_id = str(uuid.uuid4()) + self._conversation_id = str(uuid.uuid4()) + client_request_id = str(uuid.uuid4()) + + base_url = f"{self.WEBSOCKET_BASE_URL}/{object_id}@{tenant_id}" + query_params = [ + f"ClientRequestId={client_request_id}", + f"X-SessionId={self._session_id}", + f"ConversationId={self._conversation_id}", + f"access_token={access_token}", + "X-variants=feature.includeExternal,feature.AssistantConnectorsContentSources," + "3S.BizChatWprBoostAssistant,3S.EnableMEFromSkillDiscovery,feature.EnableAuthErrorMessage," + "EnableRequestPlugins,feature.EnableSensitivityLabels,feature.IsEntityAnnotationsEnabled," + "EnableUnsupportedUrlDetector", + "source=%22officeweb%22", + "scenario=OfficeWebIncludedCopilot", + ] + + websocket_url = f"{base_url}?{'&'.join(query_params)}" + logger.debug(f"WebSocket URL: {websocket_url}") + return websocket_url + def _build_prompt_message(self, prompt: str) -> dict: + request_id = trace_id = uuid.uuid4().hex + return { "arguments": [ { - "source": "officeweb", # TODO: support 'teamshub' as well - # TODO: not sure whether to uuid.uuid4() or use a static like it's done in power-pwn - # https://github.com/mbrg/power-pwn/blob/main/src/powerpwn/copilot/copilot_connector/copilot_connector.py#L156 - "clientCorrelationId": str(uuid.uuid4()), + "source": "officeweb", + "clientCorrelationId": uuid.uuid4().hex, "sessionId": self._session_id, "optionsSets": [ "enterprise_flux_web", @@ -218,11 +246,10 @@ def _build_prompt_message(self, prompt: str) -> dict: "DeveloperLogs", ], "sliceIds": [], - # TODO: enable using agents https://github.com/mbrg/power-pwn/blob/main/src/powerpwn/copilot/copilot_connector/copilot_connector.py#L192 "threadLevelGptId": {}, "conversationId": self._conversation_id, - "traceId": str(uuid.uuid4()).replace("-", ""), # TODO: same case as clientCorrelationId - "isStartOfSession": 0, + "traceId": trace_id, + "isStartOfSession": True, "productThreadType": "Office", "clientInfo": {"clientPlatform": "web"}, "message": { @@ -230,29 +257,29 @@ def _build_prompt_message(self, prompt: str) -> dict: "inputMethod": "Keyboard", "text": prompt, "entityAnnotationTypes": ["People", "File", "Event", "Email", "TeamsMessage"], - "requestId": str(uuid.uuid4()).replace("-", ""), + "requestId": request_id, "locationInfo": {"timeZoneOffset": 0, "timeZone": "UTC"}, "locale": "en-US", "messageType": "Chat", "experienceType": "Default", }, - "plugins": [], # TODO: support enabling some plugins? + "plugins": [], } ], - "invocationId": "0", # TODO: should be dynamic? + "invocationId": "0", "target": "chat", "type": CopilotMessageType.USER_PROMPT, } async def _connect_and_send(self, prompt: str) -> str: - protocol_msg = {"protocol": "json", "version": 1} - prompt_dict = self._build_prompt_message(prompt) + websocket_url = await self._build_websocket_url_async() - inputs = [protocol_msg, prompt_dict] - last_response = "" + # TODO: explain why PING is not sent here + inputs = [{"protocol": "json", "version": 1}, self._build_prompt_message(prompt)] + response = "" async with websockets.connect( - self._websocket_url, + websocket_url, open_timeout=self.CONNECTION_TIMEOUT_SECONDS, close_timeout=self.CONNECTION_TIMEOUT_SECONDS, ) as websocket: @@ -260,6 +287,8 @@ async def _connect_and_send(self, prompt: str) -> str: payload = self._dict_to_websocket(input_msg) await websocket.send(payload) + is_user_input = input_msg.get("type") == CopilotMessageType.USER_PROMPT + stop_polling = False while not stop_polling: try: @@ -288,11 +317,14 @@ async def _connect_and_send(self, prompt: str) -> str: stop_polling = True if msg_type == CopilotMessageType.FINAL_CONTENT: - last_response = content + response = content elif msg_type == CopilotMessageType.UNKNOWN: logger.debug("Received unknown or empty message type.") - return last_response + elif msg_type == CopilotMessageType.PING and not is_user_input: + stop_polling = True + + return response def _validate_request(self, *, message: Message) -> None: n_pieces = len(message.message_pieces) @@ -332,7 +364,7 @@ async def send_prompt_async(self, *, message: Message) -> list[Message]: if not response_text or not response_text.strip(): logger.error("Empty response received from Copilot.") raise EmptyResponseException(message="Copilot returned an empty response.") - logger.info(f"Received the following response from WebSocketCopilotTarget: {response_text[:100]}...") + logger.info(f"Received the following response from WebSocketCopilotTarget: \n{response_text}") response_entry = construct_response_from_request( request=request_piece, response_text_pieces=[response_text] @@ -343,7 +375,7 @@ async def send_prompt_async(self, *, message: Message) -> list[Message]: except websockets.exceptions.InvalidStatus as e: logger.error( f"WebSocket connection failed: {str(e)}\n" - "Ensure the WEBSOCKET_URL environment variable is correct and valid." + "Ensure that COPILOT_USERNAME and COPILOT_PASSWORD environment variables are set correctly." " For more details about authentication, refer to the class documentation." ) raise From c65205bc4d0b763f79c094699a2c32e0d45829ec Mon Sep 17 00:00:00 2001 From: Paulina Kalicka <71526180+paulinek13@users.noreply.github.com> Date: Mon, 29 Dec 2025 14:57:51 +0100 Subject: [PATCH 33/38] WORKING multi prompt example --- doc/api.rst | 1 + pyrit/auth/copilot_authenticator.py | 5 +- .../prompt_target/websocket_copilot_target.py | 125 +++++++++++++----- websocket_copilot_simple_example.py | 38 ++++-- 4 files changed, 124 insertions(+), 45 deletions(-) diff --git a/doc/api.rst b/doc/api.rst index cd812de72..f95617b95 100644 --- a/doc/api.rst +++ b/doc/api.rst @@ -30,6 +30,7 @@ API Reference Authenticator AzureAuth AzureStorageAuth + CopilotAuthenticator :py:mod:`pyrit.auxiliary_attacks` ================================= diff --git a/pyrit/auth/copilot_authenticator.py b/pyrit/auth/copilot_authenticator.py index 1c2bb5e49..e3cf4167d 100644 --- a/pyrit/auth/copilot_authenticator.py +++ b/pyrit/auth/copilot_authenticator.py @@ -28,9 +28,8 @@ class CopilotAuthenticator(Authenticator): Note: To be able to use this authenticator, you must set the following environment variables: - - - COPILOT_USERNAME: Your Microsoft account username (email). - - COPILOT_PASSWORD: Your Microsoft account password. + - COPILOT_USERNAME: your Microsoft account username (email) + - COPILOT_PASSWORD: your Microsoft account password Additionally, you need to have playwright installed and set up: ``pip install playwright && playwright install chromium``. diff --git a/pyrit/prompt_target/websocket_copilot_target.py b/pyrit/prompt_target/websocket_copilot_target.py index 7e83c0ee4..86ef4fc4f 100644 --- a/pyrit/prompt_target/websocket_copilot_target.py +++ b/pyrit/prompt_target/websocket_copilot_target.py @@ -39,24 +39,33 @@ class CopilotMessageType(IntEnum): class WebSocketCopilotTarget(PromptTarget): """ - A WebSocket-based prompt target for Microsoft Copilot integration. + A WebSocket-based prompt target for integrating with Microsoft Copilot. - This target enables communication with Microsoft Copilot through a WebSocket connection. - Authentication is handled automatically using CopilotAuthenticator, which uses Playwright - to automate browser login and obtain access tokens. + This class facilitates communication with Microsoft Copilot over a WebSocket connection. + Authentication is handled automatically via `CopilotAuthenticator`, which uses Playwright + to automate browser login and obtain the required access tokens. - Requirements: - Set the following environment variables: - - COPILOT_USERNAME: Your Microsoft account username (email). - - COPILOT_PASSWORD: Your Microsoft account password. + Once authenticated, the target supports multi-turn conversations through server-side + state management. For each PyRIT conversation, it automatically generates consistent + `session_id` and `conversation_id` values, enabling Copilot to preserve conversational + context across multiple turns. - Install Playwright and its browser dependencies: - pip install playwright - playwright install chromium + Because conversation state is managed entirely on the Copilot server, this target does + not resend conversation history with each request and does not support programmatic + inspection or manipulation of that history. At present, there appears to be no supported + mechanism for modifying Copilot's server-side conversation state. Note: - Only works with licensed Microsoft 365 Copilot. The free Copilot version is not compatible. - Each target instance creates a new conversation session with unique conversation and session IDs. + This integration only works with licensed Microsoft 365 Copilot. + The free version of Copilot is not compatible. + + Important: + - Ensure the following environment variables are set: + - ``COPILOT_USERNAME`` - your Microsoft account username (email) + - ``COPILOT_PASSWORD`` - your Microsoft account password + + - Install `Playwright` and its browser dependencies: + ``pip install playwright && playwright install chromium`` """ SUPPORTED_DATA_TYPES = {"text"} # TODO: support more types? @@ -94,10 +103,6 @@ def __init__( self._authenticator = authenticator or CopilotAuthenticator() self._response_timeout_seconds = response_timeout_seconds - # These will be generated fresh for each request - self._session_id: Optional[str] = None - self._conversation_id: Optional[str] = None - super().__init__( verbose=verbose, max_requests_per_minute=max_requests_per_minute, @@ -164,7 +169,7 @@ def _parse_raw_message(message: str) -> list[tuple[CopilotMessageType, str]]: return results if results else [(CopilotMessageType.UNKNOWN, "")] - async def _build_websocket_url_async(self) -> str: + async def _build_websocket_url_async(self, *, session_id: str, copilot_conversation_id: str) -> str: access_token = await self._authenticator.get_token() try: @@ -181,15 +186,13 @@ async def _build_websocket_url_async(self) -> str: f"Token claims: {list(parsed_token.keys())}" ) - self._session_id = str(uuid.uuid4()) - self._conversation_id = str(uuid.uuid4()) client_request_id = str(uuid.uuid4()) base_url = f"{self.WEBSOCKET_BASE_URL}/{object_id}@{tenant_id}" query_params = [ f"ClientRequestId={client_request_id}", - f"X-SessionId={self._session_id}", - f"ConversationId={self._conversation_id}", + f"X-SessionId={session_id}", + f"ConversationId={copilot_conversation_id}", f"access_token={access_token}", "X-variants=feature.includeExternal,feature.AssistantConnectorsContentSources," "3S.BizChatWprBoostAssistant,3S.EnableMEFromSkillDiscovery,feature.EnableAuthErrorMessage," @@ -203,7 +206,9 @@ async def _build_websocket_url_async(self) -> str: logger.debug(f"WebSocket URL: {websocket_url}") return websocket_url - def _build_prompt_message(self, prompt: str) -> dict: + def _build_prompt_message( + self, *, prompt: str, session_id: str, copilot_conversation_id: str, is_start_of_session: bool + ) -> dict: request_id = trace_id = uuid.uuid4().hex return { @@ -211,7 +216,7 @@ def _build_prompt_message(self, prompt: str) -> dict: { "source": "officeweb", "clientCorrelationId": uuid.uuid4().hex, - "sessionId": self._session_id, + "sessionId": session_id, "optionsSets": [ "enterprise_flux_web", "enterprise_flux_work", @@ -247,9 +252,9 @@ def _build_prompt_message(self, prompt: str) -> dict: ], "sliceIds": [], "threadLevelGptId": {}, - "conversationId": self._conversation_id, + "conversationId": copilot_conversation_id, "traceId": trace_id, - "isStartOfSession": True, + "isStartOfSession": is_start_of_session, "productThreadType": "Office", "clientInfo": {"clientPlatform": "web"}, "message": { @@ -271,11 +276,22 @@ def _build_prompt_message(self, prompt: str) -> dict: "type": CopilotMessageType.USER_PROMPT, } - async def _connect_and_send(self, prompt: str) -> str: - websocket_url = await self._build_websocket_url_async() + async def _connect_and_send( + self, *, prompt: str, session_id: str, copilot_conversation_id: str, is_start_of_session: bool + ) -> str: + websocket_url = await self._build_websocket_url_async( + session_id=session_id, copilot_conversation_id=copilot_conversation_id + ) - # TODO: explain why PING is not sent here - inputs = [{"protocol": "json", "version": 1}, self._build_prompt_message(prompt)] + inputs = [ + {"protocol": "json", "version": 1}, + self._build_prompt_message( + prompt=prompt, + session_id=session_id, + copilot_conversation_id=copilot_conversation_id, + is_start_of_session=is_start_of_session, + ), + ] response = "" async with websockets.connect( @@ -335,12 +351,40 @@ def _validate_request(self, *, message: Message) -> None: if piece_type != "text": raise ValueError(f"This target only supports text prompt input. Received: {piece_type}.") + def _is_start_of_session(self, *, conversation_id: str) -> bool: + conversation_history = self._memory.get_conversation(conversation_id=conversation_id) + return len(conversation_history) == 0 + + def _generate_consistent_copilot_ids(self, *, pyrit_conversation_id: str) -> tuple[str, str]: + """ + Generate consistent Copilot session_id and conversation_id for a PyRIT conversation. + + This uses a deterministic approach to ensure that the same PyRIT conversation_id + always maps to the same Copilot session identifiers. This enables multi-turn + conversations while keeping the target stateless. + + Args: + pyrit_conversation_id (str): The PyRIT conversation ID from the Message. + + Returns: + tuple[str, str]: A tuple of (session_id, copilot_conversation_id). + """ + namespace = uuid.UUID("6ba7b810-9dad-11d1-80b4-00c04fd430c8") # DNS namespace UUID + session_id = str(uuid.uuid5(namespace, f"session_{pyrit_conversation_id}")) + copilot_conversation_id = str(uuid.uuid5(namespace, f"copilot_{pyrit_conversation_id}")) + + return session_id, copilot_conversation_id + @limit_requests_per_minute @pyrit_target_retry async def send_prompt_async(self, *, message: Message) -> list[Message]: """ Asynchronously send a message to Microsoft Copilot using WebSocket. + This method enables multi-turn conversations by using consistent session and conversation + identifiers derived from the PyRIT conversation_id. The Copilot API maintains conversation + state server-side, so only the current message is sent (no explicit history required). + Args: message (Message): A message to be sent to the target. @@ -355,16 +399,31 @@ async def send_prompt_async(self, *, message: Message) -> list[Message]: self._validate_request(message=message) request_piece = message.message_pieces[0] - logger.info(f"Sending the following prompt to WebSocketCopilotTarget: {request_piece}") + pyrit_conversation_id = request_piece.conversation_id + is_start_of_session = self._is_start_of_session(conversation_id=pyrit_conversation_id) + + session_id, copilot_conversation_id = self._generate_consistent_copilot_ids( + pyrit_conversation_id=pyrit_conversation_id + ) + + logger.info( + f"Sending prompt to WebSocketCopilotTarget: {request_piece.converted_value} " + f"(conversation_id={pyrit_conversation_id}, is_start={is_start_of_session})" + ) try: prompt_text = request_piece.converted_value - response_text = await self._connect_and_send(prompt_text) + response_text = await self._connect_and_send( + prompt=prompt_text, + session_id=session_id, + copilot_conversation_id=copilot_conversation_id, + is_start_of_session=is_start_of_session, + ) if not response_text or not response_text.strip(): logger.error("Empty response received from Copilot.") raise EmptyResponseException(message="Copilot returned an empty response.") - logger.info(f"Received the following response from WebSocketCopilotTarget: \n{response_text}") + logger.info(f"Received response from WebSocketCopilotTarget (length: {len(response_text)} chars)") response_entry = construct_response_from_request( request=request_piece, response_text_pieces=[response_text] diff --git a/websocket_copilot_simple_example.py b/websocket_copilot_simple_example.py index a1e13831a..f23ae8358 100644 --- a/websocket_copilot_simple_example.py +++ b/websocket_copilot_simple_example.py @@ -5,26 +5,46 @@ import asyncio +from pyrit.executor.attack import MultiPromptSendingAttack from pyrit.models import Message, MessagePiece from pyrit.prompt_target import WebSocketCopilotTarget from pyrit.setup import IN_MEMORY, initialize_pyrit_async +from pyrit.executor.attack import ConsoleAttackResultPrinter async def main(): await initialize_pyrit_async(memory_db_type=IN_MEMORY) + target = WebSocketCopilotTarget() - message_piece = MessagePiece( - role="user", - original_value="say only one random word", - original_value_data_type="text", - converted_value_data_type="text", + prompts = [ + "I'm thinking of a number between 1 and 10.", + "It's greater than 5.", + "It's an even number.", + "What number am I thinking of?", + ] + + messages = [] + for prompt in prompts: + message_piece = MessagePiece( + role="user", + original_value=prompt, + original_value_data_type="text", + converted_value_data_type="text", + ) + messages.append(Message(message_pieces=[message_piece])) + + attack = MultiPromptSendingAttack( + objective_target=target, + ) + + result = await attack.execute_async( + objective="Engage in a multi-turn conversation about a number guessing game", + messages=messages, ) - message = Message(message_pieces=[message_piece]) - responses = await target.send_prompt_async(message=message) - for response in responses: - print(f"{response.get_value()}") + printer = ConsoleAttackResultPrinter() + await printer.print_conversation_async(result=result) # type: ignore if __name__ == "__main__": From fd745ba5dd1bc7a721ebfbd7dbd1abc91ad671c6 Mon Sep 17 00:00:00 2001 From: Paulina Kalicka <71526180+paulinek13@users.noreply.github.com> Date: Mon, 29 Dec 2025 19:27:49 +0100 Subject: [PATCH 34/38] unit tests update for `WebSocketCopilotTarget` --- .../prompt_target/websocket_copilot_target.py | 9 +- .../target/test_websocket_copilot_target.py | 495 +++++++++++++++--- websocket_copilot_simple_example.py | 5 + 3 files changed, 431 insertions(+), 78 deletions(-) diff --git a/pyrit/prompt_target/websocket_copilot_target.py b/pyrit/prompt_target/websocket_copilot_target.py index 86ef4fc4f..56709b4c4 100644 --- a/pyrit/prompt_target/websocket_copilot_target.py +++ b/pyrit/prompt_target/websocket_copilot_target.py @@ -210,8 +210,7 @@ def _build_prompt_message( self, *, prompt: str, session_id: str, copilot_conversation_id: str, is_start_of_session: bool ) -> dict: request_id = trace_id = uuid.uuid4().hex - - return { + result = { "arguments": [ { "source": "officeweb", @@ -276,6 +275,9 @@ def _build_prompt_message( "type": CopilotMessageType.USER_PROMPT, } + logger.debug(f"Built prompt message: {result}") + return result + async def _connect_and_send( self, *, prompt: str, session_id: str, copilot_conversation_id: str, is_start_of_session: bool ) -> str: @@ -439,5 +441,8 @@ async def send_prompt_async(self, *, message: Message) -> list[Message]: ) raise + except EmptyResponseException: + raise + except Exception as e: raise RuntimeError(f"An error occurred during WebSocket communication: {str(e)}") from e diff --git a/tests/unit/target/test_websocket_copilot_target.py b/tests/unit/target/test_websocket_copilot_target.py index 0b4e713e1..1004b49f4 100644 --- a/tests/unit/target/test_websocket_copilot_target.py +++ b/tests/unit/target/test_websocket_copilot_target.py @@ -1,96 +1,85 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT license. -import os -from unittest.mock import patch +import asyncio +from unittest.mock import AsyncMock, MagicMock, patch +import jwt import pytest +from pyrit.auth import CopilotAuthenticator +from pyrit.models import Message, MessagePiece from pyrit.prompt_target import WebSocketCopilotTarget - - -VALID_WEBSOCKET_URL = ( - "wss://substrate.office.com/m365Copilot/Chathub/test_object_id@test_tenant_id" - "?ClientRequestId=test_client_request_id" - "&X-SessionId=test_session_id&token=abc123" - "&ConversationId=test_conversation_id" - "&access_token=test_access_token" - # "&variants=feature.test_feature_one,feature.test_feature_two" - # "&agent=web" - # "&scenario=OfficeWebIncludedCopilot" -) +from pyrit.prompt_target.websocket_copilot_target import CopilotMessageType @pytest.fixture -def mock_env_websocket_url(): - with patch.dict(os.environ, {"WEBSOCKET_URL": VALID_WEBSOCKET_URL}): - yield +def mock_authenticator(): + token_payload = {"tid": "test_tenant_id", "oid": "test_object_id", "exp": 9999999999} + mock_token = jwt.encode(token_payload, "secret", algorithm="HS256") + if isinstance(mock_token, bytes): + mock_token = mock_token.decode("utf-8") + authenticator = MagicMock(spec=CopilotAuthenticator) + authenticator.get_token = AsyncMock(return_value=mock_token) + return authenticator @pytest.mark.usefixtures("patch_central_database") class TestWebSocketCopilotTargetInit: - def test_init_with_valid_wss_url(self, mock_env_websocket_url): - target = WebSocketCopilotTarget() - - assert target._websocket_url == VALID_WEBSOCKET_URL - assert target._conversation_id == "test_conversation_id" - assert target._session_id == "test_session_id" - assert target._model_name == "copilot" - - def test_init_with_missing_or_invalid_wss_url(self): - for env_vars in [{}, {"WEBSOCKET_URL": ""}, {"WEBSOCKET_URL": " "}]: - with patch.dict(os.environ, env_vars, clear=True): - with pytest.raises(ValueError, match="WebSocket URL must be provided"): - WebSocketCopilotTarget() - - for invalid_url in ["invalid_websocket_url", "ws://example.com", "https://example.com"]: - with patch.dict(os.environ, {"WEBSOCKET_URL": invalid_url}, clear=True): - with pytest.raises(ValueError, match="WebSocket URL must start with 'wss://'"): - WebSocketCopilotTarget() - - def test_init_with_missing_or_empty_required_params(self): - urls = [ - ("wss://example.com/?X-SessionId=session123", "`ConversationId` parameter not found"), - ("wss://example.com/?ConversationId=conv123", "`X-SessionId` parameter not found"), - ("wss://example.com/?ConversationId=&X-SessionId=session123", "`ConversationId` parameter is empty"), - ("wss://example.com/?ConversationId=conv123&X-SessionId=", "`X-SessionId` parameter is empty"), - ] + def test_init_with_default_parameters(self): + with patch("pyrit.prompt_target.websocket_copilot_target.CopilotAuthenticator") as mock_auth_class: + mock_auth_instance = MagicMock(spec=CopilotAuthenticator) + mock_auth_class.return_value = mock_auth_instance + + target = WebSocketCopilotTarget() - for url, error_msg in urls: - with patch.dict(os.environ, {"WEBSOCKET_URL": url}, clear=True): - with pytest.raises(ValueError, match=error_msg): - WebSocketCopilotTarget() + mock_auth_class.assert_called_once() + assert target._authenticator == mock_auth_instance + assert target._response_timeout_seconds == WebSocketCopilotTarget.RESPONSE_TIMEOUT_SECONDS + assert target._model_name == "copilot" + assert target._endpoint == WebSocketCopilotTarget.WEBSOCKET_BASE_URL + assert target._verbose is False + assert target._max_requests_per_minute is None - def test_init_sets_endpoint_correctly(self, mock_env_websocket_url): - target = WebSocketCopilotTarget() - assert target._endpoint == "wss://substrate.office.com/m365Copilot/Chathub/test_object_id@test_tenant_id" + def test_init_with_custom_parameters(self, mock_authenticator): + target = WebSocketCopilotTarget( + authenticator=mock_authenticator, + verbose=True, + max_requests_per_minute=10, + model_name="custom_copilot", + response_timeout_seconds=120, + ) - def test_init_with_custom_response_timeout(self, mock_env_websocket_url): - target = WebSocketCopilotTarget(response_timeout_seconds=120) + assert target._authenticator == mock_authenticator assert target._response_timeout_seconds == 120 + assert target._model_name == "custom_copilot" + assert target._verbose is True + assert target._max_requests_per_minute == 10 - for invalid_timeout in [0, -10]: + def test_init_with_invalid_response_timeout(self, mock_authenticator): + for invalid_timeout in [0, -10, -1]: with pytest.raises(ValueError, match="response_timeout_seconds must be a positive integer."): - WebSocketCopilotTarget(response_timeout_seconds=invalid_timeout) + WebSocketCopilotTarget(authenticator=mock_authenticator, response_timeout_seconds=invalid_timeout) -@pytest.mark.parametrize( - "data,expected", - [ - ({"key": "value"}, '{"key":"value"}\x1e'), - ({"protocol": "json", "version": 1}, '{"protocol":"json","version":1}\x1e'), - ({"outer": {"inner": "value"}}, '{"outer":{"inner":"value"}}\x1e'), - ({"items": [1, 2, 3]}, '{"items":[1,2,3]}\x1e'), - ], -) -def test_dict_to_websocket_static_method(data, expected): - result = WebSocketCopilotTarget._dict_to_websocket(data) - assert result == expected +@pytest.mark.usefixtures("patch_central_database") +class TestDictToWebsocket: + @pytest.mark.parametrize( + "data,expected", + [ + ({"key": "value"}, '{"key":"value"}\x1e'), + ({"protocol": "json", "version": 1}, '{"protocol":"json","version":1}\x1e'), + ({"outer": {"inner": "value"}}, '{"outer":{"inner":"value"}}\x1e'), + ({"items": [1, 2, 3]}, '{"items":[1,2,3]}\x1e'), + ], + ) + def test_dict_to_websocket_converts_to_json_with_separator(self, data, expected): + result = WebSocketCopilotTarget._dict_to_websocket(data) + assert result == expected +@pytest.mark.usefixtures("patch_central_database") class TestParseRawMessage: - from pyrit.prompt_target.websocket_copilot_target import CopilotMessageType - @pytest.mark.parametrize( "message,expected_types,expected_content", [ @@ -106,10 +95,12 @@ class TestParseRawMessage: ( '{"type":2,"item":{"result":{"message":"Final."}}}\x1e{"type":3,"invocationId":"0"}\x1e', [CopilotMessageType.FINAL_CONTENT, CopilotMessageType.STREAM_END], - [ - "Final.", - "", - ], + ["Final.", ""], + ), + ( + '{"type":3,"invocationId":"0"}\x1e', + [CopilotMessageType.STREAM_END], + [""], ), ], ) @@ -122,8 +113,6 @@ def test_parse_raw_message_with_valid_data(self, message, expected_types, expect assert result[i][1] == expected_content[i] def test_parse_final_message_without_content(self): - from pyrit.prompt_target.websocket_copilot_target import CopilotMessageType - with patch("pyrit.prompt_target.websocket_copilot_target.logger") as mock_logger: message = '{"type":2,"invocationId":"0"}\x1e' result = WebSocketCopilotTarget._parse_raw_message(message) @@ -144,9 +133,363 @@ def test_parse_final_message_without_content(self): ], ) def test_parse_unknown_or_invalid_messages(self, message): - from pyrit.prompt_target.websocket_copilot_target import CopilotMessageType - result = WebSocketCopilotTarget._parse_raw_message(message) assert len(result) == 1 assert result[0][0] == CopilotMessageType.UNKNOWN assert result[0][1] == "" + + +@pytest.mark.usefixtures("patch_central_database") +class TestBuildWebsocketUrl: + @pytest.mark.asyncio + async def test_build_websocket_url_with_valid_token(self, mock_authenticator): + target = WebSocketCopilotTarget(authenticator=mock_authenticator) + + session_id = "test_session_id" + copilot_conversation_id = "test_conversation_id" + + url = await target._build_websocket_url_async( + session_id=session_id, copilot_conversation_id=copilot_conversation_id + ) + expected_token = await mock_authenticator.get_token() + + assert url.startswith(f"{WebSocketCopilotTarget.WEBSOCKET_BASE_URL}/test_object_id@test_tenant_id?") + assert f"X-SessionId={session_id}" in url + assert f"ConversationId={copilot_conversation_id}" in url + assert f"access_token={expected_token}" in url + assert "ClientRequestId=" in url + assert "source=%22officeweb%22" in url + assert "scenario=OfficeWebIncludedCopilot" in url + + @pytest.mark.asyncio + async def test_build_websocket_url_with_missing_ids(self, mock_authenticator): + for missing_id in ["tid", "oid"]: + token_payload = {"tid": "test_tenant_id", "oid": "test_object_id", "exp": 9999999999} + del token_payload[missing_id] + + mock_token = jwt.encode(token_payload, "secret", algorithm="HS256") + if isinstance(mock_token, bytes): + mock_token = mock_token.decode("utf-8") + mock_authenticator.get_token = AsyncMock(return_value=mock_token) + + target = WebSocketCopilotTarget(authenticator=mock_authenticator) + with pytest.raises(ValueError, match="Failed to extract tenant_id \\(tid\\) or object_id \\(oid\\)"): + await target._build_websocket_url_async(session_id="test", copilot_conversation_id="test") + + @pytest.mark.asyncio + async def test_build_websocket_url_with_invalid_token(self, mock_authenticator): + mock_authenticator.get_token = AsyncMock(return_value="invalid_token") + target = WebSocketCopilotTarget(authenticator=mock_authenticator) + + with pytest.raises(ValueError, match="Failed to decode access token"): + await target._build_websocket_url_async(session_id="test", copilot_conversation_id="test") + + +@pytest.mark.usefixtures("patch_central_database") +class TestBuildPromptMessage: + def test_build_prompt_message_structure(self, mock_authenticator): + target = WebSocketCopilotTarget(authenticator=mock_authenticator) + + prompt = "Hello Copilot" + session_id = "session_123" + copilot_conversation_id = "conv_456" + is_start_of_session = True + + message = target._build_prompt_message( + prompt=prompt, + session_id=session_id, + copilot_conversation_id=copilot_conversation_id, + is_start_of_session=is_start_of_session, + ) + + assert "arguments" in message + assert "invocationId" in message + assert "target" in message + assert "type" in message + + assert message["target"] == "chat" + assert message["type"] == CopilotMessageType.USER_PROMPT + assert message["invocationId"] == "0" + + args = message["arguments"][0] + assert args["sessionId"] == session_id + assert args["conversationId"] == copilot_conversation_id + assert args["isStartOfSession"] is True + assert args["source"] == "officeweb" + assert args["productThreadType"] == "Office" + + msg = args["message"] + assert msg["text"] == prompt + assert msg["author"] == "user" + assert msg["messageType"] == "Chat" + assert msg["locale"] == "en-US" + + def test_build_prompt_message_with_different_session_states(self, mock_authenticator): + target = WebSocketCopilotTarget(authenticator=mock_authenticator) + + message = target._build_prompt_message( + prompt="Follow-up question", + session_id="session_123", + copilot_conversation_id="conv_456", + is_start_of_session=False, + ) + + args = message["arguments"][0] + assert args["isStartOfSession"] is False + + +@pytest.mark.usefixtures("patch_central_database") +class TestConnectAndSend: + @pytest.mark.asyncio + async def test_connect_and_send_successful_response(self, mock_authenticator): + target = WebSocketCopilotTarget(authenticator=mock_authenticator) + + mock_websocket = AsyncMock() + mock_websocket.send = AsyncMock() + mock_websocket.recv = AsyncMock( + side_effect=[ + '{"type":6}\x1e', # PING response to handshake + '{"type":1}\x1e', # partial response to user prompt + '{"type":2,"item":{"result":{"message":"Hello from Copilot"}}}\x1e', # final content + ] + ) + mock_websocket.__aenter__ = AsyncMock(return_value=mock_websocket) + mock_websocket.__aexit__ = AsyncMock(return_value=None) + + with patch("websockets.connect", return_value=mock_websocket): + response = await target._connect_and_send( + prompt="Hello", + session_id="session_123", + copilot_conversation_id="conv_456", + is_start_of_session=True, + ) + + assert response == "Hello from Copilot" + assert mock_websocket.send.call_count == 2 # handshake + user prompt + + @pytest.mark.asyncio + async def test_connect_and_send_timeout(self, mock_authenticator): + target = WebSocketCopilotTarget(authenticator=mock_authenticator, response_timeout_seconds=1) + + mock_websocket = AsyncMock() + mock_websocket.send = AsyncMock() + mock_websocket.recv = AsyncMock(side_effect=asyncio.TimeoutError()) + mock_websocket.__aenter__ = AsyncMock(return_value=mock_websocket) + mock_websocket.__aexit__ = AsyncMock(return_value=None) + + with patch("websockets.connect", return_value=mock_websocket): + with pytest.raises(TimeoutError, match="Timed out waiting for Copilot response"): + await target._connect_and_send( + prompt="Hello", + session_id="session_123", + copilot_conversation_id="conv_456", + is_start_of_session=True, + ) + + @pytest.mark.asyncio + async def test_connect_and_send_none_response(self, mock_authenticator): + target = WebSocketCopilotTarget(authenticator=mock_authenticator) + + mock_websocket = AsyncMock() + mock_websocket.send = AsyncMock() + mock_websocket.recv = AsyncMock(return_value=None) + mock_websocket.__aenter__ = AsyncMock(return_value=mock_websocket) + mock_websocket.__aexit__ = AsyncMock(return_value=None) + + with patch("websockets.connect", return_value=mock_websocket): + with pytest.raises(RuntimeError, match="WebSocket connection closed unexpectedly"): + await target._connect_and_send( + prompt="Hello", + session_id="session_123", + copilot_conversation_id="conv_456", + is_start_of_session=True, + ) + + +@pytest.mark.usefixtures("patch_central_database") +class TestValidateRequest: + def test_validate_request_with_single_text_piece(self, mock_authenticator): + target = WebSocketCopilotTarget(authenticator=mock_authenticator) + + message_piece = MessagePiece( + role="user", + original_value="test", + converted_value="test", + conversation_id="123", + original_value_data_type="text", + converted_value_data_type="text", + ) + message = Message(message_pieces=[message_piece]) + + # Should not raise + target._validate_request(message=message) + + def test_validate_request_with_multiple_pieces(self, mock_authenticator): + target = WebSocketCopilotTarget(authenticator=mock_authenticator) + + message_pieces = [ + MessagePiece( + role="user", + original_value="test1", + converted_value="test1", + conversation_id="123", + original_value_data_type="text", + converted_value_data_type="text", + ), + MessagePiece( + role="user", + original_value="test2", + converted_value="test2", + conversation_id="123", + original_value_data_type="text", + converted_value_data_type="text", + ), + ] + message = Message(message_pieces=message_pieces) + + with pytest.raises(ValueError, match="This target only supports a single message piece"): + target._validate_request(message=message) + + def test_validate_request_with_non_text_type(self, mock_authenticator): + target = WebSocketCopilotTarget(authenticator=mock_authenticator) + + message_piece = MessagePiece( + role="user", + original_value="image.png", + converted_value="image.png", + conversation_id="123", + original_value_data_type="image_path", + converted_value_data_type="image_path", + ) + message = Message(message_pieces=[message_piece]) + + with pytest.raises(ValueError, match="This target only supports text prompt input"): + target._validate_request(message=message) + + +@pytest.mark.usefixtures("patch_central_database") +class TestIsStartOfSession: + def test_is_start_of_session_with_empty_history(self, mock_authenticator): + target = WebSocketCopilotTarget(authenticator=mock_authenticator) + + mock_memory = MagicMock() + mock_memory.get_conversation.return_value = [] + target._memory = mock_memory + + conversation_id = "test_conv_123" + result = target._is_start_of_session(conversation_id=conversation_id) + + assert result is True + mock_memory.get_conversation.assert_called_once_with(conversation_id=conversation_id) + + def test_is_start_of_session_with_existing_history(self, mock_authenticator): + target = WebSocketCopilotTarget(authenticator=mock_authenticator) + + mock_memory = MagicMock() + mock_message = MagicMock() + mock_memory.get_conversation.return_value = [mock_message] + target._memory = mock_memory + + conversation_id = "test_conv_123" + result = target._is_start_of_session(conversation_id=conversation_id) + + assert result is False + + +@pytest.mark.usefixtures("patch_central_database") +class TestGenerateConsistentCopilotIds: + def test_generates_consistent_ids_for_same_conversation(self, mock_authenticator): + target = WebSocketCopilotTarget(authenticator=mock_authenticator) + + pyrit_conv_id = "pyrit_conversation_123" + + session_id_1, copilot_conv_id_1 = target._generate_consistent_copilot_ids(pyrit_conversation_id=pyrit_conv_id) + session_id_2, copilot_conv_id_2 = target._generate_consistent_copilot_ids(pyrit_conversation_id=pyrit_conv_id) + + assert session_id_1 == session_id_2 + assert copilot_conv_id_1 == copilot_conv_id_2 + + def test_generates_different_ids_for_different_conversations(self, mock_authenticator): + target = WebSocketCopilotTarget(authenticator=mock_authenticator) + + pyrit_conv_id_1 = "pyrit_conversation_123" + pyrit_conv_id_2 = "pyrit_conversation_456" + + session_id_1, copilot_conv_id_1 = target._generate_consistent_copilot_ids(pyrit_conversation_id=pyrit_conv_id_1) + session_id_2, copilot_conv_id_2 = target._generate_consistent_copilot_ids(pyrit_conversation_id=pyrit_conv_id_2) + + assert session_id_1 != session_id_2 + assert copilot_conv_id_1 != copilot_conv_id_2 + + def test_generated_ids_are_valid_uuids(self, mock_authenticator): + """Test that generated IDs are valid UUID format and can be parsed.""" + import uuid + + target = WebSocketCopilotTarget(authenticator=mock_authenticator) + + pyrit_conv_id = "test_conversation" + session_id, copilot_conv_id = target._generate_consistent_copilot_ids(pyrit_conversation_id=pyrit_conv_id) + + # Should be parseable as UUIDs without raising exceptions + uuid.UUID(session_id) + uuid.UUID(copilot_conv_id) + + +@pytest.mark.usefixtures("patch_central_database") +class TestSendPromptAsync: + @pytest.mark.asyncio + async def test_send_prompt_async_successful(self, mock_authenticator): + target = WebSocketCopilotTarget(authenticator=mock_authenticator) + + message_piece = MessagePiece( + role="user", + original_value="Hello", + converted_value="Hello", + conversation_id="conv_123", + original_value_data_type="text", + converted_value_data_type="text", + ) + message = Message(message_pieces=[message_piece]) + + mock_memory = MagicMock() + mock_memory.get_conversation.return_value = [] + mock_memory.add_message_to_memory = AsyncMock() + target._memory = mock_memory + + with patch.object(target, "_connect_and_send", new=AsyncMock(return_value="Response from Copilot")): + responses = await target.send_prompt_async(message=message) + + assert len(responses) == 1 + assert responses[0].message_pieces[0].converted_value == "Response from Copilot" + assert responses[0].message_pieces[0].role == "assistant" + + @pytest.mark.asyncio + async def test_send_prompt_async_with_exceptions(self, mock_authenticator): + from pyrit.exceptions import EmptyResponseException + + target = WebSocketCopilotTarget(authenticator=mock_authenticator) + + message_piece = MessagePiece( + role="user", + original_value="Hello", + converted_value="Hello", + conversation_id="conv_123", + original_value_data_type="text", + converted_value_data_type="text", + ) + message = Message(message_pieces=[message_piece]) + + mock_memory = MagicMock() + mock_memory.get_conversation.return_value = [] + target._memory = mock_memory + + # Test for various empty responses + for response in [None, "", " \n\t "]: + with patch.object(target, "_connect_and_send", new=AsyncMock(return_value=response)): + with pytest.raises(EmptyResponseException, match="Copilot returned an empty response"): + await target.send_prompt_async(message=message) + + # Test for generic exception during WebSocket communication + with patch.object(target, "_connect_and_send", new=AsyncMock(side_effect=Exception("Test error"))): + with pytest.raises(RuntimeError, match="An error occurred during WebSocket communication"): + await target.send_prompt_async(message=message) diff --git a/websocket_copilot_simple_example.py b/websocket_copilot_simple_example.py index f23ae8358..78a729ee7 100644 --- a/websocket_copilot_simple_example.py +++ b/websocket_copilot_simple_example.py @@ -11,6 +11,11 @@ from pyrit.setup import IN_MEMORY, initialize_pyrit_async from pyrit.executor.attack import ConsoleAttackResultPrinter +import logging + +logger = logging.getLogger(__name__) +logging.basicConfig(level=logging.INFO) + async def main(): await initialize_pyrit_async(memory_db_type=IN_MEMORY) From fd8bc171525276114df8961c93a6be829e2d2976 Mon Sep 17 00:00:00 2001 From: Paulina Kalicka <71526180+paulinek13@users.noreply.github.com> Date: Mon, 29 Dec 2025 19:44:54 +0100 Subject: [PATCH 35/38] fixes --- pyrit/auth/copilot_authenticator.py | 31 +++++ .../prompt_target/websocket_copilot_target.py | 118 +++++++++++++++--- .../target/test_websocket_copilot_target.py | 3 + 3 files changed, 134 insertions(+), 18 deletions(-) diff --git a/pyrit/auth/copilot_authenticator.py b/pyrit/auth/copilot_authenticator.py index e3cf4167d..33bb562f5 100644 --- a/pyrit/auth/copilot_authenticator.py +++ b/pyrit/auth/copilot_authenticator.py @@ -76,6 +76,7 @@ def __init__( raise ValueError("COPILOT_USERNAME and COPILOT_PASSWORD environment variables must be set.") self._token_cache = self._create_persistent_cache(self._cache_file, self._fallback_to_plaintext) + self._current_claims = {} @staticmethod def _create_persistent_cache(cache_file: str, fallback_to_plaintext: bool = False): @@ -166,6 +167,7 @@ async def refresh_token(self) -> str: """ logger.info("Refreshing access token...") self._clear_token_cache() + self._current_claims = {} token = await self._fetch_access_token_with_playwright() if not token: @@ -194,6 +196,35 @@ async def get_token(self) -> str: logger.info("No valid cached token found.") return await self.refresh_token() + async def get_claims(self) -> dict: + """ + Get the JWT claims from the current authentication token. + + Returns: + dict: The JWT claims. + + Raises: + ValueError: If token decoding fails. + """ + if self._current_claims: + return self._current_claims + + token = await self.get_token() + if not token: + return {} + + try: + import jwt + + logger.info("Decoding JWT claims from access token...") + + parsed_token = jwt.decode(token, algorithms=["RS256"], options={"verify_signature": False}) + self._current_claims = parsed_token + return self._current_claims + + except Exception as e: + raise ValueError(f"Failed to decode access token: {str(e)}") from e + async def _fetch_access_token_with_playwright(self) -> Optional[str]: """ Fetch access token using Playwright browser automation. diff --git a/pyrit/prompt_target/websocket_copilot_target.py b/pyrit/prompt_target/websocket_copilot_target.py index 56709b4c4..c86165a65 100644 --- a/pyrit/prompt_target/websocket_copilot_target.py +++ b/pyrit/prompt_target/websocket_copilot_target.py @@ -8,7 +8,6 @@ from enum import IntEnum from typing import Optional -import jwt import websockets from pyrit.auth import CopilotAuthenticator @@ -68,7 +67,7 @@ class WebSocketCopilotTarget(PromptTarget): ``pip install playwright && playwright install chromium`` """ - SUPPORTED_DATA_TYPES = {"text"} # TODO: support more types? + SUPPORTED_DATA_TYPES = {"text"} WEBSOCKET_BASE_URL: str = "wss://substrate.office.com/m365Copilot/Chathub" RESPONSE_TIMEOUT_SECONDS: int = 60 @@ -115,7 +114,20 @@ def __init__( @staticmethod def _dict_to_websocket(data: dict) -> str: - # Produce the smallest possible JSON string, followed by record separator + """ + Convert a dictionary to WebSocket message format. + + SignalR protocol (used by Copilot) requires JSON messages terminated with + ASCII record separator (\\x1e). Minimal JSON formatting reduces bandwidth. + + https://github.com/dotnet/aspnetcore/blob/main/src/SignalR/docs/specs/HubProtocol.md#json-encoding + + Args: + data (dict): The data to serialize. + + Returns: + str: JSON string with record separator appended. + """ return json.dumps(data, separators=(",", ":")) + "\x1e" @staticmethod @@ -132,8 +144,6 @@ def _parse_raw_message(message: str) -> list[tuple[CopilotMessageType, str]]: message type and extracted content. """ results: list[tuple[CopilotMessageType, str]] = [] - - # https://github.com/dotnet/aspnetcore/blob/main/src/SignalR/docs/specs/HubProtocol.md#json-encoding messages = message.split("\x1e") # record separator for message in messages: @@ -170,20 +180,25 @@ def _parse_raw_message(message: str) -> list[tuple[CopilotMessageType, str]]: return results if results else [(CopilotMessageType.UNKNOWN, "")] async def _build_websocket_url_async(self, *, session_id: str, copilot_conversation_id: str) -> str: - access_token = await self._authenticator.get_token() + """ + Build the WebSocket URL with all the required authentication and session parameters. - try: - parsed_token = jwt.decode(access_token, algorithms=["RS256"], options={"verify_signature": False}) - except Exception as e: - raise ValueError(f"Failed to decode access token: {str(e)}") from e + Returns: + str: Complete WebSocket URL with authentication and parameters. + + Raises: + ValueError: If token cannot be decoded or required claims (tid, oid) are missing. + """ + access_token = await self._authenticator.get_token() + token_claims = await self._authenticator.get_claims() - tenant_id = parsed_token.get("tid") - object_id = parsed_token.get("oid") + tenant_id = token_claims.get("tid") + object_id = token_claims.get("oid") if not tenant_id or not object_id: raise ValueError( "Failed to extract tenant_id (tid) or object_id (oid) from bearer token. " - f"Token claims: {list(parsed_token.keys())}" + f"Token claims: {list(token_claims.keys())}" ) client_request_id = str(uuid.uuid4()) @@ -209,6 +224,15 @@ async def _build_websocket_url_async(self, *, session_id: str, copilot_conversat def _build_prompt_message( self, *, prompt: str, session_id: str, copilot_conversation_id: str, is_start_of_session: bool ) -> dict: + """ + Construct the prompt message payload for Copilot WebSocket API. + + Builds a comprehensive message structure following Copilot's expected format, + including session metadata, feature flags, and the user's prompt text. + + Returns: + dict: The complete message payload ready to be sent via WebSocket. + """ request_id = trace_id = uuid.uuid4().hex result = { "arguments": [ @@ -281,13 +305,33 @@ def _build_prompt_message( async def _connect_and_send( self, *, prompt: str, session_id: str, copilot_conversation_id: str, is_start_of_session: bool ) -> str: + """ + Establish WebSocket connection, send prompt, and await response. + + The method polls for messages, ignoring PARTIAL_RESPONSE streaming updates + until it receives either FINAL_CONTENT (success), STREAM_END, or UNKNOWN (error). + + Args: + prompt (str): The user prompt text to send. + session_id (str): Copilot session identifier. + copilot_conversation_id (str): Copilot conversation identifier. + is_start_of_session (bool): Whether this is the first message in the conversation. + + Returns: + str: The final response text from Copilot. + + Raises: + TimeoutError: If no response received within the specified timeout period. + RuntimeError: If WebSocket connection closes unexpectedly, protocol violation occurs, + or maximum message iterations exceeded. + """ websocket_url = await self._build_websocket_url_async( session_id=session_id, copilot_conversation_id=copilot_conversation_id ) inputs = [ - {"protocol": "json", "version": 1}, - self._build_prompt_message( + {"protocol": "json", "version": 1}, # the handshake message, we expect PING in response + self._build_prompt_message( # the actual user prompt, we expect FINAL_CONTENT in response prompt=prompt, session_id=session_id, copilot_conversation_id=copilot_conversation_id, @@ -307,10 +351,21 @@ async def _connect_and_send( is_user_input = input_msg.get("type") == CopilotMessageType.USER_PROMPT + MAX_MESSAGE_ITERATIONS = 1000 + iteration_count = 0 stop_polling = False + while not stop_polling: + # Prevent infinite loops (e.g. if Copilot somehow never sends a terminating message) + iteration_count += 1 + if iteration_count > MAX_MESSAGE_ITERATIONS: + raise RuntimeError( + f"Exceeded maximum message iterations ({MAX_MESSAGE_ITERATIONS}) " + "while waiting for Copilot response." + ) + try: - response = await asyncio.wait_for( + raw_message = await asyncio.wait_for( websocket.recv(), timeout=self._response_timeout_seconds, ) @@ -319,12 +374,12 @@ async def _connect_and_send( f"Timed out waiting for Copilot response after {self._response_timeout_seconds} seconds." ) - if response is None: + if raw_message is None: raise RuntimeError( "WebSocket connection closed unexpectedly: received None from websocket.recv()" ) - parsed_messages = self._parse_raw_message(response) + parsed_messages = self._parse_raw_message(raw_message) for msg_type, content in parsed_messages: if msg_type in ( @@ -333,18 +388,33 @@ async def _connect_and_send( CopilotMessageType.STREAM_END, ): stop_polling = True + # Not breaking here to process all messages in this batch, + # possibly including FINAL_CONTENT if msg_type == CopilotMessageType.FINAL_CONTENT: response = content elif msg_type == CopilotMessageType.UNKNOWN: logger.debug("Received unknown or empty message type.") + # PING is Copilot's acknowledgment of the protocol handshake (first message in inputs[]) + # It should arrive after the handshake, not after a user prompt + # If we're processing a user prompt and receive PING, something is wrong - ignore it + # and keep polling for the actual FINAL_CONTENT response elif msg_type == CopilotMessageType.PING and not is_user_input: stop_polling = True return response def _validate_request(self, *, message: Message) -> None: + """ + Validate that the message meets target requirements. + + Args: + message (Message): The message to validate. + + Raises: + ValueError: If message contains more than one piece or non-text content. + """ n_pieces = len(message.message_pieces) if n_pieces != 1: raise ValueError(f"This target only supports a single message piece. Received: {n_pieces} pieces.") @@ -354,6 +424,18 @@ def _validate_request(self, *, message: Message) -> None: raise ValueError(f"This target only supports text prompt input. Received: {piece_type}.") def _is_start_of_session(self, *, conversation_id: str) -> bool: + """ + Determine if this is the first message in a PyRIT conversation. + + Checks memory for existing conversation history to set the appropriate + flag for Copilot's server-side conversation initialization. + + Args: + conversation_id (str): The PyRIT conversation ID. + + Returns: + bool: True if no prior messages exist in this conversation, False otherwise. + """ conversation_history = self._memory.get_conversation(conversation_id=conversation_id) return len(conversation_history) == 0 diff --git a/tests/unit/target/test_websocket_copilot_target.py b/tests/unit/target/test_websocket_copilot_target.py index 1004b49f4..041bd4b07 100644 --- a/tests/unit/target/test_websocket_copilot_target.py +++ b/tests/unit/target/test_websocket_copilot_target.py @@ -21,6 +21,7 @@ def mock_authenticator(): mock_token = mock_token.decode("utf-8") authenticator = MagicMock(spec=CopilotAuthenticator) authenticator.get_token = AsyncMock(return_value=mock_token) + authenticator.get_claims = AsyncMock(return_value=token_payload) return authenticator @@ -171,6 +172,7 @@ async def test_build_websocket_url_with_missing_ids(self, mock_authenticator): if isinstance(mock_token, bytes): mock_token = mock_token.decode("utf-8") mock_authenticator.get_token = AsyncMock(return_value=mock_token) + mock_authenticator.get_claims = AsyncMock(return_value=token_payload) target = WebSocketCopilotTarget(authenticator=mock_authenticator) with pytest.raises(ValueError, match="Failed to extract tenant_id \\(tid\\) or object_id \\(oid\\)"): @@ -179,6 +181,7 @@ async def test_build_websocket_url_with_missing_ids(self, mock_authenticator): @pytest.mark.asyncio async def test_build_websocket_url_with_invalid_token(self, mock_authenticator): mock_authenticator.get_token = AsyncMock(return_value="invalid_token") + mock_authenticator.get_claims = AsyncMock(side_effect=ValueError("Failed to decode access token")) target = WebSocketCopilotTarget(authenticator=mock_authenticator) with pytest.raises(ValueError, match="Failed to decode access token"): From 7266d7e7375bd6c619f7da64eb2721d34a95ce8f Mon Sep 17 00:00:00 2001 From: Paulina Kalicka <71526180+paulinek13@users.noreply.github.com> Date: Mon, 29 Dec 2025 23:29:44 +0100 Subject: [PATCH 36/38] AUTH FLOW improvements --- pyrit/auth/copilot_authenticator.py | 290 +++++++++++++++++----------- 1 file changed, 173 insertions(+), 117 deletions(-) diff --git a/pyrit/auth/copilot_authenticator.py b/pyrit/auth/copilot_authenticator.py index 33bb562f5..2c3a55ddb 100644 --- a/pyrit/auth/copilot_authenticator.py +++ b/pyrit/auth/copilot_authenticator.py @@ -4,7 +4,7 @@ import logging import asyncio import os -from datetime import datetime, timezone +from datetime import datetime, timezone, timedelta from typing import Optional import json @@ -21,8 +21,9 @@ class CopilotAuthenticator(Authenticator): """ Playwright-based authenticator for Microsoft Copilot. Used by WebSocketCopilotTarget. - This authenticator automates browser login to obtain and refresh access tokens that are necessary for accessing - Microsoft Copilot via WebSocket connections. It uses Playwright to simulate user interactions for authentication, and msal-extensions for encrypted token persistence. + This authenticator automates browser login to obtain and refresh access tokens that are necessary + for accessing Microsoft Copilot via WebSocket connections. It uses Playwright to simulate user + interactions for authentication, and msal-extensions for encrypted token persistence. An access token acquired by this authenticator is usually valid for about 60 minutes. @@ -35,14 +36,27 @@ class CopilotAuthenticator(Authenticator): ``pip install playwright && playwright install chromium``. """ + # TODO: ensure login with account with MFA enabled work correctly + + #: Name of the cache file to store tokens CACHE_FILE_NAME: str = "copilot_token_cache.bin" + #: Buffer before token expiry to avoid using tokens about to expire (in seconds) + EXPIRY_BUFFER_SECONDS: int = 300 + #: Default timeout for capturing token via network monitoring (in seconds) + DEFAULT_TOKEN_CAPTURE_TIMEOUT: int = 60 + #: Default timeout for waiting on page elements (in seconds) + DEFAULT_ELEMENT_TIMEOUT_SECONDS: int = 10 + #: Number of retries for network operations + DEFAULT_NETWORK_RETRIES: int = 3 def __init__( self, *, headless: bool = False, maximized: bool = True, - timeout_for_elements: int = 10, + timeout_for_elements_seconds: int = DEFAULT_ELEMENT_TIMEOUT_SECONDS, + token_capture_timeout_seconds: int = DEFAULT_TOKEN_CAPTURE_TIMEOUT, + network_retries: int = DEFAULT_NETWORK_RETRIES, fallback_to_plaintext: bool = False, ): """ @@ -51,9 +65,12 @@ def __init__( Args: headless (bool): Whether to run the browser in headless mode. Default is False. maximized (bool): Whether to start the browser maximized. Default is True. - timeout_for_elements (int): Timeout used when waiting for page elements, in seconds. Default is 10. + timeout_for_elements_seconds (int): Timeout used when waiting for page elements, in seconds. + token_capture_timeout_seconds (int): Maximum time to wait for token capture via network monitoring. + network_retries (int): Number of retry attempts for network operations. Default is 3. fallback_to_plaintext (bool): Whether to fallback to plaintext storage if encryption is unavailable. If set to False (default), an exception will be raised if encryption cannot be used. + WARNING: Setting to True stores tokens in plaintext. Raises: ValueError: If the required environment variables are not set. @@ -65,7 +82,9 @@ def __init__( self._headless = headless self._maximized = maximized - self._timeout = timeout_for_elements * 1000 # ms + self._elements_timeout = timeout_for_elements_seconds * 1000 + self._token_capture_timeout = token_capture_timeout_seconds + self._network_retries = network_retries self._fallback_to_plaintext = fallback_to_plaintext self._cache_dir = PYRIT_CACHE_PATH @@ -76,10 +95,80 @@ def __init__( raise ValueError("COPILOT_USERNAME and COPILOT_PASSWORD environment variables must be set.") self._token_cache = self._create_persistent_cache(self._cache_file, self._fallback_to_plaintext) + self._current_claims = {} # for easy access to claims without re-decoding token + + # Lock to prevent concurrent token fetches from launching multiple browsers + self._token_fetch_lock = asyncio.Lock() + + async def refresh_token(self) -> str: + """ + Refresh the authentication token asynchronously. + + This will clear the existing token cache and fetch a new token with automated browser login. + + Returns: + str: The refreshed authentication token. + + Raises: + RuntimeError: If token refresh fails. + """ + logger.info("Refreshing access token...") + self._clear_token_cache() self._current_claims = {} + token = await self._fetch_access_token_with_playwright() + + if not token: + raise RuntimeError("Failed to refresh access token.") + + return token + + async def get_token(self) -> str: + """ + Get the current authentication token. + + This checks the cache first and only launches the browser if no valid token is found. + If multiple calls are made concurrently, they will be serialized via an asyncio lock + to prevent launching multiple browser instances. + + Returns: + str: A valid Bearer token for Microsoft Copilot. + """ + async with self._token_fetch_lock: + cached_token = await self._get_cached_token_if_available_and_valid() + if cached_token and "access_token" in cached_token: + logger.info("Using cached access token.") + return cached_token["access_token"] + + logger.info("No valid cached token found. Initiating browser authentication.") + return await self.refresh_token() + + async def get_claims(self) -> dict: + """ + Get the JWT claims from the current authentication token. + + Returns: + dict: The JWT claims decoded from the access token. + """ + return self._current_claims or {} @staticmethod def _create_persistent_cache(cache_file: str, fallback_to_plaintext: bool = False): + """ + Create a persistent cache for token storage with encryption. + + Uses msal-extensions to provide encrypted storage. Falls back to plaintext + only if explicitly allowed and encryption is unavailable. + + Args: + cache_file: Path to the cache file. + fallback_to_plaintext: Whether to allow plaintext fallback. + + Returns: + A persistence object (encrypted or plaintext). + + Raises: + Exception: If encryption fails and fallback is not allowed. + """ # https://github.com/AzureAD/microsoft-authentication-extensions-for-python try: @@ -87,13 +176,23 @@ def _create_persistent_cache(cache_file: str, fallback_to_plaintext: bool = Fals return build_encrypted_persistence(cache_file) except Exception as e: if fallback_to_plaintext: - logger.warning(f"Encryption unavailable ({e}). Opting in to plain text.") + logger.warning(f"Encryption unavailable ({e}). Falling back to PLAINTEXT storage.") return FilePersistence(cache_file) - logger.error("Encryption unavailable and fallback_to_plaintext is False.") + logger.error(f"Encryption unavailable ({e}) and fallback_to_plaintext is False. Cannot proceed.") raise - def _get_cached_token_if_available_and_valid(self) -> Optional[dict]: - # TODO: make sure the cached token matches the proper user account + async def _get_cached_token_if_available_and_valid(self) -> Optional[dict]: + """ + Retrieve and validate cached token. + + Validates that: + - Token exists and is properly formatted. + - Token belongs to the current user (username match). + - Token has not expired (with safety buffer). + + Returns: + Token data dictionary if valid, None otherwise. + """ try: cache_data = self._token_cache.load() if not cache_data: @@ -105,14 +204,30 @@ def _get_cached_token_if_available_and_valid(self) -> Optional[dict]: logger.info("No access token in cache.") return None + cached_user = token_data.get("claims").get("upn") + if not cached_user: + logger.info("No user associated with cached token. Token invalidated.") + return None + elif cached_user != self._username: + logger.info( + f"Cached token is for different user (cached: {cached_user}, current: {self._username}). " + "Token invalidated." + ) + return None + expires_at = token_data.get("expires_at") if expires_at: expiry_time = datetime.fromtimestamp(expires_at, tz=timezone.utc) current_time = datetime.now(timezone.utc) - # TODO: add n-minute buffer to avoid using tokens about to expire - if current_time >= expiry_time: - logger.info("Cached token has expired.") + # This should prevent most mid-request failures due to token expiration + expiry_with_buffer = expiry_time - timedelta(seconds=self.EXPIRY_BUFFER_SECONDS) + if current_time >= expiry_with_buffer: + minutes_until_expiry = (expiry_time - current_time).total_seconds() / 60 + logger.info( + f"Cached token expires in {minutes_until_expiry:.2f} minutes, " + f"within {self.EXPIRY_BUFFER_SECONDS}s safety buffer. Token invalidated." + ) return None minutes_left = (expiry_time - current_time).total_seconds() / 60 @@ -129,9 +244,27 @@ def _get_cached_token_if_available_and_valid(self) -> Optional[dict]: return None def _save_token_to_cache(self, *, token: str, expires_in: Optional[int] = None) -> None: + """ + Save token to persistent cache with metadata. + + Args: + token: The access token to cache. + expires_in: Token lifetime in seconds (optional). + """ + self._current_claims = {} + + try: + import jwt + + self._current_claims = jwt.decode(token, algorithms=["RS256"], options={"verify_signature": False}) + + except Exception as e: + logger.error(f"Failed to decode token for caching: {e}") + token_data = { "access_token": token, "token_type": "Bearer", + "claims": self._current_claims, "cached_at": datetime.now(timezone.utc).timestamp(), } @@ -153,95 +286,23 @@ def _clear_token_cache(self) -> None: except Exception as e: logger.error(f"Failed to clear cache: {e}") - async def refresh_token(self) -> str: - """ - Refresh the authentication token asynchronously. - - This will clear the existing token cache and fetch a new token with automated browser login. - - Returns: - str: The refreshed authentication token. - - Raises: - RuntimeError: If token refresh fails. - """ - logger.info("Refreshing access token...") - self._clear_token_cache() - self._current_claims = {} - token = await self._fetch_access_token_with_playwright() - - if not token: - raise RuntimeError("Failed to refresh access token.") - - return token - - async def get_token(self) -> str: - """ - Get the current authentication token. - - This will check the cache first and only launch the browser if no valid token is found. - - Returns: - str: The current authentication token. - - Raises: - RuntimeError: If token retrieval fails. - """ - # TODO: make sure multiple concurrent calls don't launch multiple browsers - cached_token = self._get_cached_token_if_available_and_valid() - if cached_token and "access_token" in cached_token: - logger.info("Using cached access token.") - return cached_token["access_token"] - - logger.info("No valid cached token found.") - return await self.refresh_token() - - async def get_claims(self) -> dict: - """ - Get the JWT claims from the current authentication token. - - Returns: - dict: The JWT claims. - - Raises: - ValueError: If token decoding fails. - """ - if self._current_claims: - return self._current_claims - - token = await self.get_token() - if not token: - return {} - - try: - import jwt - - logger.info("Decoding JWT claims from access token...") - - parsed_token = jwt.decode(token, algorithms=["RS256"], options={"verify_signature": False}) - self._current_claims = parsed_token - return self._current_claims - - except Exception as e: - raise ValueError(f"Failed to decode access token: {str(e)}") from e - async def _fetch_access_token_with_playwright(self) -> Optional[str]: """ Fetch access token using Playwright browser automation. - Raises: - RuntimeError: If Playwright is not installed. - Returns: - Optional[str]: The bearer token if successfully retrieved, else None. + Optional[str]: The bearer token if successfully retrieved, None otherwise. + + Raises: + RuntimeError: If Playwright is not installed or browser launch fails. """ - # TODO: it's a long function, maybe split into smaller ones? try: from playwright.async_api import async_playwright - - pass except ImportError: - raise RuntimeError("Playwright is not installed. Please install it with 'pip install playwright'.") + raise RuntimeError( + "Playwright is not installed. Please install it with: " + "'pip install playwright && playwright install chromium'" + ) bearer_token = None token_expires_in = None @@ -278,20 +339,10 @@ async def response_handler(response): if "access_token" in data: bearer_token = data["access_token"] token_expires_in = data.get("expires_in") + logger.info("Captured bearer token from JSON response.") - except json.JSONDecodeError: - logger.info("Response JSON decode failed, trying regex extraction...") - - match = re.search(r'"access_token"\s*:\s*"([^"]+)"', text) - if match: - bearer_token = match.group(1) - logger.info("Captured bearer token using regex.") - - expires_match = re.search(r'"expires_in"\s*:\s*(\d+)', text) - if expires_match: - token_expires_in = int(expires_match.group(1)) - else: - logger.error("Failed to extract bearer token using regex.") + except Exception as e: + logger.error(f"Error parsing JSON token response: {e}") except Exception as e: logger.error(f"Error reading response: {e}") @@ -307,21 +358,21 @@ async def response_handler(response): await page.goto("https://www.office.com/") logger.info("Waiting for profile icon...") - await page.wait_for_selector("#mectrl_headerPicture", timeout=self._timeout) + await page.wait_for_selector("#mectrl_headerPicture", timeout=self._elements_timeout) await page.click("#mectrl_headerPicture") logger.info("Waiting for email input...") - await page.wait_for_selector("#i0116", timeout=self._timeout) + await page.wait_for_selector("#i0116", timeout=self._elements_timeout) await page.fill("#i0116", self._username) await page.click("#idSIButton9") logger.info("Waiting for password input...") - await page.wait_for_selector("#i0118", timeout=self._timeout) + await page.wait_for_selector("#i0118", timeout=self._elements_timeout) await page.fill("#i0118", self._password) await page.click("#idSIButton9") logger.info("Waiting for 'Stay signed in?' prompt...") - await page.wait_for_selector("#idSIButton9", timeout=self._timeout) + await page.wait_for_selector("#idSIButton9", timeout=self._elements_timeout) logger.info("Clicking 'Yes' to stay signed in...") await page.click("#idSIButton9") @@ -329,12 +380,13 @@ async def response_handler(response): logger.info("Navigating to Copilot...") logger.info("Waiting for Copilot button and clicking it...") - await page.wait_for_selector('div[aria-label="M365 Copilot"]', timeout=self._timeout) - await page.click('div[aria-label="M365 Copilot"]', timeout=self._timeout) + await page.wait_for_selector('div[aria-label="M365 Copilot"]', timeout=self._elements_timeout) + await page.click('div[aria-label="M365 Copilot"]', timeout=self._elements_timeout) - logger.info("Waiting 60 seconds for bearer token to be captured...") - for _ in range(60): + logger.info(f"Waiting up to {self._token_capture_timeout}s for bearer token to be captured...") + for elapsed in range(self._token_capture_timeout): if bearer_token: + logger.info(f"Token captured after {elapsed}s") break await asyncio.sleep(1) @@ -344,16 +396,20 @@ async def response_handler(response): ) self._save_token_to_cache(token=bearer_token, expires_in=token_expires_in) else: - logger.error("Failed to retrieve bearer token within 60 seconds.") + logger.error(f"Failed to retrieve bearer token within {self._token_capture_timeout} seconds.") return bearer_token except Exception as e: logger.error("Failed to retrieve access token using Playwright.") - if str(e).startswith("BrowserType.launch"): + if "BrowserType.launch" in str(e): logger.error("Playwright browser launch failed. Did you run 'playwright install chromium'?") else: - logger.error(f"Error details: {e}") + # Sanitize error message to avoid leaking sensitive info + error_msg = str(e) + if self._password and self._password in error_msg: + error_msg = error_msg.replace(self._password, "******") + logger.error(f"Error details: {error_msg}") return None finally: From 52411608dc34c232275c694befbdb9ed13539721 Mon Sep 17 00:00:00 2001 From: Paulina Kalicka <71526180+paulinek13@users.noreply.github.com> Date: Tue, 30 Dec 2025 11:46:00 +0100 Subject: [PATCH 37/38] `CopilotAuthenticator` tests --- pyrit/auth/copilot_authenticator.py | 1 - tests/unit/auth/__init__.py | 2 + tests/unit/auth/test_copilot_authenticator.py | 853 ++++++++++++++++++ 3 files changed, 855 insertions(+), 1 deletion(-) create mode 100644 tests/unit/auth/__init__.py create mode 100644 tests/unit/auth/test_copilot_authenticator.py diff --git a/pyrit/auth/copilot_authenticator.py b/pyrit/auth/copilot_authenticator.py index 2c3a55ddb..256f084b1 100644 --- a/pyrit/auth/copilot_authenticator.py +++ b/pyrit/auth/copilot_authenticator.py @@ -8,7 +8,6 @@ from typing import Optional import json -import re from msal_extensions import build_encrypted_persistence, FilePersistence from pyrit.auth.authenticator import Authenticator diff --git a/tests/unit/auth/__init__.py b/tests/unit/auth/__init__.py new file mode 100644 index 000000000..9a0454564 --- /dev/null +++ b/tests/unit/auth/__init__.py @@ -0,0 +1,2 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. diff --git a/tests/unit/auth/test_copilot_authenticator.py b/tests/unit/auth/test_copilot_authenticator.py new file mode 100644 index 000000000..1eb890b75 --- /dev/null +++ b/tests/unit/auth/test_copilot_authenticator.py @@ -0,0 +1,853 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +import asyncio +import json +import os +import tempfile +from datetime import datetime, timezone, timedelta +from pathlib import Path +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest + +from pyrit.auth.copilot_authenticator import CopilotAuthenticator + + +@pytest.fixture +def mock_env_vars(): + """Mock required environment variables.""" + + with patch.dict( + os.environ, + { + "COPILOT_USERNAME": "test@example.com", + "COPILOT_PASSWORD": "test_password_123", + }, + ): + yield + + +@pytest.fixture +def mock_persistent_cache(): + """Mock msal-extensions persistence.""" + + mock_cache = MagicMock() + mock_cache.load.return_value = None + mock_cache.save.return_value = None + return mock_cache + + +@pytest.fixture +def temp_cache_dir(): + """Create a temporary directory for cache files.""" + + with tempfile.TemporaryDirectory() as temp_dir: + yield Path(temp_dir) + + +class TestCopilotAuthenticatorConstants: + """Test class-level constants.""" + + def test_class_constants_defined(self): + assert hasattr(CopilotAuthenticator, "CACHE_FILE_NAME") + assert hasattr(CopilotAuthenticator, "EXPIRY_BUFFER_SECONDS") + assert hasattr(CopilotAuthenticator, "DEFAULT_TOKEN_CAPTURE_TIMEOUT") + assert hasattr(CopilotAuthenticator, "DEFAULT_ELEMENT_TIMEOUT_SECONDS") + assert hasattr(CopilotAuthenticator, "DEFAULT_NETWORK_RETRIES") + + def test_constant_values(self): + assert CopilotAuthenticator.CACHE_FILE_NAME == "copilot_token_cache.bin" + assert CopilotAuthenticator.EXPIRY_BUFFER_SECONDS == 300 + assert CopilotAuthenticator.DEFAULT_TOKEN_CAPTURE_TIMEOUT == 60 + assert CopilotAuthenticator.DEFAULT_ELEMENT_TIMEOUT_SECONDS == 10 + assert CopilotAuthenticator.DEFAULT_NETWORK_RETRIES == 3 + + +class TestCopilotAuthenticatorInitialization: + """Test CopilotAuthenticator initialization scenarios.""" + + def test_init_with_required_env_vars(self, mock_env_vars, mock_persistent_cache): + """Test successful initialization with required environment variables.""" + + with patch( + "pyrit.auth.copilot_authenticator.CopilotAuthenticator._create_persistent_cache", + return_value=mock_persistent_cache, + ): + authenticator = CopilotAuthenticator() + + assert authenticator._username == "test@example.com" + assert authenticator._password == "test_password_123" + assert authenticator._headless is False + assert authenticator._maximized is True + assert authenticator._elements_timeout == 10000 + assert authenticator._token_capture_timeout == 60 + assert authenticator._network_retries == 3 + assert authenticator._fallback_to_plaintext is False + + assert isinstance(authenticator._token_fetch_lock, asyncio.Lock) + assert authenticator._current_claims == {} + assert authenticator._token_cache is not None + + def test_init_with_custom_parameters(self, mock_env_vars, mock_persistent_cache): + """Test initialization with custom parameters.""" + + with patch( + "pyrit.auth.copilot_authenticator.CopilotAuthenticator._create_persistent_cache", + return_value=mock_persistent_cache, + ): + authenticator = CopilotAuthenticator( + headless=True, + maximized=False, + timeout_for_elements_seconds=20, + token_capture_timeout_seconds=120, + network_retries=5, + fallback_to_plaintext=True, + ) + + assert authenticator._headless is True + assert authenticator._maximized is False + assert authenticator._elements_timeout == 20000 + assert authenticator._token_capture_timeout == 120 + assert authenticator._network_retries == 5 + assert authenticator._fallback_to_plaintext is True + + def test_init_missing_env_var_raises_error(self): + """Test that missing a required environment variable raises ValueError.""" + + for missing_var in ["COPILOT_USERNAME", "COPILOT_PASSWORD"]: + env_vars = { + "COPILOT_USERNAME": "test@example.com", + "COPILOT_PASSWORD": "test_password_123", + } + env_vars.pop(missing_var) + with patch.dict(os.environ, env_vars, clear=True): + with pytest.raises( + ValueError, match="COPILOT_USERNAME and COPILOT_PASSWORD environment variables must be set" + ): + CopilotAuthenticator() + + def test_init_creates_cache_directory(self, mock_env_vars, mock_persistent_cache, temp_cache_dir): + """Test that initialization creates cache directory if it doesn't exist.""" + + with ( + patch( + "pyrit.auth.copilot_authenticator.CopilotAuthenticator._create_persistent_cache", + return_value=mock_persistent_cache, + ), + patch("pyrit.auth.copilot_authenticator.PYRIT_CACHE_PATH", temp_cache_dir / "new_cache"), + ): + authenticator = CopilotAuthenticator() + + assert (temp_cache_dir / "new_cache").exists() + assert authenticator._cache_dir == temp_cache_dir / "new_cache" + + +class TestCopilotAuthenticatorCacheManagement: + """Test cache creation, loading, and saving functionality.""" + + def test_create_persistent_cache_with_encryption(self): + """Test cache creation with encryption enabled.""" + + mock_encrypted_cache = MagicMock() + + with patch( + "pyrit.auth.copilot_authenticator.build_encrypted_persistence", + return_value=mock_encrypted_cache, + ) as mock_build: + result = CopilotAuthenticator._create_persistent_cache("/test/cache.bin", fallback_to_plaintext=False) + + mock_build.assert_called_once_with("/test/cache.bin") + assert result == mock_encrypted_cache + + def test_create_persistent_cache_fallback_to_plaintext(self): + """Test cache creation falls back to plaintext when encryption fails.""" + + mock_plaintext_cache = MagicMock() + + with ( + patch( + "pyrit.auth.copilot_authenticator.build_encrypted_persistence", + side_effect=Exception("Encryption not available"), + ), + patch( + "pyrit.auth.copilot_authenticator.FilePersistence", + return_value=mock_plaintext_cache, + ) as mock_file_persistence, + ): + result = CopilotAuthenticator._create_persistent_cache("/test/cache.bin", fallback_to_plaintext=True) + + mock_file_persistence.assert_called_once_with("/test/cache.bin") + assert result == mock_plaintext_cache + + def test_create_persistent_cache_raises_on_encryption_failure_without_fallback(self): + """Test cache creation raises exception when encryption fails and no fallback.""" + + with patch( + "pyrit.auth.copilot_authenticator.build_encrypted_persistence", + side_effect=Exception("Encryption not available"), + ): + with pytest.raises(Exception, match="Encryption not available"): + CopilotAuthenticator._create_persistent_cache("/test/cache.bin", fallback_to_plaintext=False) + + def test_save_token_to_cache_with_expiry(self, mock_env_vars, mock_persistent_cache): + """Test saving token to cache with expiration time.""" + + with ( + patch( + "pyrit.auth.copilot_authenticator.CopilotAuthenticator._create_persistent_cache", + return_value=mock_persistent_cache, + ), + patch("jwt.decode") as mock_jwt_decode, + ): + mock_claims = {"upn": "test@example.com", "aud": "sydney"} + mock_jwt_decode.return_value = mock_claims + + authenticator = CopilotAuthenticator() + test_token = "test.jwt.token" + + authenticator._save_token_to_cache(token=test_token, expires_in=3600) + + assert mock_persistent_cache.save.called + saved_data = json.loads(mock_persistent_cache.save.call_args[0][0]) + + assert saved_data["access_token"] == test_token + assert saved_data["token_type"] == "Bearer" + assert saved_data["claims"] == mock_claims + assert saved_data["expires_in"] == 3600 + assert "expires_at" in saved_data + assert "cached_at" in saved_data + assert authenticator._current_claims == mock_claims + + def test_save_token_to_cache_without_expiry(self, mock_env_vars, mock_persistent_cache): + """Test saving token to cache without expiration time.""" + + with ( + patch( + "pyrit.auth.copilot_authenticator.CopilotAuthenticator._create_persistent_cache", + return_value=mock_persistent_cache, + ), + patch("jwt.decode") as mock_jwt_decode, + ): + mock_claims = {"upn": "test@example.com"} + mock_jwt_decode.return_value = mock_claims + + authenticator = CopilotAuthenticator() + test_token = "test.jwt.token" + + authenticator._save_token_to_cache(token=test_token, expires_in=None) + + saved_data = json.loads(mock_persistent_cache.save.call_args[0][0]) + + assert "expires_in" not in saved_data + assert "expires_at" not in saved_data + assert saved_data["access_token"] == test_token + + def test_save_token_handles_jwt_decode_failure(self, mock_env_vars, mock_persistent_cache): + """Test that save_token handles JWT decode failures gracefully.""" + + with ( + patch( + "pyrit.auth.copilot_authenticator.CopilotAuthenticator._create_persistent_cache", + return_value=mock_persistent_cache, + ), + patch("jwt.decode", side_effect=Exception("Invalid JWT")), + patch("pyrit.auth.copilot_authenticator.logger") as mock_logger, + ): + authenticator = CopilotAuthenticator() + test_token = "invalid.jwt.token" + + authenticator._save_token_to_cache(token=test_token, expires_in=3600) + mock_logger.error.assert_called_with("Failed to decode token for caching: Invalid JWT") + + saved_data = json.loads(mock_persistent_cache.save.call_args[0][0]) + assert saved_data["access_token"] == test_token + assert saved_data["claims"] == {} + + def test_clear_token_cache(self, mock_env_vars, mock_persistent_cache): + """Test clearing the token cache.""" + + with patch( + "pyrit.auth.copilot_authenticator.CopilotAuthenticator._create_persistent_cache", + return_value=mock_persistent_cache, + ): + authenticator = CopilotAuthenticator() + authenticator._clear_token_cache() + mock_persistent_cache.save.assert_called_with(json.dumps({})) + + def test_clear_token_cache_handles_error(self, mock_env_vars, mock_persistent_cache): + """Test that clear_token_cache handles errors gracefully.""" + + mock_persistent_cache.save.side_effect = Exception("Cache error") + + with ( + patch( + "pyrit.auth.copilot_authenticator.CopilotAuthenticator._create_persistent_cache", + return_value=mock_persistent_cache, + ), + patch("pyrit.auth.copilot_authenticator.logger") as mock_logger, + ): + authenticator = CopilotAuthenticator() + authenticator._clear_token_cache() + mock_logger.error.assert_called_with("Failed to clear cache: Cache error") + + +class TestCopilotAuthenticatorCachedTokenRetrieval: + """Test cached token validation and retrieval.""" + + def test_get_cached_token_valid(self, mock_env_vars, mock_persistent_cache): + """Test retrieving valid cached token.""" + + expires_at = (datetime.now(timezone.utc) + timedelta(hours=1)).timestamp() + cached_data = { + "access_token": "cached.token.value", + "token_type": "Bearer", + "claims": {"upn": "test@example.com"}, + "expires_at": expires_at, + "cached_at": datetime.now(timezone.utc).timestamp(), + } + mock_persistent_cache.load.return_value = json.dumps(cached_data) + + with patch( + "pyrit.auth.copilot_authenticator.CopilotAuthenticator._create_persistent_cache", + return_value=mock_persistent_cache, + ): + authenticator = CopilotAuthenticator() + result = asyncio.run(authenticator._get_cached_token_if_available_and_valid()) + assert result is not None + assert result["access_token"] == "cached.token.value" + + def test_get_cached_token_expired(self, mock_env_vars, mock_persistent_cache): + """Test that expired token is not returned.""" + + expires_at = (datetime.now(timezone.utc) - timedelta(minutes=5)).timestamp() + cached_data = { + "access_token": "expired.token.value", + "claims": {"upn": "test@example.com"}, + "expires_at": expires_at, + } + mock_persistent_cache.load.return_value = json.dumps(cached_data) + + with patch( + "pyrit.auth.copilot_authenticator.CopilotAuthenticator._create_persistent_cache", + return_value=mock_persistent_cache, + ): + authenticator = CopilotAuthenticator() + result = asyncio.run(authenticator._get_cached_token_if_available_and_valid()) + assert result is None + + def test_get_cached_token_within_expiry_buffer(self, mock_env_vars, mock_persistent_cache): + """Test that token within expiry buffer is not returned.""" + + expires_at = (datetime.now(timezone.utc) + timedelta(seconds=200)).timestamp() + cached_data = { + "access_token": "soon.to.expire", + "claims": {"upn": "test@example.com"}, + "expires_at": expires_at, + } + mock_persistent_cache.load.return_value = json.dumps(cached_data) + + with patch( + "pyrit.auth.copilot_authenticator.CopilotAuthenticator._create_persistent_cache", + return_value=mock_persistent_cache, + ): + authenticator = CopilotAuthenticator() + result = asyncio.run(authenticator._get_cached_token_if_available_and_valid()) + assert result is None # default buffer is 300 seconds, so should return None + + def test_get_cached_token_no_cache_file(self, mock_env_vars, mock_persistent_cache): + """Test behavior when cache file doesn't exist.""" + + mock_persistent_cache.load.return_value = None + with patch( + "pyrit.auth.copilot_authenticator.CopilotAuthenticator._create_persistent_cache", + return_value=mock_persistent_cache, + ): + authenticator = CopilotAuthenticator() + result = asyncio.run(authenticator._get_cached_token_if_available_and_valid()) + assert result is None + + def test_get_cached_token_wrong_user(self, mock_env_vars, mock_persistent_cache): + """Test that cached token for different user is invalidated.""" + + expires_at = (datetime.now(timezone.utc) + timedelta(hours=1)).timestamp() + cached_data = { + "access_token": "other.user.token", + "claims": {"upn": "different@example.com"}, + "expires_at": expires_at, + } + mock_persistent_cache.load.return_value = json.dumps(cached_data) + + with patch( + "pyrit.auth.copilot_authenticator.CopilotAuthenticator._create_persistent_cache", + return_value=mock_persistent_cache, + ): + authenticator = CopilotAuthenticator() + result = asyncio.run(authenticator._get_cached_token_if_available_and_valid()) + assert result is None + + def test_get_cached_token_no_upn_in_claims(self, mock_env_vars, mock_persistent_cache): + """Test that cached token without upn claim is invalidated.""" + + expires_at = (datetime.now(timezone.utc) + timedelta(hours=1)).timestamp() + cached_data = { + "access_token": "token.without.upn", + "claims": {"aud": "sydney"}, + "expires_at": expires_at, + } + mock_persistent_cache.load.return_value = json.dumps(cached_data) + + with patch( + "pyrit.auth.copilot_authenticator.CopilotAuthenticator._create_persistent_cache", + return_value=mock_persistent_cache, + ): + authenticator = CopilotAuthenticator() + result = asyncio.run(authenticator._get_cached_token_if_available_and_valid()) + assert result is None + + def test_get_cached_token_missing_access_token(self, mock_env_vars, mock_persistent_cache): + """Test that cache data without access_token is invalid.""" + + cached_data = { + "token_type": "Bearer", + "claims": {"upn": "test@example.com"}, + } + mock_persistent_cache.load.return_value = json.dumps(cached_data) + + with patch( + "pyrit.auth.copilot_authenticator.CopilotAuthenticator._create_persistent_cache", + return_value=mock_persistent_cache, + ): + authenticator = CopilotAuthenticator() + result = asyncio.run(authenticator._get_cached_token_if_available_and_valid()) + assert result is None + + def test_get_cached_token_invalid_json(self, mock_env_vars, mock_persistent_cache): + """Test handling of corrupted cache data.""" + + mock_persistent_cache.load.return_value = "invalid json {{" + + with patch( + "pyrit.auth.copilot_authenticator.CopilotAuthenticator._create_persistent_cache", + return_value=mock_persistent_cache, + ): + authenticator = CopilotAuthenticator() + result = asyncio.run(authenticator._get_cached_token_if_available_and_valid()) + assert result is None + + +class TestCopilotAuthenticatorTokenRetrieval: + """Test token retrieval via get_token method.""" + + @pytest.mark.asyncio + async def test_get_token_uses_cached_token(self, mock_env_vars, mock_persistent_cache): + """Test that get_token uses cached token when available.""" + + expires_at = (datetime.now(timezone.utc) + timedelta(hours=1)).timestamp() + cached_data = { + "access_token": "cached.valid.token", + "claims": {"upn": "test@example.com"}, + "expires_at": expires_at, + } + mock_persistent_cache.load.return_value = json.dumps(cached_data) + + with patch( + "pyrit.auth.copilot_authenticator.CopilotAuthenticator._create_persistent_cache", + return_value=mock_persistent_cache, + ): + authenticator = CopilotAuthenticator() + token = await authenticator.get_token() + assert token == "cached.valid.token" + + @pytest.mark.asyncio + async def test_get_token_fetches_new_when_no_cache(self, mock_env_vars, mock_persistent_cache): + """Test that get_token fetches new token when cache is empty.""" + + mock_persistent_cache.load.return_value = None + with ( + patch( + "pyrit.auth.copilot_authenticator.CopilotAuthenticator._create_persistent_cache", + return_value=mock_persistent_cache, + ), + patch( + "pyrit.auth.copilot_authenticator.CopilotAuthenticator._fetch_access_token_with_playwright", + new_callable=AsyncMock, + return_value="new.fetched.token", + ) as mock_fetch, + ): + authenticator = CopilotAuthenticator() + token = await authenticator.get_token() + mock_fetch.assert_called_once() + assert token == "new.fetched.token" + + @pytest.mark.asyncio + async def test_get_token_serializes_concurrent_requests(self, mock_env_vars, mock_persistent_cache): + """Test that concurrent get_token calls are serialized via lock.""" + + fetch_call_count = 0 + mock_persistent_cache.load.return_value = None # start with no cache + + async def mock_fetch(): + nonlocal fetch_call_count + fetch_call_count += 1 + await asyncio.sleep(0.01) # minimal delay to test concurrency + return f"token.{fetch_call_count}" + + def mock_load_side_effect(): + # After first fetch, return cached token for subsequent calls + if fetch_call_count > 0: + expires_at = (datetime.now(timezone.utc) + timedelta(hours=1)).timestamp() + return json.dumps( + { + "access_token": "token.1", + "claims": {"upn": "test@example.com"}, + "expires_at": expires_at, + } + ) + return None + + mock_persistent_cache.load.side_effect = mock_load_side_effect + with ( + patch( + "pyrit.auth.copilot_authenticator.CopilotAuthenticator._create_persistent_cache", + return_value=mock_persistent_cache, + ), + patch( + "pyrit.auth.copilot_authenticator.CopilotAuthenticator._fetch_access_token_with_playwright", + new_callable=AsyncMock, + side_effect=mock_fetch, + ), + patch("jwt.decode", return_value={"upn": "test@example.com"}), + ): + authenticator = CopilotAuthenticator() + + results = await asyncio.gather( + authenticator.get_token(), + authenticator.get_token(), + authenticator.get_token(), + ) + + # Only one fetch should have occurred due to lock + caching + assert fetch_call_count == 1 + assert results[0] == results[1] == results[2] == "token.1" + + +class TestCopilotAuthenticatorTokenRefresh: + """Test token refresh functionality.""" + + @pytest.mark.asyncio + async def test_refresh_token_clears_cache(self, mock_env_vars, mock_persistent_cache): + """Test that refresh_token clears existing cache.""" + + with ( + patch( + "pyrit.auth.copilot_authenticator.CopilotAuthenticator._create_persistent_cache", + return_value=mock_persistent_cache, + ), + patch( + "pyrit.auth.copilot_authenticator.CopilotAuthenticator._fetch_access_token_with_playwright", + new_callable=AsyncMock, + return_value="refreshed.token", + ), + ): + authenticator = CopilotAuthenticator() + await authenticator.refresh_token() + assert any(json.dumps({}) in str(call) for call in mock_persistent_cache.save.call_args_list) + + @pytest.mark.asyncio + async def test_refresh_token_fetches_new_token(self, mock_env_vars, mock_persistent_cache): + """Test that refresh_token fetches new token.""" + + with ( + patch( + "pyrit.auth.copilot_authenticator.CopilotAuthenticator._create_persistent_cache", + return_value=mock_persistent_cache, + ), + patch( + "pyrit.auth.copilot_authenticator.CopilotAuthenticator._fetch_access_token_with_playwright", + new_callable=AsyncMock, + return_value="refreshed.token", + ) as mock_fetch, + ): + authenticator = CopilotAuthenticator() + token = await authenticator.refresh_token() + mock_fetch.assert_called_once() + assert token == "refreshed.token" + + @pytest.mark.asyncio + async def test_refresh_token_raises_on_failure(self, mock_env_vars, mock_persistent_cache): + """Test that refresh_token raises exception when fetch fails.""" + + with ( + patch( + "pyrit.auth.copilot_authenticator.CopilotAuthenticator._create_persistent_cache", + return_value=mock_persistent_cache, + ), + patch( + "pyrit.auth.copilot_authenticator.CopilotAuthenticator._fetch_access_token_with_playwright", + new_callable=AsyncMock, + return_value=None, + ), + ): + authenticator = CopilotAuthenticator() + with pytest.raises(RuntimeError, match="Failed to refresh access token"): + await authenticator.refresh_token() + + @pytest.mark.asyncio + async def test_refresh_token_clears_current_claims(self, mock_env_vars, mock_persistent_cache): + """Test that refresh_token clears current claims.""" + + with ( + patch( + "pyrit.auth.copilot_authenticator.CopilotAuthenticator._create_persistent_cache", + return_value=mock_persistent_cache, + ), + patch( + "pyrit.auth.copilot_authenticator.CopilotAuthenticator._fetch_access_token_with_playwright", + new_callable=AsyncMock, + return_value="refreshed.token", + ), + ): + authenticator = CopilotAuthenticator() + authenticator._current_claims = {"old": "claims"} + await authenticator.refresh_token() + assert authenticator._current_claims == {} + + +class TestCopilotAuthenticatorGetClaims: + """Test JWT claims retrieval.""" + + @pytest.mark.asyncio + async def test_get_claims_returns_current_claims(self, mock_env_vars, mock_persistent_cache): + """Test that get_claims returns current claims.""" + + with patch( + "pyrit.auth.copilot_authenticator.CopilotAuthenticator._create_persistent_cache", + return_value=mock_persistent_cache, + ): + authenticator = CopilotAuthenticator() + test_claims = {"upn": "test@example.com", "aud": "sydney"} + authenticator._current_claims = test_claims + claims = await authenticator.get_claims() + assert claims == test_claims + + @pytest.mark.asyncio + async def test_get_claims_returns_empty_dict_when_no_claims(self, mock_env_vars, mock_persistent_cache): + """Test that get_claims returns empty dict when no claims set.""" + + with patch( + "pyrit.auth.copilot_authenticator.CopilotAuthenticator._create_persistent_cache", + return_value=mock_persistent_cache, + ): + authenticator = CopilotAuthenticator() + claims = await authenticator.get_claims() + assert claims == {} + + +class TestCopilotAuthenticatorPlaywrightIntegration: + """Test Playwright browser automation (mocked).""" + + @pytest.mark.asyncio + async def test_fetch_token_playwright_not_installed(self, mock_env_vars, mock_persistent_cache): + """Test that RuntimeError is raised when Playwright is not installed.""" + + with ( + patch( + "pyrit.auth.copilot_authenticator.CopilotAuthenticator._create_persistent_cache", + return_value=mock_persistent_cache, + ), + patch.dict("sys.modules", {"playwright.async_api": None}), + ): + authenticator = CopilotAuthenticator() + with pytest.raises(RuntimeError, match="Playwright is not installed"): + await authenticator._fetch_access_token_with_playwright() + + @pytest.mark.asyncio + async def test_fetch_token_with_playwright_success(self, mock_env_vars, mock_persistent_cache): + """Test successful token fetch with Playwright.""" + + # Setup mock Playwright objects + mock_page = AsyncMock() + mock_context = AsyncMock() + mock_browser = AsyncMock() + mock_playwright = AsyncMock() + + # Setup mock browser hierarchy + mock_playwright.chromium.launch = AsyncMock(return_value=mock_browser) + mock_browser.new_context = AsyncMock(return_value=mock_context) + mock_context.new_page = AsyncMock(return_value=mock_page) + mock_browser.close = AsyncMock() + mock_context.close = AsyncMock() + + # Mock page methods + mock_page.goto = AsyncMock() + mock_page.wait_for_selector = AsyncMock() + mock_page.fill = AsyncMock() + + # Create a response that will trigger the token capture + mock_response = AsyncMock() + mock_response.url = "https://login.microsoftonline.com/oauth2/v2.0/token" + mock_response.text = AsyncMock( + return_value=( + '{"access_token":"captured.bearer.token","token_type":"Bearer","expires_in":3600,"resource":"sydney"}' + ) + ) + + response_handler = None + + def capture_handler(event, handler): + """Capture the response handler when page.on() is called.""" + nonlocal response_handler + if event == "response": + response_handler = handler + + mock_page.on = MagicMock(side_effect=capture_handler) + + async def trigger_response_on_click(*args, **kwargs): + """Trigger the response handler when click is called.""" + if response_handler: + await response_handler(mock_response) + + mock_page.click = AsyncMock(side_effect=trigger_response_on_click) + + mock_async_playwright = AsyncMock() + mock_async_playwright.__aenter__ = AsyncMock(return_value=mock_playwright) + mock_async_playwright.__aexit__ = AsyncMock() + + with ( + patch( + "pyrit.auth.copilot_authenticator.CopilotAuthenticator._create_persistent_cache", + return_value=mock_persistent_cache, + ), + patch("playwright.async_api.async_playwright", return_value=mock_async_playwright), + patch("jwt.decode", return_value={"upn": "test@example.com"}), + ): + authenticator = CopilotAuthenticator() + token = await authenticator._fetch_access_token_with_playwright() + + assert token == "captured.bearer.token" + mock_browser.close.assert_called_once() + mock_context.close.assert_called_once() + + @pytest.mark.asyncio + async def test_fetch_token_handles_browser_launch_failure(self, mock_env_vars, mock_persistent_cache): + """Test handling of browser launch failure.""" + + mock_playwright = AsyncMock() + mock_playwright.chromium.launch = AsyncMock(side_effect=Exception("BrowserType.launch failed")) + + mock_async_playwright = AsyncMock() + mock_async_playwright.__aenter__ = AsyncMock(return_value=mock_playwright) + mock_async_playwright.__aexit__ = AsyncMock() + + with ( + patch( + "pyrit.auth.copilot_authenticator.CopilotAuthenticator._create_persistent_cache", + return_value=mock_persistent_cache, + ), + patch("playwright.async_api.async_playwright", return_value=mock_async_playwright), + ): + authenticator = CopilotAuthenticator() + token = await authenticator._fetch_access_token_with_playwright() + assert token is None + + @pytest.mark.asyncio + async def test_fetch_token_sanitizes_password_in_errors(self, mock_env_vars, mock_persistent_cache): + """Test that password is sanitized in error messages.""" + + mock_playwright = AsyncMock() + error_with_password = Exception(f"Login failed with password: test_password_123") + mock_playwright.chromium.launch = AsyncMock(side_effect=error_with_password) + + mock_async_playwright = AsyncMock() + mock_async_playwright.__aenter__ = AsyncMock(return_value=mock_playwright) + mock_async_playwright.__aexit__ = AsyncMock() + + with ( + patch( + "pyrit.auth.copilot_authenticator.CopilotAuthenticator._create_persistent_cache", + return_value=mock_persistent_cache, + ), + patch("playwright.async_api.async_playwright", return_value=mock_async_playwright), + patch("pyrit.auth.copilot_authenticator.logger") as mock_logger, + ): + authenticator = CopilotAuthenticator() + await authenticator._fetch_access_token_with_playwright() + + # Verify password was sanitized in error log + logged_messages = [str(call) for call in mock_logger.error.call_args_list] + assert any("******" in msg for msg in logged_messages) + assert not any("test_password_123" in msg for msg in logged_messages) + + @pytest.mark.asyncio + async def test_fetch_token_timeout_waiting_for_token(self, mock_env_vars, mock_persistent_cache): + """Test timeout when waiting for token capture.""" + + mock_page = AsyncMock() + mock_context = AsyncMock() + mock_browser = AsyncMock() + mock_playwright = AsyncMock() + + mock_playwright.chromium.launch = AsyncMock(return_value=mock_browser) + mock_browser.new_context = AsyncMock(return_value=mock_context) + mock_context.new_page = AsyncMock(return_value=mock_page) + mock_browser.close = AsyncMock() + mock_context.close = AsyncMock() + + mock_page.goto = AsyncMock() + mock_page.wait_for_selector = AsyncMock() + mock_page.click = AsyncMock() + mock_page.fill = AsyncMock() + mock_page.on = MagicMock() + + mock_async_playwright = AsyncMock() + mock_async_playwright.__aenter__ = AsyncMock(return_value=mock_playwright) + mock_async_playwright.__aexit__ = AsyncMock() + + with ( + patch( + "pyrit.auth.copilot_authenticator.CopilotAuthenticator._create_persistent_cache", + return_value=mock_persistent_cache, + ), + patch("playwright.async_api.async_playwright", return_value=mock_async_playwright), + patch("asyncio.sleep", new_callable=AsyncMock), # mock sleep to speed up test + ): + authenticator = CopilotAuthenticator(token_capture_timeout_seconds=1) + token = await authenticator._fetch_access_token_with_playwright() + assert token is None + mock_browser.close.assert_called_once() + + @pytest.mark.asyncio + async def test_fetch_token_closes_browser_on_exception(self, mock_env_vars, mock_persistent_cache): + """Test that browser is closed even when exception occurs.""" + + mock_page = AsyncMock() + mock_context = AsyncMock() + mock_browser = AsyncMock() + mock_playwright = AsyncMock() + + mock_playwright.chromium.launch = AsyncMock(return_value=mock_browser) + mock_browser.new_context = AsyncMock(return_value=mock_context) + mock_context.new_page = AsyncMock(return_value=mock_page) + mock_browser.close = AsyncMock() + mock_context.close = AsyncMock() + + mock_page.goto = AsyncMock(side_effect=Exception("Navigation failed")) + mock_page.on = MagicMock() # page.on is synchronous in Playwright + + mock_async_playwright = AsyncMock() + mock_async_playwright.__aenter__ = AsyncMock(return_value=mock_playwright) + mock_async_playwright.__aexit__ = AsyncMock() + + with ( + patch( + "pyrit.auth.copilot_authenticator.CopilotAuthenticator._create_persistent_cache", + return_value=mock_persistent_cache, + ), + patch("playwright.async_api.async_playwright", return_value=mock_async_playwright), + ): + authenticator = CopilotAuthenticator() + token = await authenticator._fetch_access_token_with_playwright() + assert token is None + mock_context.close.assert_called_once() + mock_browser.close.assert_called_once() From dfa0acefbc89958945f2288c5bd51d7405d3c3ee Mon Sep 17 00:00:00 2001 From: Paulina Kalicka <71526180+paulinek13@users.noreply.github.com> Date: Tue, 30 Dec 2025 15:01:31 +0100 Subject: [PATCH 38/38] various fixes --- pyrit/auth/copilot_authenticator.py | 10 ++--- tests/unit/auth/test_copilot_authenticator.py | 2 +- .../target/test_websocket_copilot_target.py | 44 +++++++++++++++++++ websocket_copilot_simple_example.py | 9 ++-- 4 files changed, 52 insertions(+), 13 deletions(-) diff --git a/pyrit/auth/copilot_authenticator.py b/pyrit/auth/copilot_authenticator.py index 256f084b1..d5fa08c41 100644 --- a/pyrit/auth/copilot_authenticator.py +++ b/pyrit/auth/copilot_authenticator.py @@ -1,14 +1,14 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT license. -import logging import asyncio +import json +import logging import os -from datetime import datetime, timezone, timedelta +from datetime import datetime, timedelta, timezone from typing import Optional -import json -from msal_extensions import build_encrypted_persistence, FilePersistence +from msal_extensions import FilePersistence, build_encrypted_persistence from pyrit.auth.authenticator import Authenticator from pyrit.common.path import PYRIT_CACHE_PATH @@ -35,8 +35,6 @@ class CopilotAuthenticator(Authenticator): ``pip install playwright && playwright install chromium``. """ - # TODO: ensure login with account with MFA enabled work correctly - #: Name of the cache file to store tokens CACHE_FILE_NAME: str = "copilot_token_cache.bin" #: Buffer before token expiry to avoid using tokens about to expire (in seconds) diff --git a/tests/unit/auth/test_copilot_authenticator.py b/tests/unit/auth/test_copilot_authenticator.py index 1eb890b75..b4252ade8 100644 --- a/tests/unit/auth/test_copilot_authenticator.py +++ b/tests/unit/auth/test_copilot_authenticator.py @@ -5,7 +5,7 @@ import json import os import tempfile -from datetime import datetime, timezone, timedelta +from datetime import datetime, timedelta, timezone from pathlib import Path from unittest.mock import AsyncMock, MagicMock, patch diff --git a/tests/unit/target/test_websocket_copilot_target.py b/tests/unit/target/test_websocket_copilot_target.py index 041bd4b07..18356f1a1 100644 --- a/tests/unit/target/test_websocket_copilot_target.py +++ b/tests/unit/target/test_websocket_copilot_target.py @@ -308,6 +308,50 @@ async def test_connect_and_send_none_response(self, mock_authenticator): is_start_of_session=True, ) + @pytest.mark.asyncio + async def test_connect_and_send_stream_end_without_final_content(self, mock_authenticator): + target = WebSocketCopilotTarget(authenticator=mock_authenticator) + + mock_websocket = AsyncMock() + mock_websocket.send = AsyncMock() + mock_websocket.recv = AsyncMock( + side_effect=[ + '{"type":6}\x1e', + '{"type":3}\x1e', + ] + ) + mock_websocket.__aenter__ = AsyncMock(return_value=mock_websocket) + mock_websocket.__aexit__ = AsyncMock(return_value=None) + + with patch("websockets.connect", return_value=mock_websocket): + response = await target._connect_and_send( + prompt="Hello", + session_id="sid", + copilot_conversation_id="cid", + is_start_of_session=True, + ) + + assert response == "" + + @pytest.mark.asyncio + async def test_connect_and_send_exceeds_max_iterations(self, mock_authenticator): + target = WebSocketCopilotTarget(authenticator=mock_authenticator) + + mock_websocket = AsyncMock() + mock_websocket.send = AsyncMock() + mock_websocket.recv = AsyncMock(return_value='{"type":1}\x1e') + mock_websocket.__aenter__ = AsyncMock(return_value=mock_websocket) + mock_websocket.__aexit__ = AsyncMock(return_value=None) + + with patch("websockets.connect", return_value=mock_websocket): + with pytest.raises(RuntimeError, match="Exceeded maximum message iterations"): + await target._connect_and_send( + prompt="Hello", + session_id="sid", + copilot_conversation_id="cid", + is_start_of_session=True, + ) + @pytest.mark.usefixtures("patch_central_database") class TestValidateRequest: diff --git a/websocket_copilot_simple_example.py b/websocket_copilot_simple_example.py index 78a729ee7..92d0097ec 100644 --- a/websocket_copilot_simple_example.py +++ b/websocket_copilot_simple_example.py @@ -1,17 +1,14 @@ """ -# TODO -THIS WILL BE REMOVED after proper unit tests are in place :) +# TODO: add notebook example instead of this """ import asyncio +import logging -from pyrit.executor.attack import MultiPromptSendingAttack +from pyrit.executor.attack import ConsoleAttackResultPrinter, MultiPromptSendingAttack from pyrit.models import Message, MessagePiece from pyrit.prompt_target import WebSocketCopilotTarget from pyrit.setup import IN_MEMORY, initialize_pyrit_async -from pyrit.executor.attack import ConsoleAttackResultPrinter - -import logging logger = logging.getLogger(__name__) logging.basicConfig(level=logging.INFO)