diff --git a/src/mcp/server/session.py b/src/mcp/server/session.py index 759d2131a..c738b877f 100644 --- a/src/mcp/server/session.py +++ b/src/mcp/server/session.py @@ -46,6 +46,7 @@ async def handle_list_prompts(ctx: RequestContext, params) -> ListPromptsResult: from mcp.shared.message import ServerMessageMetadata, SessionMessage from mcp.shared.session import ( BaseSession, + ProgressFnT, RequestResponder, ) from mcp.shared.version import SUPPORTED_PROTOCOL_VERSIONS @@ -245,6 +246,7 @@ async def create_message( tools: None = None, tool_choice: types.ToolChoice | None = None, related_request_id: types.RequestId | None = None, + progress_callback: ProgressFnT | None = None, ) -> types.CreateMessageResult: """Overload: Without tools, returns single content.""" ... @@ -264,6 +266,7 @@ async def create_message( tools: list[types.Tool], tool_choice: types.ToolChoice | None = None, related_request_id: types.RequestId | None = None, + progress_callback: ProgressFnT | None = None, ) -> types.CreateMessageResultWithTools: """Overload: With tools, returns array-capable content.""" ... @@ -282,6 +285,7 @@ async def create_message( tools: list[types.Tool] | None = None, tool_choice: types.ToolChoice | None = None, related_request_id: types.RequestId | None = None, + progress_callback: ProgressFnT | None = None, ) -> types.CreateMessageResult | types.CreateMessageResultWithTools: """Send a sampling/create_message request. @@ -338,11 +342,13 @@ async def create_message( request=request, result_type=types.CreateMessageResultWithTools, metadata=metadata_obj, + progress_callback=progress_callback, ) return await self.send_request( request=request, result_type=types.CreateMessageResult, metadata=metadata_obj, + progress_callback=progress_callback, ) async def list_roots(self) -> types.ListRootsResult: @@ -359,6 +365,7 @@ async def elicit( message: str, requested_schema: types.ElicitRequestedSchema, related_request_id: types.RequestId | None = None, + progress_callback: ProgressFnT | None = None, ) -> types.ElicitResult: """Send a form mode elicitation/create request. @@ -374,13 +381,16 @@ async def elicit( This method is deprecated in favor of elicit_form(). It remains for backward compatibility but new code should use elicit_form(). """ - return await self.elicit_form(message, requested_schema, related_request_id) + return await self.elicit_form( + message, requested_schema, related_request_id, progress_callback=progress_callback + ) async def elicit_form( self, message: str, requested_schema: types.ElicitRequestedSchema, related_request_id: types.RequestId | None = None, + progress_callback: ProgressFnT | None = None, ) -> types.ElicitResult: """Send a form mode elicitation/create request. @@ -406,6 +416,7 @@ async def elicit_form( ), types.ElicitResult, metadata=ServerMessageMetadata(related_request_id=related_request_id), + progress_callback=progress_callback, ) async def elicit_url( @@ -414,6 +425,7 @@ async def elicit_url( url: str, elicitation_id: str, related_request_id: types.RequestId | None = None, + progress_callback: ProgressFnT | None = None, ) -> types.ElicitResult: """Send a URL mode elicitation/create request. @@ -444,6 +456,7 @@ async def elicit_url( ), types.ElicitResult, metadata=ServerMessageMetadata(related_request_id=related_request_id), + progress_callback=progress_callback, ) async def send_ping(self) -> types.EmptyResult: # pragma: no cover