# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

import asyncio
import json
import logging
import os
from datetime import datetime, timedelta, timezone
from typing import Optional

from msal_extensions import FilePersistence, build_encrypted_persistence

from pyrit.auth.authenticator import Authenticator
from pyrit.common.path import PYRIT_CACHE_PATH

logger = logging.getLogger(__name__)


class CopilotAuthenticator(Authenticator):
    """
    Playwright-based authenticator for Microsoft Copilot. Used by WebSocketCopilotTarget.

    This authenticator automates browser login to obtain and refresh access tokens that are necessary
    for accessing Microsoft Copilot via WebSocket connections. It uses Playwright to simulate user
    interactions for authentication, and msal-extensions for encrypted token persistence.

    An access token acquired by this authenticator is usually valid for about 60 minutes.

    Note:
        To be able to use this authenticator, you must set the following environment variables:
            - COPILOT_USERNAME: your Microsoft account username (email)
            - COPILOT_PASSWORD: your Microsoft account password

        Additionally, you need to have playwright installed and set up:
        ``pip install playwright && playwright install chromium``.
    """

    #: Name of the cache file to store tokens
    CACHE_FILE_NAME: str = "copilot_token_cache.bin"
    #: Buffer before token expiry to avoid using tokens about to expire (in seconds)
    EXPIRY_BUFFER_SECONDS: int = 300
    #: Default timeout for capturing token via network monitoring (in seconds)
    DEFAULT_TOKEN_CAPTURE_TIMEOUT: int = 60
    #: Default timeout for waiting on page elements (in seconds)
    DEFAULT_ELEMENT_TIMEOUT_SECONDS: int = 10
    #: Number of retries for network operations
    DEFAULT_NETWORK_RETRIES: int = 3

    def __init__(
        self,
        *,
        headless: bool = False,
        maximized: bool = True,
        timeout_for_elements_seconds: int = DEFAULT_ELEMENT_TIMEOUT_SECONDS,
        token_capture_timeout_seconds: int = DEFAULT_TOKEN_CAPTURE_TIMEOUT,
        network_retries: int = DEFAULT_NETWORK_RETRIES,
        fallback_to_plaintext: bool = False,
    ):
        """
        Initialize the CopilotAuthenticator.

        Args:
            headless (bool): Whether to run the browser in headless mode. Default is False.
            maximized (bool): Whether to start the browser maximized. Default is True.
            timeout_for_elements_seconds (int): Timeout used when waiting for page elements, in seconds.
            token_capture_timeout_seconds (int): Maximum time to wait for token capture via network monitoring.
            network_retries (int): Number of retry attempts for network operations. Default is 3.
            fallback_to_plaintext (bool): Whether to fallback to plaintext storage if encryption is unavailable.
                If set to False (default), an exception will be raised if encryption cannot be used.
                WARNING: Setting to True stores tokens in plaintext.

        Raises:
            ValueError: If the required environment variables are not set.
        """
        super().__init__()

        self._username = os.getenv("COPILOT_USERNAME")
        self._password = os.getenv("COPILOT_PASSWORD")

        # Fail fast: validate credentials before touching the filesystem or the token cache.
        if not self._username or not self._password:
            raise ValueError("COPILOT_USERNAME and COPILOT_PASSWORD environment variables must be set.")

        self._headless = headless
        self._maximized = maximized
        # Playwright expects timeouts in milliseconds.
        self._elements_timeout = timeout_for_elements_seconds * 1000
        self._token_capture_timeout = token_capture_timeout_seconds
        self._network_retries = network_retries
        self._fallback_to_plaintext = fallback_to_plaintext

        self._cache_dir = PYRIT_CACHE_PATH
        self._cache_dir.mkdir(parents=True, exist_ok=True)
        self._cache_file = str(self._cache_dir / self.CACHE_FILE_NAME)

        self._token_cache = self._create_persistent_cache(self._cache_file, self._fallback_to_plaintext)
        self._current_claims: dict = {}  # for easy access to claims without re-decoding token

        # Lock to prevent concurrent token fetches from launching multiple browsers
        self._token_fetch_lock = asyncio.Lock()

    async def refresh_token(self) -> str:
        """
        Refresh the authentication token asynchronously.

        This will clear the existing token cache and fetch a new token with automated browser login.
        This method does not acquire the token-fetch lock itself; ``get_token`` calls it while
        already holding that lock.

        Returns:
            str: The refreshed authentication token.

        Raises:
            RuntimeError: If token refresh fails.
        """
        logger.info("Refreshing access token...")
        self._clear_token_cache()
        self._current_claims = {}
        token = await self._fetch_access_token_with_playwright()

        if not token:
            raise RuntimeError("Failed to refresh access token.")

        return token

    async def get_token(self) -> str:
        """
        Get the current authentication token.

        This checks the cache first and only launches the browser if no valid token is found.
        If multiple calls are made concurrently, they will be serialized via an asyncio lock
        to prevent launching multiple browser instances.

        On a cache hit the claims stored next to the token are restored, so that ``get_claims``
        returns them even when the token was cached by an earlier process.

        Returns:
            str: A valid Bearer token for Microsoft Copilot.
        """
        async with self._token_fetch_lock:
            cached_token = await self._get_cached_token_if_available_and_valid()
            if cached_token and "access_token" in cached_token:
                logger.info("Using cached access token.")

                # Without this, a fresh process that hits the cache would report empty claims.
                cached_claims = cached_token.get("claims")
                if isinstance(cached_claims, dict) and cached_claims:
                    self._current_claims = cached_claims

                return cached_token["access_token"]

            logger.info("No valid cached token found. Initiating browser authentication.")
            return await self.refresh_token()

    async def get_claims(self) -> dict:
        """
        Get the JWT claims from the current authentication token.

        Returns:
            dict: The JWT claims decoded from the access token, or an empty dict if none are known.
        """
        return self._current_claims or {}

    @staticmethod
    def _create_persistent_cache(cache_file: str, fallback_to_plaintext: bool = False):
        """
        Create a persistent cache for token storage with encryption.

        Uses msal-extensions to provide encrypted storage. Falls back to plaintext
        only if explicitly allowed and encryption is unavailable.

        Args:
            cache_file: Path to the cache file.
            fallback_to_plaintext: Whether to allow plaintext fallback.

        Returns:
            A persistence object (encrypted or plaintext).

        Raises:
            Exception: If encryption fails and fallback is not allowed.
        """
        # https://github.com/AzureAD/microsoft-authentication-extensions-for-python

        try:
            logger.info(f"Using encrypted persistent token cache: {cache_file}")
            return build_encrypted_persistence(cache_file)
        except Exception as e:
            if fallback_to_plaintext:
                logger.warning(f"Encryption unavailable ({e}). Falling back to PLAINTEXT storage.")
                return FilePersistence(cache_file)
            logger.error(f"Encryption unavailable ({e}) and fallback_to_plaintext is False. Cannot proceed.")
            raise

    async def _get_cached_token_if_available_and_valid(self) -> Optional[dict]:
        """
        Retrieve and validate cached token.

        Validates that:
            - Token exists and is properly formatted.
            - Token belongs to the current user (username match).
            - Token has not expired (with safety buffer).

        The expiry is read from the cached ``expires_at`` value; when that is absent, the ``exp``
        claim of the cached JWT is used instead. Any failure while loading or parsing the cache
        is logged and treated as "no valid token".

        Returns:
            Token data dictionary if valid, None otherwise.
        """
        try:
            cache_data = self._token_cache.load()
            if not cache_data:
                logger.info("No cached token data found.")
                return None

            token_data = json.loads(cache_data)
            if not isinstance(token_data, dict) or "access_token" not in token_data:
                logger.info("No access token in cache.")
                return None

            # "claims" may be missing or null (e.g. the token could not be decoded when it was saved).
            claims = token_data.get("claims")
            if not isinstance(claims, dict):
                claims = {}

            cached_user = claims.get("upn")
            if not cached_user:
                logger.info("No user associated with cached token. Token invalidated.")
                return None
            elif cached_user != self._username:
                logger.info(
                    f"Cached token is for different user (cached: {cached_user}, current: {self._username}). "
                    "Token invalidated."
                )
                return None

            # Fall back to the JWT "exp" claim so a token cached without "expires_at"
            # is not treated as valid forever.
            expires_at = token_data.get("expires_at") or claims.get("exp")
            if expires_at:
                expiry_time = datetime.fromtimestamp(expires_at, tz=timezone.utc)
                current_time = datetime.now(timezone.utc)

                # This should prevent most mid-request failures due to token expiration
                expiry_with_buffer = expiry_time - timedelta(seconds=self.EXPIRY_BUFFER_SECONDS)
                if current_time >= expiry_with_buffer:
                    minutes_until_expiry = (expiry_time - current_time).total_seconds() / 60
                    logger.info(
                        f"Cached token expires in {minutes_until_expiry:.2f} minutes, "
                        f"within {self.EXPIRY_BUFFER_SECONDS}s safety buffer. Token invalidated."
                    )
                    return None

                minutes_left = (expiry_time - current_time).total_seconds() / 60
                logger.info(f"Cached token is valid for another {minutes_left:.2f} minutes")

            return token_data

        except Exception as e:
            # Best effort by design: an unreadable cache must never block authentication.
            error_name = type(e).__name__
            if "PersistenceNotFound" in error_name or "FileNotFoundError" in error_name:
                logger.info("Cache file does not exist yet. Will be created on first token save.")
            else:
                logger.error(f"Failed to load cached token ({error_name}): {e}")
            return None

    def _save_token_to_cache(self, *, token: str, expires_in: Optional[int] = None) -> None:
        """
        Save token to persistent cache with metadata.

        The token is decoded without signature verification, only to record its claims. Decoding
        and persistence failures are logged and do not raise.

        Args:
            token: The access token to cache.
            expires_in: Token lifetime in seconds (optional). Ignored unless it is a positive number.
        """
        self._current_claims = {}

        try:
            import jwt

            self._current_claims = jwt.decode(token, algorithms=["RS256"], options={"verify_signature": False})

        except Exception as e:
            logger.error(f"Failed to decode token for caching: {e}")

        now_timestamp = datetime.now(timezone.utc).timestamp()
        token_data = {
            "access_token": token,
            "token_type": "Bearer",
            "claims": self._current_claims,
            "cached_at": now_timestamp,
        }

        # Guard the arithmetic: a malformed "expires_in" must not abort caching of a valid token.
        if isinstance(expires_in, (int, float)) and expires_in > 0:
            token_data["expires_at"] = now_timestamp + expires_in
            token_data["expires_in"] = expires_in

        try:
            self._token_cache.save(json.dumps(token_data))
            logger.info("Token successfully cached.")
        except Exception as e:
            logger.error(f"Failed to cache token: {e}")

    def _clear_token_cache(self) -> None:
        """Overwrite the persistent cache with an empty JSON object; failures are logged, not raised."""
        try:
            self._token_cache.save(json.dumps({}))
            logger.info("Token cache cleared.")
        except Exception as e:
            logger.error(f"Failed to clear cache: {e}")

    async def _fetch_access_token_with_playwright(self) -> Optional[str]:
        """
        Fetch access token using Playwright browser automation.

        Logs in through the Office.com sign-in flow, opens Copilot, and watches network responses
        from the OAuth token endpoint until a bearer token is captured or the capture timeout elapses.
        A captured token is written to the persistent cache.

        Returns:
            Optional[str]: The bearer token if successfully retrieved, None otherwise.

        Raises:
            RuntimeError: If Playwright is not installed.
        """
        try:
            from playwright.async_api import async_playwright
        except ImportError as e:
            raise RuntimeError(
                "Playwright is not installed. Please install it with: "
                "'pip install playwright && playwright install chromium'"
            ) from e

        bearer_token = None
        token_expires_in = None

        async with async_playwright() as playwright:
            browser = None
            context = None

            try:
                logger.info(f"Launching browser for authentication (headless={self._headless})...")
                browser = await playwright.chromium.launch(
                    headless=self._headless, args=["--start-maximized"] if self._maximized else []
                )

                context = await browser.new_context(no_viewport=True)
                page = await context.new_page()

                async def response_handler(response):
                    """Capture the Copilot bearer token from OAuth token endpoint responses."""
                    nonlocal bearer_token, token_expires_in

                    try:
                        if "/oauth2/v2.0/token" not in response.url:
                            return

                        try:
                            text = await response.text()
                        except Exception as e:
                            logger.error(f"Error reading response: {e}")
                            return

                        is_bearer = '"token_type":"Bearer"' in text or '"tokenType":"Bearer"' in text
                        if not (is_bearer and "sydney" in text):
                            return

                        try:
                            data = json.loads(text)
                        except Exception as e:
                            logger.error(f"Error parsing JSON token response: {e}")
                            return

                        if "access_token" in data:
                            bearer_token = data["access_token"]
                            token_expires_in = data.get("expires_in")
                            logger.info("Captured bearer token from JSON response.")

                    except Exception as e:
                        logger.error(f"Error handling response: {e}")

                page.on("response", response_handler)

                logger.info("Navigating to Office.com for authentication...")
                await page.goto("https://www.office.com/")

                logger.info("Waiting for profile icon...")
                await page.wait_for_selector("#mectrl_headerPicture", timeout=self._elements_timeout)
                await page.click("#mectrl_headerPicture")

                logger.info("Waiting for email input...")
                await page.wait_for_selector("#i0116", timeout=self._elements_timeout)
                await page.fill("#i0116", self._username)
                await page.click("#idSIButton9")

                logger.info("Waiting for password input...")
                await page.wait_for_selector("#i0118", timeout=self._elements_timeout)
                await page.fill("#i0118", self._password)
                await page.click("#idSIButton9")

                logger.info("Waiting for 'Stay signed in?' prompt...")
                await page.wait_for_selector("#idSIButton9", timeout=self._elements_timeout)
                logger.info("Clicking 'Yes' to stay signed in...")
                await page.click("#idSIButton9")

                logger.info("Successfully logged in.")
                logger.info("Navigating to Copilot...")

                logger.info("Waiting for Copilot button and clicking it...")
                await page.wait_for_selector('div[aria-label="M365 Copilot"]', timeout=self._elements_timeout)
                await page.click('div[aria-label="M365 Copilot"]', timeout=self._elements_timeout)

                logger.info(f"Waiting up to {self._token_capture_timeout}s for bearer token to be captured...")
                for elapsed in range(self._token_capture_timeout):
                    if bearer_token:
                        logger.info(f"Token captured after {elapsed}s")
                        break
                    await asyncio.sleep(1)

                if bearer_token:
                    # Log only the length: even a partial token preview is secret material.
                    logger.info(f"Bearer token successfully retrieved (length: {len(bearer_token)} chars).")
                    self._save_token_to_cache(token=bearer_token, expires_in=token_expires_in)
                else:
                    logger.error(f"Failed to retrieve bearer token within {self._token_capture_timeout} seconds.")

                return bearer_token
            except Exception as e:
                logger.error("Failed to retrieve access token using Playwright.")

                if "BrowserType.launch" in str(e):
                    logger.error("Playwright browser launch failed. Did you run 'playwright install chromium'?")
                else:
                    # Sanitize error message to avoid leaking sensitive info
                    error_msg = str(e)
                    if self._password and self._password in error_msg:
                        error_msg = error_msg.replace(self._password, "******")
                    logger.error(f"Error details: {error_msg}")

                return None
            finally:
                logger.info("Gracefully closing Playwright browser instance...")

                # Close each resource independently so that a failure while closing the
                # context cannot leave the browser process running.
                for resource in (context, browser):
                    if resource is None:
                        continue
                    try:
                        await resource.close()
                    except Exception as close_error:
                        logger.warning(f"Failed to close Playwright resource cleanly: {close_error}")
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

