Skip to content

Commit 603909d

Browse files
g97iulio1609Copilot
andcommitted
fix: collapse single-error ExceptionGroups from task group cancellations
When a task in an anyio task group fails, sibling tasks are cancelled and the resulting Cancelled exceptions are wrapped alongside the real error in a BaseExceptionGroup. This makes it extremely difficult for callers to classify the root cause of failures. Added collapse_exception_group() utility and a drop-in create_task_group() context manager that automatically unwraps single-error exception groups. When there is exactly one non-cancellation error, callers now receive the original exception directly instead of a wrapped group. Applied to all client-facing code paths: - BaseSession.__aexit__ (affects all session-based operations) - Client transports: stdio, SSE, streamable HTTP, websocket, memory Multiple real errors (non-cancellation) are preserved as-is in the exception group. Fixes #2114 Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
1 parent 62575ed commit 603909d

File tree

8 files changed

+193
-6
lines changed

8 files changed

+193
-6
lines changed

src/mcp/client/_memory.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
from mcp.client._transport import TransportStreams
1313
from mcp.server import Server
1414
from mcp.server.mcpserver import MCPServer
15+
from mcp.shared._exception_utils import create_task_group as _create_task_group
1516
from mcp.shared.memory import create_client_server_memory_streams
1617

1718

@@ -48,7 +49,7 @@ async def _connect(self) -> AsyncIterator[TransportStreams]:
4849
client_read, client_write = client_streams
4950
server_read, server_write = server_streams
5051

