@@ -1102,6 +1102,131 @@ async def throw_error():
11021102 # cancellation happens here and error is more understandable
11031103 await asyncio .sleep (0 )
11041104
1105+ async def test_taskgroup_cancel_children (self ):
1106+ # (asserting that TimeoutError is not raised)
1107+ async with asyncio .timeout (1 ):
1108+ async with asyncio .TaskGroup () as tg :
1109+ tg .create_task (asyncio .sleep (10 ))
1110+ tg .create_task (asyncio .sleep (10 ))
1111+ await asyncio .sleep (0 )
1112+ tg .cancel ()
1113+
1114+ async def test_taskgroup_cancel_body (self ):
1115+ count = 0
1116+ async with asyncio .TaskGroup () as tg :
1117+ tg .cancel ()
1118+ count += 1
1119+ await asyncio .sleep (0 )
1120+ count += 1
1121+ self .assertEqual (count , 1 )
1122+
1123+ async def test_taskgroup_cancel_idempotent (self ):
1124+ count = 0
1125+ async with asyncio .TaskGroup () as tg :
1126+ tg .cancel ()
1127+ tg .cancel ()
1128+ count += 1
1129+ await asyncio .sleep (0 )
1130+ count += 1
1131+ self .assertEqual (count , 1 )
1132+
1133+ async def test_taskgroup_cancel_after_exit (self ):
1134+ async with asyncio .TaskGroup () as tg :
1135+ await asyncio .sleep (0 )
1136+ # (asserting that exception is not raised)
1137+ tg .cancel ()
1138+
1139+ async def test_taskgroup_cancel_before_enter (self ):
1140+ tg = asyncio .TaskGroup ()
1141+ tg .cancel ()
1142+ count = 0
1143+ async with tg :
1144+ count += 1
1145+ await asyncio .sleep (0 )
1146+ count += 1
1147+ self .assertEqual (count , 1 )
1148+
1149+ async def test_taskgroup_cancel_before_create_task (self ):
1150+ async with asyncio .TaskGroup () as tg :
1151+ tg .cancel ()
1152+ # TODO: This behavior is not ideal. We'd rather have no exception
1153+ # raised, and the child task run until the first await.
1154+ with self .assertRaises (RuntimeError ):
1155+ tg .create_task (asyncio .sleep (1 ))
1156+
1157+ async def test_taskgroup_cancel_before_exception (self ):
1158+ async def raise_exc (parent_tg : asyncio .TaskGroup ):
1159+ parent_tg .cancel ()
1160+ raise RuntimeError
1161+
1162+ with self .assertRaises (ExceptionGroup ):
1163+ async with asyncio .TaskGroup () as tg :
1164+ tg .create_task (raise_exc (tg ))
1165+ await asyncio .sleep (1 )
1166+
1167+ async def test_taskgroup_cancel_after_exception (self ):
1168+ async def raise_exc (parent_tg : asyncio .TaskGroup ):
1169+ try :
1170+ raise RuntimeError
1171+ finally :
1172+ parent_tg .cancel ()
1173+
1174+ with self .assertRaises (ExceptionGroup ):
1175+ async with asyncio .TaskGroup () as tg :
1176+ tg .create_task (raise_exc (tg ))
1177+ await asyncio .sleep (1 )
1178+
1179+ async def test_taskgroup_body_cancel_before_exception (self ):
1180+ with self .assertRaises (ExceptionGroup ):
1181+ async with asyncio .TaskGroup () as tg :
1182+ tg .cancel ()
1183+ raise RuntimeError
1184+
1185+ async def test_taskgroup_body_cancel_after_exception (self ):
1186+ with self .assertRaises (ExceptionGroup ):
1187+ async with asyncio .TaskGroup () as tg :
1188+ try :
1189+ raise RuntimeError
1190+ finally :
1191+ tg .cancel ()
1192+
1193+ async def test_taskgroup_cancel_one_winner (self ):
1194+ async def race (* fns ):
1195+ outcome = None
1196+ async def run (fn ):
1197+ nonlocal outcome
1198+ outcome = await fn ()
1199+ tg .cancel ()
1200+
1201+ async with asyncio .TaskGroup () as tg :
1202+ for fn in fns :
1203+ tg .create_task (run (fn ))
1204+ return outcome
1205+
1206+ event = asyncio .Event ()
1207+ record = []
1208+ async def fn_1 ():
1209+ record .append ("1 started" )
1210+ await event .wait ()
1211+ record .append ("1 finished" )
1212+ return 1
1213+
1214+ async def fn_2 ():
1215+ record .append ("2 started" )
1216+ await event .wait ()
1217+ record .append ("2 finished" )
1218+ return 2
1219+
1220+ async def fn_3 ():
1221+ record .append ("3 started" )
1222+ event .set ()
1223+ await asyncio .sleep (10 )
1224+ record .append ("3 finished" )
1225+ return 3
1226+
1227+ self .assertEqual (await race (fn_1 , fn_2 , fn_3 ), 1 )
1228+ self .assertListEqual (record , ["1 started" , "2 started" , "3 started" , "1 finished" ])
1229+
11051230
11061231class TestTaskGroup (BaseTestTaskGroup , unittest .IsolatedAsyncioTestCase ):
11071232 loop_factory = asyncio .EventLoop
0 commit comments