Skip to content

Commit 4c8502d

Browse files
committed
[5/N] Autobucketing: add reordering for comm overlapping
ghstack-source-id: a001d97 Pull-Request: #134
1 parent 72ebeda commit 4c8502d

File tree

2 files changed

+280
-1
lines changed

2 files changed

+280
-1
lines changed

autoparallel/auto_bucketing.py

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55

66
import torch
77

8-
from .autobucketing_util import bucket_func, bucket_plan, bucket_utils
8+
from .autobucketing_util import bucket_func, bucket_plan, bucket_utils, reorder
99

1010

1111
class simplefsdp_autobucketing_config:
@@ -71,4 +71,19 @@ def simple_fsdp_autobucketing_reordering_pass(
7171
reduce_scatter_plan,
7272
bucketable_nodes,
7373
)
74+
75+
if configs.enable_reorder_ir:
76+
print("Reorder scheduler nodes with autobucketing algroithm")
77+
node_length = len(snodes)
78+
snodes = reorder.reorder_all_gather(
79+
snodes, bucketable_nodes, all_gather_before_last_wait=False
80+
)
81+
assert node_length == len(
82+
snodes
83+
), f"Missed nodes in reordering all gather: expected {node_length}, but got {len(snodes)}"
84+
snodes = reorder.reorder_reduce_scatter(snodes, bucketable_nodes)
85+
assert node_length == len(
86+
snodes
87+
), f"Missed nodes in reordering reduce scatter: expected {node_length}, but got {len(snodes)}"
88+
7489
return snodes
Lines changed: 264 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,264 @@
1+
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
2+
#
3+
# This source code is licensed under the BSD license found in the
4+
# LICENSE file in the root directory of this source tree.
5+
6+
# mypy: ignore-errors
7+
from collections import defaultdict
8+
from enum import IntEnum
9+
from typing import Dict, List, Optional, Tuple
10+
11+
import torch
12+
from torch._inductor import ir, scheduler
13+
from torch._inductor.utils import contains_collective, contains_wait, is_collective
14+
from torch.utils._ordered_set import OrderedSet
15+
16+
from .bucket_utils import check_ir_node_bucketable
17+
18+
19+
class NodeType(IntEnum):
    """Classification of a scheduler node for the reordering passes.

    The integer values are stable tags used for bookkeeping only; they
    carry no ordering semantics beyond equality/membership tests.
    """

    ALL_GATHER = 0
    COMPUTE = 1
    REDUCE_SCATTER = 2
    AG_WAIT = 3
    RS_WAIT = 4
