@@ -415,3 +415,30 @@ async def test_client_session_group_establish_session_parameterized(
415415 # 3. Assert returned values
416416 assert returned_server_info is mock_initialize_result .server_info
417417 assert returned_session is mock_entered_session
418+
419+
420+ @pytest .mark .anyio
421+ async def test_client_session_group_establish_session_closes_stack_on_initialize_error ():
422+ group_exit_stack = mock .AsyncMock (spec = contextlib .AsyncExitStack )
423+ session_stack = mock .AsyncMock (spec = contextlib .AsyncExitStack )
424+ mock_read_stream = mock .AsyncMock (name = "Read" )
425+ mock_write_stream = mock .AsyncMock (name = "Write" )
426+ mock_session = mock .AsyncMock (spec = mcp .ClientSession )
427+ mock_session .initialize .side_effect = RuntimeError ("initialize failed" )
428+ session_stack .enter_async_context .side_effect = [
429+ (mock_read_stream , mock_write_stream ),
430+ mock_session ,
431+ ]
432+
433+ group = ClientSessionGroup (exit_stack = group_exit_stack )
434+
435+ with (
436+ mock .patch ("mcp.client.session_group.contextlib.AsyncExitStack" , return_value = session_stack ),
437+ mock .patch ("mcp.client.session_group.mcp.stdio_client" , return_value = mock .AsyncMock ()),
438+ mock .patch ("mcp.client.session_group.mcp.ClientSession" , return_value = mock .AsyncMock ()),
439+ pytest .raises (RuntimeError , match = "initialize failed" ),
440+ ):
441+ await group ._establish_session (StdioServerParameters (command = "test" ), ClientSessionParameters ())
442+
443+ session_stack .aclose .assert_awaited_once ()
444+ group_exit_stack .enter_async_context .assert_not_awaited ()
0 commit comments