@@ -3747,7 +3747,8 @@ def _init_two_pg2_subgroups(self, world_size: int = 4):
 
     @requires_nccl()
    @skip_if_lt_x_gpu(4)
-    def test_gather_subgroup(self):
+    @parametrize("group_rank", [True, False])
+    def test_gather_subgroup(self, group_rank):
         world_size = 4
         if self.rank >= world_size:
             # just easier to write the test for exactly 4 gpus, even if this test class increased to 8gpu later
@@ -3758,28 +3759,48 @@ def test_gather_subgroup(self):
         input = torch.ones((10,), device=device) * self.rank
         if self.rank == 0 or self.rank == 2:
             gather_list = [torch.empty_like(input) for _ in range(subgroup.size())]
-            torch.distributed.gather(
-                input,
-                gather_list=gather_list,
-                dst=self.rank,
-                group=subgroup,
-                async_op=False,
-            )
+            if group_rank:
+                # global_dst=0 group_dst=0 my_global_rank=2 gather_list is not None=True
+                torch.distributed.gather(
+                    input,
+                    gather_list=gather_list,
+                    group_dst=0,
+                    group=subgroup,
+                    async_op=False,
+                )
+            else:
+                torch.distributed.gather(
+                    input,
+                    gather_list=gather_list,
+                    dst=self.rank,
+                    group=subgroup,
+                    async_op=False,
+                )
             for src in range(len(gather_list)):
                 expected = (torch.ones_like(input) * self.rank) + src
                 self.assertEqual(gather_list[src], expected)
         else:
-            torch.distributed.gather(
-                input,
-                gather_list=None,
-                dst=self.rank - 1,
-                group=subgroup,
-                async_op=False,
-            )
+            if group_rank:
+                torch.distributed.gather(
+                    input,
+                    gather_list=None,
+                    group_dst=0,
+                    group=subgroup,
+                    async_op=False,
+                )
+            else:
+                torch.distributed.gather(
+                    input,
+                    gather_list=None,
+                    dst=self.rank - 1,
+                    group=subgroup,
+                    async_op=False,
+                )
 
     @requires_nccl()
     @skip_if_lt_x_gpu(4)
-    def test_gather_object_subgroup(self):
+    @parametrize("group_rank", [True, False])
+    def test_gather_object_subgroup(self, group_rank):
         world_size = 4
         if self.rank >= world_size:
             # just easier to write the test for exactly 4 gpus, even if this test class increased to 8gpu later
@@ -3797,15 +3818,25 @@ def test_gather_object_subgroup(self):
             # another weird thing- what's the point of making me specify some empty objects in my list?
             # empty list should be valid imo. (but it throws an error)
             gather_list = [{}, {}]
-            torch.distributed.gather_object(
-                input, object_gather_list=gather_list, dst=self.rank, group=subgroup
-            )
+            if group_rank:
+                torch.distributed.gather_object(
+                    input, object_gather_list=gather_list, group_dst=0, group=subgroup
+                )
+            else:
+                torch.distributed.gather_object(
+                    input, object_gather_list=gather_list, dst=self.rank, group=subgroup
+                )
             for src in range(len(gather_list)):
                 self.assertEqual(gather_list[src]["rank"], self.rank + src)
         else:
-            torch.distributed.gather_object(
-                input, object_gather_list=None, dst=self.rank - 1, group=subgroup
-            )
+            if group_rank:
+                torch.distributed.gather_object(
+                    input, object_gather_list=None, group_dst=0, group=subgroup
+                )
+            else:
+                torch.distributed.gather_object(
+                    input, object_gather_list=None, dst=self.rank - 1, group=subgroup
+                )
 
     @requires_nccl()
     @skip_if_lt_x_gpu(4)
@@ -3931,7 +3962,8 @@ def test_broadcast_object_list_subgroup(self, set_device: SetDeviceMethod):
 
     @requires_nccl()
     @skip_if_lt_x_gpu(4)
-    def test_scatter_subgroup(self):
+    @parametrize("group_rank", [True, False])
+    def test_scatter_subgroup(self, group_rank):
         world_size = 4
         if self.rank >= world_size:
             return
@@ -3940,18 +3972,27 @@ def test_scatter_subgroup(self):
         x = torch.empty((10,), device=device)
         expected = torch.ones((10,), device=device) * self.rank
         if self.rank == 0 or self.rank == 2:
-            c10d.scatter(x, scatter_list=None, src=self.rank + 1, group=subgroup)
+            if group_rank:
+                c10d.scatter(x, scatter_list=None, group_src=1, group=subgroup)
+            else:
+                c10d.scatter(x, scatter_list=None, src=self.rank + 1, group=subgroup)
         else:
             scatter_list = [
                 torch.ones((10,), device=device) * (self.rank - 1),
                 torch.ones((10,), device=device) * self.rank,
             ]
-            c10d.scatter(x, scatter_list=scatter_list, src=self.rank, group=subgroup)
+            if group_rank:
+                c10d.scatter(x, scatter_list=scatter_list, group_src=1, group=subgroup)
+            else:
+                c10d.scatter(
+                    x, scatter_list=scatter_list, src=self.rank, group=subgroup
+                )
         self.assertEqual(x, expected)
 
     @requires_nccl()
     @skip_if_lt_x_gpu(4)
-    def test_scatter_object_list_subgroup(self):
+    @parametrize("group_rank", [True, False])
+    def test_scatter_object_list_subgroup(self, group_rank):
         world_size = 4
         if self.rank >= world_size:
             return
@@ -3960,24 +4001,40 @@ def test_scatter_object_list_subgroup(self):
         scatter_object_output_list = [None]
         expected = [{"rank": self.rank}]
         if self.rank == 0 or self.rank == 2:
-            c10d.scatter_object_list(
-                scatter_object_output_list=scatter_object_output_list,
-                scatter_object_input_list=None,
-                src=self.rank + 1,
-                group=subgroup,
-            )
+            if group_rank:
+                c10d.scatter_object_list(
+                    scatter_object_output_list=scatter_object_output_list,
+                    scatter_object_input_list=None,
+                    group_src=1,
+                    group=subgroup,
+                )
+            else:
+                c10d.scatter_object_list(
+                    scatter_object_output_list=scatter_object_output_list,
+                    scatter_object_input_list=None,
+                    src=self.rank + 1,
+                    group=subgroup,
+                )
 
         else:
             scatter_object_input_list = [
                 {"rank": self.rank - 1},
                 {"rank": self.rank},
             ]
-            c10d.scatter_object_list(
-                scatter_object_output_list=scatter_object_output_list,
-                scatter_object_input_list=scatter_object_input_list,
-                src=self.rank,
-                group=subgroup,
-            )
+            if group_rank:
+                c10d.scatter_object_list(
+                    scatter_object_output_list=scatter_object_output_list,
+                    scatter_object_input_list=scatter_object_input_list,
+                    group_src=1,
+                    group=subgroup,
+                )
+            else:
+                c10d.scatter_object_list(
+                    scatter_object_output_list=scatter_object_output_list,
+                    scatter_object_input_list=scatter_object_input_list,
+                    src=self.rank,
+                    group=subgroup,
+                )
         self.assertEqual(scatter_object_output_list, expected)
 
 
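Note on what the new `group_rank` parametrization covers: `torch.distributed.gather`/`scatter` and their object-list variants accept `group_dst`/`group_src`, which name the destination or source by its rank *inside* the passed subgroup, whereas the existing `dst`/`src` arguments use global ranks. Below is a minimal sketch of the equivalence these tests assert, assuming a 4-rank job split into subgroups [0, 1] and [2, 3] as in `_init_two_pg2_subgroups`; the helper function is illustrative, not part of the test suite.

```python
import torch
import torch.distributed as dist


def gather_to_subgroup_leader(tensor, subgroup, use_group_rank):
    # Illustrative helper (not from the test suite). Assumes subgroups
    # [0, 1] and [2, 3], so each subgroup's leader (group rank 0) has
    # global rank 0 or 2.
    my_global_rank = dist.get_rank()
    leader_global_rank = 0 if my_global_rank < 2 else 2
    # Only the leader supplies a gather_list; the other rank passes None.
    gather_list = (
        [torch.empty_like(tensor) for _ in range(subgroup.size())]
        if my_global_rank == leader_global_rank
        else None
    )
    if use_group_rank:
        # group_dst indexes ranks within `subgroup`: the leader is 0.
        dist.gather(tensor, gather_list=gather_list, group_dst=0, group=subgroup)
    else:
        # dst is the leader's global rank; both calls target the same process.
        dist.gather(
            tensor, gather_list=gather_list, dst=leader_global_rank, group=subgroup
        )
    return gather_list
```

The scatter tests exercise the same mapping in the other direction: `group_src=1` within a subgroup corresponds to global `src` 1 or 3.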