25+
26+
27+
def compute_node_users(
28+
snodes: List["scheduler.BaseSchedulerNode"],
29+
) -> Tuple[
30+
Dict["scheduler.BaseSchedulerNode", OrderedSet["scheduler.BaseSchedulerNode"]],
31+
Dict["scheduler.BaseSchedulerNode", OrderedSet["scheduler.BaseSchedulerNode"]],
32+
]:
33+
"""
34+
Compute the inverse users and users of each node
35+
"""
36+
buf_to_snode: Dict[str, scheduler.BaseSchedulerNode] = {}
37+
for node in snodes:
38+
if isinstance(node, scheduler.FusedSchedulerNode):
39+
for x in node.snodes:
40+
for buf in x.get_outputs():
41+
buf_to_snode[buf.get_name()] = node
42+
43+
for buf in node.get_outputs():
44+
buf_to_snode[buf.get_name()] = node
45+
46+
inverse_users = {}
47+
keys = list(buf_to_snode.keys())
48+
for node in snodes:
49+
dep_list = []
50+
for dep in node.unmet_dependencies:
51+
if dep.name in keys:
52+
dep_list.append(buf_to_snode[dep.name])
53+
inverse_users.update({node: OrderedSet(dep_list)})
54+
55+
node_users: Dict[
56+
scheduler.BaseSchedulerNode, OrderedSet[scheduler.BaseSchedulerNode]
57+
] = defaultdict(OrderedSet)
58+
for node, node_inverse_users in inverse_users.items():
59+
for inverse_user in node_inverse_users:
60+
node_users[inverse_user].add(node)
61+
62+
return inverse_users, node_users
63+
64+
65+
def _get_ir_node_type(ir_node: "ir.Operation", bucketable_ir_nodes) -> NodeType:
    """
    Determine the type of a ir node.

    Classification rules, in order:
    - ``_WaitKernel`` waiting on a bucketable all_gather / reduce_scatter
      -> AG_WAIT / RS_WAIT.
    - ``_CollectiveKernel`` that is a bucketable all_gather / reduce_scatter
      -> ALL_GATHER / REDUCE_SCATTER.
    - ``FallbackKernel``: bucketed collectives are apparently lowered to
      fallback kernels, so the python kernel name (and that of its input)
      is inspected instead of op overloads.
    - anything else -> COMPUTE.

    NOTE(review): a collective/wait that is NOT in ``bucketable_ir_nodes``
    deliberately falls through to COMPUTE, so non-bucketable communication
    is left untouched by the reordering passes.
    """
    if isinstance(ir_node, ir._WaitKernel):
        # Determine if the wait node is waiting for ALL_GATHER or REDUCE_SCATTER
        ir_op_overload = getattr(ir_node.inputs[0], "op_overload", None)
        if (
            ir_op_overload == torch.ops._c10d_functional.all_gather_into_tensor.default
            and check_ir_node_bucketable(ir_node.inputs[0], bucketable_ir_nodes)
        ):
            return NodeType.AG_WAIT
        elif (
            ir_op_overload == torch.ops._c10d_functional.reduce_scatter_tensor.default
            and check_ir_node_bucketable(ir_node.inputs[0], bucketable_ir_nodes)
        ):
            return NodeType.RS_WAIT
    if isinstance(ir_node, ir._CollectiveKernel):
        # Determine if the collective kernel is for ALL_GATHER or REDUCE_SCATTER
        ir_op_overload = getattr(ir_node, "op_overload", None)
        if is_collective(
            ir_node, op=torch.ops._c10d_functional.all_gather_into_tensor.default
        ) and check_ir_node_bucketable(ir_node, bucketable_ir_nodes):
            return NodeType.ALL_GATHER
        elif is_collective(
            ir_node, op=torch.ops._c10d_functional.reduce_scatter_tensor.default
        ) and check_ir_node_bucketable(ir_node, bucketable_ir_nodes):
            return NodeType.REDUCE_SCATTER

    if isinstance(ir_node, ir.FallbackKernel):
        python_kernel_name = ir_node.python_kernel_name
        if (
            python_kernel_name == "torch.ops._c10d_functional.wait_tensor.default"
            and check_ir_node_bucketable(ir_node, bucketable_ir_nodes)
        ):
            # A wait fallback: decide AG vs RS by the kernel name of its
            # input (checked both directly and one level down, since the
            # collective may itself be wrapped).
            inputs_rs_kernel_name1 = (
                getattr(ir_node.inputs[0], "python_kernel_name", "")
                == "torch.ops._c10d_functional.reduce_scatter_tensor.default"
            )
            inputs_rs_kernel_name2 = (
                hasattr(ir_node.inputs[0], "inputs")
                and getattr(ir_node.inputs[0].inputs[0], "python_kernel_name", "")
                == "torch.ops._c10d_functional.reduce_scatter_tensor.default"
            )
            if inputs_rs_kernel_name1 or inputs_rs_kernel_name2:
                return NodeType.RS_WAIT

            # NOTE(review): the all_gather check uses the *_out variant of
            # the op name, unlike the op-overload checks above -- presumably
            # because bucketed all_gathers are emitted as
            # all_gather_into_tensor_out; confirm against the bucketing pass.
            inputs_ag_kernel_name1 = (
                getattr(ir_node.inputs[0], "python_kernel_name", "")
                == "torch.ops._c10d_functional.all_gather_into_tensor_out.default"
            )
            inputs_ag_kernel_name2 = (
                hasattr(ir_node.inputs[0], "inputs")
                and getattr(ir_node.inputs[0].inputs[0], "python_kernel_name", "")
                == "torch.ops._c10d_functional.all_gather_into_tensor_out.default"
            )
            if inputs_ag_kernel_name1 or inputs_ag_kernel_name2:
                return NodeType.AG_WAIT
        elif (
            python_kernel_name
            == "torch.ops._c10d_functional.reduce_scatter_tensor.default"
        ) and check_ir_node_bucketable(ir_node, bucketable_ir_nodes):
            return NodeType.REDUCE_SCATTER
        elif (
            python_kernel_name
            == "torch.ops._c10d_functional.all_gather_into_tensor_out.default"
        ) and check_ir_node_bucketable(ir_node, bucketable_ir_nodes):
            return NodeType.ALL_GATHER
    return NodeType.COMPUTE
