Commit 313a597

Add differentiable collectives and axis_size / index (#183)
* Add differentiable collectives
* Use API similar to JAX. This requires pytorch/pytorch#164473. I had to tweak the local_map implementation a bit to be able to retrieve the current device_mesh being used.
* Fix type
1 parent d397aa6 commit 313a597

File tree

1 file changed: +146 −0 lines changed


autoparallel/collectives.py

Lines changed: 146 additions & 0 deletions
@@ -0,0 +1,146 @@
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.

from typing import Any

import torch
import torch.distributed.distributed_c10d as c10d
from torch.distributed._tensor.experimental import local_map as _local_map

_local_map_device_mesh = None


def local_map(*args, **kwargs):
    # TODO: ideally after we get out of the local map region we should
    # just reset the global device mesh to None. For now we just keep it
    # around.
    global _local_map_device_mesh
    _local_map_device_mesh = kwargs.get("device_mesh", None)
    return _local_map(*args, **kwargs)


def get_mesh_from_global():
    global _local_map_device_mesh
    if _local_map_device_mesh is None:
        raise RuntimeError(
            "No mesh found, make sure to call this collective in a local_map region"
        )
    return _local_map_device_mesh


def _get_group_name_from_axis_name(mesh_name):
    mesh = get_mesh_from_global()
    group = mesh.get_group(mesh_name)
    return group.group_name


def axis_size(axis_name):
    mesh = get_mesh_from_global()
    assert axis_name in mesh.mesh_dim_names
    axis_dim = mesh.mesh_dim_names.index(axis_name)
    return mesh.size(axis_dim)


def axis_index(axis_name):
    mesh = get_mesh_from_global()
    return mesh.get_local_rank(mesh_dim=axis_name)


def _all_gather_tensor(
    x: torch.Tensor,
    gather_dim: int,
    group_name: str,
) -> torch.Tensor:
    x = x.contiguous()
    group_size = c10d._get_group_size_by_name(group_name)
    tensor = torch.ops._c10d_functional.all_gather_into_tensor(
        x, group_size, group_name
    )
    res = torch.ops._c10d_functional.wait_tensor(tensor)
    if gather_dim != 0:
        # torch.cat accesses the data, so we already need to wait here; doing
        # the wait first and then chunk + cat avoids going through ACT
        # dispatching logic again
        res = torch.cat(torch.chunk(res, group_size, dim=0), dim=gather_dim)
    return res


def _reduce_scatter_tensor(
    self: torch.Tensor, reduceOp: str, scatter_dim: int, group_name: str
):
    group_size = c10d._get_group_size_by_name(group_name)

    assert (
        self.size(scatter_dim) % group_size == 0
    ), f"input dimension {scatter_dim} ({self.size(scatter_dim)}) must be a multiple of group_size {group_size}"
    if scatter_dim != 0:
        tensor_list = torch.chunk(self, group_size, dim=scatter_dim)
        self = torch.cat(tensor_list)

    tensor = torch.ops._c10d_functional.reduce_scatter_tensor(
        self,
        reduceOp.lower(),
        group_size,
        group_name,
    )
    res = torch.ops._c10d_functional.wait_tensor(tensor)
    return res


def _all_reduce(self: torch.Tensor, reduceOp: str, group_name: str):
    tensor = torch.ops._c10d_functional.all_reduce(self, reduceOp.lower(), group_name)
    res = torch.ops._c10d_functional.wait_tensor(tensor)
    return res


class _AllGather(torch.autograd.Function):
    @staticmethod
    def forward(ctx: Any, x: torch.Tensor, gather_dim: int, axis_name: str):
        group_name = _get_group_name_from_axis_name(axis_name)
        ctx.group_name = group_name
        ctx.gather_dim = gather_dim
        return _all_gather_tensor(x, gather_dim, group_name)

    @staticmethod
    def backward(ctx: Any, grad_output: torch.Tensor):  # type: ignore[override]
        return (
            _reduce_scatter_tensor(grad_output, "sum", ctx.gather_dim, ctx.group_name),
            None,
            None,
        )


class _ReduceScatter(torch.autograd.Function):
    @staticmethod
    def forward(ctx: Any, x: torch.Tensor, scatter_dim: int, axis_name: str):
        group_name = _get_group_name_from_axis_name(axis_name)
        ctx.group_name = group_name
        ctx.scatter_dim = scatter_dim
        return _reduce_scatter_tensor(x, "sum", scatter_dim, group_name)

    @staticmethod
    def backward(ctx: Any, grad_output: torch.Tensor):  # type: ignore[override]
        return (
            _all_gather_tensor(grad_output, ctx.scatter_dim, ctx.group_name),
            None,
            None,
        )


class _AllReduce(torch.autograd.Function):
    @staticmethod
    def forward(ctx: Any, x: torch.Tensor, axis_name: str):
        group_name = _get_group_name_from_axis_name(axis_name)
        ctx.group_name = group_name
        return _all_reduce(x, "sum", group_name)

    @staticmethod
    def backward(ctx: Any, grad_output: torch.Tensor):  # type: ignore[override]
        # TODO: split this into a function that does all-reduce and one which is
        # the identity
        return _all_reduce(grad_output, "sum", ctx.group_name), None


all_gather = _AllGather.apply
all_reduce = _AllReduce.apply
reduce_scatter = _ReduceScatter.apply
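
A minimal usage sketch of the new collectives (not part of this diff): it assumes torch.distributed is already initialized with 8 CUDA ranks, a single mesh dim named "tp", and the DTensor helpers init_device_mesh / distribute_tensor / Shard / Replicate; the function fn, the shapes, and the axis name are made up for illustration. Note that device_mesh is passed as a keyword argument so the local_map wrapper above can record it.

# Hypothetical example, not part of this commit. Assumes torch.distributed is
# initialized with 8 CUDA ranks; axis name "tp" and all shapes are made up.
import torch
from torch.distributed._tensor import Replicate, Shard, distribute_tensor
from torch.distributed.device_mesh import init_device_mesh

from autoparallel.collectives import all_gather, axis_size, local_map

mesh = init_device_mesh("cuda", (8,), mesh_dim_names=("tp",))


def fn(x):
    # x is the local shard; the local_map wrapper has stashed the mesh, so the
    # collectives can resolve the "tp" process group by axis name.
    full = all_gather(x, 0, "tp")  # differentiable: backward is a reduce-scatter
    return full / axis_size("tp")


wrapped = local_map(
    fn,
    out_placements=(Replicate(),),
    in_placements=((Shard(0),),),
    device_mesh=mesh,  # kwarg, so the wrapper can record the mesh globally
)

x = distribute_tensor(torch.randn(16, 4, requires_grad=True), mesh, [Shard(0)])
out = wrapped(x)
out.sum().backward()  # gradients flow back through the collective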