import asyncio
import json
import logging
import uuid
from enum import IntEnum
from typing import Optional

import websockets

from pyrit.auth import CopilotAuthenticator
from pyrit.exceptions import (
    EmptyResponseException,
    pyrit_target_retry,
)
from pyrit.models import Message, construct_response_from_request
from pyrit.prompt_target import PromptTarget, limit_requests_per_minute

logger = logging.getLogger(__name__)

# Useful links:
# https://github.com/mbrg/power-pwn/blob/main/src/powerpwn/copilot/copilot_connector/copilot_connector.py
# https://labs.zenity.io/p/access-copilot-m365-terminal


class CopilotMessageType(IntEnum):
    """Enumeration for Copilot WebSocket message types."""

    UNKNOWN = -1
    PARTIAL_RESPONSE = 1
    FINAL_CONTENT = 2
    STREAM_END = 3
    USER_PROMPT = 4
    PING = 6


class WebSocketCopilotTarget(PromptTarget):
    """
    A WebSocket-based prompt target for integrating with Microsoft Copilot.

    This class facilitates communication with Microsoft Copilot over a WebSocket connection.
    Authentication is handled automatically via `CopilotAuthenticator`, which uses Playwright
    to automate browser login and obtain the required access tokens.

    Once authenticated, the target supports multi-turn conversations through server-side
    state management. For each PyRIT conversation, it automatically generates consistent
    `session_id` and `conversation_id` values, enabling Copilot to preserve conversational
    context across multiple turns.

    Because conversation state is managed entirely on the Copilot server, this target does
    not resend conversation history with each request and does not support programmatic
    inspection or manipulation of that history. At present, there appears to be no supported
    mechanism for modifying Copilot's server-side conversation state.

    Note:
        This integration only works with licensed Microsoft 365 Copilot.
        The free version of Copilot is not compatible.

    Important:
        - Ensure the following environment variables are set:
            - ``COPILOT_USERNAME`` - your Microsoft account username (email)
            - ``COPILOT_PASSWORD`` - your Microsoft account password

        - Install `Playwright` and its browser dependencies:
            ``pip install playwright && playwright install chromium``
    """

    SUPPORTED_DATA_TYPES = {"text"}

    WEBSOCKET_BASE_URL: str = "wss://substrate.office.com/m365Copilot/Chathub"
    RESPONSE_TIMEOUT_SECONDS: int = 60
    CONNECTION_TIMEOUT_SECONDS: int = 30

    def __init__(
        self,
        *,
        verbose: bool = False,
        max_requests_per_minute: Optional[int] = None,
        model_name: str = "copilot",
        response_timeout_seconds: int = RESPONSE_TIMEOUT_SECONDS,
        authenticator: Optional[CopilotAuthenticator] = None,
    ) -> None:
        """
        Initialize the WebSocketCopilotTarget.

        Args:
            verbose (bool): Enable verbose logging. Defaults to False.
            max_requests_per_minute (Optional[int]): Maximum number of requests per minute.
            model_name (str): The model name. Defaults to "copilot".
            response_timeout_seconds (int): Timeout for receiving responses in seconds. Defaults to 60s.
            authenticator (Optional[CopilotAuthenticator]): Authenticator instance for token management.
                If None, a new CopilotAuthenticator instance will be created with default settings.

        Raises:
            ValueError: If ``response_timeout_seconds`` is not a positive integer.
        """
        if response_timeout_seconds <= 0:
            raise ValueError("response_timeout_seconds must be a positive integer.")

        self._authenticator = authenticator or CopilotAuthenticator()
        self._response_timeout_seconds = response_timeout_seconds

        super().__init__(
            verbose=verbose,
            max_requests_per_minute=max_requests_per_minute,
            endpoint=self.WEBSOCKET_BASE_URL,
            model_name=model_name,
        )

        if self._verbose:
            logger.info("WebSocketCopilotTarget initialized")

    @staticmethod
    def _dict_to_websocket(data: dict) -> str:
        """
        Convert a dictionary to WebSocket message format.

        SignalR protocol (used by Copilot) requires JSON messages terminated with
        ASCII record separator (\\x1e). Minimal JSON formatting reduces bandwidth.

        https://github.com/dotnet/aspnetcore/blob/main/src/SignalR/docs/specs/HubProtocol.md#json-encoding

        Args:
            data (dict): The data to serialize.

        Returns:
            str: JSON string with record separator appended.
        """
        return json.dumps(data, separators=(",", ":")) + "\x1e"

    @staticmethod
    def _parse_raw_message(message: str) -> list[tuple[CopilotMessageType, str]]:
        """
        Extract actionable content from a raw WebSocket message.
        Returns more than one JSON message if multiple are found.

        Frames that are not valid JSON, are not JSON objects, or carry an unrecognized ``type``
        are reported as ``CopilotMessageType.UNKNOWN`` rather than raising.

        Args:
            message (str): The raw WebSocket message string.

        Returns:
            list[tuple[CopilotMessageType, str]]: A list of tuples where each tuple contains
                message type and extracted content. Content is non-empty only for FINAL_CONTENT.
                If no frame is found, a single ``(UNKNOWN, "")`` entry is returned.
        """
        results: list[tuple[CopilotMessageType, str]] = []

        for frame in message.split("\x1e"):  # record separator
            if not frame or not frame.strip():
                continue

            try:
                data = json.loads(frame)
            except json.JSONDecodeError as e:
                logger.error(f"Failed to decode JSON message: {str(e)}")
                results.append((CopilotMessageType.UNKNOWN, ""))
                continue

            # Valid JSON that is not an object (list, number, string, ...) has no "type" field.
            if not isinstance(data, dict):
                logger.debug(f"Ignoring non-object JSON frame: {frame}")
                results.append((CopilotMessageType.UNKNOWN, ""))
                continue

            try:
                msg_type = CopilotMessageType(data.get("type", -1))
            except (ValueError, TypeError):
                msg_type = CopilotMessageType.UNKNOWN

            if msg_type != CopilotMessageType.FINAL_CONTENT:
                # PING, PARTIAL_RESPONSE, STREAM_END and everything else carry no content we use.
                results.append((msg_type, ""))
                continue

            # Walk item -> result -> message defensively: any level may be missing or null.
            item = data.get("item")
            result = item.get("result") if isinstance(item, dict) else None
            bot_text = result.get("message", "") if isinstance(result, dict) else ""
            if not isinstance(bot_text, str):
                bot_text = ""

            if not bot_text:
                # In this case, EmptyResponseException will be raised anyway
                logger.warning("FINAL_CONTENT received but no parseable content found.")
                logger.debug(f"Full raw message: {frame}")

            results.append((CopilotMessageType.FINAL_CONTENT, bot_text))

        return results if results else [(CopilotMessageType.UNKNOWN, "")]

    async def _build_websocket_url_async(self, *, session_id: str, copilot_conversation_id: str) -> str:
        """
        Build the WebSocket URL with all the required authentication and session parameters.

        The access token is embedded in the returned URL as a query parameter, so the URL itself
        is a secret; only a redacted copy is written to the debug log.

        Args:
            session_id (str): Copilot session identifier.
            copilot_conversation_id (str): Copilot conversation identifier.

        Returns:
            str: Complete WebSocket URL with authentication and parameters.

        Raises:
            ValueError: If required claims (tid, oid) are missing from the token claims.
        """
        access_token = await self._authenticator.get_token()
        token_claims = await self._authenticator.get_claims()

        tenant_id = token_claims.get("tid")
        object_id = token_claims.get("oid")

        if not tenant_id or not object_id:
            raise ValueError(
                "Failed to extract tenant_id (tid) or object_id (oid) from bearer token. "
                f"Token claims: {list(token_claims.keys())}"
            )

        client_request_id = str(uuid.uuid4())

        base_url = f"{self.WEBSOCKET_BASE_URL}/{object_id}@{tenant_id}"
        query_params = [
            f"ClientRequestId={client_request_id}",
            f"X-SessionId={session_id}",
            f"ConversationId={copilot_conversation_id}",
            f"access_token={access_token}",
            "X-variants=feature.includeExternal,feature.AssistantConnectorsContentSources,"
            "3S.BizChatWprBoostAssistant,3S.EnableMEFromSkillDiscovery,feature.EnableAuthErrorMessage,"
            "EnableRequestPlugins,feature.EnableSensitivityLabels,feature.IsEntityAnnotationsEnabled,"
            "EnableUnsupportedUrlDetector",
            "source=%22officeweb%22",
            "scenario=OfficeWebIncludedCopilot",
        ]

        websocket_url = f"{base_url}?{'&'.join(query_params)}"

        # Never write the bearer token to the logs.
        redacted_url = websocket_url.replace(access_token, "<redacted>") if access_token else websocket_url
        logger.debug(f"WebSocket URL: {redacted_url}")
        return websocket_url

    def _build_prompt_message(
        self, *, prompt: str, session_id: str, copilot_conversation_id: str, is_start_of_session: bool
    ) -> dict:
        """
        Construct the prompt message payload for Copilot WebSocket API.

        Builds a comprehensive message structure following Copilot's expected format,
        including session metadata, feature flags, and the user's prompt text.

        Args:
            prompt (str): The user prompt text to send.
            session_id (str): Copilot session identifier.
            copilot_conversation_id (str): Copilot conversation identifier.
            is_start_of_session (bool): Whether this is the first message in the conversation.

        Returns:
            dict: The complete message payload ready to be sent via WebSocket.
        """
        # The same fresh identifier is used as both the request ID and the trace ID.
        request_id = trace_id = uuid.uuid4().hex
        result = {
            "arguments": [
                {
                    "source": "officeweb",
                    "clientCorrelationId": uuid.uuid4().hex,
                    "sessionId": session_id,
                    "optionsSets": [
                        "enterprise_flux_web",
                        "enterprise_flux_work",
                        "enable_request_response_interstitials",
                        "enterprise_flux_image_v1",
                        "enterprise_toolbox_with_skdsstore",
                        "enterprise_toolbox_with_skdsstore_search_message_extensions",
                        "enable_ME_auth_interstitial",
                        "skdsstorethirdparty",
                        "enable_confirmation_interstitial",
                        "enable_plugin_auth_interstitial",
                        "enable_response_action_processing",
                        "enterprise_flux_work_gptv",
                        "enterprise_flux_work_code_interpreter",
                        "enable_batch_token_processing",
                    ],
                    "options": {},
                    "allowedMessageTypes": [
                        "Chat",
                        "Suggestion",
                        "InternalSearchQuery",
                        "InternalSearchResult",
                        "Disengaged",
                        "InternalLoaderMessage",
                        "RenderCardRequest",
                        "AdsQuery",
                        "SemanticSerp",
                        "GenerateContentQuery",
                        "SearchQuery",
                        "ConfirmationCard",
                        "AuthError",
                        "DeveloperLogs",
                    ],
                    "sliceIds": [],
                    "threadLevelGptId": {},
                    "conversationId": copilot_conversation_id,
                    "traceId": trace_id,
                    "isStartOfSession": is_start_of_session,
                    "productThreadType": "Office",
                    "clientInfo": {"clientPlatform": "web"},
                    "message": {
                        "author": "user",
                        "inputMethod": "Keyboard",
                        "text": prompt,
                        "entityAnnotationTypes": ["People", "File", "Event", "Email", "TeamsMessage"],
                        "requestId": request_id,
                        "locationInfo": {"timeZoneOffset": 0, "timeZone": "UTC"},
                        "locale": "en-US",
                        "messageType": "Chat",
                        "experienceType": "Default",
                    },
                    "plugins": [],
                }
            ],
            "invocationId": "0",
            "target": "chat",
            "type": CopilotMessageType.USER_PROMPT,
        }

        logger.debug(f"Built prompt message: {result}")
        return result

    async def _connect_and_send(
        self, *, prompt: str, session_id: str, copilot_conversation_id: str, is_start_of_session: bool
    ) -> str:
        """
        Establish WebSocket connection, send prompt, and await response.

        The method polls for messages, ignoring PARTIAL_RESPONSE streaming updates
        until it receives either FINAL_CONTENT (success), STREAM_END, or UNKNOWN (error).

        Args:
            prompt (str): The user prompt text to send.
            session_id (str): Copilot session identifier.
            copilot_conversation_id (str): Copilot conversation identifier.
            is_start_of_session (bool): Whether this is the first message in the conversation.

        Returns:
            str: The final response text from Copilot (empty if no FINAL_CONTENT was received).

        Raises:
            TimeoutError: If no response received within the specified timeout period.
            RuntimeError: If WebSocket connection closes unexpectedly, protocol violation occurs,
                or maximum message iterations exceeded.
        """
        websocket_url = await self._build_websocket_url_async(
            session_id=session_id, copilot_conversation_id=copilot_conversation_id
        )

        inputs = [
            {"protocol": "json", "version": 1},  # the handshake message, we expect PING in response
            self._build_prompt_message(  # the actual user prompt, we expect FINAL_CONTENT in response
                prompt=prompt,
                session_id=session_id,
                copilot_conversation_id=copilot_conversation_id,
                is_start_of_session=is_start_of_session,
            ),
        ]
        response = ""

        async with websockets.connect(
            websocket_url,
            open_timeout=self.CONNECTION_TIMEOUT_SECONDS,
            close_timeout=self.CONNECTION_TIMEOUT_SECONDS,
        ) as websocket:
            for input_msg in inputs:
                payload = self._dict_to_websocket(input_msg)
                await websocket.send(payload)

                is_user_input = input_msg.get("type") == CopilotMessageType.USER_PROMPT

                MAX_MESSAGE_ITERATIONS = 1000
                iteration_count = 0
                stop_polling = False

                while not stop_polling:
                    # Prevent infinite loops (e.g. if Copilot somehow never sends a terminating message)
                    iteration_count += 1
                    if iteration_count > MAX_MESSAGE_ITERATIONS:
                        raise RuntimeError(
                            f"Exceeded maximum message iterations ({MAX_MESSAGE_ITERATIONS}) "
                            "while waiting for Copilot response."
                        )

                    try:
                        raw_message = await asyncio.wait_for(
                            websocket.recv(),
                            timeout=self._response_timeout_seconds,
                        )
                    except asyncio.TimeoutError:
                        raise TimeoutError(
                            f"Timed out waiting for Copilot response after {self._response_timeout_seconds} seconds."
                        )

                    if raw_message is None:
                        raise RuntimeError(
                            "WebSocket connection closed unexpectedly: received None from websocket.recv()"
                        )

                    # recv() yields bytes for binary frames; the parser works on text.
                    if isinstance(raw_message, (bytes, bytearray)):
                        raw_message = bytes(raw_message).decode("utf-8", errors="replace")

                    parsed_messages = self._parse_raw_message(raw_message)

                    for msg_type, content in parsed_messages:
                        if msg_type in (
                            CopilotMessageType.UNKNOWN,
                            CopilotMessageType.FINAL_CONTENT,
                            CopilotMessageType.STREAM_END,
                        ):
                            stop_polling = True
                            # Not breaking here to process all messages in this batch,
                            # possibly including FINAL_CONTENT

                        if msg_type == CopilotMessageType.FINAL_CONTENT:
                            response = content
                        elif msg_type == CopilotMessageType.UNKNOWN:
                            logger.debug("Received unknown or empty message type.")

                        # PING is Copilot's acknowledgment of the protocol handshake (first message in inputs[])
                        # It should arrive after the handshake, not after a user prompt
                        # If we're processing a user prompt and receive PING, something is wrong - ignore it
                        # and keep polling for the actual FINAL_CONTENT response
                        elif msg_type == CopilotMessageType.PING and not is_user_input:
                            stop_polling = True

        return response

    def _validate_request(self, *, message: Message) -> None:
        """
        Validate that the message meets target requirements.

        Args:
            message (Message): The message to validate.

        Raises:
            ValueError: If message contains more than one piece or non-text content.
        """
        n_pieces = len(message.message_pieces)
        if n_pieces != 1:
            raise ValueError(f"This target only supports a single message piece. Received: {n_pieces} pieces.")

        piece_type = message.message_pieces[0].converted_value_data_type
        if piece_type != "text":
            raise ValueError(f"This target only supports text prompt input. Received: {piece_type}.")

    def _is_start_of_session(self, *, conversation_id: str) -> bool:
        """
        Determine if this is the first message in a PyRIT conversation.

        Checks memory for existing conversation history to set the appropriate
        flag for Copilot's server-side conversation initialization.

        Args:
            conversation_id (str): The PyRIT conversation ID.

        Returns:
            bool: True if no prior messages exist in this conversation, False otherwise.
        """
        conversation_history = self._memory.get_conversation(conversation_id=conversation_id)
        return not conversation_history

    def _generate_consistent_copilot_ids(self, *, pyrit_conversation_id: str) -> tuple[str, str]:
        """
        Generate consistent Copilot session_id and conversation_id for a PyRIT conversation.

        This uses a deterministic approach to ensure that the same PyRIT conversation_id
        always maps to the same Copilot session identifiers. This enables multi-turn
        conversations while keeping the target stateless.

        Args:
            pyrit_conversation_id (str): The PyRIT conversation ID from the Message.

        Returns:
            tuple[str, str]: A tuple of (session_id, copilot_conversation_id).
        """
        # uuid.NAMESPACE_DNS is 6ba7b810-9dad-11d1-80b4-00c04fd430c8, the namespace used so far;
        # it must not change or existing conversations would map to different Copilot IDs.
        namespace = uuid.NAMESPACE_DNS
        session_id = str(uuid.uuid5(namespace, f"session_{pyrit_conversation_id}"))
        copilot_conversation_id = str(uuid.uuid5(namespace, f"copilot_{pyrit_conversation_id}"))

        return session_id, copilot_conversation_id

    @limit_requests_per_minute
    @pyrit_target_retry
    async def send_prompt_async(self, *, message: Message) -> list[Message]:
        """
        Asynchronously send a message to Microsoft Copilot using WebSocket.

        This method enables multi-turn conversations by using consistent session and conversation
        identifiers derived from the PyRIT conversation_id. The Copilot API maintains conversation
        state server-side, so only the current message is sent (no explicit history required).

        Args:
            message (Message): A message to be sent to the target.

        Returns:
            list[Message]: A list containing the response from Copilot.

        Raises:
            EmptyResponseException: If the response from Copilot is empty.
            InvalidStatus: If the WebSocket handshake fails with an HTTP status error.
            RuntimeError: If any other error occurs during WebSocket communication.
        """
        self._validate_request(message=message)
        request_piece = message.message_pieces[0]

        pyrit_conversation_id = request_piece.conversation_id
        is_start_of_session = self._is_start_of_session(conversation_id=pyrit_conversation_id)

        session_id, copilot_conversation_id = self._generate_consistent_copilot_ids(
            pyrit_conversation_id=pyrit_conversation_id
        )

        logger.info(
            f"Sending prompt to WebSocketCopilotTarget: {request_piece.converted_value} "
            f"(conversation_id={pyrit_conversation_id}, is_start={is_start_of_session})"
        )

        try:
            prompt_text = request_piece.converted_value
            response_text = await self._connect_and_send(
                prompt=prompt_text,
                session_id=session_id,
                copilot_conversation_id=copilot_conversation_id,
                is_start_of_session=is_start_of_session,
            )

            if not response_text or not response_text.strip():
                logger.error("Empty response received from Copilot.")
                raise EmptyResponseException(message="Copilot returned an empty response.")
            logger.info(f"Received response from WebSocketCopilotTarget (length: {len(response_text)} chars)")

            response_entry = construct_response_from_request(
                request=request_piece, response_text_pieces=[response_text]
            )

            return [response_entry]

        except websockets.exceptions.InvalidStatus as e:
            logger.error(
                f"WebSocket connection failed: {str(e)}\n"
                "Ensure that COPILOT_USERNAME and COPILOT_PASSWORD environment variables are set correctly."
                " For more details about authentication, refer to the class documentation."
            )
            raise

        except EmptyResponseException:
            # Re-raised untouched so it is not wrapped in RuntimeError below.
            raise

        except Exception as e:
            raise RuntimeError(f"An error occurred during WebSocket communication: {str(e)}") from e
