Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion doc/code/converters/ansi_attack_converter.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -366,7 +366,6 @@
}
],
"source": [
"\n",
"from pyrit.executor.attack import (\n",
" AttackConverterConfig,\n",
" AttackExecutor,\n",
Expand Down
2 changes: 2 additions & 0 deletions pyrit/executor/attack/component/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
ConversationManager,
ConversationState,
format_conversation_context,
mark_messages_as_simulated,
)
from pyrit.executor.attack.component.objective_evaluator import ObjectiveEvaluator
from pyrit.executor.attack.component.simulated_conversation import (
Expand All @@ -19,6 +20,7 @@
"ConversationManager",
"ConversationState",
"format_conversation_context",
"mark_messages_as_simulated",
"ObjectiveEvaluator",
"generate_simulated_conversation_async",
"SimulatedConversationResult",
Expand Down
56 changes: 42 additions & 14 deletions pyrit/executor/attack/component/conversation_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import logging
import uuid
from dataclasses import dataclass, field
from typing import Dict, List, Optional
from typing import Dict, List, Optional, Sequence

from pyrit.memory import CentralMemory
from pyrit.models import ChatMessageRole, Message, MessagePiece, Score
Expand All @@ -18,6 +18,29 @@
logger = logging.getLogger(__name__)


def mark_messages_as_simulated(messages: Sequence[Message]) -> List[Message]:
"""
Mark assistant messages as simulated_assistant for traceability.

This function converts all assistant roles to simulated_assistant in the
provided messages. This is useful when loading conversations from YAML files
or other sources where the responses are not from actual targets.

Args:
messages (Sequence[Message]): The messages to mark as simulated.

Returns:
List[Message]: The same messages with assistant roles converted to simulated_assistant.
Modifies the messages in place and also returns them for convenience.
"""
result = list(messages)
for message in result:
for piece in message.message_pieces:
if piece._role == "assistant":
piece._role = "simulated_assistant"
return result


def format_conversation_context(messages: List[Message]) -> str:
"""
Format a list of messages into a context string for adversarial chat system prompts.
Expand Down Expand Up @@ -55,17 +78,22 @@ def format_conversation_context(messages: List[Message]) -> str:
piece = message.get_piece()

# Skip system messages - they're handled separately
if piece.role == "system":
if piece.api_role == "system":
continue

# Start a new turn when we see a user message
if piece.role == "user":
if piece.api_role == "user":
turn_number += 1
context_parts.append(f"Turn {turn_number}:")

# Format the piece content
content = _format_piece_content(piece)
role_label = "User" if piece.role == "user" else "Assistant"
if piece.api_role == "user":
role_label = "User"
elif piece.is_simulated:
role_label = "Assistant (simulated)"
else:
role_label = "Assistant"
context_parts.append(f"{role_label}: {content}")

return "\n".join(context_parts)
Expand Down Expand Up @@ -175,7 +203,7 @@ def get_last_message(
if role:
for m in reversed(conversation):
piece = m.get_piece()
if piece.role == role:
if piece.api_role == role:
return piece
return None

Expand Down Expand Up @@ -288,7 +316,7 @@ async def update_conversation_state_async(
# Determine if we should exclude the last message (if it's a user message in multi-turn context)
last_message = valid_requests[-1].message_pieces[0]
is_multi_turn = max_turns is not None
should_exclude_last = is_multi_turn and last_message.role == "user"
should_exclude_last = is_multi_turn and last_message.api_role == "user"

# Process all messages except potentially the last one
for i, request in enumerate(valid_requests):
Expand Down Expand Up @@ -351,9 +379,9 @@ async def _apply_role_specific_converters_async(
for piece in request.message_pieces:
applicable_converters: Optional[List[PromptConverterConfiguration]] = None

if piece.role == "user" and request_converters:
if piece.api_role == "user" and request_converters:
applicable_converters = request_converters
elif piece.role == "assistant" and response_converters:
elif piece.api_role == "assistant" and response_converters:
applicable_converters = response_converters
# System messages get no converters (applicable_converters remains None)

Expand Down Expand Up @@ -429,7 +457,7 @@ def _process_piece(
is_multi_turn = max_turns is not None

# Only assistant messages count as turns
if piece.role == "assistant" and is_multi_turn:
if piece.api_role == "assistant" and is_multi_turn:
conversation_state.turn_count += 1

if conversation_state.turn_count > max_turns:
Expand Down Expand Up @@ -466,11 +494,11 @@ async def _populate_conversation_state_async(
return # Nothing to extract from empty history

# Extract the last user message and assistant message scores from the last message
if last_message.role == "user":
if last_message.api_role == "user":
conversation_state.last_user_message = last_message.converted_value
logger.debug(f"Extracted last user message: {conversation_state.last_user_message[:50]}...")

elif last_message.role == "assistant":
elif last_message.api_role == "assistant":
# Get scores for the last assistant message based off of the original id
conversation_state.last_assistant_message_scores = list(
self._memory.get_prompt_scores(prompt_ids=[str(last_message.original_prompt_id)])
Expand All @@ -482,7 +510,7 @@ async def _populate_conversation_state_async(
return

# Check assumption that there will be a user message preceding the assistant message
if len(prepended_conversation) > 1 and prepended_conversation[-2].get_piece().role == "user":
if len(prepended_conversation) > 1 and prepended_conversation[-2].get_piece().api_role == "user":
conversation_state.last_user_message = prepended_conversation[-2].get_value()
logger.debug(f"Extracted preceding user message: {conversation_state.last_user_message[:50]}...")
else:
Expand Down Expand Up @@ -533,11 +561,11 @@ async def prepend_to_adversarial_chat_async(
for message in prepended_conversation:
for piece in message.message_pieces:
# Skip system messages - adversarial chat has its own system prompt
if piece.role == "system":
if piece.api_role == "system":
continue

# Create a new piece with swapped role for adversarial chat
swapped_role = role_swap.get(piece.role, piece.role)
swapped_role = role_swap.get(piece.api_role, piece.api_role)

adversarial_piece = MessagePiece(
id=uuid.uuid4(),
Expand Down
11 changes: 8 additions & 3 deletions pyrit/executor/attack/component/simulated_conversation.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ def _effective_turn_index(self) -> int:
# Calculate total complete turns (user+assistant pairs)
total_turns = len(self.conversation) // 2
# Account for trailing user message (incomplete turn)
if len(self.conversation) % 2 == 1 and self.conversation[-1].role == "user":
if len(self.conversation) % 2 == 1 and self.conversation[-1].api_role == "user":
total_turns += 1

if self.turn_index is None:
Expand Down Expand Up @@ -104,7 +104,7 @@ def next_message(self) -> Optional[Message]:
return None
# User message for turn N is at index (N-1) * 2
user_idx = (turn - 1) * 2
if user_idx < len(self.conversation) and self.conversation[user_idx].role == "user":
if user_idx < len(self.conversation) and self.conversation[user_idx].api_role == "user":
return self.conversation[user_idx]
return None

Expand Down Expand Up @@ -225,9 +225,14 @@ async def generate_simulated_conversation_async(

# Filter out system messages - prepended_conversation should only have user/assistant turns
# System prompts are set separately on each target during attack execution
# Also mark assistant messages as simulated for traceability
filtered_messages: List[Message] = []
for message in raw_messages:
if message.role != "system":
if message.api_role != "system":
# Mark assistant responses as simulated since this is a simulated conversation
if message.api_role == "assistant":
for piece in message.message_pieces:
piece._role = "simulated_assistant"
filtered_messages.append(message)

# Get the score from the result (there should be one score for the last turn)
Expand Down
2 changes: 1 addition & 1 deletion pyrit/executor/attack/multi_turn/tree_of_attacks.py
Original file line number Diff line number Diff line change
Expand Up @@ -777,7 +777,7 @@ async def _generate_subsequent_turn_prompt_async(self, objective: str) -> str:
target_messages = self._memory.get_conversation(conversation_id=self.objective_target_conversation_id)

# Extract the last assistant response
assistant_responses = [r for r in target_messages if r.get_piece().role == "assistant"]
assistant_responses = [r for r in target_messages if r.get_piece().api_role == "assistant"]
if not assistant_responses:
logger.error(f"No assistant responses found in the conversation {self.objective_target_conversation_id}.")
raise RuntimeError("Cannot proceed without an assistant response.")
Expand Down
13 changes: 7 additions & 6 deletions pyrit/executor/attack/printer/console_printer.py
Original file line number Diff line number Diff line change
Expand Up @@ -152,14 +152,14 @@ async def print_messages_async(
turn_number = 0
for message in messages:
# Increment turn number once per message with role="user"
if message.role == "user":
if message.api_role == "user":
turn_number += 1
# User message header
print()
self._print_colored("─" * self._width, Fore.BLUE)
self._print_colored(f"🔹 Turn {turn_number} - USER", Style.BRIGHT, Fore.BLUE)
self._print_colored("─" * self._width, Fore.BLUE)
elif message.role == "system":
elif message.api_role == "system":
# System message header (not counted as a turn)
print()
self._print_colored("─" * self._width, Fore.MAGENTA)
Expand All @@ -169,7 +169,8 @@ async def print_messages_async(
# Assistant or other role message header
print()
self._print_colored("─" * self._width, Fore.YELLOW)
self._print_colored(f"🔸 {message.role.upper()}", Style.BRIGHT, Fore.YELLOW)
role_label = "ASSISTANT (SIMULATED)" if message.is_simulated else message.api_role.upper()
self._print_colored(f"🔸 {role_label}", Style.BRIGHT, Fore.YELLOW)
self._print_colored("─" * self._width, Fore.YELLOW)

# Now print all pieces in this message
Expand All @@ -179,15 +180,15 @@ async def print_messages_async(
continue

# Handle converted values for user messages
if piece.role == "user" and piece.converted_value != piece.original_value:
if piece.api_role == "user" and piece.converted_value != piece.original_value:
self._print_colored(f"{self._indent} Original:", Fore.CYAN)
self._print_wrapped_text(piece.original_value, Fore.WHITE)
print()
self._print_colored(f"{self._indent} Converted:", Fore.CYAN)
self._print_wrapped_text(piece.converted_value, Fore.WHITE)
elif piece.role == "user":
elif piece.api_role == "user":
self._print_wrapped_text(piece.converted_value, Fore.BLUE)
elif piece.role == "system":
elif piece.api_role == "system":
self._print_wrapped_text(piece.converted_value, Fore.MAGENTA)
else:
self._print_wrapped_text(piece.converted_value, Fore.YELLOW)
Expand Down
5 changes: 3 additions & 2 deletions pyrit/executor/attack/printer/markdown_printer.py
Original file line number Diff line number Diff line change
Expand Up @@ -213,7 +213,7 @@ async def _get_conversation_markdown_async(
if not message.message_pieces:
continue

message_role = message.get_piece().role
message_role = message.get_piece().api_role

if message_role == "system":
markdown_lines.extend(self._format_system_message(message))
Expand Down Expand Up @@ -284,7 +284,8 @@ async def _format_assistant_message_async(self, *, message: Message) -> List[str
List[str]: List of markdown strings representing the response message.
"""
lines = []
role_name = message.message_pieces[0].role.capitalize()
piece = message.message_pieces[0]
role_name = "Assistant (Simulated)" if piece.is_simulated else piece.api_role.capitalize()

lines.append(f"\n#### {role_name}\n")

Expand Down
4 changes: 2 additions & 2 deletions pyrit/executor/promptgen/fuzzer/fuzzer.py
Original file line number Diff line number Diff line change
Expand Up @@ -452,11 +452,11 @@ def _print_conversations(self, result: FuzzerResult) -> None:
continue

for message in target_messages:
if message.role == "user":
if message.api_role == "user":
self._print_colored(f"{self._indent * 2} USER:", Style.BRIGHT, Fore.BLUE)
self._print_wrapped_text(message.converted_value, Fore.BLUE)
else:
self._print_colored(f"{self._indent * 2} {message.role.upper()}:", Style.BRIGHT, Fore.YELLOW)
self._print_colored(f"{self._indent * 2} {message.api_role.upper()}:", Style.BRIGHT, Fore.YELLOW)
self._print_wrapped_text(message.converted_value, Fore.YELLOW)

# Print scores if available
Expand Down
6 changes: 3 additions & 3 deletions pyrit/memory/memory_interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -471,7 +471,7 @@ def get_request_from_response(self, *, response: Message) -> Message:
Raises:
ValueError: If the response is not from an assistant role or has no preceding request.
"""
if response.role != "assistant":
if response.api_role != "assistant":
raise ValueError("The provided request is not a response (role must be 'assistant').")
if response.sequence < 1:
raise ValueError("The provided request does not have a preceding request (sequence < 1).")
Expand Down Expand Up @@ -628,7 +628,7 @@ def duplicate_conversation_excluding_last_turn(self, *, conversation_id: str) ->

length_of_sequence_to_remove = 0

if last_message.role == "system" or last_message.role == "user":
if last_message.api_role == "system" or last_message.api_role == "user":
length_of_sequence_to_remove = 1
else:
length_of_sequence_to_remove = 2
Expand Down Expand Up @@ -778,7 +778,7 @@ def get_chat_messages_with_conversation_id(self, *, conversation_id: str) -> Seq
Sequence[ChatMessage]: The list of chat messages.
"""
memory_entries = self.get_message_pieces(conversation_id=conversation_id)
return [ChatMessage(role=me.role, content=me.converted_value) for me in memory_entries] # type: ignore
return [ChatMessage(role=me.api_role, content=me.converted_value) for me in memory_entries] # type: ignore

def get_seeds(
self,
Expand Down
6 changes: 4 additions & 2 deletions pyrit/memory/memory_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -149,7 +149,9 @@ class PromptMemoryEntry(Base):
__tablename__ = "PromptMemoryEntries"
__table_args__ = {"extend_existing": True}
id = mapped_column(CustomUUID, nullable=False, primary_key=True)
role: Mapped[Literal["system", "user", "assistant", "tool", "developer"]] = mapped_column(String, nullable=False)
role: Mapped[Literal["system", "user", "assistant", "simulated_assistant", "tool", "developer"]] = mapped_column(
String, nullable=False
)
conversation_id = mapped_column(String, nullable=False)
sequence = mapped_column(INTEGER, nullable=False)
timestamp = mapped_column(DateTime, nullable=False)
Expand Down Expand Up @@ -192,7 +194,7 @@ def __init__(self, *, entry: MessagePiece):
entry (MessagePiece): The message piece to convert into a database entry.
"""
self.id = entry.id
self.role = entry.role
self.role = entry._role
self.conversation_id = entry.conversation_id
self.sequence = entry.sequence
self.timestamp = entry.timestamp
Expand Down
2 changes: 1 addition & 1 deletion pyrit/models/chat_message.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@

from pyrit.models.literals import ChatMessageRole

ALLOWED_CHAT_MESSAGE_ROLES = ["system", "user", "assistant", "tool", "developer"]
ALLOWED_CHAT_MESSAGE_ROLES = ["system", "user", "assistant", "simulated_assistant", "tool", "developer"]


class ToolCall(BaseModel):
Expand Down
2 changes: 1 addition & 1 deletion pyrit/models/literals.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,4 +25,4 @@
"""
PromptResponseError = Literal["blocked", "none", "processing", "empty", "unknown"]

ChatMessageRole = Literal["system", "user", "assistant", "tool", "developer"]
ChatMessageRole = Literal["system", "user", "assistant", "simulated_assistant", "tool", "developer"]
Loading