diff --git a/doc/code/converters/ansi_attack_converter.ipynb b/doc/code/converters/ansi_attack_converter.ipynb
index cdd7dadf2..47a25e0cf 100644
--- a/doc/code/converters/ansi_attack_converter.ipynb
+++ b/doc/code/converters/ansi_attack_converter.ipynb
@@ -366,7 +366,6 @@
     }
    ],
    "source": [
-    "\n",
     "from pyrit.executor.attack import (\n",
     "    AttackConverterConfig,\n",
     "    AttackExecutor,\n",
diff --git a/pyrit/executor/attack/component/__init__.py b/pyrit/executor/attack/component/__init__.py
index 029b6dfe4..583ceb5b0 100644
--- a/pyrit/executor/attack/component/__init__.py
+++ b/pyrit/executor/attack/component/__init__.py
@@ -7,6 +7,7 @@
     ConversationManager,
     ConversationState,
     format_conversation_context,
+    mark_messages_as_simulated,
 )
 from pyrit.executor.attack.component.objective_evaluator import ObjectiveEvaluator
 from pyrit.executor.attack.component.simulated_conversation import (
@@ -19,6 +20,7 @@
     "ConversationManager",
     "ConversationState",
     "format_conversation_context",
+    "mark_messages_as_simulated",
     "ObjectiveEvaluator",
     "generate_simulated_conversation_async",
     "SimulatedConversationResult",
diff --git a/pyrit/executor/attack/component/conversation_manager.py b/pyrit/executor/attack/component/conversation_manager.py
index 2e9fc46ee..aca6eb793 100644
--- a/pyrit/executor/attack/component/conversation_manager.py
+++ b/pyrit/executor/attack/component/conversation_manager.py
@@ -4,7 +4,7 @@
 import logging
 import uuid
 from dataclasses import dataclass, field
-from typing import Dict, List, Optional
+from typing import Dict, List, Optional, Sequence

 from pyrit.memory import CentralMemory
 from pyrit.models import ChatMessageRole, Message, MessagePiece, Score
@@ -18,6 +18,29 @@
 logger = logging.getLogger(__name__)


+def mark_messages_as_simulated(messages: Sequence[Message]) -> List[Message]:
+    """
+    Mark assistant messages as simulated_assistant for traceability.
+
+    Converts every assistant role in the provided messages to simulated_assistant.
+    This is useful when loading conversations from YAML files or other sources
+    where the responses did not come from an actual target.
+
+    Args:
+        messages (Sequence[Message]): The messages to mark as simulated.
+
+    Returns:
+        List[Message]: The messages with assistant roles converted to simulated_assistant.
+            The pieces are modified in place; the list is returned for convenience.
+    """
+    result = list(messages)
+    for message in result:
+        for piece in message.message_pieces:
+            if piece._role == "assistant":
+                piece._role = "simulated_assistant"
+    return result
+
+
 def format_conversation_context(messages: List[Message]) -> str:
     """
     Format a list of messages into a context string for adversarial chat system prompts.
@@ -55,17 +78,22 @@ def format_conversation_context(messages: List[Message]) -> str: piece = message.get_piece() # Skip system messages - they're handled separately - if piece.role == "system": + if piece.api_role == "system": continue # Start a new turn when we see a user message - if piece.role == "user": + if piece.api_role == "user": turn_number += 1 context_parts.append(f"Turn {turn_number}:") # Format the piece content content = _format_piece_content(piece) - role_label = "User" if piece.role == "user" else "Assistant" + if piece.api_role == "user": + role_label = "User" + elif piece.is_simulated: + role_label = "Assistant (simulated)" + else: + role_label = "Assistant" context_parts.append(f"{role_label}: {content}") return "\n".join(context_parts) @@ -175,7 +203,7 @@ def get_last_message( if role: for m in reversed(conversation): piece = m.get_piece() - if piece.role == role: + if piece.api_role == role: return piece return None @@ -288,7 +316,7 @@ async def update_conversation_state_async( # Determine if we should exclude the last message (if it's a user message in multi-turn context) last_message = valid_requests[-1].message_pieces[0] is_multi_turn = max_turns is not None - should_exclude_last = is_multi_turn and last_message.role == "user" + should_exclude_last = is_multi_turn and last_message.api_role == "user" # Process all messages except potentially the last one for i, request in enumerate(valid_requests): @@ -351,9 +379,9 @@ async def _apply_role_specific_converters_async( for piece in request.message_pieces: applicable_converters: Optional[List[PromptConverterConfiguration]] = None - if piece.role == "user" and request_converters: + if piece.api_role == "user" and request_converters: applicable_converters = request_converters - elif piece.role == "assistant" and response_converters: + elif piece.api_role == "assistant" and response_converters: applicable_converters = response_converters # System messages get no converters (applicable_converters remains None) @@ -429,7 +457,7 @@ def _process_piece( is_multi_turn = max_turns is not None # Only assistant messages count as turns - if piece.role == "assistant" and is_multi_turn: + if piece.api_role == "assistant" and is_multi_turn: conversation_state.turn_count += 1 if conversation_state.turn_count > max_turns: @@ -466,11 +494,11 @@ async def _populate_conversation_state_async( return # Nothing to extract from empty history # Extract the last user message and assistant message scores from the last message - if last_message.role == "user": + if last_message.api_role == "user": conversation_state.last_user_message = last_message.converted_value logger.debug(f"Extracted last user message: {conversation_state.last_user_message[:50]}...") - elif last_message.role == "assistant": + elif last_message.api_role == "assistant": # Get scores for the last assistant message based off of the original id conversation_state.last_assistant_message_scores = list( self._memory.get_prompt_scores(prompt_ids=[str(last_message.original_prompt_id)]) @@ -482,7 +510,7 @@ async def _populate_conversation_state_async( return # Check assumption that there will be a user message preceding the assistant message - if len(prepended_conversation) > 1 and prepended_conversation[-2].get_piece().role == "user": + if len(prepended_conversation) > 1 and prepended_conversation[-2].get_piece().api_role == "user": conversation_state.last_user_message = prepended_conversation[-2].get_value() logger.debug(f"Extracted preceding user message: 
{conversation_state.last_user_message[:50]}...") else: @@ -533,11 +561,11 @@ async def prepend_to_adversarial_chat_async( for message in prepended_conversation: for piece in message.message_pieces: # Skip system messages - adversarial chat has its own system prompt - if piece.role == "system": + if piece.api_role == "system": continue # Create a new piece with swapped role for adversarial chat - swapped_role = role_swap.get(piece.role, piece.role) + swapped_role = role_swap.get(piece.api_role, piece.api_role) adversarial_piece = MessagePiece( id=uuid.uuid4(), diff --git a/pyrit/executor/attack/component/simulated_conversation.py b/pyrit/executor/attack/component/simulated_conversation.py index 5def320c7..6d9753908 100644 --- a/pyrit/executor/attack/component/simulated_conversation.py +++ b/pyrit/executor/attack/component/simulated_conversation.py @@ -63,7 +63,7 @@ def _effective_turn_index(self) -> int: # Calculate total complete turns (user+assistant pairs) total_turns = len(self.conversation) // 2 # Account for trailing user message (incomplete turn) - if len(self.conversation) % 2 == 1 and self.conversation[-1].role == "user": + if len(self.conversation) % 2 == 1 and self.conversation[-1].api_role == "user": total_turns += 1 if self.turn_index is None: @@ -104,7 +104,7 @@ def next_message(self) -> Optional[Message]: return None # User message for turn N is at index (N-1) * 2 user_idx = (turn - 1) * 2 - if user_idx < len(self.conversation) and self.conversation[user_idx].role == "user": + if user_idx < len(self.conversation) and self.conversation[user_idx].api_role == "user": return self.conversation[user_idx] return None @@ -225,9 +225,14 @@ async def generate_simulated_conversation_async( # Filter out system messages - prepended_conversation should only have user/assistant turns # System prompts are set separately on each target during attack execution + # Also mark assistant messages as simulated for traceability filtered_messages: List[Message] = [] for message in raw_messages: - if message.role != "system": + if message.api_role != "system": + # Mark assistant responses as simulated since this is a simulated conversation + if message.api_role == "assistant": + for piece in message.message_pieces: + piece._role = "simulated_assistant" filtered_messages.append(message) # Get the score from the result (there should be one score for the last turn) diff --git a/pyrit/executor/attack/multi_turn/tree_of_attacks.py b/pyrit/executor/attack/multi_turn/tree_of_attacks.py index 4759b7fb4..4ffb6d535 100644 --- a/pyrit/executor/attack/multi_turn/tree_of_attacks.py +++ b/pyrit/executor/attack/multi_turn/tree_of_attacks.py @@ -777,7 +777,7 @@ async def _generate_subsequent_turn_prompt_async(self, objective: str) -> str: target_messages = self._memory.get_conversation(conversation_id=self.objective_target_conversation_id) # Extract the last assistant response - assistant_responses = [r for r in target_messages if r.get_piece().role == "assistant"] + assistant_responses = [r for r in target_messages if r.get_piece().api_role == "assistant"] if not assistant_responses: logger.error(f"No assistant responses found in the conversation {self.objective_target_conversation_id}.") raise RuntimeError("Cannot proceed without an assistant response.") diff --git a/pyrit/executor/attack/printer/console_printer.py b/pyrit/executor/attack/printer/console_printer.py index c5997a30a..ea6902cc7 100644 --- a/pyrit/executor/attack/printer/console_printer.py +++ b/pyrit/executor/attack/printer/console_printer.py @@ 
-152,14 +152,14 @@ async def print_messages_async( turn_number = 0 for message in messages: # Increment turn number once per message with role="user" - if message.role == "user": + if message.api_role == "user": turn_number += 1 # User message header print() self._print_colored("─" * self._width, Fore.BLUE) self._print_colored(f"🔹 Turn {turn_number} - USER", Style.BRIGHT, Fore.BLUE) self._print_colored("─" * self._width, Fore.BLUE) - elif message.role == "system": + elif message.api_role == "system": # System message header (not counted as a turn) print() self._print_colored("─" * self._width, Fore.MAGENTA) @@ -169,7 +169,8 @@ async def print_messages_async( # Assistant or other role message header print() self._print_colored("─" * self._width, Fore.YELLOW) - self._print_colored(f"🔸 {message.role.upper()}", Style.BRIGHT, Fore.YELLOW) + role_label = "ASSISTANT (SIMULATED)" if message.is_simulated else message.api_role.upper() + self._print_colored(f"🔸 {role_label}", Style.BRIGHT, Fore.YELLOW) self._print_colored("─" * self._width, Fore.YELLOW) # Now print all pieces in this message @@ -179,15 +180,15 @@ async def print_messages_async( continue # Handle converted values for user messages - if piece.role == "user" and piece.converted_value != piece.original_value: + if piece.api_role == "user" and piece.converted_value != piece.original_value: self._print_colored(f"{self._indent} Original:", Fore.CYAN) self._print_wrapped_text(piece.original_value, Fore.WHITE) print() self._print_colored(f"{self._indent} Converted:", Fore.CYAN) self._print_wrapped_text(piece.converted_value, Fore.WHITE) - elif piece.role == "user": + elif piece.api_role == "user": self._print_wrapped_text(piece.converted_value, Fore.BLUE) - elif piece.role == "system": + elif piece.api_role == "system": self._print_wrapped_text(piece.converted_value, Fore.MAGENTA) else: self._print_wrapped_text(piece.converted_value, Fore.YELLOW) diff --git a/pyrit/executor/attack/printer/markdown_printer.py b/pyrit/executor/attack/printer/markdown_printer.py index df4bfebc2..746c6118b 100644 --- a/pyrit/executor/attack/printer/markdown_printer.py +++ b/pyrit/executor/attack/printer/markdown_printer.py @@ -213,7 +213,7 @@ async def _get_conversation_markdown_async( if not message.message_pieces: continue - message_role = message.get_piece().role + message_role = message.get_piece().api_role if message_role == "system": markdown_lines.extend(self._format_system_message(message)) @@ -284,7 +284,8 @@ async def _format_assistant_message_async(self, *, message: Message) -> List[str List[str]: List of markdown strings representing the response message. 
""" lines = [] - role_name = message.message_pieces[0].role.capitalize() + piece = message.message_pieces[0] + role_name = "Assistant (Simulated)" if piece.is_simulated else piece.api_role.capitalize() lines.append(f"\n#### {role_name}\n") diff --git a/pyrit/executor/promptgen/fuzzer/fuzzer.py b/pyrit/executor/promptgen/fuzzer/fuzzer.py index 3f07ddb28..276c0b02f 100644 --- a/pyrit/executor/promptgen/fuzzer/fuzzer.py +++ b/pyrit/executor/promptgen/fuzzer/fuzzer.py @@ -452,11 +452,11 @@ def _print_conversations(self, result: FuzzerResult) -> None: continue for message in target_messages: - if message.role == "user": + if message.api_role == "user": self._print_colored(f"{self._indent * 2} USER:", Style.BRIGHT, Fore.BLUE) self._print_wrapped_text(message.converted_value, Fore.BLUE) else: - self._print_colored(f"{self._indent * 2} {message.role.upper()}:", Style.BRIGHT, Fore.YELLOW) + self._print_colored(f"{self._indent * 2} {message.api_role.upper()}:", Style.BRIGHT, Fore.YELLOW) self._print_wrapped_text(message.converted_value, Fore.YELLOW) # Print scores if available diff --git a/pyrit/memory/memory_interface.py b/pyrit/memory/memory_interface.py index 6773a5a72..83e5473b6 100644 --- a/pyrit/memory/memory_interface.py +++ b/pyrit/memory/memory_interface.py @@ -471,7 +471,7 @@ def get_request_from_response(self, *, response: Message) -> Message: Raises: ValueError: If the response is not from an assistant role or has no preceding request. """ - if response.role != "assistant": + if response.api_role != "assistant": raise ValueError("The provided request is not a response (role must be 'assistant').") if response.sequence < 1: raise ValueError("The provided request does not have a preceding request (sequence < 1).") @@ -628,7 +628,7 @@ def duplicate_conversation_excluding_last_turn(self, *, conversation_id: str) -> length_of_sequence_to_remove = 0 - if last_message.role == "system" or last_message.role == "user": + if last_message.api_role == "system" or last_message.api_role == "user": length_of_sequence_to_remove = 1 else: length_of_sequence_to_remove = 2 @@ -778,7 +778,7 @@ def get_chat_messages_with_conversation_id(self, *, conversation_id: str) -> Seq Sequence[ChatMessage]: The list of chat messages. """ memory_entries = self.get_message_pieces(conversation_id=conversation_id) - return [ChatMessage(role=me.role, content=me.converted_value) for me in memory_entries] # type: ignore + return [ChatMessage(role=me.api_role, content=me.converted_value) for me in memory_entries] # type: ignore def get_seeds( self, diff --git a/pyrit/memory/memory_models.py b/pyrit/memory/memory_models.py index 2f3e928d0..67e664ac7 100644 --- a/pyrit/memory/memory_models.py +++ b/pyrit/memory/memory_models.py @@ -149,7 +149,9 @@ class PromptMemoryEntry(Base): __tablename__ = "PromptMemoryEntries" __table_args__ = {"extend_existing": True} id = mapped_column(CustomUUID, nullable=False, primary_key=True) - role: Mapped[Literal["system", "user", "assistant", "tool", "developer"]] = mapped_column(String, nullable=False) + role: Mapped[Literal["system", "user", "assistant", "simulated_assistant", "tool", "developer"]] = mapped_column( + String, nullable=False + ) conversation_id = mapped_column(String, nullable=False) sequence = mapped_column(INTEGER, nullable=False) timestamp = mapped_column(DateTime, nullable=False) @@ -192,7 +194,7 @@ def __init__(self, *, entry: MessagePiece): entry (MessagePiece): The message piece to convert into a database entry. 
""" self.id = entry.id - self.role = entry.role + self.role = entry._role self.conversation_id = entry.conversation_id self.sequence = entry.sequence self.timestamp = entry.timestamp diff --git a/pyrit/models/chat_message.py b/pyrit/models/chat_message.py index 341e38872..7c11b7907 100644 --- a/pyrit/models/chat_message.py +++ b/pyrit/models/chat_message.py @@ -7,7 +7,7 @@ from pyrit.models.literals import ChatMessageRole -ALLOWED_CHAT_MESSAGE_ROLES = ["system", "user", "assistant", "tool", "developer"] +ALLOWED_CHAT_MESSAGE_ROLES = ["system", "user", "assistant", "simulated_assistant", "tool", "developer"] class ToolCall(BaseModel): diff --git a/pyrit/models/literals.py b/pyrit/models/literals.py index ba08a1d3a..24b0f55d3 100644 --- a/pyrit/models/literals.py +++ b/pyrit/models/literals.py @@ -25,4 +25,4 @@ """ PromptResponseError = Literal["blocked", "none", "processing", "empty", "unknown"] -ChatMessageRole = Literal["system", "user", "assistant", "tool", "developer"] +ChatMessageRole = Literal["system", "user", "assistant", "simulated_assistant", "tool", "developer"] diff --git a/pyrit/models/message.py b/pyrit/models/message.py index 8736c0c4b..73f9d0664 100644 --- a/pyrit/models/message.py +++ b/pyrit/models/message.py @@ -5,6 +5,7 @@ import copy import uuid +import warnings from datetime import datetime from typing import Dict, MutableSequence, Optional, Sequence, Union @@ -51,11 +52,44 @@ def get_piece(self, n: int = 0) -> MessagePiece: return self.message_pieces[n] @property - def role(self) -> ChatMessageRole: - """Return the role of the first request piece (they should all be the same).""" + def api_role(self) -> ChatMessageRole: + """ + Return the API-compatible role of the first message piece. + + Maps simulated_assistant to assistant for API compatibility. + All message pieces in a Message should have the same role. + """ if len(self.message_pieces) == 0: raise ValueError("Empty message pieces.") - return self.message_pieces[0].role + return self.message_pieces[0].api_role + + @property + def is_simulated(self) -> bool: + """ + Check if this is a simulated assistant response. + + Simulated responses come from prepended conversations or generated + simulated conversations, not from actual target responses. + """ + if len(self.message_pieces) == 0: + return False + return self.message_pieces[0].is_simulated + + @property + def role(self) -> ChatMessageRole: + """ + Deprecated: Use api_role for comparisons or _role for internal storage. + + This property is deprecated and will be removed in a future version. + Returns api_role for backward compatibility. + """ + warnings.warn( + "Message.role getter is deprecated. Use api_role for comparisons. 
" + "This property will be removed in 0.13.0.", + DeprecationWarning, + stacklevel=2, + ) + return self.api_role @property def conversation_id(self) -> str: @@ -98,7 +132,7 @@ def validate(self): conversation_id = self.message_pieces[0].conversation_id sequence = self.message_pieces[0].sequence - role = self.message_pieces[0].role + role = self.message_pieces[0]._role for message_piece in self.message_pieces: if message_piece.conversation_id != conversation_id: @@ -110,7 +144,7 @@ def validate(self): if message_piece.converted_value is None: raise ValueError("Converted prompt text is None.") - if message_piece.role != role: + if message_piece._role != role: raise ValueError("Inconsistent roles within the same message entry.") def __str__(self): diff --git a/pyrit/models/message_piece.py b/pyrit/models/message_piece.py index fe06a2491..7cf4d814d 100644 --- a/pyrit/models/message_piece.py +++ b/pyrit/models/message_piece.py @@ -4,6 +4,7 @@ from __future__ import annotations import uuid +import warnings from datetime import datetime from typing import Dict, List, Literal, Optional, Union, cast, get_args from uuid import uuid4 @@ -86,7 +87,7 @@ def __init__( if role not in ChatMessageRole.__args__: # type: ignore raise ValueError(f"Role {role} is not a valid role.") - self.role: ChatMessageRole = role + self._role: ChatMessageRole = role if converted_value is None: converted_value = original_value @@ -164,8 +165,76 @@ async def set_sha256_values_async(self): ) self.converted_value_sha256 = await converted_serializer.get_sha256() + @property + def api_role(self) -> ChatMessageRole: + """ + Role to use for API calls. + + Maps simulated_assistant to assistant for API compatibility. + Use this property when sending messages to external APIs. + """ + return "assistant" if self._role == "simulated_assistant" else self._role + + @property + def is_simulated(self) -> bool: + """ + Check if this is a simulated assistant response. + + Simulated responses come from prepended conversations or generated + simulated conversations, not from actual target responses. + """ + return self._role == "simulated_assistant" + + def get_role_for_storage(self) -> ChatMessageRole: + """ + Get the actual stored role, including simulated_assistant. + + Use this when duplicating messages or preserving role information + for storage. For API calls or comparisons, use api_role instead. + + Returns: + The actual role stored (may be simulated_assistant). + """ + return self._role + + @property + def role(self) -> ChatMessageRole: + """ + Deprecated: Use api_role for comparisons or _role for internal storage. + + This property is deprecated and will be removed in a future version. + Returns api_role for backward compatibility. + """ + warnings.warn( + "MessagePiece.role getter is deprecated. Use api_role for comparisons. " + "This property will be removed in 0.13.0.", + DeprecationWarning, + stacklevel=2, + ) + return self.api_role + + @role.setter + def role(self, value: ChatMessageRole) -> None: + """ + Set the role for this message piece. + + Args: + value: The role to set (system, user, assistant, simulated_assistant, tool, developer). + + Raises: + ValueError: If the role is not a valid ChatMessageRole. + """ + if value not in ChatMessageRole.__args__: # type: ignore + raise ValueError(f"Role {value} is not a valid role.") + self._role = value + def to_chat_message(self) -> ChatMessage: - return ChatMessage(role=cast(ChatMessageRole, self.role), content=self.converted_value) + """ + Convert to a ChatMessage for API calls. 
+ + Uses api_role to ensure simulated_assistant is mapped to assistant. + """ + return ChatMessage(role=cast(ChatMessageRole, self.api_role), content=self.converted_value) def to_message(self) -> Message: # type: ignore # noqa F821 from pyrit.models.message import Message @@ -195,7 +264,7 @@ def set_piece_not_in_database(self): def to_dict(self) -> dict: return { "id": str(self.id), - "role": self.role, + "role": self._role, "conversation_id": self.conversation_id, "sequence": self.sequence, "timestamp": self.timestamp.isoformat() if self.timestamp else None, @@ -219,14 +288,14 @@ def to_dict(self) -> dict: } def __str__(self): - return f"{self.prompt_target_identifier}: {self.role}: {self.converted_value}" + return f"{self.prompt_target_identifier}: {self._role}: {self.converted_value}" __repr__ = __str__ def __eq__(self, other) -> bool: return ( self.id == other.id - and self.role == other.role + and self._role == other._role and self.original_value == other.original_value and self.original_value_data_type == other.original_value_data_type and self.original_value_sha256 == other.original_value_sha256 diff --git a/pyrit/models/seed_group.py b/pyrit/models/seed_group.py index 03a595ba5..6903eb7e9 100644 --- a/pyrit/models/seed_group.py +++ b/pyrit/models/seed_group.py @@ -312,8 +312,14 @@ def _prompts_to_messages(self, prompts: Sequence[SeedPrompt]) -> List[Message]: # Convert each prompt to a MessagePiece message_pieces = [] for prompt in sequence_prompts: + # Convert assistant to simulated_assistant for YAML-loaded conversations + # since these represent simulated/prepended content, not actual target responses + role = prompt.role or "user" + if role == "assistant": + role = "simulated_assistant" + piece = MessagePiece( - role=prompt.role or "user", + role=role, original_value=prompt.value, original_value_data_type=prompt.data_type or "text", prompt_target_identifier=None, diff --git a/pyrit/prompt_target/openai/openai_chat_target.py b/pyrit/prompt_target/openai/openai_chat_target.py index f68e9fd54..852aaab9c 100644 --- a/pyrit/prompt_target/openai/openai_chat_target.py +++ b/pyrit/prompt_target/openai/openai_chat_target.py @@ -342,7 +342,7 @@ def _build_chat_messages_for_text(self, conversation: MutableSequence[Message]) if message_piece.converted_value_data_type != "text": raise ValueError("_build_chat_messages_for_text only supports text.") - chat_message = ChatMessage(role=message_piece.role, content=message_piece.converted_value) + chat_message = ChatMessage(role=message_piece.api_role, content=message_piece.converted_value) chat_messages.append(chat_message.model_dump(exclude_none=True)) return chat_messages @@ -368,7 +368,7 @@ async def _build_chat_messages_for_multi_modal_async(self, conversation: Mutable content = [] role = None for message_piece in message_pieces: - role = message_piece.role + role = message_piece.api_role if message_piece.converted_value_data_type == "text": entry = {"type": "text", "text": message_piece.converted_value} content.append(entry) diff --git a/pyrit/prompt_target/openai/openai_realtime_target.py b/pyrit/prompt_target/openai/openai_realtime_target.py index a0be70a3e..d3f3b5ddb 100644 --- a/pyrit/prompt_target/openai/openai_realtime_target.py +++ b/pyrit/prompt_target/openai/openai_realtime_target.py @@ -279,7 +279,7 @@ def _get_system_prompt_from_conversation(self, *, conversation_id: str) -> str: # Look for a system message at the beginning of the conversation if conversation and len(conversation) > 0: first_message = conversation[0] - if 
first_message.message_pieces and first_message.message_pieces[0].role == "system": + if first_message.message_pieces and first_message.message_pieces[0].api_role == "system": return first_message.message_pieces[0].converted_value # Return default system prompt if none found in conversation diff --git a/pyrit/prompt_target/openai/openai_response_target.py b/pyrit/prompt_target/openai/openai_response_target.py index eb4ce472b..d87aa9f6a 100644 --- a/pyrit/prompt_target/openai/openai_response_target.py +++ b/pyrit/prompt_target/openai/openai_response_target.py @@ -189,7 +189,7 @@ async def _construct_input_item_from_piece(self, piece: MessagePiece) -> Dict[st """ if piece.converted_value_data_type == "text": return { - "type": "input_text" if piece.role in ["developer", "user"] else "output_text", + "type": "input_text" if piece.api_role in ["developer", "user"] else "output_text", "text": piece.converted_value, } if piece.converted_value_data_type == "image_path": @@ -230,7 +230,7 @@ async def _build_input_for_multi_modal_async(self, conversation: MutableSequence ) # System message (remapped to developer) - if pieces[0].role == "system": + if pieces[0].api_role == "system": system_content = [] for piece in pieces: system_content.append({"type": "input_text", "text": piece.converted_value}) @@ -238,7 +238,7 @@ async def _build_input_for_multi_modal_async(self, conversation: MutableSequence continue # All pieces in a Message share the same role - role = pieces[0].role + role = pieces[0].api_role content: List[Dict[str, Any]] = [] for piece in pieces: @@ -644,7 +644,7 @@ def _find_last_pending_tool_call(self, reply: Message) -> Optional[dict[str, Any The tool-call section dict, or None if not found. """ for piece in reversed(reply.message_pieces): - if piece.role == "assistant": + if piece.api_role == "assistant": try: section = json.loads(piece.original_value) except Exception: diff --git a/pyrit/score/conversation_scorer.py b/pyrit/score/conversation_scorer.py index cfdc22ee1..eacd64787 100644 --- a/pyrit/score/conversation_scorer.py +++ b/pyrit/score/conversation_scorer.py @@ -70,8 +70,8 @@ async def _score_async(self, message: Message, *, objective: Optional[str] = Non for conv_message in conversation: for piece in conv_message.message_pieces: # Only include user and assistant messages in the conversation text - if piece.role in ["user", "assistant", "tool"]: - role_display = piece.role.capitalize() + if piece.api_role in ["user", "assistant", "tool"]: + role_display = "Assistant (simulated)" if piece.is_simulated else piece.api_role.capitalize() conversation_text += f"{role_display}: {piece.converted_value}\n" # Create a new message with the concatenated conversation text @@ -80,7 +80,7 @@ async def _score_async(self, message: Message, *, objective: Optional[str] = Non conversation_message = Message( message_pieces=[ MessagePiece( - role=original_piece.role, + role=original_piece.get_role_for_storage(), original_value=conversation_text, converted_value=conversation_text, id=original_piece.id, diff --git a/pyrit/score/scorer.py b/pyrit/score/scorer.py index 8a8e311b0..8b050de8b 100644 --- a/pyrit/score/scorer.py +++ b/pyrit/score/scorer.py @@ -152,7 +152,9 @@ async def score_async( message (Message): The message to be scored. objective (Optional[str]): The task or objective based on which the message should be scored. Defaults to None. - role_filter (Optional[ChatMessageRole]): Only score messages with this role. Defaults to None. 
+ role_filter (Optional[ChatMessageRole]): Only score messages with this exact stored role. + Use "assistant" to score only real assistant responses, or "simulated_assistant" + to score only simulated responses. Defaults to None (no filtering). skip_on_error_result (bool): If True, skip scoring if the message contains an error. Defaults to False. infer_objective_from_request (bool): If True, infer the objective from the message's previous request when objective is not provided. Defaults to False. @@ -166,7 +168,7 @@ async def score_async( """ self._validator.validate(message, objective=objective) - if role_filter is not None and message.role != role_filter: + if role_filter is not None and message.get_piece().get_role_for_storage() != role_filter: logger.debug("Skipping scoring due to role filter mismatch.") return [] @@ -638,7 +640,7 @@ def _extract_objective_from_response(self, response: Message) -> str: piece = response.get_piece() - if piece.role != "assistant": + if piece.api_role != "assistant": return "" conversation = self._memory.get_message_pieces(conversation_id=piece.conversation_id) @@ -672,7 +674,8 @@ async def score_response_async( response (Message): Response containing pieces to score. objective_scorer (Optional[Scorer]): The main scorer to determine success. Defaults to None. auxiliary_scorers (Optional[List[Scorer]]): List of auxiliary scorers to apply. Defaults to None. - role_filter (ChatMessageRole): Only score pieces with this role. Defaults to "assistant". + role_filter (ChatMessageRole): Only score pieces with this exact stored role. + Defaults to "assistant" (real responses only, not simulated). objective (Optional[str]): Task/objective for scoring context. Defaults to None. skip_on_error_result (bool): If True, skip scoring pieces that have errors. Defaults to True. @@ -748,7 +751,8 @@ async def score_response_multiple_scorers_async( Args: response (Message): The response containing pieces to score. scorers (List[Scorer]): List of scorers to apply. - role_filter (ChatMessageRole): Only score pieces with this role (default: "assistant"). + role_filter (ChatMessageRole): Only score pieces with this exact stored role. + Defaults to "assistant" (real responses only, not simulated). objective (Optional[str]): Optional objective description for scoring context. skip_on_error_result (bool): If True, skip scoring pieces that have errors (default: True). 
diff --git a/pyrit/score/true_false/gandalf_scorer.py b/pyrit/score/true_false/gandalf_scorer.py index 1703e3195..b2add6322 100644 --- a/pyrit/score/true_false/gandalf_scorer.py +++ b/pyrit/score/true_false/gandalf_scorer.py @@ -100,7 +100,7 @@ async def _check_for_password_in_conversation(self, conversation_id: str) -> str conversation_as_text = "" for message in conversation: - conversation_as_text += "Gandalf" if message.message_pieces[0].role == "assistant" else "user" + conversation_as_text += "Gandalf" if message.message_pieces[0].api_role == "assistant" else "user" conversation_as_text += ": " conversation_as_text += message.get_value() conversation_as_text += "\n" diff --git a/pyrit/score/video_scorer.py b/pyrit/score/video_scorer.py index b4d14d819..a56c05cef 100644 --- a/pyrit/score/video_scorer.py +++ b/pyrit/score/video_scorer.py @@ -91,7 +91,7 @@ async def _score_frames_async(self, *, message_piece: MessagePiece, objective: O piece = MessagePiece( original_value=message_piece.converted_value, - role=message_piece.role, + role=message_piece.get_role_for_storage(), original_prompt_id=original_prompt_id, converted_value=frame, converted_value_data_type="image_path", diff --git a/tests/unit/executor/attack/component/test_conversation_manager.py b/tests/unit/executor/attack/component/test_conversation_manager.py index 73e32018e..dabbe68eb 100644 --- a/tests/unit/executor/attack/component/test_conversation_manager.py +++ b/tests/unit/executor/attack/component/test_conversation_manager.py @@ -852,3 +852,78 @@ def test_conversation_state_dataclass_with_values(self, sample_score: Score): assert state.turn_count == 5 assert state.last_user_message == "Test message" assert state.last_assistant_message_scores == [sample_score] + + +class TestMarkMessagesAsSimulated: + """Tests for the mark_messages_as_simulated helper function.""" + + def test_mark_messages_as_simulated_converts_assistant(self): + """Test that assistant role is converted to simulated_assistant.""" + from pyrit.executor.attack.component.conversation_manager import ( + mark_messages_as_simulated, + ) + + piece = MessagePiece(role="assistant", original_value="Hello", conversation_id="test") + message = Message(message_pieces=[piece]) + + result = mark_messages_as_simulated([message]) + + assert len(result) == 1 + assert result[0].message_pieces[0].get_role_for_storage() == "simulated_assistant" + assert result[0].message_pieces[0].api_role == "assistant" + assert result[0].message_pieces[0].is_simulated is True + + def test_mark_messages_as_simulated_leaves_user_unchanged(self): + """Test that user role is not changed.""" + from pyrit.executor.attack.component.conversation_manager import ( + mark_messages_as_simulated, + ) + + piece = MessagePiece(role="user", original_value="Hello", conversation_id="test") + message = Message(message_pieces=[piece]) + + result = mark_messages_as_simulated([message]) + + assert len(result) == 1 + assert result[0].message_pieces[0].get_role_for_storage() == "user" + assert result[0].message_pieces[0].is_simulated is False + + def test_mark_messages_as_simulated_leaves_system_unchanged(self): + """Test that system role is not changed.""" + from pyrit.executor.attack.component.conversation_manager import ( + mark_messages_as_simulated, + ) + + piece = MessagePiece(role="system", original_value="You are helpful", conversation_id="test") + message = Message(message_pieces=[piece]) + + result = mark_messages_as_simulated([message]) + + assert len(result) == 1 + assert 
result[0].message_pieces[0].get_role_for_storage() == "system" + assert result[0].message_pieces[0].is_simulated is False + + def test_mark_messages_as_simulated_mixed_conversation(self): + """Test marking a conversation with mixed roles.""" + from pyrit.executor.attack.component.conversation_manager import ( + mark_messages_as_simulated, + ) + + user_piece = MessagePiece(role="user", original_value="Hello", conversation_id="test", sequence=1) + assistant_piece = MessagePiece(role="assistant", original_value="Hi there", conversation_id="test", sequence=2) + + messages = [ + Message(message_pieces=[user_piece]), + Message(message_pieces=[assistant_piece]), + ] + + result = mark_messages_as_simulated(messages) + + assert len(result) == 2 + # User should be unchanged + assert result[0].message_pieces[0].get_role_for_storage() == "user" + assert result[0].is_simulated is False + # Assistant should be converted + assert result[1].message_pieces[0].get_role_for_storage() == "simulated_assistant" + assert result[1].is_simulated is True + assert result[1].api_role == "assistant" diff --git a/tests/unit/test_literals.py b/tests/unit/test_literals.py index 4cb1ec29a..a7790dee2 100644 --- a/tests/unit/test_literals.py +++ b/tests/unit/test_literals.py @@ -34,5 +34,5 @@ def test_prompt_response_error(): def test_chat_message_role(): assert get_origin(ChatMessageRole) is Literal - expected_literals = {"system", "user", "assistant", "tool", "developer"} + expected_literals = {"system", "user", "assistant", "simulated_assistant", "tool", "developer"} assert set(get_args(ChatMessageRole)) == expected_literals diff --git a/tests/unit/test_message.py b/tests/unit/test_message.py index 83343f313..dd6c40b6a 100644 --- a/tests/unit/test_message.py +++ b/tests/unit/test_message.py @@ -175,3 +175,37 @@ def test_from_prompt_with_empty_string(self) -> None: assert len(message.message_pieces) == 1 assert message.message_pieces[0].original_value == "" + + +class TestMessageSimulatedAssistantRole: + """Tests for Message simulated_assistant role properties.""" + + def test_api_role_returns_assistant_for_simulated_assistant(self) -> None: + """Test that Message.api_role returns 'assistant' for simulated_assistant.""" + piece = MessagePiece(role="simulated_assistant", original_value="Hello", conversation_id="test") + message = Message(message_pieces=[piece]) + assert message.api_role == "assistant" + + def test_api_role_returns_assistant_for_assistant(self) -> None: + """Test that Message.api_role returns 'assistant' for assistant.""" + piece = MessagePiece(role="assistant", original_value="Hello", conversation_id="test") + message = Message(message_pieces=[piece]) + assert message.api_role == "assistant" + + def test_is_simulated_true_for_simulated_assistant(self) -> None: + """Test that Message.is_simulated returns True for simulated_assistant.""" + piece = MessagePiece(role="simulated_assistant", original_value="Hello", conversation_id="test") + message = Message(message_pieces=[piece]) + assert message.is_simulated is True + + def test_is_simulated_false_for_assistant(self) -> None: + """Test that Message.is_simulated returns False for assistant.""" + piece = MessagePiece(role="assistant", original_value="Hello", conversation_id="test") + message = Message(message_pieces=[piece]) + assert message.is_simulated is False + + def test_is_simulated_false_for_empty_pieces(self) -> None: + """Test that Message.is_simulated returns False for empty pieces (via skip_validation).""" + message = 
Message(message_pieces=[MessagePiece(role="user", original_value="x", conversation_id="test")]) + message.message_pieces = [] # Manually empty for edge case test + assert message.is_simulated is False diff --git a/tests/unit/test_message_piece.py b/tests/unit/test_message_piece.py index 6fa1d9098..313333920 100644 --- a/tests/unit/test_message_piece.py +++ b/tests/unit/test_message_piece.py @@ -918,3 +918,65 @@ def test_message_piece_harm_categories_with_labels(): result = entry.to_dict() assert result["targeted_harm_categories"] == harm_categories assert result["labels"] == labels + + +class TestSimulatedAssistantRole: + """Tests for simulated_assistant role properties.""" + + def test_api_role_returns_assistant_for_assistant(self): + """Test that api_role returns 'assistant' for assistant role.""" + piece = MessagePiece(role="assistant", original_value="Hello") + assert piece.api_role == "assistant" + + def test_api_role_returns_assistant_for_simulated_assistant(self): + """Test that api_role returns 'assistant' for simulated_assistant role.""" + piece = MessagePiece(role="simulated_assistant", original_value="Hello") + assert piece.api_role == "assistant" + + def test_api_role_returns_user_for_user(self): + """Test that api_role returns 'user' for user role.""" + piece = MessagePiece(role="user", original_value="Hello") + assert piece.api_role == "user" + + def test_api_role_returns_system_for_system(self): + """Test that api_role returns 'system' for system role.""" + piece = MessagePiece(role="system", original_value="Hello") + assert piece.api_role == "system" + + def test_is_simulated_true_for_simulated_assistant(self): + """Test that is_simulated returns True for simulated_assistant.""" + piece = MessagePiece(role="simulated_assistant", original_value="Hello") + assert piece.is_simulated is True + + def test_is_simulated_false_for_assistant(self): + """Test that is_simulated returns False for assistant.""" + piece = MessagePiece(role="assistant", original_value="Hello") + assert piece.is_simulated is False + + def test_is_simulated_false_for_user(self): + """Test that is_simulated returns False for user.""" + piece = MessagePiece(role="user", original_value="Hello") + assert piece.is_simulated is False + + def test_get_role_for_storage_returns_simulated_assistant(self): + """Test that get_role_for_storage returns the actual stored role.""" + piece = MessagePiece(role="simulated_assistant", original_value="Hello") + assert piece.get_role_for_storage() == "simulated_assistant" + + def test_get_role_for_storage_returns_assistant(self): + """Test that get_role_for_storage returns assistant for assistant role.""" + piece = MessagePiece(role="assistant", original_value="Hello") + assert piece.get_role_for_storage() == "assistant" + + def test_get_role_for_storage_returns_user(self): + """Test that get_role_for_storage returns user for user role.""" + piece = MessagePiece(role="user", original_value="Hello") + assert piece.get_role_for_storage() == "user" + + def test_role_setter_sets_simulated_assistant(self): + """Test that role setter can set simulated_assistant.""" + piece = MessagePiece(role="assistant", original_value="Hello") + piece.role = "simulated_assistant" + assert piece.get_role_for_storage() == "simulated_assistant" + assert piece.api_role == "assistant" + assert piece.is_simulated is True
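# --- Reviewer note (not part of the patch) ----------------------------------
# A small usage sketch of the role API introduced above, modeled on the new
# unit tests in this diff. All names (api_role, is_simulated,
# get_role_for_storage, mark_messages_as_simulated, to_chat_message) come from
# the patch itself; the sketch assumes an environment where MessagePiece and
# Message can be constructed directly, as the unit tests do.

from pyrit.executor.attack.component import mark_messages_as_simulated
from pyrit.models import Message, MessagePiece

# A response loaded from YAML or produced by a simulated conversation,
# initially stored with the plain "assistant" role.
piece = MessagePiece(role="assistant", original_value="Hi there", conversation_id="demo")
message = Message(message_pieces=[piece])

# Mark it as simulated for traceability; the pieces are mutated in place.
mark_messages_as_simulated([message])

assert piece.get_role_for_storage() == "simulated_assistant"  # stored role
assert piece.api_role == "assistant"                          # role used for API calls
assert piece.is_simulated is True                             # easy to filter on
assert message.api_role == "assistant" and message.is_simulated is True

# to_chat_message() also maps the simulated role back to "assistant", so
# payloads sent to chat targets remain API-compatible.
assert piece.to_chat_message().role == "assistant"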