134+
135+
136+
def get_node_type(node: "scheduler.BaseSchedulerNode", bucketable_ir_nodes) -> NodeType:
    """Classify a scheduler node as compute, collective, or wait.

    Fused nodes are always compute. Grouped nodes (produced by the
    bucketing pass for newly created collectives) are classified by their
    children, with wait types taking precedence over the collectives
    themselves. Everything else defers to ``_get_ir_node_type``.
    """
    if isinstance(node, scheduler.FusedSchedulerNode):
        # Fusion only ever merges compute nodes.
        return NodeType.COMPUTE

    if isinstance(node, scheduler.GroupedSchedulerNode):
        # [Only for bucketing]: newly created AG and RS are grouped as
        # GroupedSchedulerNode; pick the highest-precedence child type.
        child_types = [
            _get_ir_node_type(child.node, bucketable_ir_nodes)
            for child in node.snodes
        ]
        for candidate in (
            NodeType.AG_WAIT,
            NodeType.RS_WAIT,
            NodeType.ALL_GATHER,
            NodeType.REDUCE_SCATTER,
        ):
            if candidate in child_types:
                return candidate
        return NodeType.COMPUTE

    return _get_ir_node_type(node.node, bucketable_ir_nodes)
161+
162+
163+
def reorder_all_gather(
    snodes: List["scheduler.BaseSchedulerNode"],
    bucketable_ir_nodes: set[str],
    all_gather_before_last_wait: Optional[bool] = True,
) -> List["scheduler.BaseSchedulerNode"]:
    """
    Reorder All Gather and Wait in the forward/backward pass;
    1. all_gather_before_last_wait set to True: all_gather_i is reordered before wait_i-1
    2. all_gather_before_last_wait set to False: all_gather_i is reordered after wait_i-1

    The list is walked back-to-front: each all_gather (together with its
    pure-compute producers) is buffered in ``all_gather_list`` and flushed
    around the previous iteration's AG wait, which in execution order moves
    the all_gather earlier (prefetch).

    NOTE(review): ``snodes.reverse()`` mutates the caller's list in place;
    callers that reuse the argument after this returns will see it
    reversed -- confirm this is intended.

    Args:
        snodes: scheduler nodes in execution order.
        bucketable_ir_nodes: names of IR nodes eligible for bucketing.
        all_gather_before_last_wait: placement of all_gather_i relative to
            wait_{i-1}, as described above.

    Returns:
        A new list containing exactly the same nodes, reordered.
    """
    result_list: List[scheduler.BaseSchedulerNode] = []
    # all_gathers (and their compute producers) waiting to be re-anchored
    all_gather_list: List[scheduler.BaseSchedulerNode] = []
    node_to_type: Dict[scheduler.BaseSchedulerNode, NodeType] = {}
    inverse_users, node_users = compute_node_users(snodes)

    for node in snodes:
        node_to_type[node] = get_node_type(node, bucketable_ir_nodes)
    # walk in reverse execution order (mutates the input list in place)
    snodes.reverse()
    for idx, node in enumerate(snodes):
        node_type = node_to_type[node]
        if node_type in [NodeType.REDUCE_SCATTER, NodeType.COMPUTE, NodeType.RS_WAIT]:
            # we do not reorder reduce scatter and compute node
            if node not in result_list and node not in all_gather_list:
                result_list.append(node)
        elif node_type == NodeType.ALL_GATHER:
            # gather i-th all gather node and its dependencies
            all_gather_list.append(node)
            inverse_user = list(inverse_users[node])
            # only pure compute producers travel with the all_gather;
            # producers that contain collectives or waits stay put
            inverse_user = [
                n
                for n in inverse_user
                if node_to_type[n] == NodeType.COMPUTE
                and not contains_collective(n)
                and not contains_wait(n)
            ]
            if len(inverse_user) > 0:
                all_gather_list.extend(inverse_user)
        elif node_type == NodeType.AG_WAIT:
            if not all_gather_before_last_wait and len(all_gather_list) > 0:
                # sanity check: an AG wait is directly preceded (in execution
                # order, i.e. next in this reversed list) by its all_gather
                assert node_to_type[snodes[idx + 1]] == NodeType.ALL_GATHER
                # move i-th all gather node and its dependencies after (i-1)-th wait node (bc this is a reverse list)
                result_list.extend(all_gather_list)
                all_gather_list = []

            result_list.append(node)

            if all_gather_before_last_wait and len(all_gather_list) > 0:
                assert node_to_type[snodes[idx + 1]] == NodeType.ALL_GATHER
                # move i-th all gather node and its dependencies before (i-1)-th wait node (bc this is a reverse list)
                result_list.extend(all_gather_list)
                all_gather_list = []
    # leftover all_gathers had no earlier wait to anchor on; emit them last
    # (i.e. first in execution order)
    if len(all_gather_list) > 0:
        result_list.extend(all_gather_list)
    # restore execution order
    result_list.reverse()

    return result_list
