@@ -2593,7 +2593,13 @@ def batch_isend_irecv(p2p_op_list):


 @_exception_logger
-def broadcast(tensor, src, group=None, async_op=False):
+def broadcast(
+    tensor: torch.Tensor,
+    src: Optional[int] = None,
+    group: Optional[ProcessGroup] = None,
+    async_op: bool = False,
+    group_src: Optional[int] = None,
+):
     """
     Broadcasts the tensor to the whole group.

@@ -2607,29 +2613,26 @@ def broadcast(tensor, src, group=None, async_op=False):
         group (ProcessGroup, optional): The process group to work on. If None,
             the default process group will be used.
         async_op (bool, optional): Whether this op should be an async op
+        group_src (int): Source rank on ``group``. Must specify one of ``group_src``
+            and ``src`` but not both.

     Returns:
         Async work handle, if async_op is set to True.
         None, if not async_op or if not part of the group

     """
+    group = _group_or_default_group(group)
+    group_src = _canonicalize_group_rank(group, src, group_src, return_global=False)
     _check_single_tensor(tensor, "tensor")
     if _rank_not_in_group(group):
         _warn_not_in_group("broadcast")
         return

     opts = BroadcastOptions()
-    opts.rootRank = src
+    opts.rootRank = group_src
     opts.rootTensor = 0
     opts.asyncOp = async_op
-
-    if group is None or group is GroupMember.WORLD:
-        default_pg = _get_default_group()
-        work = default_pg.broadcast([tensor], opts)
-    else:
-        group_src_rank = get_group_rank(group, src)
-        opts.rootRank = group_src_rank
-        work = group.broadcast([tensor], opts)
+    work = group.broadcast([tensor], opts)
     if async_op:
         return work
     else:
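
For illustration only, a minimal usage sketch of the new ``group_src`` keyword based on the docstring above; the 4-process job, the gloo backend, and the specific ranks are assumptions, not part of this diff:

import torch
import torch.distributed as dist

# Assumed setup: 4 processes launched via torchrun, gloo backend.
dist.init_process_group("gloo")
subgroup = dist.new_group(ranks=[2, 3])
t = torch.zeros(4)

if dist.get_rank() in (2, 3):
    # Before this change: the source is given as a *global* rank.
    dist.broadcast(t, src=2, group=subgroup)
    # With this change: the source may instead be given as its rank *within*
    # the group (global rank 2 is rank 0 inside `subgroup`).
    dist.broadcast(t, group_src=0, group=subgroup)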
@@ -2783,7 +2786,14 @@ def all_reduce_coalesced(tensors, op=ReduceOp.SUM, group=None, async_op=False):


 @_exception_logger
-def reduce(tensor, dst, op=ReduceOp.SUM, group=None, async_op=False):
+def reduce(
+    tensor: torch.Tensor,
+    dst: Optional[int] = None,
+    op=ReduceOp.SUM,
+    group: Optional[ProcessGroup] = None,
+    async_op: bool = False,
+    group_dst: Optional[int] = None,
+):
     """
     Reduces the tensor data across all machines.

@@ -2799,29 +2809,25 @@ def reduce(tensor, dst, op=ReduceOp.SUM, group=None, async_op=False):
         group (ProcessGroup, optional): The process group to work on. If None,
             the default process group will be used.
         async_op (bool, optional): Whether this op should be an async op
+        group_dst (int): Destination rank on ``group``. Must specify one of ``group_dst``
+            and ``dst`` but not both.

     Returns:
         Async work handle, if async_op is set to True.
         None, if not async_op or if not part of the group

     """
+    group = _group_or_default_group(group)
+    group_dst = _canonicalize_group_rank(group, dst, group_dst, return_global=False)
     _check_single_tensor(tensor, "tensor")
     if _rank_not_in_group(group):
         _warn_not_in_group("reduce")
         return

     opts = ReduceOptions()
     opts.reduceOp = op
-    opts.rootRank = dst
-
-    if group is None or group is GroupMember.WORLD:
-        default_pg = _get_default_group()
-        work = default_pg.reduce([tensor], opts)
-    else:
-        group_dst_rank = get_group_rank(group, dst)
-        opts.rootRank = group_dst_rank
-        work = group.reduce([tensor], opts)
-
+    opts.rootRank = group_dst
+    work = group.reduce([tensor], opts)
     if async_op:
         return work
     else:
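
Under the same assumed 4-process setup as the broadcast sketch above, ``reduce`` gains the analogous ``group_dst`` keyword; this is a sketch, not part of the diff:

# Continuing the assumed setup from the broadcast sketch above.
if dist.get_rank() in (2, 3):
    # Destination addressed by global rank (old style) ...
    dist.reduce(t, dst=3, op=dist.ReduceOp.SUM, group=subgroup)
    # ... or by its rank within `subgroup` (new style); both target global rank 3.
    dist.reduce(t, group_dst=1, op=dist.ReduceOp.SUM, group=subgroup)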
@@ -3270,7 +3276,13 @@ def recv_object_list(object_list, src=None, group=None, device=None):


 @_exception_logger
-def broadcast_object_list(object_list, src=0, group=None, device=None):
+def broadcast_object_list(
+    object_list: List[Any],
+    src: Optional[int] = None,
+    group: Optional[ProcessGroup] = None,
+    device: Optional[torch.device] = None,
+    group_src: Optional[int] = None,
+):
     """
     Broadcasts picklable objects in ``object_list`` to the whole group.

@@ -3289,6 +3301,8 @@ def broadcast_object_list(object_list, src=0, group=None, device=None):
         device (``torch.device``, optional): If not None, the objects are
             serialized and converted to tensors which are moved to the
             ``device`` before broadcasting. Default is ``None``.
+        group_src (int): Source rank on ``group``. Must not specify both
+            ``group_src`` and ``src``.

     Returns:
         ``None``. If rank is part of the group, ``object_list`` will contain the
@@ -3331,6 +3345,10 @@ def broadcast_object_list(object_list, src=0, group=None, device=None):
         >>> objects
         ['foo', 12, {1: 2}]
     """
+    group = _group_or_default_group(group)
+    if src is None and group_src is None:
+        src = 0
+    global_src = _canonicalize_group_rank(group, src, group_src, return_global=True)
     if _rank_not_in_group(group):
         _warn_not_in_group("broadcast_object_list")
         return
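
The two private helpers introduced by this diff are only called here, not defined. The following is a rough paraphrase of their apparent contract, inferred from the call sites (``return_global`` selecting between a group-local and a global rank); it is not the actual implementation:

# Paraphrase inferred from the call sites in this diff; not the real helpers.
def _group_or_default_group(group=None):
    # Fall back to the default (WORLD) process group when none is given.
    return _get_default_group() if group is None else group

def _canonicalize_group_rank(group, global_rank=None, group_rank=None, return_global=False):
    # Exactly one of global_rank / group_rank may be given; the result is the
    # same rank expressed either group-locally or globally.
    if (global_rank is None) == (group_rank is None):
        raise ValueError("Specify exactly one of a global rank and a group rank")
    if group_rank is None:
        group_rank = get_group_rank(group, global_rank)
    return get_global_rank(group, group_rank) if return_global else group_rank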
@@ -3342,9 +3360,9 @@ def broadcast_object_list(object_list, src=0, group=None, device=None):
     # case it is not ``None`` we move the size and object tensors to be
     # broadcasted to this device.
     current_device = device or _get_object_coll_device(group)
-    my_rank = get_rank()
+    my_global_rank = get_rank()
     # Serialize object_list elements to tensors on src rank.
-    if my_rank == src:
+    if my_global_rank == global_src:
         tensor_list, size_list = zip(
             *[_object_to_tensor(obj, current_device, group) for obj in object_list]
         )
@@ -3355,12 +3373,12 @@ def broadcast_object_list(object_list, src=0, group=None, device=None):
         )

     # Broadcast object sizes
-    broadcast(object_sizes_tensor, src=src, group=group)
+    broadcast(object_sizes_tensor, src=global_src, group=group)

     # Concatenate and broadcast serialized object tensors
     # Note: torch.cat will do an extra memory copy to the current device, if the tensor_list
     # has only one element, we can skip the copy.
-    if my_rank == src:
+    if my_global_rank == global_src:
         if len(tensor_list) == 1:  # type: ignore[possibly-undefined]
             object_tensor = tensor_list[0]
         else:
@@ -3372,10 +3390,10 @@ def broadcast_object_list(object_list, src=0, group=None, device=None):
             device=current_device,
         )

-    broadcast(object_tensor, src=src, group=group)
+    broadcast(object_tensor, src=global_src, group=group)
     # Deserialize objects using their stored sizes.
     offset = 0
-    if my_rank != src:
+    if my_global_rank != global_src:
         for i, obj_size in enumerate(object_sizes_tensor):
             obj_view = object_tensor[offset : offset + obj_size]
             obj_view = obj_view.type(torch.uint8)
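
A usage sketch for the object-collective variant, again illustrative only; the subgroup and ranks are assumptions, and the object values mirror the docstring example already in the file:

# Reusing `subgroup` (global ranks [2, 3]) and the setup from the earlier sketches.
if dist.get_rank() == 2:
    objects = ["foo", 12, {1: 2}]     # source payload
else:
    objects = [None, None, None]      # placeholders, overwritten on receivers
if dist.get_rank() in (2, 3):
    # group_src=0 addresses global rank 2 by its rank inside `subgroup`.
    dist.broadcast_object_list(objects, group_src=0, group=subgroup)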