Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
52 changes: 37 additions & 15 deletions src/agents/extensions/visualization.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,16 @@
from agents.handoffs import Handoff


def _escape_label(name: str) -> str:
"""Escape a name for use inside a Graphviz double-quoted ID or label.

Backslashes are escaped first, then double quotes, so a name containing
either character does not terminate the DOT string early or produce
malformed output.
"""
return name.replace("\\", "\\\\").replace('"', '\\"')


def get_main_graph(agent: Agent) -> str:
"""
Generates the main graph structure in DOT format for the given agent.
Expand Down Expand Up @@ -59,34 +69,42 @@ def get_all_nodes(
"fillcolor=lightblue, width=0.5, height=0.3];"
)
# Ensure parent agent node is colored
name = _escape_label(agent.name)
parts.append(
f'"{agent.name}" [label="{agent.name}", shape=box, style=filled, '
f'"{name}" [label="{name}", '
"shape=box, style=filled, "
"fillcolor=lightyellow, width=1.5, height=0.8];"
)

for tool in agent.tools:
name = _escape_label(tool.name)
parts.append(
f'"{tool.name}" [label="{tool.name}", shape=ellipse, style=filled, '
f"fillcolor=lightgreen, width=0.5, height=0.3];"
f'"{name}" [label="{name}", '
"shape=ellipse, style=filled, "
"fillcolor=lightgreen, width=0.5, height=0.3];"
)

for mcp_server in agent.mcp_servers:
name = _escape_label(mcp_server.name)
parts.append(
f'"{mcp_server.name}" [label="{mcp_server.name}", shape=box, style=filled, '
f"fillcolor=lightgrey, width=1, height=0.5];"
f'"{name}" [label="{name}", '
"shape=box, style=filled, "
"fillcolor=lightgrey, width=1, height=0.5];"
)

for handoff in agent.handoffs:
if isinstance(handoff, Handoff):
name = _escape_label(handoff.agent_name)
parts.append(
f'"{handoff.agent_name}" [label="{handoff.agent_name}", '
f'"{name}" [label="{name}", '
f"shape=box, style=filled, style=rounded, "
f"fillcolor=lightyellow, width=1.5, height=0.8];"
)
if isinstance(handoff, Agent):
if handoff.name not in visited:
name = _escape_label(handoff.name)
parts.append(
f'"{handoff.name}" [label="{handoff.name}", '
f'"{name}" [label="{name}", '
f"shape=box, style=filled, style=rounded, "
f"fillcolor=lightyellow, width=1.5, height=0.8];"
)
Expand Down Expand Up @@ -116,30 +134,34 @@ def get_all_edges(

parts = []

agent_name = _escape_label(agent.name)

if not parent:
parts.append(f'"__start__" -> "{agent.name}";')
parts.append(f'"__start__" -> "{agent_name}";')

for tool in agent.tools:
tool_name = _escape_label(tool.name)
parts.append(f"""
"{agent.name}" -> "{tool.name}" [style=dotted, penwidth=1.5];
"{tool.name}" -> "{agent.name}" [style=dotted, penwidth=1.5];""")
"{agent_name}" -> "{tool_name}" [style=dotted, penwidth=1.5];
"{tool_name}" -> "{agent_name}" [style=dotted, penwidth=1.5];""")

for mcp_server in agent.mcp_servers:
server_name = _escape_label(mcp_server.name)
parts.append(f"""
"{agent.name}" -> "{mcp_server.name}" [style=dashed, penwidth=1.5];
"{mcp_server.name}" -> "{agent.name}" [style=dashed, penwidth=1.5];""")
"{agent_name}" -> "{server_name}" [style=dashed, penwidth=1.5];
"{server_name}" -> "{agent_name}" [style=dashed, penwidth=1.5];""")

for handoff in agent.handoffs:
if isinstance(handoff, Handoff):
parts.append(f"""
"{agent.name}" -> "{handoff.agent_name}";""")
"{agent_name}" -> "{_escape_label(handoff.agent_name)}";""")
if isinstance(handoff, Agent):
parts.append(f"""
"{agent.name}" -> "{handoff.name}";""")
"{agent_name}" -> "{_escape_label(handoff.name)}";""")
parts.append(get_all_edges(handoff, agent, visited))

if not agent.handoffs:
parts.append(f'"{agent.name}" -> "__end__";')
parts.append(f'"{agent_name}" -> "__end__";')

return "".join(parts)

Expand Down
21 changes: 21 additions & 0 deletions tests/test_visualization.py
Original file line number Diff line number Diff line change
Expand Up @@ -172,6 +172,27 @@ def test_cycle_detection():
assert '"B" -> "A"' in edges


def test_names_with_quotes_and_backslashes_are_escaped(mock_agent):
"""Names containing double quotes or backslashes must be escaped in DOT.

Otherwise an embedded quote closes the Graphviz identifier early and
produces a malformed graph. Backslashes are escaped first, then quotes.
"""
mock_agent.name = 'Weird"Name'
mock_agent.tools[0].name = "Back\\slash"

nodes = get_all_nodes(mock_agent)
edges = get_all_edges(mock_agent)

# The quote is backslash-escaped and the bare unescaped form is gone.
assert '"Weird\\"Name" [label="Weird\\"Name"' in nodes
assert '"Weird"Name"' not in nodes
# The backslash is doubled.
assert '"Back\\\\slash"' in nodes
# Edges escape names too, so the start arrow points at the escaped id.
assert '"__start__" -> "Weird\\"Name";' in edges


def test_draw_graph_with_real_agent_no_handoffs():
"""Test that draw_graph works with a real Agent object without handoffs.

Expand Down