diff --git a/src/mcp/client/client.py b/src/mcp/client/client.py index 7dc67c584..958ecfb31 100644 --- a/src/mcp/client/client.py +++ b/src/mcp/client/client.py @@ -8,7 +8,17 @@ from mcp.client._memory import InMemoryTransport from mcp.client._transport import Transport -from mcp.client.session import ClientSession, ElicitationFnT, ListRootsFnT, LoggingFnT, MessageHandlerFnT, SamplingFnT +from mcp.client.session import ( + ClientSession, + ElicitationFnT, + ListRootsFnT, + LoggingFnT, + MessageHandlerFnT, + PromptListChangedFnT, + ResourceListChangedFnT, + SamplingFnT, + ToolListChangedFnT, +) from mcp.client.streamable_http import streamable_http_client from mcp.server import Server from mcp.server.mcpserver import MCPServer @@ -95,6 +105,15 @@ async def main(): elicitation_callback: ElicitationFnT | None = None """Callback for handling elicitation requests.""" + tool_list_changed_callback: ToolListChangedFnT | None = None + """Callback invoked when the server signals its tool list has changed.""" + + prompt_list_changed_callback: PromptListChangedFnT | None = None + """Callback invoked when the server signals its prompt list has changed.""" + + resource_list_changed_callback: ResourceListChangedFnT | None = None + """Callback invoked when the server signals its resource list has changed.""" + _session: ClientSession | None = field(init=False, default=None) _exit_stack: AsyncExitStack | None = field(init=False, default=None) _transport: Transport = field(init=False) @@ -126,6 +145,9 @@ async def __aenter__(self) -> Client: message_handler=self.message_handler, client_info=self.client_info, elicitation_callback=self.elicitation_callback, + tool_list_changed_callback=self.tool_list_changed_callback, + prompt_list_changed_callback=self.prompt_list_changed_callback, + resource_list_changed_callback=self.resource_list_changed_callback, ) ) diff --git a/src/mcp/client/session.py b/src/mcp/client/session.py index a0ca751bd..12b75b585 100644 --- a/src/mcp/client/session.py +++ b/src/mcp/client/session.py @@ -47,6 +47,18 @@ class LoggingFnT(Protocol): async def __call__(self, params: types.LoggingMessageNotificationParams) -> None: ... # pragma: no branch +class ResourceListChangedFnT(Protocol): + async def __call__(self) -> None: ... # pragma: no branch + + +class ToolListChangedFnT(Protocol): + async def __call__(self) -> None: ... # pragma: no branch + + +class PromptListChangedFnT(Protocol): + async def __call__(self) -> None: ... # pragma: no branch + + class MessageHandlerFnT(Protocol): async def __call__( self, @@ -95,6 +107,10 @@ async def _default_logging_callback( pass +async def _default_list_changed_callback() -> None: + pass + + ClientResponse: TypeAdapter[types.ClientResult | types.ErrorData] = TypeAdapter(types.ClientResult | types.ErrorData) @@ -121,6 +137,9 @@ def __init__( *, sampling_capabilities: types.SamplingCapability | None = None, experimental_task_handlers: ExperimentalTaskHandlers | None = None, + tool_list_changed_callback: ToolListChangedFnT | None = None, + prompt_list_changed_callback: PromptListChangedFnT | None = None, + resource_list_changed_callback: ResourceListChangedFnT | None = None, ) -> None: super().__init__(read_stream, write_stream, read_timeout_seconds=read_timeout_seconds) self._client_info = client_info or DEFAULT_CLIENT_INFO @@ -130,6 +149,9 @@ def __init__( self._list_roots_callback = list_roots_callback or _default_list_roots_callback self._logging_callback = logging_callback or _default_logging_callback self._message_handler = message_handler or _default_message_handler + self._tool_list_changed_callback = tool_list_changed_callback or _default_list_changed_callback + self._prompt_list_changed_callback = prompt_list_changed_callback or _default_list_changed_callback + self._resource_list_changed_callback = resource_list_changed_callback or _default_list_changed_callback self._tool_output_schemas: dict[str, dict[str, Any] | None] = {} self._server_capabilities: types.ServerCapabilities | None = None self._experimental_features: ExperimentalClientFeatures | None = None @@ -470,6 +492,21 @@ async def _received_notification(self, notification: types.ServerNotification) - match notification: case types.LoggingMessageNotification(params=params): await self._logging_callback(params) + case types.ToolListChangedNotification(): + try: + await self._tool_list_changed_callback() + except Exception: + logger.exception("Tool list changed callback raised an exception") + case types.PromptListChangedNotification(): + try: + await self._prompt_list_changed_callback() + except Exception: + logger.exception("Prompt list changed callback raised an exception") + case types.ResourceListChangedNotification(): + try: + await self._resource_list_changed_callback() + except Exception: + logger.exception("Resource list changed callback raised an exception") case types.ElicitCompleteNotification(params=params): # Handle elicitation completion notification # Clients MAY use this to retry requests or update UI diff --git a/src/mcp/client/session_group.py b/src/mcp/client/session_group.py index 961021264..bd3516347 100644 --- a/src/mcp/client/session_group.py +++ b/src/mcp/client/session_group.py @@ -20,7 +20,16 @@ import mcp from mcp import types -from mcp.client.session import ElicitationFnT, ListRootsFnT, LoggingFnT, MessageHandlerFnT, SamplingFnT +from mcp.client.session import ( + ElicitationFnT, + ListRootsFnT, + LoggingFnT, + MessageHandlerFnT, + PromptListChangedFnT, + ResourceListChangedFnT, + SamplingFnT, + ToolListChangedFnT, +) from mcp.client.sse import sse_client from mcp.client.stdio import StdioServerParameters from mcp.client.streamable_http import streamable_http_client @@ -80,6 +89,9 @@ class ClientSessionParameters: logging_callback: LoggingFnT | None = None message_handler: MessageHandlerFnT | None = None client_info: types.Implementation | None = None + tool_list_changed_callback: ToolListChangedFnT | None = None + prompt_list_changed_callback: PromptListChangedFnT | None = None + resource_list_changed_callback: ResourceListChangedFnT | None = None class ClientSessionGroup: @@ -310,6 +322,9 @@ async def _establish_session( logging_callback=session_params.logging_callback, message_handler=session_params.message_handler, client_info=session_params.client_info, + tool_list_changed_callback=session_params.tool_list_changed_callback, + prompt_list_changed_callback=session_params.prompt_list_changed_callback, + resource_list_changed_callback=session_params.resource_list_changed_callback, ) ) diff --git a/tests/client/test_list_changed_callbacks.py b/tests/client/test_list_changed_callbacks.py new file mode 100644 index 000000000..edb8231b2 --- /dev/null +++ b/tests/client/test_list_changed_callbacks.py @@ -0,0 +1,291 @@ +"""Tests for list_changed notification callbacks in ClientSession.""" + +import anyio +import pytest + +from mcp import types +from mcp.client.session import ClientSession +from mcp.server import Server +from mcp.server.lowlevel import NotificationOptions +from mcp.server.models import InitializationOptions +from mcp.server.session import ServerSession +from mcp.shared.message import SessionMessage + +pytestmark = pytest.mark.anyio + + +async def test_tool_list_changed_callback(): + """Verify that the client invokes the tool_list_changed callback when + the server sends a notifications/tools/list_changed notification.""" + callback_called = anyio.Event() + + async def on_tools_changed() -> None: + callback_called.set() + + async def _list_tools(_ctx: object, _params: object) -> types.ListToolsResult: + return types.ListToolsResult(tools=[]) # pragma: no cover + + server = Server( + name="ListChangedServer", + on_list_tools=_list_tools, + ) + + server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[SessionMessage](5) + client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[SessionMessage](5) + + async with anyio.create_task_group() as tg: + + async def run_server(): + async with ServerSession( + client_to_server_receive, + server_to_client_send, + InitializationOptions( + server_name="ListChangedServer", + server_version="0.1.0", + capabilities=server.get_capabilities(NotificationOptions(tools_changed=True), {}), + ), + ) as server_session: + async for message in server_session.incoming_messages: + await server._handle_message(message, server_session, {}) + + tg.start_soon(run_server) + + async with ClientSession( + server_to_client_receive, + client_to_server_send, + tool_list_changed_callback=on_tools_changed, + ) as session: + await session.initialize() + + # Have the server send a tool list changed notification directly + await server_to_client_send.send( + SessionMessage( + message=types.JSONRPCNotification( + jsonrpc="2.0", + **types.ToolListChangedNotification().model_dump(by_alias=True, mode="json", exclude_none=True), + ), + ) + ) + + with anyio.fail_after(2): + await callback_called.wait() + + tg.cancel_scope.cancel() + + +async def test_prompt_list_changed_callback(): + """Verify the prompt_list_changed callback is invoked.""" + callback_called = anyio.Event() + + async def on_prompts_changed() -> None: + callback_called.set() + + server = Server(name="ListChangedServer") + + server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[SessionMessage](5) + client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[SessionMessage](5) + + async with anyio.create_task_group() as tg: + + async def run_server(): + async with ServerSession( + client_to_server_receive, + server_to_client_send, + InitializationOptions( + server_name="ListChangedServer", + server_version="0.1.0", + capabilities=server.get_capabilities(NotificationOptions(prompts_changed=True), {}), + ), + ) as server_session: + async for message in server_session.incoming_messages: + await server._handle_message(message, server_session, {}) + + tg.start_soon(run_server) + + async with ClientSession( + server_to_client_receive, + client_to_server_send, + prompt_list_changed_callback=on_prompts_changed, + ) as session: + await session.initialize() + + await server_to_client_send.send( + SessionMessage( + message=types.JSONRPCNotification( + jsonrpc="2.0", + **types.PromptListChangedNotification().model_dump( + by_alias=True, mode="json", exclude_none=True + ), + ), + ) + ) + + with anyio.fail_after(2): + await callback_called.wait() + + tg.cancel_scope.cancel() + + +async def test_resource_list_changed_callback(): + """Verify the resource_list_changed callback is invoked.""" + callback_called = anyio.Event() + + async def on_resources_changed() -> None: + callback_called.set() + + server = Server(name="ListChangedServer") + + server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[SessionMessage](5) + client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[SessionMessage](5) + + async with anyio.create_task_group() as tg: + + async def run_server(): + async with ServerSession( + client_to_server_receive, + server_to_client_send, + InitializationOptions( + server_name="ListChangedServer", + server_version="0.1.0", + capabilities=server.get_capabilities(NotificationOptions(resources_changed=True), {}), + ), + ) as server_session: + async for message in server_session.incoming_messages: + await server._handle_message(message, server_session, {}) + + tg.start_soon(run_server) + + async with ClientSession( + server_to_client_receive, + client_to_server_send, + resource_list_changed_callback=on_resources_changed, + ) as session: + await session.initialize() + + await server_to_client_send.send( + SessionMessage( + message=types.JSONRPCNotification( + jsonrpc="2.0", + **types.ResourceListChangedNotification().model_dump( + by_alias=True, mode="json", exclude_none=True + ), + ), + ) + ) + + with anyio.fail_after(2): + await callback_called.wait() + + tg.cancel_scope.cancel() + + +async def test_list_changed_default_no_error(): + """Verify that without callbacks, list_changed notifications are handled + silently (no errors, no hangs).""" + server = Server(name="ListChangedServer") + + server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[SessionMessage](5) + client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[SessionMessage](5) + + async with anyio.create_task_group() as tg: + + async def run_server(): + async with ServerSession( + client_to_server_receive, + server_to_client_send, + InitializationOptions( + server_name="ListChangedServer", + server_version="0.1.0", + capabilities=server.get_capabilities(NotificationOptions(), {}), + ), + ) as server_session: + async for message in server_session.incoming_messages: + await server._handle_message(message, server_session, {}) + + tg.start_soon(run_server) + + async with ClientSession( + server_to_client_receive, + client_to_server_send, + ) as session: + await session.initialize() + + # Send all three list_changed notifications — none should cause errors + for notification_cls in ( + types.ToolListChangedNotification, + types.PromptListChangedNotification, + types.ResourceListChangedNotification, + ): + await server_to_client_send.send( + SessionMessage( + message=types.JSONRPCNotification( + jsonrpc="2.0", + **notification_cls().model_dump(by_alias=True, mode="json", exclude_none=True), + ), + ) + ) + + # Give the session a moment to process + await anyio.sleep(0.1) + + tg.cancel_scope.cancel() + + +async def test_callback_exception_does_not_crash_session(): + """Verify that an exception in a list_changed callback is logged but does + not crash the client session.""" + + async def bad_callback() -> None: + raise RuntimeError("boom") + + server = Server(name="ListChangedServer") + + server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[SessionMessage](5) + client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[SessionMessage](5) + + async with anyio.create_task_group() as tg: + + async def run_server(): + async with ServerSession( + client_to_server_receive, + server_to_client_send, + InitializationOptions( + server_name="ListChangedServer", + server_version="0.1.0", + capabilities=server.get_capabilities(NotificationOptions(), {}), + ), + ) as server_session: + async for message in server_session.incoming_messages: + await server._handle_message(message, server_session, {}) + + tg.start_soon(run_server) + + async with ClientSession( + server_to_client_receive, + client_to_server_send, + tool_list_changed_callback=bad_callback, + prompt_list_changed_callback=bad_callback, + resource_list_changed_callback=bad_callback, + ) as session: + await session.initialize() + + # Send all three notification types — all callbacks will raise, + # but the session should survive. + for notification_cls in ( + types.ToolListChangedNotification, + types.PromptListChangedNotification, + types.ResourceListChangedNotification, + ): + await server_to_client_send.send( + SessionMessage( + message=types.JSONRPCNotification( + jsonrpc="2.0", + **notification_cls().model_dump(by_alias=True, mode="json", exclude_none=True), + ), + ) + ) + + # Session should still be alive — verify by waiting for processing + await anyio.sleep(0.1) + + tg.cancel_scope.cancel() diff --git a/tests/client/test_session_group.py b/tests/client/test_session_group.py index 6a58b39f3..cffdc96b2 100644 --- a/tests/client/test_session_group.py +++ b/tests/client/test_session_group.py @@ -378,6 +378,9 @@ async def test_client_session_group_establish_session_parameterized( logging_callback=None, message_handler=None, client_info=None, + tool_list_changed_callback=None, + prompt_list_changed_callback=None, + resource_list_changed_callback=None, ) mock_raw_session_cm.__aenter__.assert_awaited_once() mock_entered_session.initialize.assert_awaited_once()