Skip to content
Open
5 changes: 2 additions & 3 deletions src/mcp/client/_memory.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,10 @@
from types import TracebackType
from typing import Any

import anyio

from mcp.client._transport import TransportStreams
from mcp.server import Server
from mcp.server.mcpserver import MCPServer
from mcp.shared._exception_utils import create_task_group as _create_task_group
from mcp.shared.memory import create_client_server_memory_streams


Expand Down Expand Up @@ -48,7 +47,7 @@ async def _connect(self) -> AsyncIterator[TransportStreams]:
client_read, client_write = client_streams
server_read, server_write = server_streams

async with anyio.create_task_group() as tg:
async with _create_task_group() as tg:
# Start server in background
tg.start_soon(
lambda: actual_server.run(
Expand Down
3 changes: 2 additions & 1 deletion src/mcp/client/sse.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from httpx_sse._exceptions import SSEError

from mcp import types
from mcp.shared._exception_utils import create_task_group
from mcp.shared._httpx_utils import McpHttpClientFactory, create_mcp_http_client
from mcp.shared.message import SessionMessage

Expand Down Expand Up @@ -60,7 +61,7 @@ async def sse_client(
read_stream_writer, read_stream = anyio.create_memory_object_stream(0)
write_stream, write_stream_reader = anyio.create_memory_object_stream(0)

async with anyio.create_task_group() as tg:
async with create_task_group() as tg:
try:
logger.debug(f"Connecting to SSE endpoint: {remove_request_params(url)}")
async with httpx_client_factory(
Expand Down
3 changes: 2 additions & 1 deletion src/mcp/client/stdio.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
get_windows_executable_command,
terminate_windows_process_tree,
)
from mcp.shared._exception_utils import create_task_group
from mcp.shared.message import SessionMessage

logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -177,7 +178,7 @@ async def stdin_writer():
except anyio.ClosedResourceError: # pragma: no cover
await anyio.lowlevel.checkpoint()

async with anyio.create_task_group() as tg, process:
async with create_task_group() as tg, process:
tg.start_soon(stdout_reader)
tg.start_soon(stdin_writer)
try:
Expand Down
3 changes: 2 additions & 1 deletion src/mcp/client/streamable_http.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
from pydantic import ValidationError

from mcp.client._transport import TransportStreams
from mcp.shared._exception_utils import create_task_group
from mcp.shared._httpx_utils import create_mcp_http_client
from mcp.shared.message import ClientMessageMetadata, SessionMessage
from mcp.types import (
Expand Down Expand Up @@ -546,7 +547,7 @@ async def streamable_http_client(

transport = StreamableHTTPTransport(url)

async with anyio.create_task_group() as tg:
async with create_task_group() as tg:
try:
logger.debug(f"Connecting to StreamableHTTP endpoint: {url}")

Expand Down
3 changes: 2 additions & 1 deletion src/mcp/client/websocket.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from websockets.typing import Subprotocol

from mcp import types
from mcp.shared._exception_utils import create_task_group
from mcp.shared.message import SessionMessage


Expand Down Expand Up @@ -68,7 +69,7 @@ async def ws_writer():
msg_dict = session_message.message.model_dump(by_alias=True, mode="json", exclude_unset=True)
await ws.send(json.dumps(msg_dict))

async with anyio.create_task_group() as tg:
async with create_task_group() as tg:
# Start reader and writer tasks
tg.start_soon(ws_reader)
tg.start_soon(ws_writer)
Expand Down
62 changes: 62 additions & 0 deletions src/mcp/shared/_exception_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
"""Utilities for collapsing ExceptionGroups from anyio task group cancellations.

When a task group has one real failure and N cancelled siblings, anyio wraps them
all in a BaseExceptionGroup. This makes it hard for callers to classify the root
cause. These utilities extract the single real error when possible.
"""

from __future__ import annotations

import sys
from collections.abc import AsyncIterator
from contextlib import asynccontextmanager

import anyio
from anyio.abc import TaskGroup

if sys.version_info < (3, 11): # pragma: lax no cover
from exceptiongroup import BaseExceptionGroup # pragma: lax no cover


def collapse_exception_group(exc_group: BaseExceptionGroup[BaseException]) -> BaseException:
    """Reduce an exception group to its lone meaningful exception when possible.

    Once any task in an anyio task group fails, the group cancels every
    sibling task, so the resulting ``BaseExceptionGroup`` typically mixes one
    genuine failure with several ``Cancelled`` exceptions. When exactly one
    non-cancellation exception is present, return it bare so callers can
    catch it directly without unwrapping the group themselves.

    Args:
        exc_group: The exception group to inspect.

    Returns:
        The single non-cancellation exception if exactly one exists,
        otherwise ``exc_group`` itself, untouched.
    """
    cancelled = anyio.get_cancelled_exc_class()
    meaningful = [e for e in exc_group.exceptions if not isinstance(e, cancelled)]

    # Only collapse when the root cause is unambiguous; zero or several real
    # errors means the group itself is the most faithful representation.
    return meaningful[0] if len(meaningful) == 1 else exc_group


@asynccontextmanager
async def create_task_group() -> AsyncIterator[TaskGroup]:
    """Task group wrapper that surfaces a lone failure as a bare exception.

    Behaves exactly like ``anyio.create_task_group()`` except that when the
    group exits with a ``BaseExceptionGroup`` holding exactly one
    non-cancellation error, that error is re-raised directly — with the
    original group attached as ``__cause__`` — instead of the wrapped group.
    """
    try:
        async with anyio.create_task_group() as task_group:
            yield task_group
    except BaseExceptionGroup as exc_group:
        flattened = collapse_exception_group(exc_group)
        if flattened is exc_group:
            # Nothing to unwrap (zero or multiple real errors): propagate as-is.
            raise
        raise flattened from exc_group
13 changes: 12 additions & 1 deletion src/mcp/shared/session.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from __future__ import annotations

import logging
import sys
from collections.abc import Callable
from contextlib import AsyncExitStack
from types import TracebackType
Expand All @@ -11,6 +12,10 @@
from pydantic import BaseModel, TypeAdapter
from typing_extensions import Self

if sys.version_info < (3, 11): # pragma: lax no cover
from exceptiongroup import BaseExceptionGroup # pragma: lax no cover

from mcp.shared._exception_utils import collapse_exception_group
from mcp.shared.exceptions import MCPError
from mcp.shared.message import MessageMetadata, ServerMessageMetadata, SessionMessage
from mcp.shared.response_router import ResponseRouter
Expand Down Expand Up @@ -228,7 +233,13 @@ async def __aexit__(
# would be very surprising behavior), so make sure to cancel the tasks
# in the task group.
self._task_group.cancel_scope.cancel()
return await self._task_group.__aexit__(exc_type, exc_val, exc_tb)
try:
return await self._task_group.__aexit__(exc_type, exc_val, exc_tb)
except BaseExceptionGroup as eg:
collapsed = collapse_exception_group(eg)
if collapsed is not eg:
raise collapsed from eg
raise # pragma: no cover

async def send_request(
self,
Expand Down
132 changes: 132 additions & 0 deletions tests/shared/test_exception_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,132 @@
"""Tests for exception group collapsing utilities."""

import sys

import anyio
import pytest

if sys.version_info < (3, 11): # pragma: lax no cover
from exceptiongroup import BaseExceptionGroup # pragma: lax no cover

from mcp.shared._exception_utils import collapse_exception_group, create_task_group


class TestCollapseExceptionGroup:
    """Behavioral tests for the collapse_exception_group function."""

    @pytest.mark.anyio
    async def test_single_real_error_with_cancelled(self) -> None:
        """A single real error alongside Cancelled exceptions should be extracted."""
        failure = RuntimeError("connection failed")
        group = BaseExceptionGroup("test", [failure, anyio.get_cancelled_exc_class()()])

        assert collapse_exception_group(group) is failure

    @pytest.mark.anyio
    async def test_single_real_error_only(self) -> None:
        """A single real error without Cancelled should be extracted."""
        failure = ValueError("bad value")

        assert collapse_exception_group(BaseExceptionGroup("test", [failure])) is failure

    @pytest.mark.anyio
    async def test_multiple_real_errors_preserved(self) -> None:
        """Multiple non-cancellation errors should keep the group intact."""
        group = BaseExceptionGroup("test", [RuntimeError("first"), ValueError("second")])

        assert collapse_exception_group(group) is group

    @pytest.mark.anyio
    async def test_all_cancelled_preserved(self) -> None:
        """All-cancelled groups should be returned as-is."""
        make_cancelled = anyio.get_cancelled_exc_class()
        group = BaseExceptionGroup("test", [make_cancelled(), make_cancelled()])

        assert collapse_exception_group(group) is group

    @pytest.mark.anyio
    async def test_multiple_cancelled_one_real(self) -> None:
        """One real error with multiple Cancelled should extract the real error."""
        make_cancelled = anyio.get_cancelled_exc_class()
        failure = ConnectionError("lost connection")
        group = BaseExceptionGroup("test", [make_cancelled(), failure, make_cancelled()])

        assert collapse_exception_group(group) is failure


class TestCreateTaskGroup:
    """Behavioral tests for the create_task_group context manager."""

    @pytest.mark.anyio
    async def test_single_failure_unwrapped(self) -> None:
        """A single task failure should propagate the original exception, not a group."""

        async def boom() -> None:
            raise RuntimeError("task failed")

        async def sleeper() -> None:
            await anyio.sleep(100)

        with pytest.raises(RuntimeError, match="task failed"):
            async with create_task_group() as tg:
                tg.start_soon(boom)
                tg.start_soon(sleeper)

    @pytest.mark.anyio
    async def test_no_failure_clean_exit(self) -> None:
        """Task group with no failures should exit cleanly."""
        collected: list[int] = []

        async def record(value: int) -> None:
            collected.append(value)

        async with create_task_group() as tg:
            tg.start_soon(record, 1)
            tg.start_soon(record, 2)

        assert sorted(collected) == [1, 2]

    @pytest.mark.anyio
    async def test_chained_cause(self) -> None:
        """The collapsed exception should chain to the original group via __cause__."""

        async def boom() -> None:
            raise RuntimeError("root cause")

        async def sleeper() -> None:
            await anyio.sleep(100)

        with pytest.raises(RuntimeError) as exc_info:
            async with create_task_group() as tg:
                tg.start_soon(boom)
                tg.start_soon(sleeper)

        assert isinstance(exc_info.value.__cause__, BaseExceptionGroup)

    @pytest.mark.anyio
    async def test_multiple_failures_raises_group(self) -> None:
        """Multiple real task failures should raise as a BaseExceptionGroup."""

        async def boom_runtime() -> None:
            raise RuntimeError("error A")

        async def boom_value() -> None:
            raise ValueError("error B")

        with pytest.raises(BaseExceptionGroup):
            async with create_task_group() as tg:
                tg.start_soon(boom_runtime)
                tg.start_soon(boom_value)