hasattr(CopilotAuthenticator, "DEFAULT_ELEMENT_TIMEOUT_SECONDS") + assert hasattr(CopilotAuthenticator, "DEFAULT_NETWORK_RETRIES") + + def test_constant_values(self): + assert CopilotAuthenticator.CACHE_FILE_NAME == "copilot_token_cache.bin" + assert CopilotAuthenticator.EXPIRY_BUFFER_SECONDS == 300 + assert CopilotAuthenticator.DEFAULT_TOKEN_CAPTURE_TIMEOUT == 60 + assert CopilotAuthenticator.DEFAULT_ELEMENT_TIMEOUT_SECONDS == 10 + assert CopilotAuthenticator.DEFAULT_NETWORK_RETRIES == 3 + + +class TestCopilotAuthenticatorInitialization: + """Test CopilotAuthenticator initialization scenarios.""" + + def test_init_with_required_env_vars(self, mock_env_vars, mock_persistent_cache): + """Test successful initialization with required environment variables.""" + + with patch( + "pyrit.auth.copilot_authenticator.CopilotAuthenticator._create_persistent_cache", + return_value=mock_persistent_cache, + ): + authenticator = CopilotAuthenticator() + + assert authenticator._username == "test@example.com" + assert authenticator._password == "test_password_123" + assert authenticator._headless is False + assert authenticator._maximized is True + assert authenticator._elements_timeout == 10000 + assert authenticator._token_capture_timeout == 60 + assert authenticator._network_retries == 3 + assert authenticator._fallback_to_plaintext is False + + assert isinstance(authenticator._token_fetch_lock, asyncio.Lock) + assert authenticator._current_claims == {} + assert authenticator._token_cache is not None + + def test_init_with_custom_parameters(self, mock_env_vars, mock_persistent_cache): + """Test initialization with custom parameters.""" + + with patch( + "pyrit.auth.copilot_authenticator.CopilotAuthenticator._create_persistent_cache", + return_value=mock_persistent_cache, + ): + authenticator = CopilotAuthenticator( + headless=True, + maximized=False, + timeout_for_elements_seconds=20, + token_capture_timeout_seconds=120, + network_retries=5, + fallback_to_plaintext=True, + ) + + 
assert authenticator._headless is True + assert authenticator._maximized is False + assert authenticator._elements_timeout == 20000 + assert authenticator._token_capture_timeout == 120 + assert authenticator._network_retries == 5 + assert authenticator._fallback_to_plaintext is True + + def test_init_missing_env_var_raises_error(self): + """Test that missing a required environment variable raises ValueError.""" + + for missing_var in ["COPILOT_USERNAME", "COPILOT_PASSWORD"]: + env_vars = { + "COPILOT_USERNAME": "test@example.com", + "COPILOT_PASSWORD": "test_password_123", + } + env_vars.pop(missing_var) + with patch.dict(os.environ, env_vars, clear=True): + with pytest.raises( + ValueError, match="COPILOT_USERNAME and COPILOT_PASSWORD environment variables must be set" + ): + CopilotAuthenticator() + + def test_init_creates_cache_directory(self, mock_env_vars, mock_persistent_cache, temp_cache_dir): + """Test that initialization creates cache directory if it doesn't exist.""" + + with ( + patch( + "pyrit.auth.copilot_authenticator.CopilotAuthenticator._create_persistent_cache", + return_value=mock_persistent_cache, + ), + patch("pyrit.auth.copilot_authenticator.PYRIT_CACHE_PATH", temp_cache_dir / "new_cache"), + ): + authenticator = CopilotAuthenticator() + + assert (temp_cache_dir / "new_cache").exists() + assert authenticator._cache_dir == temp_cache_dir / "new_cache" + + +class TestCopilotAuthenticatorCacheManagement: + """Test cache creation, loading, and saving functionality.""" + + def test_create_persistent_cache_with_encryption(self): + """Test cache creation with encryption enabled.""" + + mock_encrypted_cache = MagicMock() + + with patch( + "pyrit.auth.copilot_authenticator.build_encrypted_persistence", + return_value=mock_encrypted_cache, + ) as mock_build: + result = CopilotAuthenticator._create_persistent_cache("/test/cache.bin", fallback_to_plaintext=False) + + mock_build.assert_called_once_with("/test/cache.bin") + assert result == 
mock_encrypted_cache + + def test_create_persistent_cache_fallback_to_plaintext(self): + """Test cache creation falls back to plaintext when encryption fails.""" + + mock_plaintext_cache = MagicMock() + + with ( + patch( + "pyrit.auth.copilot_authenticator.build_encrypted_persistence", + side_effect=Exception("Encryption not available"), + ), + patch( + "pyrit.auth.copilot_authenticator.FilePersistence", + return_value=mock_plaintext_cache, + ) as mock_file_persistence, + ): + result = CopilotAuthenticator._create_persistent_cache("/test/cache.bin", fallback_to_plaintext=True) + + mock_file_persistence.assert_called_once_with("/test/cache.bin") + assert result == mock_plaintext_cache + + def test_create_persistent_cache_raises_on_encryption_failure_without_fallback(self): + """Test cache creation raises exception when encryption fails and no fallback.""" + + with patch( + "pyrit.auth.copilot_authenticator.build_encrypted_persistence", + side_effect=Exception("Encryption not available"), + ): + with pytest.raises(Exception, match="Encryption not available"): + CopilotAuthenticator._create_persistent_cache("/test/cache.bin", fallback_to_plaintext=False) + + def test_save_token_to_cache_with_expiry(self, mock_env_vars, mock_persistent_cache): + """Test saving token to cache with expiration time.""" + + with ( + patch( + "pyrit.auth.copilot_authenticator.CopilotAuthenticator._create_persistent_cache", + return_value=mock_persistent_cache, + ), + patch("jwt.decode") as mock_jwt_decode, + ): + mock_claims = {"upn": "test@example.com", "aud": "sydney"} + mock_jwt_decode.return_value = mock_claims + + authenticator = CopilotAuthenticator() + test_token = "test.jwt.token" + + authenticator._save_token_to_cache(token=test_token, expires_in=3600) + + assert mock_persistent_cache.save.called + saved_data = json.loads(mock_persistent_cache.save.call_args[0][0]) + + assert saved_data["access_token"] == test_token + assert saved_data["token_type"] == "Bearer" + assert 
saved_data["claims"] == mock_claims + assert saved_data["expires_in"] == 3600 + assert "expires_at" in saved_data + assert "cached_at" in saved_data + assert authenticator._current_claims == mock_claims + + def test_save_token_to_cache_without_expiry(self, mock_env_vars, mock_persistent_cache): + """Test saving token to cache without expiration time.""" + + with ( + patch( + "pyrit.auth.copilot_authenticator.CopilotAuthenticator._create_persistent_cache", + return_value=mock_persistent_cache, + ), + patch("jwt.decode") as mock_jwt_decode, + ): + mock_claims = {"upn": "test@example.com"} + mock_jwt_decode.return_value = mock_claims + + authenticator = CopilotAuthenticator() + test_token = "test.jwt.token" + + authenticator._save_token_to_cache(token=test_token, expires_in=None) + + saved_data = json.loads(mock_persistent_cache.save.call_args[0][0]) + + assert "expires_in" not in saved_data + assert "expires_at" not in saved_data + assert saved_data["access_token"] == test_token + + def test_save_token_handles_jwt_decode_failure(self, mock_env_vars, mock_persistent_cache): + """Test that save_token handles JWT decode failures gracefully.""" + + with ( + patch( + "pyrit.auth.copilot_authenticator.CopilotAuthenticator._create_persistent_cache", + return_value=mock_persistent_cache, + ), + patch("jwt.decode", side_effect=Exception("Invalid JWT")), + patch("pyrit.auth.copilot_authenticator.logger") as mock_logger, + ): + authenticator = CopilotAuthenticator() + test_token = "invalid.jwt.token" + + authenticator._save_token_to_cache(token=test_token, expires_in=3600) + mock_logger.error.assert_called_with("Failed to decode token for caching: Invalid JWT") + + saved_data = json.loads(mock_persistent_cache.save.call_args[0][0]) + assert saved_data["access_token"] == test_token + assert saved_data["claims"] == {} + + def test_clear_token_cache(self, mock_env_vars, mock_persistent_cache): + """Test clearing the token cache.""" + + with patch( + 
"pyrit.auth.copilot_authenticator.CopilotAuthenticator._create_persistent_cache", + return_value=mock_persistent_cache, + ): + authenticator = CopilotAuthenticator() + authenticator._clear_token_cache() + mock_persistent_cache.save.assert_called_with(json.dumps({})) + + def test_clear_token_cache_handles_error(self, mock_env_vars, mock_persistent_cache): + """Test that clear_token_cache handles errors gracefully.""" + + mock_persistent_cache.save.side_effect = Exception("Cache error") + + with ( + patch( + "pyrit.auth.copilot_authenticator.CopilotAuthenticator._create_persistent_cache", + return_value=mock_persistent_cache, + ), + patch("pyrit.auth.copilot_authenticator.logger") as mock_logger, + ): + authenticator = CopilotAuthenticator() + authenticator._clear_token_cache() + mock_logger.error.assert_called_with("Failed to clear cache: Cache error") + + +class TestCopilotAuthenticatorCachedTokenRetrieval: + """Test cached token validation and retrieval.""" + + def test_get_cached_token_valid(self, mock_env_vars, mock_persistent_cache): + """Test retrieving valid cached token.""" + + expires_at = (datetime.now(timezone.utc) + timedelta(hours=1)).timestamp() + cached_data = { + "access_token": "cached.token.value", + "token_type": "Bearer", + "claims": {"upn": "test@example.com"}, + "expires_at": expires_at, + "cached_at": datetime.now(timezone.utc).timestamp(), + } + mock_persistent_cache.load.return_value = json.dumps(cached_data) + + with patch( + "pyrit.auth.copilot_authenticator.CopilotAuthenticator._create_persistent_cache", + return_value=mock_persistent_cache, + ): + authenticator = CopilotAuthenticator() + result = asyncio.run(authenticator._get_cached_token_if_available_and_valid()) + assert result is not None + assert result["access_token"] == "cached.token.value" + + def test_get_cached_token_expired(self, mock_env_vars, mock_persistent_cache): + """Test that expired token is not returned.""" + + expires_at = (datetime.now(timezone.utc) - 
timedelta(minutes=5)).timestamp() + cached_data = { + "access_token": "expired.token.value", + "claims": {"upn": "test@example.com"}, + "expires_at": expires_at, + } + mock_persistent_cache.load.return_value = json.dumps(cached_data) + + with patch( + "pyrit.auth.copilot_authenticator.CopilotAuthenticator._create_persistent_cache", + return_value=mock_persistent_cache, + ): + authenticator = CopilotAuthenticator() + result = asyncio.run(authenticator._get_cached_token_if_available_and_valid()) + assert result is None + + def test_get_cached_token_within_expiry_buffer(self, mock_env_vars, mock_persistent_cache): + """Test that token within expiry buffer is not returned.""" + + expires_at = (datetime.now(timezone.utc) + timedelta(seconds=200)).timestamp() + cached_data = { + "access_token": "soon.to.expire", + "claims": {"upn": "test@example.com"}, + "expires_at": expires_at, + } + mock_persistent_cache.load.return_value = json.dumps(cached_data) + + with patch( + "pyrit.auth.copilot_authenticator.CopilotAuthenticator._create_persistent_cache", + return_value=mock_persistent_cache, + ): + authenticator = CopilotAuthenticator() + result = asyncio.run(authenticator._get_cached_token_if_available_and_valid()) + assert result is None # default buffer is 300 seconds, so should return None + + def test_get_cached_token_no_cache_file(self, mock_env_vars, mock_persistent_cache): + """Test behavior when cache file doesn't exist.""" + + mock_persistent_cache.load.return_value = None + with patch( + "pyrit.auth.copilot_authenticator.CopilotAuthenticator._create_persistent_cache", + return_value=mock_persistent_cache, + ): + authenticator = CopilotAuthenticator() + result = asyncio.run(authenticator._get_cached_token_if_available_and_valid()) + assert result is None + + def test_get_cached_token_wrong_user(self, mock_env_vars, mock_persistent_cache): + """Test that cached token for different user is invalidated.""" + + expires_at = (datetime.now(timezone.utc) + 
timedelta(hours=1)).timestamp() + cached_data = { + "access_token": "other.user.token", + "claims": {"upn": "different@example.com"}, + "expires_at": expires_at, + } + mock_persistent_cache.load.return_value = json.dumps(cached_data) + + with patch( + "pyrit.auth.copilot_authenticator.CopilotAuthenticator._create_persistent_cache", + return_value=mock_persistent_cache, + ): + authenticator = CopilotAuthenticator() + result = asyncio.run(authenticator._get_cached_token_if_available_and_valid()) + assert result is None + + def test_get_cached_token_no_upn_in_claims(self, mock_env_vars, mock_persistent_cache): + """Test that cached token without upn claim is invalidated.""" + + expires_at = (datetime.now(timezone.utc) + timedelta(hours=1)).timestamp() + cached_data = { + "access_token": "token.without.upn", + "claims": {"aud": "sydney"}, + "expires_at": expires_at, + } + mock_persistent_cache.load.return_value = json.dumps(cached_data) + + with patch( + "pyrit.auth.copilot_authenticator.CopilotAuthenticator._create_persistent_cache", + return_value=mock_persistent_cache, + ): + authenticator = CopilotAuthenticator() + result = asyncio.run(authenticator._get_cached_token_if_available_and_valid()) + assert result is None + + def test_get_cached_token_missing_access_token(self, mock_env_vars, mock_persistent_cache): + """Test that cache data without access_token is invalid.""" + + cached_data = { + "token_type": "Bearer", + "claims": {"upn": "test@example.com"}, + } + mock_persistent_cache.load.return_value = json.dumps(cached_data) + + with patch( + "pyrit.auth.copilot_authenticator.CopilotAuthenticator._create_persistent_cache", + return_value=mock_persistent_cache, + ): + authenticator = CopilotAuthenticator() + result = asyncio.run(authenticator._get_cached_token_if_available_and_valid()) + assert result is None + + def test_get_cached_token_invalid_json(self, mock_env_vars, mock_persistent_cache): + """Test handling of corrupted cache data.""" + + 
mock_persistent_cache.load.return_value = "invalid json {{" + + with patch( + "pyrit.auth.copilot_authenticator.CopilotAuthenticator._create_persistent_cache", + return_value=mock_persistent_cache, + ): + authenticator = CopilotAuthenticator() + result = asyncio.run(authenticator._get_cached_token_if_available_and_valid()) + assert result is None + + +class TestCopilotAuthenticatorTokenRetrieval: + """Test token retrieval via get_token method.""" + + @pytest.mark.asyncio + async def test_get_token_uses_cached_token(self, mock_env_vars, mock_persistent_cache): + """Test that get_token uses cached token when available.""" + + expires_at = (datetime.now(timezone.utc) + timedelta(hours=1)).timestamp() + cached_data = { + "access_token": "cached.valid.token", + "claims": {"upn": "test@example.com"}, + "expires_at": expires_at, + } + mock_persistent_cache.load.return_value = json.dumps(cached_data) + + with patch( + "pyrit.auth.copilot_authenticator.CopilotAuthenticator._create_persistent_cache", + return_value=mock_persistent_cache, + ): + authenticator = CopilotAuthenticator() + token = await authenticator.get_token() + assert token == "cached.valid.token" + + @pytest.mark.asyncio + async def test_get_token_fetches_new_when_no_cache(self, mock_env_vars, mock_persistent_cache): + """Test that get_token fetches new token when cache is empty.""" + + mock_persistent_cache.load.return_value = None + with ( + patch( + "pyrit.auth.copilot_authenticator.CopilotAuthenticator._create_persistent_cache", + return_value=mock_persistent_cache, + ), + patch( + "pyrit.auth.copilot_authenticator.CopilotAuthenticator._fetch_access_token_with_playwright", + new_callable=AsyncMock, + return_value="new.fetched.token", + ) as mock_fetch, + ): + authenticator = CopilotAuthenticator() + token = await authenticator.get_token() + mock_fetch.assert_called_once() + assert token == "new.fetched.token" + + @pytest.mark.asyncio + async def test_get_token_serializes_concurrent_requests(self, 
mock_env_vars, mock_persistent_cache): + """Test that concurrent get_token calls are serialized via lock.""" + + fetch_call_count = 0 + mock_persistent_cache.load.return_value = None # start with no cache + + async def mock_fetch(): + nonlocal fetch_call_count + fetch_call_count += 1 + await asyncio.sleep(0.01) # minimal delay to test concurrency + return f"token.{fetch_call_count}" + + def mock_load_side_effect(): + # After first fetch, return cached token for subsequent calls + if fetch_call_count > 0: + expires_at = (datetime.now(timezone.utc) + timedelta(hours=1)).timestamp() + return json.dumps( + { + "access_token": "token.1", + "claims": {"upn": "test@example.com"}, + "expires_at": expires_at, + } + ) + return None + + mock_persistent_cache.load.side_effect = mock_load_side_effect + with ( + patch( + "pyrit.auth.copilot_authenticator.CopilotAuthenticator._create_persistent_cache", + return_value=mock_persistent_cache, + ), + patch( + "pyrit.auth.copilot_authenticator.CopilotAuthenticator._fetch_access_token_with_playwright", + new_callable=AsyncMock, + side_effect=mock_fetch, + ), + patch("jwt.decode", return_value={"upn": "test@example.com"}), + ): + authenticator = CopilotAuthenticator() + + results = await asyncio.gather( + authenticator.get_token(), + authenticator.get_token(), + authenticator.get_token(), + ) + + # Only one fetch should have occurred due to lock + caching + assert fetch_call_count == 1 + assert results[0] == results[1] == results[2] == "token.1" + + +class TestCopilotAuthenticatorTokenRefresh: + """Test token refresh functionality.""" + + @pytest.mark.asyncio + async def test_refresh_token_clears_cache(self, mock_env_vars, mock_persistent_cache): + """Test that refresh_token clears existing cache.""" + + with ( + patch( + "pyrit.auth.copilot_authenticator.CopilotAuthenticator._create_persistent_cache", + return_value=mock_persistent_cache, + ), + patch( + 
"pyrit.auth.copilot_authenticator.CopilotAuthenticator._fetch_access_token_with_playwright", + new_callable=AsyncMock, + return_value="refreshed.token", + ), + ): + authenticator = CopilotAuthenticator() + await authenticator.refresh_token() + assert any(json.dumps({}) in str(call) for call in mock_persistent_cache.save.call_args_list) + + @pytest.mark.asyncio + async def test_refresh_token_fetches_new_token(self, mock_env_vars, mock_persistent_cache): + """Test that refresh_token fetches new token.""" + + with ( + patch( + "pyrit.auth.copilot_authenticator.CopilotAuthenticator._create_persistent_cache", + return_value=mock_persistent_cache, + ), + patch( + "pyrit.auth.copilot_authenticator.CopilotAuthenticator._fetch_access_token_with_playwright", + new_callable=AsyncMock, + return_value="refreshed.token", + ) as mock_fetch, + ): + authenticator = CopilotAuthenticator() + token = await authenticator.refresh_token() + mock_fetch.assert_called_once() + assert token == "refreshed.token" + + @pytest.mark.asyncio + async def test_refresh_token_raises_on_failure(self, mock_env_vars, mock_persistent_cache): + """Test that refresh_token raises exception when fetch fails.""" + + with ( + patch( + "pyrit.auth.copilot_authenticator.CopilotAuthenticator._create_persistent_cache", + return_value=mock_persistent_cache, + ), + patch( + "pyrit.auth.copilot_authenticator.CopilotAuthenticator._fetch_access_token_with_playwright", + new_callable=AsyncMock, + return_value=None, + ), + ): + authenticator = CopilotAuthenticator() + with pytest.raises(RuntimeError, match="Failed to refresh access token"): + await authenticator.refresh_token() + + @pytest.mark.asyncio + async def test_refresh_token_clears_current_claims(self, mock_env_vars, mock_persistent_cache): + """Test that refresh_token clears current claims.""" + + with ( + patch( + "pyrit.auth.copilot_authenticator.CopilotAuthenticator._create_persistent_cache", + return_value=mock_persistent_cache, + ), + patch( + 
"pyrit.auth.copilot_authenticator.CopilotAuthenticator._fetch_access_token_with_playwright", + new_callable=AsyncMock, + return_value="refreshed.token", + ), + ): + authenticator = CopilotAuthenticator() + authenticator._current_claims = {"old": "claims"} + await authenticator.refresh_token() + assert authenticator._current_claims == {} + + +class TestCopilotAuthenticatorGetClaims: + """Test JWT claims retrieval.""" + + @pytest.mark.asyncio + async def test_get_claims_returns_current_claims(self, mock_env_vars, mock_persistent_cache): + """Test that get_claims returns current claims.""" + + with patch( + "pyrit.auth.copilot_authenticator.CopilotAuthenticator._create_persistent_cache", + return_value=mock_persistent_cache, + ): + authenticator = CopilotAuthenticator() + test_claims = {"upn": "test@example.com", "aud": "sydney"} + authenticator._current_claims = test_claims + claims = await authenticator.get_claims() + assert claims == test_claims + + @pytest.mark.asyncio + async def test_get_claims_returns_empty_dict_when_no_claims(self, mock_env_vars, mock_persistent_cache): + """Test that get_claims returns empty dict when no claims set.""" + + with patch( + "pyrit.auth.copilot_authenticator.CopilotAuthenticator._create_persistent_cache", + return_value=mock_persistent_cache, + ): + authenticator = CopilotAuthenticator() + claims = await authenticator.get_claims() + assert claims == {} + + +class TestCopilotAuthenticatorPlaywrightIntegration: + """Test Playwright browser automation (mocked).""" + + @pytest.mark.asyncio + async def test_fetch_token_playwright_not_installed(self, mock_env_vars, mock_persistent_cache): + """Test that RuntimeError is raised when Playwright is not installed.""" + + with ( + patch( + "pyrit.auth.copilot_authenticator.CopilotAuthenticator._create_persistent_cache", + return_value=mock_persistent_cache, + ), + patch.dict("sys.modules", {"playwright.async_api": None}), + ): + authenticator = CopilotAuthenticator() + with 
pytest.raises(RuntimeError, match="Playwright is not installed"): + await authenticator._fetch_access_token_with_playwright() + + @pytest.mark.asyncio + async def test_fetch_token_with_playwright_success(self, mock_env_vars, mock_persistent_cache): + """Test successful token fetch with Playwright.""" + + # Setup mock Playwright objects + mock_page = AsyncMock() + mock_context = AsyncMock() + mock_browser = AsyncMock() + mock_playwright = AsyncMock() + + # Setup mock browser hierarchy + mock_playwright.chromium.launch = AsyncMock(return_value=mock_browser) + mock_browser.new_context = AsyncMock(return_value=mock_context) + mock_context.new_page = AsyncMock(return_value=mock_page) + mock_browser.close = AsyncMock() + mock_context.close = AsyncMock() + + # Mock page methods + mock_page.goto = AsyncMock() + mock_page.wait_for_selector = AsyncMock() + mock_page.fill = AsyncMock() + + # Create a response that will trigger the token capture + mock_response = AsyncMock() + mock_response.url = "https://login.microsoftonline.com/oauth2/v2.0/token" + mock_response.text = AsyncMock( + return_value=( + '{"access_token":"captured.bearer.token","token_type":"Bearer","expires_in":3600,"resource":"sydney"}' + ) + ) + + response_handler = None + + def capture_handler(event, handler): + """Capture the response handler when page.on() is called.""" + nonlocal response_handler + if event == "response": + response_handler = handler + + mock_page.on = MagicMock(side_effect=capture_handler) + + async def trigger_response_on_click(*args, **kwargs): + """Trigger the response handler when click is called.""" + if response_handler: + await response_handler(mock_response) + + mock_page.click = AsyncMock(side_effect=trigger_response_on_click) + + mock_async_playwright = AsyncMock() + mock_async_playwright.__aenter__ = AsyncMock(return_value=mock_playwright) + mock_async_playwright.__aexit__ = AsyncMock() + + with ( + patch( + 
"pyrit.auth.copilot_authenticator.CopilotAuthenticator._create_persistent_cache", + return_value=mock_persistent_cache, + ), + patch("playwright.async_api.async_playwright", return_value=mock_async_playwright), + patch("jwt.decode", return_value={"upn": "test@example.com"}), + ): + authenticator = CopilotAuthenticator() + token = await authenticator._fetch_access_token_with_playwright() + + assert token == "captured.bearer.token" + mock_browser.close.assert_called_once() + mock_context.close.assert_called_once() + + @pytest.mark.asyncio + async def test_fetch_token_handles_browser_launch_failure(self, mock_env_vars, mock_persistent_cache): + """Test handling of browser launch failure.""" + + mock_playwright = AsyncMock() + mock_playwright.chromium.launch = AsyncMock(side_effect=Exception("BrowserType.launch failed")) + + mock_async_playwright = AsyncMock() + mock_async_playwright.__aenter__ = AsyncMock(return_value=mock_playwright) + mock_async_playwright.__aexit__ = AsyncMock() + + with ( + patch( + "pyrit.auth.copilot_authenticator.CopilotAuthenticator._create_persistent_cache", + return_value=mock_persistent_cache, + ), + patch("playwright.async_api.async_playwright", return_value=mock_async_playwright), + ): + authenticator = CopilotAuthenticator() + token = await authenticator._fetch_access_token_with_playwright() + assert token is None + + @pytest.mark.asyncio + async def test_fetch_token_sanitizes_password_in_errors(self, mock_env_vars, mock_persistent_cache): + """Test that password is sanitized in error messages.""" + + mock_playwright = AsyncMock() + error_with_password = Exception(f"Login failed with password: test_password_123") + mock_playwright.chromium.launch = AsyncMock(side_effect=error_with_password) + + mock_async_playwright = AsyncMock() + mock_async_playwright.__aenter__ = AsyncMock(return_value=mock_playwright) + mock_async_playwright.__aexit__ = AsyncMock() + + with ( + patch( + 
"pyrit.auth.copilot_authenticator.CopilotAuthenticator._create_persistent_cache", + return_value=mock_persistent_cache, + ), + patch("playwright.async_api.async_playwright", return_value=mock_async_playwright), + patch("pyrit.auth.copilot_authenticator.logger") as mock_logger, + ): + authenticator = CopilotAuthenticator() + await authenticator._fetch_access_token_with_playwright() + + # Verify password was sanitized in error log + logged_messages = [str(call) for call in mock_logger.error.call_args_list] + assert any("******" in msg for msg in logged_messages) + assert not any("test_password_123" in msg for msg in logged_messages) + + @pytest.mark.asyncio + async def test_fetch_token_timeout_waiting_for_token(self, mock_env_vars, mock_persistent_cache): + """Test timeout when waiting for token capture.""" + + mock_page = AsyncMock() + mock_context = AsyncMock() + mock_browser = AsyncMock() + mock_playwright = AsyncMock() + + mock_playwright.chromium.launch = AsyncMock(return_value=mock_browser) + mock_browser.new_context = AsyncMock(return_value=mock_context) + mock_context.new_page = AsyncMock(return_value=mock_page) + mock_browser.close = AsyncMock() + mock_context.close = AsyncMock() + + mock_page.goto = AsyncMock() + mock_page.wait_for_selector = AsyncMock() + mock_page.click = AsyncMock() + mock_page.fill = AsyncMock() + mock_page.on = MagicMock() + + mock_async_playwright = AsyncMock() + mock_async_playwright.__aenter__ = AsyncMock(return_value=mock_playwright) + mock_async_playwright.__aexit__ = AsyncMock() + + with ( + patch( + "pyrit.auth.copilot_authenticator.CopilotAuthenticator._create_persistent_cache", + return_value=mock_persistent_cache, + ), + patch("playwright.async_api.async_playwright", return_value=mock_async_playwright), + patch("asyncio.sleep", new_callable=AsyncMock), # mock sleep to speed up test + ): + authenticator = CopilotAuthenticator(token_capture_timeout_seconds=1) + token = await authenticator._fetch_access_token_with_playwright() + 
assert token is None + mock_browser.close.assert_called_once() + + @pytest.mark.asyncio + async def test_fetch_token_closes_browser_on_exception(self, mock_env_vars, mock_persistent_cache): + """Test that browser is closed even when exception occurs.""" + + mock_page = AsyncMock() + mock_context = AsyncMock() + mock_browser = AsyncMock() + mock_playwright = AsyncMock() + + mock_playwright.chromium.launch = AsyncMock(return_value=mock_browser) + mock_browser.new_context = AsyncMock(return_value=mock_context) + mock_context.new_page = AsyncMock(return_value=mock_page) + mock_browser.close = AsyncMock() + mock_context.close = AsyncMock() + + mock_page.goto = AsyncMock(side_effect=Exception("Navigation failed")) + mock_page.on = MagicMock() # page.on is synchronous in Playwright + + mock_async_playwright = AsyncMock() + mock_async_playwright.__aenter__ = AsyncMock(return_value=mock_playwright) + mock_async_playwright.__aexit__ = AsyncMock() + + with ( + patch( + "pyrit.auth.copilot_authenticator.CopilotAuthenticator._create_persistent_cache", + return_value=mock_persistent_cache, + ), + patch("playwright.async_api.async_playwright", return_value=mock_async_playwright), + ): + authenticator = CopilotAuthenticator() + token = await authenticator._fetch_access_token_with_playwright() + assert token is None + mock_context.close.assert_called_once() + mock_browser.close.assert_called_once() diff --git a/tests/unit/target/test_websocket_copilot_target.py b/tests/unit/target/test_websocket_copilot_target.py new file mode 100644 index 000000000..18356f1a1 --- /dev/null +++ b/tests/unit/target/test_websocket_copilot_target.py @@ -0,0 +1,542 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. 
+ +import asyncio +from unittest.mock import AsyncMock, MagicMock, patch + +import jwt +import pytest + +from pyrit.auth import CopilotAuthenticator +from pyrit.models import Message, MessagePiece +from pyrit.prompt_target import WebSocketCopilotTarget +from pyrit.prompt_target.websocket_copilot_target import CopilotMessageType + + +@pytest.fixture +def mock_authenticator(): + token_payload = {"tid": "test_tenant_id", "oid": "test_object_id", "exp": 9999999999} + mock_token = jwt.encode(token_payload, "secret", algorithm="HS256") + if isinstance(mock_token, bytes): + mock_token = mock_token.decode("utf-8") + authenticator = MagicMock(spec=CopilotAuthenticator) + authenticator.get_token = AsyncMock(return_value=mock_token) + authenticator.get_claims = AsyncMock(return_value=token_payload) + return authenticator + + +@pytest.mark.usefixtures("patch_central_database") +class TestWebSocketCopilotTargetInit: + def test_init_with_default_parameters(self): + with patch("pyrit.prompt_target.websocket_copilot_target.CopilotAuthenticator") as mock_auth_class: + mock_auth_instance = MagicMock(spec=CopilotAuthenticator) + mock_auth_class.return_value = mock_auth_instance + + target = WebSocketCopilotTarget() + + mock_auth_class.assert_called_once() + assert target._authenticator == mock_auth_instance + assert target._response_timeout_seconds == WebSocketCopilotTarget.RESPONSE_TIMEOUT_SECONDS + assert target._model_name == "copilot" + assert target._endpoint == WebSocketCopilotTarget.WEBSOCKET_BASE_URL + assert target._verbose is False + assert target._max_requests_per_minute is None + + def test_init_with_custom_parameters(self, mock_authenticator): + target = WebSocketCopilotTarget( + authenticator=mock_authenticator, + verbose=True, + max_requests_per_minute=10, + model_name="custom_copilot", + response_timeout_seconds=120, + ) + + assert target._authenticator == mock_authenticator + assert target._response_timeout_seconds == 120 + assert target._model_name == 
"custom_copilot" + assert target._verbose is True + assert target._max_requests_per_minute == 10 + + def test_init_with_invalid_response_timeout(self, mock_authenticator): + for invalid_timeout in [0, -10, -1]: + with pytest.raises(ValueError, match="response_timeout_seconds must be a positive integer."): + WebSocketCopilotTarget(authenticator=mock_authenticator, response_timeout_seconds=invalid_timeout) + + +@pytest.mark.usefixtures("patch_central_database") +class TestDictToWebsocket: + @pytest.mark.parametrize( + "data,expected", + [ + ({"key": "value"}, '{"key":"value"}\x1e'), + ({"protocol": "json", "version": 1}, '{"protocol":"json","version":1}\x1e'), + ({"outer": {"inner": "value"}}, '{"outer":{"inner":"value"}}\x1e'), + ({"items": [1, 2, 3]}, '{"items":[1,2,3]}\x1e'), + ], + ) + def test_dict_to_websocket_converts_to_json_with_separator(self, data, expected): + result = WebSocketCopilotTarget._dict_to_websocket(data) + assert result == expected + + +@pytest.mark.usefixtures("patch_central_database") +class TestParseRawMessage: + @pytest.mark.parametrize( + "message,expected_types,expected_content", + [ + ("", [CopilotMessageType.UNKNOWN], [""]), + (" \n\t ", [CopilotMessageType.UNKNOWN], [""]), + ("{}\x1e", [CopilotMessageType.UNKNOWN], [""]), + ('{"type":6}\x1e', [CopilotMessageType.PING], [""]), + ( + '{"type":1,"target":"update","arguments":[{"messages":[{"text":"Partial","author":"bot"}]}]}\x1e', + [CopilotMessageType.PARTIAL_RESPONSE], + [""], + ), + ( + '{"type":2,"item":{"result":{"message":"Final."}}}\x1e{"type":3,"invocationId":"0"}\x1e', + [CopilotMessageType.FINAL_CONTENT, CopilotMessageType.STREAM_END], + ["Final.", ""], + ), + ( + '{"type":3,"invocationId":"0"}\x1e', + [CopilotMessageType.STREAM_END], + [""], + ), + ], + ) + def test_parse_raw_message_with_valid_data(self, message, expected_types, expected_content): + result = WebSocketCopilotTarget._parse_raw_message(message) + + assert len(result) == len(expected_types) + for i, 
expected_type in enumerate(expected_types): + assert result[i][0] == expected_type + assert result[i][1] == expected_content[i] + + def test_parse_final_message_without_content(self): + with patch("pyrit.prompt_target.websocket_copilot_target.logger") as mock_logger: + message = '{"type":2,"invocationId":"0"}\x1e' + result = WebSocketCopilotTarget._parse_raw_message(message) + + assert len(result) == 1 + assert result[0][0] == CopilotMessageType.FINAL_CONTENT + assert result[0][1] == "" + + mock_logger.warning.assert_called_with("FINAL_CONTENT received but no parseable content found.") + mock_logger.debug.assert_called_with(f"Full raw message: {message[:-1]}") + + @pytest.mark.parametrize( + "message", + [ + '{"type":99,"data":"unknown"}\x1e', + '{"data":"no type field"}\x1e', + '{"invalid json structure\x1e', + ], + ) + def test_parse_unknown_or_invalid_messages(self, message): + result = WebSocketCopilotTarget._parse_raw_message(message) + assert len(result) == 1 + assert result[0][0] == CopilotMessageType.UNKNOWN + assert result[0][1] == "" + + +@pytest.mark.usefixtures("patch_central_database") +class TestBuildWebsocketUrl: + @pytest.mark.asyncio + async def test_build_websocket_url_with_valid_token(self, mock_authenticator): + target = WebSocketCopilotTarget(authenticator=mock_authenticator) + + session_id = "test_session_id" + copilot_conversation_id = "test_conversation_id" + + url = await target._build_websocket_url_async( + session_id=session_id, copilot_conversation_id=copilot_conversation_id + ) + expected_token = await mock_authenticator.get_token() + + assert url.startswith(f"{WebSocketCopilotTarget.WEBSOCKET_BASE_URL}/test_object_id@test_tenant_id?") + assert f"X-SessionId={session_id}" in url + assert f"ConversationId={copilot_conversation_id}" in url + assert f"access_token={expected_token}" in url + assert "ClientRequestId=" in url + assert "source=%22officeweb%22" in url + assert "scenario=OfficeWebIncludedCopilot" in url + + @pytest.mark.asyncio 
+ async def test_build_websocket_url_with_missing_ids(self, mock_authenticator): + for missing_id in ["tid", "oid"]: + token_payload = {"tid": "test_tenant_id", "oid": "test_object_id", "exp": 9999999999} + del token_payload[missing_id] + + mock_token = jwt.encode(token_payload, "secret", algorithm="HS256") + if isinstance(mock_token, bytes): + mock_token = mock_token.decode("utf-8") + mock_authenticator.get_token = AsyncMock(return_value=mock_token) + mock_authenticator.get_claims = AsyncMock(return_value=token_payload) + + target = WebSocketCopilotTarget(authenticator=mock_authenticator) + with pytest.raises(ValueError, match="Failed to extract tenant_id \\(tid\\) or object_id \\(oid\\)"): + await target._build_websocket_url_async(session_id="test", copilot_conversation_id="test") + + @pytest.mark.asyncio + async def test_build_websocket_url_with_invalid_token(self, mock_authenticator): + mock_authenticator.get_token = AsyncMock(return_value="invalid_token") + mock_authenticator.get_claims = AsyncMock(side_effect=ValueError("Failed to decode access token")) + target = WebSocketCopilotTarget(authenticator=mock_authenticator) + + with pytest.raises(ValueError, match="Failed to decode access token"): + await target._build_websocket_url_async(session_id="test", copilot_conversation_id="test") + + +@pytest.mark.usefixtures("patch_central_database") +class TestBuildPromptMessage: + def test_build_prompt_message_structure(self, mock_authenticator): + target = WebSocketCopilotTarget(authenticator=mock_authenticator) + + prompt = "Hello Copilot" + session_id = "session_123" + copilot_conversation_id = "conv_456" + is_start_of_session = True + + message = target._build_prompt_message( + prompt=prompt, + session_id=session_id, + copilot_conversation_id=copilot_conversation_id, + is_start_of_session=is_start_of_session, + ) + + assert "arguments" in message + assert "invocationId" in message + assert "target" in message + assert "type" in message + + assert message["target"] 
== "chat" + assert message["type"] == CopilotMessageType.USER_PROMPT + assert message["invocationId"] == "0" + + args = message["arguments"][0] + assert args["sessionId"] == session_id + assert args["conversationId"] == copilot_conversation_id + assert args["isStartOfSession"] is True + assert args["source"] == "officeweb" + assert args["productThreadType"] == "Office" + + msg = args["message"] + assert msg["text"] == prompt + assert msg["author"] == "user" + assert msg["messageType"] == "Chat" + assert msg["locale"] == "en-US" + + def test_build_prompt_message_with_different_session_states(self, mock_authenticator): + target = WebSocketCopilotTarget(authenticator=mock_authenticator) + + message = target._build_prompt_message( + prompt="Follow-up question", + session_id="session_123", + copilot_conversation_id="conv_456", + is_start_of_session=False, + ) + + args = message["arguments"][0] + assert args["isStartOfSession"] is False + + +@pytest.mark.usefixtures("patch_central_database") +class TestConnectAndSend: + @pytest.mark.asyncio + async def test_connect_and_send_successful_response(self, mock_authenticator): + target = WebSocketCopilotTarget(authenticator=mock_authenticator) + + mock_websocket = AsyncMock() + mock_websocket.send = AsyncMock() + mock_websocket.recv = AsyncMock( + side_effect=[ + '{"type":6}\x1e', # PING response to handshake + '{"type":1}\x1e', # partial response to user prompt + '{"type":2,"item":{"result":{"message":"Hello from Copilot"}}}\x1e', # final content + ] + ) + mock_websocket.__aenter__ = AsyncMock(return_value=mock_websocket) + mock_websocket.__aexit__ = AsyncMock(return_value=None) + + with patch("websockets.connect", return_value=mock_websocket): + response = await target._connect_and_send( + prompt="Hello", + session_id="session_123", + copilot_conversation_id="conv_456", + is_start_of_session=True, + ) + + assert response == "Hello from Copilot" + assert mock_websocket.send.call_count == 2 # handshake + user prompt + + 
@pytest.mark.asyncio + async def test_connect_and_send_timeout(self, mock_authenticator): + target = WebSocketCopilotTarget(authenticator=mock_authenticator, response_timeout_seconds=1) + + mock_websocket = AsyncMock() + mock_websocket.send = AsyncMock() + mock_websocket.recv = AsyncMock(side_effect=asyncio.TimeoutError()) + mock_websocket.__aenter__ = AsyncMock(return_value=mock_websocket) + mock_websocket.__aexit__ = AsyncMock(return_value=None) + + with patch("websockets.connect", return_value=mock_websocket): + with pytest.raises(TimeoutError, match="Timed out waiting for Copilot response"): + await target._connect_and_send( + prompt="Hello", + session_id="session_123", + copilot_conversation_id="conv_456", + is_start_of_session=True, + ) + + @pytest.mark.asyncio + async def test_connect_and_send_none_response(self, mock_authenticator): + target = WebSocketCopilotTarget(authenticator=mock_authenticator) + + mock_websocket = AsyncMock() + mock_websocket.send = AsyncMock() + mock_websocket.recv = AsyncMock(return_value=None) + mock_websocket.__aenter__ = AsyncMock(return_value=mock_websocket) + mock_websocket.__aexit__ = AsyncMock(return_value=None) + + with patch("websockets.connect", return_value=mock_websocket): + with pytest.raises(RuntimeError, match="WebSocket connection closed unexpectedly"): + await target._connect_and_send( + prompt="Hello", + session_id="session_123", + copilot_conversation_id="conv_456", + is_start_of_session=True, + ) + + @pytest.mark.asyncio + async def test_connect_and_send_stream_end_without_final_content(self, mock_authenticator): + target = WebSocketCopilotTarget(authenticator=mock_authenticator) + + mock_websocket = AsyncMock() + mock_websocket.send = AsyncMock() + mock_websocket.recv = AsyncMock( + side_effect=[ + '{"type":6}\x1e', + '{"type":3}\x1e', + ] + ) + mock_websocket.__aenter__ = AsyncMock(return_value=mock_websocket) + mock_websocket.__aexit__ = AsyncMock(return_value=None) + + with patch("websockets.connect", 
return_value=mock_websocket): + response = await target._connect_and_send( + prompt="Hello", + session_id="sid", + copilot_conversation_id="cid", + is_start_of_session=True, + ) + + assert response == "" + + @pytest.mark.asyncio + async def test_connect_and_send_exceeds_max_iterations(self, mock_authenticator): + target = WebSocketCopilotTarget(authenticator=mock_authenticator) + + mock_websocket = AsyncMock() + mock_websocket.send = AsyncMock() + mock_websocket.recv = AsyncMock(return_value='{"type":1}\x1e') + mock_websocket.__aenter__ = AsyncMock(return_value=mock_websocket) + mock_websocket.__aexit__ = AsyncMock(return_value=None) + + with patch("websockets.connect", return_value=mock_websocket): + with pytest.raises(RuntimeError, match="Exceeded maximum message iterations"): + await target._connect_and_send( + prompt="Hello", + session_id="sid", + copilot_conversation_id="cid", + is_start_of_session=True, + ) + + +@pytest.mark.usefixtures("patch_central_database") +class TestValidateRequest: + def test_validate_request_with_single_text_piece(self, mock_authenticator): + target = WebSocketCopilotTarget(authenticator=mock_authenticator) + + message_piece = MessagePiece( + role="user", + original_value="test", + converted_value="test", + conversation_id="123", + original_value_data_type="text", + converted_value_data_type="text", + ) + message = Message(message_pieces=[message_piece]) + + # Should not raise + target._validate_request(message=message) + + def test_validate_request_with_multiple_pieces(self, mock_authenticator): + target = WebSocketCopilotTarget(authenticator=mock_authenticator) + + message_pieces = [ + MessagePiece( + role="user", + original_value="test1", + converted_value="test1", + conversation_id="123", + original_value_data_type="text", + converted_value_data_type="text", + ), + MessagePiece( + role="user", + original_value="test2", + converted_value="test2", + conversation_id="123", + original_value_data_type="text", + 
converted_value_data_type="text", + ), + ] + message = Message(message_pieces=message_pieces) + + with pytest.raises(ValueError, match="This target only supports a single message piece"): + target._validate_request(message=message) + + def test_validate_request_with_non_text_type(self, mock_authenticator): + target = WebSocketCopilotTarget(authenticator=mock_authenticator) + + message_piece = MessagePiece( + role="user", + original_value="image.png", + converted_value="image.png", + conversation_id="123", + original_value_data_type="image_path", + converted_value_data_type="image_path", + ) + message = Message(message_pieces=[message_piece]) + + with pytest.raises(ValueError, match="This target only supports text prompt input"): + target._validate_request(message=message) + + +@pytest.mark.usefixtures("patch_central_database") +class TestIsStartOfSession: + def test_is_start_of_session_with_empty_history(self, mock_authenticator): + target = WebSocketCopilotTarget(authenticator=mock_authenticator) + + mock_memory = MagicMock() + mock_memory.get_conversation.return_value = [] + target._memory = mock_memory + + conversation_id = "test_conv_123" + result = target._is_start_of_session(conversation_id=conversation_id) + + assert result is True + mock_memory.get_conversation.assert_called_once_with(conversation_id=conversation_id) + + def test_is_start_of_session_with_existing_history(self, mock_authenticator): + target = WebSocketCopilotTarget(authenticator=mock_authenticator) + + mock_memory = MagicMock() + mock_message = MagicMock() + mock_memory.get_conversation.return_value = [mock_message] + target._memory = mock_memory + + conversation_id = "test_conv_123" + result = target._is_start_of_session(conversation_id=conversation_id) + + assert result is False + + +@pytest.mark.usefixtures("patch_central_database") +class TestGenerateConsistentCopilotIds: + def test_generates_consistent_ids_for_same_conversation(self, mock_authenticator): + target = 
WebSocketCopilotTarget(authenticator=mock_authenticator) + + pyrit_conv_id = "pyrit_conversation_123" + + session_id_1, copilot_conv_id_1 = target._generate_consistent_copilot_ids(pyrit_conversation_id=pyrit_conv_id) + session_id_2, copilot_conv_id_2 = target._generate_consistent_copilot_ids(pyrit_conversation_id=pyrit_conv_id) + + assert session_id_1 == session_id_2 + assert copilot_conv_id_1 == copilot_conv_id_2 + + def test_generates_different_ids_for_different_conversations(self, mock_authenticator): + target = WebSocketCopilotTarget(authenticator=mock_authenticator) + + pyrit_conv_id_1 = "pyrit_conversation_123" + pyrit_conv_id_2 = "pyrit_conversation_456" + + session_id_1, copilot_conv_id_1 = target._generate_consistent_copilot_ids(pyrit_conversation_id=pyrit_conv_id_1) + session_id_2, copilot_conv_id_2 = target._generate_consistent_copilot_ids(pyrit_conversation_id=pyrit_conv_id_2) + + assert session_id_1 != session_id_2 + assert copilot_conv_id_1 != copilot_conv_id_2 + + def test_generated_ids_are_valid_uuids(self, mock_authenticator): + """Test that generated IDs are valid UUID format and can be parsed.""" + import uuid + + target = WebSocketCopilotTarget(authenticator=mock_authenticator) + + pyrit_conv_id = "test_conversation" + session_id, copilot_conv_id = target._generate_consistent_copilot_ids(pyrit_conversation_id=pyrit_conv_id) + + # Should be parseable as UUIDs without raising exceptions + uuid.UUID(session_id) + uuid.UUID(copilot_conv_id) + + +@pytest.mark.usefixtures("patch_central_database") +class TestSendPromptAsync: + @pytest.mark.asyncio + async def test_send_prompt_async_successful(self, mock_authenticator): + target = WebSocketCopilotTarget(authenticator=mock_authenticator) + + message_piece = MessagePiece( + role="user", + original_value="Hello", + converted_value="Hello", + conversation_id="conv_123", + original_value_data_type="text", + converted_value_data_type="text", + ) + message = Message(message_pieces=[message_piece]) + + 
mock_memory = MagicMock() + mock_memory.get_conversation.return_value = [] + mock_memory.add_message_to_memory = AsyncMock() + target._memory = mock_memory + + with patch.object(target, "_connect_and_send", new=AsyncMock(return_value="Response from Copilot")): + responses = await target.send_prompt_async(message=message) + + assert len(responses) == 1 + assert responses[0].message_pieces[0].converted_value == "Response from Copilot" + assert responses[0].message_pieces[0].role == "assistant" + + @pytest.mark.asyncio + async def test_send_prompt_async_with_exceptions(self, mock_authenticator): + from pyrit.exceptions import EmptyResponseException + + target = WebSocketCopilotTarget(authenticator=mock_authenticator) + + message_piece = MessagePiece( + role="user", + original_value="Hello", + converted_value="Hello", + conversation_id="conv_123", + original_value_data_type="text", + converted_value_data_type="text", + ) + message = Message(message_pieces=[message_piece]) + + mock_memory = MagicMock() + mock_memory.get_conversation.return_value = [] + target._memory = mock_memory + + # Test for various empty responses + for response in [None, "", " \n\t "]: + with patch.object(target, "_connect_and_send", new=AsyncMock(return_value=response)): + with pytest.raises(EmptyResponseException, match="Copilot returned an empty response"): + await target.send_prompt_async(message=message) + + # Test for generic exception during WebSocket communication + with patch.object(target, "_connect_and_send", new=AsyncMock(side_effect=Exception("Test error"))): + with pytest.raises(RuntimeError, match="An error occurred during WebSocket communication"): + await target.send_prompt_async(message=message) diff --git a/websocket_copilot_simple_example.py b/websocket_copilot_simple_example.py new file mode 100644 index 000000000..92d0097ec --- /dev/null +++ b/websocket_copilot_simple_example.py @@ -0,0 +1,53 @@ +""" +# TODO: add notebook example instead of this +""" + +import asyncio +import 
logging + +from pyrit.executor.attack import ConsoleAttackResultPrinter, MultiPromptSendingAttack +from pyrit.models import Message, MessagePiece +from pyrit.prompt_target import WebSocketCopilotTarget +from pyrit.setup import IN_MEMORY, initialize_pyrit_async + +logger = logging.getLogger(__name__) +logging.basicConfig(level=logging.INFO) + + +async def main(): + await initialize_pyrit_async(memory_db_type=IN_MEMORY) + + target = WebSocketCopilotTarget() + + prompts = [ + "I'm thinking of a number between 1 and 10.", + "It's greater than 5.", + "It's an even number.", + "What number am I thinking of?", + ] + + messages = [] + for prompt in prompts: + message_piece = MessagePiece( + role="user", + original_value=prompt, + original_value_data_type="text", + converted_value_data_type="text", + ) + messages.append(Message(message_pieces=[message_piece])) + + attack = MultiPromptSendingAttack( + objective_target=target, + ) + + result = await attack.execute_async( + objective="Engage in a multi-turn conversation about a number guessing game", + messages=messages, + ) + + printer = ConsoleAttackResultPrinter() + await printer.print_conversation_async(result=result) # type: ignore + + +if __name__ == "__main__": + asyncio.run(main())