Skip to content
24 changes: 23 additions & 1 deletion src/mcp/client/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,17 @@

from mcp.client._memory import InMemoryTransport
from mcp.client._transport import Transport
from mcp.client.session import ClientSession, ElicitationFnT, ListRootsFnT, LoggingFnT, MessageHandlerFnT, SamplingFnT
from mcp.client.session import (
ClientSession,
ElicitationFnT,
ListRootsFnT,
LoggingFnT,
MessageHandlerFnT,
PromptListChangedFnT,
ResourceListChangedFnT,
SamplingFnT,
ToolListChangedFnT,
)
from mcp.client.streamable_http import streamable_http_client
from mcp.server import Server
from mcp.server.mcpserver import MCPServer
Expand Down Expand Up @@ -95,6 +105,15 @@ async def main():
elicitation_callback: ElicitationFnT | None = None
"""Callback for handling elicitation requests."""

tool_list_changed_callback: ToolListChangedFnT | None = None
"""Callback invoked when the server signals its tool list has changed."""

prompt_list_changed_callback: PromptListChangedFnT | None = None
"""Callback invoked when the server signals its prompt list has changed."""

resource_list_changed_callback: ResourceListChangedFnT | None = None
"""Callback invoked when the server signals its resource list has changed."""

_session: ClientSession | None = field(init=False, default=None)
_exit_stack: AsyncExitStack | None = field(init=False, default=None)
_transport: Transport = field(init=False)
Expand Down Expand Up @@ -126,6 +145,9 @@ async def __aenter__(self) -> Client:
message_handler=self.message_handler,
client_info=self.client_info,
elicitation_callback=self.elicitation_callback,
tool_list_changed_callback=self.tool_list_changed_callback,
prompt_list_changed_callback=self.prompt_list_changed_callback,
resource_list_changed_callback=self.resource_list_changed_callback,
)
)

Expand Down
37 changes: 37 additions & 0 deletions src/mcp/client/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,18 @@ class LoggingFnT(Protocol):
async def __call__(self, params: types.LoggingMessageNotificationParams) -> None: ... # pragma: no branch


class ResourceListChangedFnT(Protocol):
async def __call__(self) -> None: ... # pragma: no branch


class ToolListChangedFnT(Protocol):
async def __call__(self) -> None: ... # pragma: no branch


class PromptListChangedFnT(Protocol):
async def __call__(self) -> None: ... # pragma: no branch


class MessageHandlerFnT(Protocol):
async def __call__(
self,
Expand Down Expand Up @@ -95,6 +107,10 @@ async def _default_logging_callback(
pass


async def _default_list_changed_callback() -> None:
pass


ClientResponse: TypeAdapter[types.ClientResult | types.ErrorData] = TypeAdapter(types.ClientResult | types.ErrorData)


Expand All @@ -121,6 +137,9 @@ def __init__(
*,
sampling_capabilities: types.SamplingCapability | None = None,
experimental_task_handlers: ExperimentalTaskHandlers | None = None,
tool_list_changed_callback: ToolListChangedFnT | None = None,
prompt_list_changed_callback: PromptListChangedFnT | None = None,
resource_list_changed_callback: ResourceListChangedFnT | None = None,
) -> None:
super().__init__(read_stream, write_stream, read_timeout_seconds=read_timeout_seconds)
self._client_info = client_info or DEFAULT_CLIENT_INFO
Expand All @@ -130,6 +149,9 @@ def __init__(
self._list_roots_callback = list_roots_callback or _default_list_roots_callback
self._logging_callback = logging_callback or _default_logging_callback
self._message_handler = message_handler or _default_message_handler
self._tool_list_changed_callback = tool_list_changed_callback or _default_list_changed_callback
self._prompt_list_changed_callback = prompt_list_changed_callback or _default_list_changed_callback
self._resource_list_changed_callback = resource_list_changed_callback or _default_list_changed_callback
self._tool_output_schemas: dict[str, dict[str, Any] | None] = {}
self._server_capabilities: types.ServerCapabilities | None = None
self._experimental_features: ExperimentalClientFeatures | None = None
Expand Down Expand Up @@ -470,6 +492,21 @@ async def _received_notification(self, notification: types.ServerNotification) -
match notification:
case types.LoggingMessageNotification(params=params):
await self._logging_callback(params)
case types.ToolListChangedNotification():
try:
await self._tool_list_changed_callback()
except Exception:
logger.exception("Tool list changed callback raised an exception")
case types.PromptListChangedNotification():
try:
await self._prompt_list_changed_callback()
except Exception:
logger.exception("Prompt list changed callback raised an exception")
case types.ResourceListChangedNotification():
try:
await self._resource_list_changed_callback()
except Exception:
logger.exception("Resource list changed callback raised an exception")
case types.ElicitCompleteNotification(params=params):
# Handle elicitation completion notification
# Clients MAY use this to retry requests or update UI
Expand Down
17 changes: 16 additions & 1 deletion src/mcp/client/session_group.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,16 @@

import mcp
from mcp import types
from mcp.client.session import ElicitationFnT, ListRootsFnT, LoggingFnT, MessageHandlerFnT, SamplingFnT
from mcp.client.session import (
ElicitationFnT,
ListRootsFnT,
LoggingFnT,
MessageHandlerFnT,
PromptListChangedFnT,
ResourceListChangedFnT,
SamplingFnT,
ToolListChangedFnT,
)
from mcp.client.sse import sse_client
from mcp.client.stdio import StdioServerParameters
from mcp.client.streamable_http import streamable_http_client
Expand Down Expand Up @@ -80,6 +89,9 @@ class ClientSessionParameters:
logging_callback: LoggingFnT | None = None
message_handler: MessageHandlerFnT | None = None
client_info: types.Implementation | None = None
tool_list_changed_callback: ToolListChangedFnT | None = None
prompt_list_changed_callback: PromptListChangedFnT | None = None
resource_list_changed_callback: ResourceListChangedFnT | None = None


class ClientSessionGroup:
Expand Down Expand Up @@ -310,6 +322,9 @@ async def _establish_session(
logging_callback=session_params.logging_callback,
message_handler=session_params.message_handler,
client_info=session_params.client_info,
tool_list_changed_callback=session_params.tool_list_changed_callback,
prompt_list_changed_callback=session_params.prompt_list_changed_callback,
resource_list_changed_callback=session_params.resource_list_changed_callback,
)
)

Expand Down
Loading