51-
async with anyio.create_task_group() as tg:
52+
async with _create_task_group() as tg:
5253
# Start server in background
5354
tg.start_soon(
5455
lambda: actual_server.run(

src/mcp/client/sse.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
from httpx_sse._exceptions import SSEError
1313

1414
from mcp import types
15+
from mcp.shared._exception_utils import create_task_group
1516
from mcp.shared._httpx_utils import McpHttpClientFactory, create_mcp_http_client
1617
from mcp.shared.message import SessionMessage
1718

@@ -60,7 +61,7 @@ async def sse_client(
6061
read_stream_writer, read_stream = anyio.create_memory_object_stream(0)
6162
write_stream, write_stream_reader = anyio.create_memory_object_stream(0)
6263

63-
async with anyio.create_task_group() as tg:
64+
async with create_task_group() as tg:
6465
try:
6566
logger.debug(f"Connecting to SSE endpoint: {remove_request_params(url)}")
6667
async with httpx_client_factory(

src/mcp/client/stdio.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
get_windows_executable_command,
2121
terminate_windows_process_tree,
2222
)
23+
from mcp.shared._exception_utils import create_task_group
2324
from mcp.shared.message import SessionMessage
2425

2526
logger = logging.getLogger(__name__)
@@ -177,7 +178,7 @@ async def stdin_writer():
177178
except anyio.ClosedResourceError: # pragma: no cover
178179
await anyio.lowlevel.checkpoint()
179180

180-
async with anyio.create_task_group() as tg, process:
181+
async with create_task_group() as tg, process:
181182
tg.start_soon(stdout_reader)
182183
tg.start_soon(stdin_writer)
183184
try:

src/mcp/client/streamable_http.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
from pydantic import ValidationError
1717

1818
from mcp.client._transport import TransportStreams
19+
from mcp.shared._exception_utils import create_task_group
1920
from mcp.shared._httpx_utils import create_mcp_http_client
2021
from mcp.shared.message import ClientMessageMetadata, SessionMessage
2122
from mcp.types import (
@@ -546,7 +547,7 @@ async def streamable_http_client(
546547

547548
transport = StreamableHTTPTransport(url)
548549

549-
async with anyio.create_task_group() as tg:
550+
async with create_task_group() as tg:
550551
try:
551552
logger.debug(f"Connecting to StreamableHTTP endpoint: {url}")
552553

src/mcp/client/websocket.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
from websockets.typing import Subprotocol
1010

1111
from mcp import types
12+
from mcp.shared._exception_utils import create_task_group
1213
from mcp.shared.message import SessionMessage
1314

1415

@@ -68,7 +69,7 @@ async def ws_writer():
6869
msg_dict = session_message.message.model_dump(by_alias=True, mode="json", exclude_unset=True)
6970
await ws.send(json.dumps(msg_dict))
7071

71-
async with anyio.create_task_group() as tg:
72+
async with create_task_group() as tg:
7273
# Start reader and writer tasks
7374
tg.start_soon(ws_reader)
7475
tg.start_soon(ws_writer)

src/mcp/shared/_exception_utils.py

Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,60 @@
1+
"""Utilities for collapsing ExceptionGroups from anyio task group cancellations.
2+
3+
When a task group has one real failure and N cancelled siblings, anyio wraps them
4+
all in a BaseExceptionGroup. This makes it hard for callers to classify the root
5+
cause. These utilities extract the single real error when possible.
6+
"""
7+
8+
from __future__ import annotations
9+
10+
from collections.abc import AsyncIterator
11+
from contextlib import asynccontextmanager
12+
13+
import anyio
14+
from anyio.abc import TaskGroup
15+
16+
17+
def collapse_exception_group(exc_group: BaseExceptionGroup[BaseException]) -> BaseException:
18+
"""Collapse a single-error exception group into the underlying exception.
19+
20+
When a task in an anyio task group fails, sibling tasks are cancelled,
21+
producing ``Cancelled`` exceptions. The task group then wraps everything
22+
in a ``BaseExceptionGroup``. If there is exactly one non-cancellation
23+
error, this function returns it directly so callers can handle it without
24+
unwrapping.
25+
26+
Args:
27+
exc_group: The exception group to collapse.
28+
29+
Returns:
30+
The single non-cancellation exception if there is exactly one,
31+
otherwise the original exception group unchanged.
32+
"""
33+
cancelled_class = anyio.get_cancelled_exc_class()
34+
real_errors: list[BaseException] = [
35+
exc for exc in exc_group.exceptions if not isinstance(exc, cancelled_class)
36+
]
37+
38+
if len(real_errors) == 1:
39+
return real_errors[0]
40+
41+
return exc_group
42+
43+
44+
@asynccontextmanager
45+
async def create_task_group() -> AsyncIterator[TaskGroup]:
46+
"""Create an anyio task group that collapses single-error exception groups.
47+
48+
Drop-in replacement for ``anyio.create_task_group()`` that automatically
49+
unwraps ``BaseExceptionGroup`` when there is exactly one non-cancellation
50+
error. This makes error handling transparent for callers — they receive
51+
the original exception instead of a wrapped group.
52+
"""
53+
try:
54+
async with anyio.create_task_group() as tg:
55+
yield tg
56+
except BaseExceptionGroup as eg:
57+
collapsed = collapse_exception_group(eg)
58+
if collapsed is not eg:
59+
raise collapsed from eg
60+
raise

src/mcp/shared/session.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
from pydantic import BaseModel, TypeAdapter
1212
from typing_extensions import Self
1313

14+
from mcp.shared._exception_utils import collapse_exception_group
1415
from mcp.shared.exceptions import MCPError
1516
from mcp.shared.message import MessageMetadata, ServerMessageMetadata, SessionMessage
1617
from mcp.shared.response_router import ResponseRouter
@@ -228,7 +229,13 @@ async def __aexit__(
228229
# would be very surprising behavior), so make sure to cancel the tasks
229230
# in the task group.
230231
self._task_group.cancel_scope.cancel()
231-
return await self._task_group.__aexit__(exc_type, exc_val, exc_tb)
232+
try:
233+
return await self._task_group.__aexit__(exc_type, exc_val, exc_tb)
234+
except BaseExceptionGroup as eg:
235+
collapsed = collapse_exception_group(eg)
236+
if collapsed is not eg:
237+
raise collapsed from eg
238+
raise
232239

233240
async def send_request(
234241
self,
Lines changed: 115 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,115 @@
1+
"""Tests for exception group collapsing utilities."""
2+
3+
import pytest
4+
5+
import anyio
6+
7+
from mcp.shared._exception_utils import collapse_exception_group, create_task_group
8+
9+
10+
class TestCollapseExceptionGroup:
11+
"""Tests for the collapse_exception_group function."""
12+
13+
@pytest.mark.anyio
14+
async def test_single_real_error_with_cancelled(self) -> None:
15+
"""A single real error alongside Cancelled exceptions should be extracted."""
16+
real_error = RuntimeError("connection failed")
17+
cancelled = anyio.get_cancelled_exc_class()()
18+
19+
group = BaseExceptionGroup("test", [real_error, cancelled])
20+
result = collapse_exception_group(group)
21+
22+
assert result is real_error
23+
24+
@pytest.mark.anyio
25+
async def test_single_real_error_only(self) -> None:
26+
"""A single real error without Cancelled should be extracted."""
27+
real_error = ValueError("bad value")
28+
29+
group = BaseExceptionGroup("test", [real_error])
30+
result = collapse_exception_group(group)
31+
32+
assert result is real_error
33+
34+
@pytest.mark.anyio
35+
async def test_multiple_real_errors_preserved(self) -> None:
36+
"""Multiple non-cancellation errors should keep the group intact."""
37+
err1 = RuntimeError("first")
38+
err2 = ValueError("second")
39+
40+
group = BaseExceptionGroup("test", [err1, err2])
41+
result = collapse_exception_group(group)
42+
43+
assert result is group
44+
45+
@pytest.mark.anyio
46+
async def test_all_cancelled_preserved(self) -> None:
47+
"""All-cancelled groups should be returned as-is."""
48+
cancelled_class = anyio.get_cancelled_exc_class()
49+
group = BaseExceptionGroup("test", [cancelled_class(), cancelled_class()])
50+
result = collapse_exception_group(group)
51+
52+
assert result is group
53+
54+
@pytest.mark.anyio
55+
async def test_multiple_cancelled_one_real(self) -> None:
56+
"""One real error with multiple Cancelled should extract the real error."""
57+
cancelled_class = anyio.get_cancelled_exc_class()
58+
real_error = ConnectionError("lost connection")
59+
60+
group = BaseExceptionGroup(
61+
"test", [cancelled_class(), real_error, cancelled_class()]
62+
)
63+
result = collapse_exception_group(group)
64+
65+
assert result is real_error
66+
67+
68+
class TestCreateTaskGroup:
69+
"""Tests for the create_task_group context manager."""
70+
71+
@pytest.mark.anyio
72+
async def test_single_failure_unwrapped(self) -> None:
73+
"""A single task failure should propagate the original exception, not a group."""
74+
with pytest.raises(RuntimeError, match="task failed"):
75+
async with create_task_group() as tg:
76+
77+
async def failing_task() -> None:
78+
raise RuntimeError("task failed")
79+
80+
async def long_task() -> None:
81+
await anyio.sleep(100)
82+
83+
tg.start_soon(failing_task)
84+
tg.start_soon(long_task)
85+
86+
@pytest.mark.anyio
87+
async def test_no_failure_clean_exit(self) -> None:
88+
"""Task group with no failures should exit cleanly."""
89+
results: list[int] = []
90+
async with create_task_group() as tg:
91+
92+
async def worker(n: int) -> None:
93+
results.append(n)
94+
95+
tg.start_soon(worker, 1)
96+
tg.start_soon(worker, 2)
97+
98+
assert sorted(results) == [1, 2]
99+
100+
@pytest.mark.anyio
101+
async def test_chained_cause(self) -> None:
102+
"""The collapsed exception should chain to the original group via __cause__."""
103+
with pytest.raises(RuntimeError) as exc_info:
104+
async with create_task_group() as tg:
105+
106+
async def failing_task() -> None:
107+
raise RuntimeError("root cause")
108+
109+
async def long_task() -> None:
110+
await anyio.sleep(100)
111+
112+
tg.start_soon(failing_task)
113+
tg.start_soon(long_task)
114+
115+
assert isinstance(exc_info.value.__cause__, BaseExceptionGroup)

0 commit comments

Comments
 (0)