219+
220+
221+
def reorder_reduce_scatter(
    snodes: List["scheduler.BaseSchedulerNode"],
    bucketable_ir_nodes: set[str],
) -> List["scheduler.BaseSchedulerNode"]:
    """
    Reorder Reduce Scatter and Wait in the backward pass:
    wait_i (the wait of reduce_scatter_i) is moved right before
    reduce_scatter_{i+1}, letting the scatter overlap with the compute
    in between.

    Args:
        snodes: scheduler nodes in execution order.
        bucketable_ir_nodes: names of IR nodes eligible for bucketing.

    Returns:
        A new reordered list with the same nodes; the input list is
        returned unchanged when it contains no reduce scatter (e.g. in
        the forward pass).
    """
    result_list: List[scheduler.BaseSchedulerNode] = []
    # RS waits deferred until the next reduce scatter (or end of list)
    wait_list: List[scheduler.BaseSchedulerNode] = []
    node_to_type: Dict[scheduler.BaseSchedulerNode, NodeType] = {}
    for node in snodes:
        # classify each node exactly once (the original called
        # get_node_type twice per node, and also computed node users
        # that were never consumed)
        node_to_type[node] = get_node_type(node, bucketable_ir_nodes)

    if NodeType.REDUCE_SCATTER not in node_to_type.values():
        # nothing to reorder
        return snodes

    for idx, node in enumerate(snodes):
        node_type = node_to_type[node]
        if node_type in [NodeType.ALL_GATHER, NodeType.COMPUTE, NodeType.AG_WAIT]:
            # non reduce-scatter nodes keep their original relative order
            if node not in result_list and node not in wait_list:
                result_list.append(node)
        elif node_type == NodeType.RS_WAIT:
            # an RS wait is expected to directly follow its reduce scatter
            assert node_to_type[snodes[idx - 1]] == NodeType.REDUCE_SCATTER
            # defer the wait; it is flushed just before the next reduce scatter
            wait_list.append(node)
        elif node_type == NodeType.REDUCE_SCATTER:
            if len(wait_list) > 0:
                # move the i-th wait node before the (i+1)-th reduce scatter node
                result_list.extend(wait_list)
                wait_list = []
            # add reduce scatter node
            result_list.append(node)

    # flush waits trailing the last reduce scatter
    if len(wait_list) > 0:
        result_list.extend(wait_list)
    return result_list

0 commit comments

Comments
 (0)