|
| 1 | +# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. |
| 2 | +# |
| 3 | +# This source code is licensed under the BSD license found in the |
| 4 | +# LICENSE file in the root directory of this source tree. |
| 5 | + |
| 6 | +from typing import Any |
| 7 | + |
| 8 | +import torch |
| 9 | +import torch.distributed.distributed_c10d as c10d |
| 10 | +from torch.distributed._tensor.experimental import local_map as _local_map |
| 11 | + |
| 12 | +_local_map_device_mesh = None |
| 13 | + |
| 14 | + |
| 15 | +def local_map(*args, **kwargs): |
| 16 | + # TODO: ideally after we get out of the local map region we should |
| 17 | + # just reset the global device mesh to None. For now we just keep it |
| 18 | + # around. |
| 19 | + global _local_map_device_mesh |
| 20 | + _local_map_device_mesh = kwargs.get("device_mesh", None) |
| 21 | + return _local_map(*args, **kwargs) |
| 22 | + |
| 23 | + |
| 24 | +def get_mesh_from_global(): |
| 25 | + global _local_map_device_mesh |
| 26 | + if _local_map_device_mesh is None: |
| 27 | + raise RuntimeError( |
| 28 | + "No mesh found, make sure to call this collective in a local_map region" |
| 29 | + ) |
| 30 | + return _local_map_device_mesh |
| 31 | + |
| 32 | + |
| 33 | +def _get_group_name_from_axis_name(mesh_name): |
| 34 | + mesh = get_mesh_from_global() |
| 35 | + group = mesh.get_group(mesh_name) |
| 36 | + return group.group_name |
| 37 | + |
| 38 | + |
| 39 | +def axis_size(axis_name): |
| 40 | + mesh = get_mesh_from_global() |
| 41 | + assert axis_name in mesh.mesh_dim_names |
| 42 | + axis_dim = mesh.mesh_dim_names.index(axis_name) |
| 43 | + return mesh.size(axis_dim) |
| 44 | + |
| 45 | + |
| 46 | +def axis_index(axis_name): |
| 47 | + mesh = get_mesh_from_global() |
| 48 | + return mesh.get_local_rank(mesh_dim=axis_name) |
| 49 | + |
| 50 | + |
| 51 | +def _all_gather_tensor( |
| 52 | + x: torch.Tensor, |
| 53 | + gather_dim: int, |
| 54 | + group_name: str, |
| 55 | +) -> torch.Tensor: |
| 56 | + x = x.contiguous() |
| 57 | + group_size = c10d._get_group_size_by_name(group_name) |
| 58 | + tensor = torch.ops._c10d_functional.all_gather_into_tensor( |
| 59 | + x, group_size, group_name |
| 60 | + ) |
| 61 | + res = torch.ops._c10d_functional.wait_tensor(tensor) |
| 62 | + if gather_dim != 0: |
| 63 | + # torch.cat access the data so we already need to wait here, first do wait |
| 64 | + # and then chunk + cat avoid us going through ACT dispatching logic again |
| 65 | + res = torch.cat(torch.chunk(res, group_size, dim=0), dim=gather_dim) |
| 66 | + return res |
| 67 | + |
| 68 | + |
| 69 | +def _reduce_scatter_tensor( |
| 70 | + self: torch.Tensor, reduceOp: str, scatter_dim: int, group_name: str |
| 71 | +): |
| 72 | + group_size = c10d._get_group_size_by_name(group_name) |
| 73 | + |
| 74 | + assert ( |
| 75 | + self.size(scatter_dim) % group_size == 0 |
| 76 | + ), f"input dimension 0 ({self.size(0)} must be a multiple of group_size {group_size})" |
| 77 | + if scatter_dim != 0: |
| 78 | + tensor_list = torch.chunk(self, group_size, dim=scatter_dim) |
| 79 | + self = torch.cat(tensor_list) |
| 80 | + |
| 81 | + tensor = torch.ops._c10d_functional.reduce_scatter_tensor( |
| 82 | + self, |
| 83 | + reduceOp.lower(), |
| 84 | + group_size, |
| 85 | + group_name, |
| 86 | + ) |
| 87 | + res = torch.ops._c10d_functional.wait_tensor(tensor) |
| 88 | + return res |
| 89 | + |
| 90 | + |
| 91 | +def _all_reduce(self: torch.Tensor, reduceOp: str, group_name: str): |
| 92 | + tensor = torch.ops._c10d_functional.all_reduce(self, reduceOp.lower(), group_name) |
| 93 | + res = torch.ops._c10d_functional.wait_tensor(tensor) |
| 94 | + return res |
| 95 | + |
| 96 | + |
| 97 | +class _AllGather(torch.autograd.Function): |
| 98 | + @staticmethod |
| 99 | + def forward(ctx: Any, x: torch.Tensor, gather_dim: int, axis_name: str): |
| 100 | + group_name = _get_group_name_from_axis_name(axis_name) |
| 101 | + ctx.group_name = group_name |
| 102 | + ctx.gather_dim = gather_dim |
| 103 | + return _all_gather_tensor(x, gather_dim, group_name) |
| 104 | + |
| 105 | + @staticmethod |
| 106 | + def backward(ctx: Any, grad_output: torch.Tensor): # type: ignore[override] |
| 107 | + return ( |
| 108 | + _reduce_scatter_tensor(grad_output, "sum", ctx.gather_dim, ctx.group_name), |
| 109 | + None, |
| 110 | + None, |
| 111 | + ) |
| 112 | + |
| 113 | + |
| 114 | +class _ReduceScatter(torch.autograd.Function): |
| 115 | + @staticmethod |
| 116 | + def forward(ctx: Any, x: torch.Tensor, scatter_dim: int, axis_name: str): |
| 117 | + group_name = _get_group_name_from_axis_name(axis_name) |
| 118 | + ctx.group_name = group_name |
| 119 | + ctx.scatter_dim = scatter_dim |
| 120 | + return _reduce_scatter_tensor(x, "sum", scatter_dim, group_name) |
| 121 | + |
| 122 | + @staticmethod |
| 123 | + def backward(ctx: Any, grad_output: torch.Tensor): # type: ignore[override] |
| 124 | + return ( |
| 125 | + _all_gather_tensor(grad_output, ctx.scatter_dim, ctx.group_name), |
| 126 | + None, |
| 127 | + None, |
| 128 | + ) |
| 129 | + |
| 130 | + |
| 131 | +class _AllReduce(torch.autograd.Function): |
| 132 | + @staticmethod |
| 133 | + def forward(ctx: Any, x: torch.Tensor, axis_name: str): |
| 134 | + group_name = _get_group_name_from_axis_name(axis_name) |
| 135 | + ctx.group_name = group_name |
| 136 | + return _all_reduce(x, "sum", group_name) |
| 137 | + |
| 138 | + @staticmethod |
| 139 | + def backward(ctx: Any, grad_output: torch.Tensor): # type: ignore[override] |
| 140 | + # TODO: split this into a function that does all-reduce and one which is the identity |
| 141 | + return _all_reduce(grad_output, "sum", ctx.group_name), None |
| 142 | + |
| 143 | + |
| 144 | +all_gather = _AllGather.apply |
| 145 | +all_reduce = _AllReduce.apply |
| 146 | +reduce_scatter = _ReduceScatter.apply |
0 commit comments